Nearest neighbor classification
CSE 250B
The problem we’ll solve today
Given an image of a handwritten digit, say which digit it is.
[image of a handwritten digit] =⇒ 3
Some more examples: [images of more handwritten digits]
The machine learning approach
Assemble a data set:
The MNIST data set of handwritten digits:
• Training set of 60,000 images and their labels.
• Test set of 10,000 images and their labels.
And let the machine figure out the underlying patterns.
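As a concrete starting point, here is a minimal sketch of loading this data in Python, assuming scikit-learn's fetch_openml copy of MNIST ('mnist_784') and the usual 60,000/10,000 split; the names train_x, train_y, test_x, test_y are introduced here and reused in the later sketches.

```python
# A sketch of loading MNIST in Python (assumption: scikit-learn's OpenML
# mirror of the dataset, 'mnist_784', is available for download).
import numpy as np
from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X = mnist.data.astype(np.float64)    # 70,000 x 784 array of pixel values (0-255)
y = mnist.target.astype(int)         # 70,000 labels, digits 0-9

# Conventional split: the first 60,000 points form the training set,
# the remaining 10,000 the test set.
train_x, train_y = X[:60000], y[:60000]
test_x, test_y = X[60000:], y[60000:]
```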
Nearest neighbor classification
Training images x^(1), x^(2), x^(3), ..., x^(60000)
Labels y^(1), y^(2), y^(3), ..., y^(60000) are numbers in the range 0-9
How to classify a new image x?
• Find its nearest neighbor amongst the x^(i)
• Return y^(i)
The data space
How to measure the distance between images?
MNIST images:
• Size 28 × 28 (total: 784 pixels)
• Each pixel is grayscale: 0-255
Stretch each image into a vector with 784 coordinates:
• Data space X = R^784
• Label space Y = {0, 1, ..., 9}
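A one-line illustration of this "stretching" step in numpy (the array img below is only a stand-in for an actual MNIST image):

```python
import numpy as np

# Suppose img is one MNIST image stored as a 28 x 28 grayscale array.
img = np.zeros((28, 28), dtype=np.uint8)   # placeholder image, for illustration only

# "Stretch" it into a vector with 784 coordinates: a point in the data space R^784.
x = img.reshape(784)
print(x.shape)                             # (784,)
```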
The distance function
Remember Euclidean distance in two dimensions?
x = (1, 2), z = (3, 5)
‖x − z‖ = √((3 − 1)² + (5 − 2)²) = √13 ≈ 3.61
Euclidean distance in higher dimension
Euclidean distance between 784-dimensional vectors x , z is
‖x − z‖ = √( ∑_{i=1}^{784} (x_i − z_i)² )
Here x_i is the i-th coordinate of x.
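A direct numpy translation of this formula (a sketch; the helper name euclidean_distance is ours):

```python
import numpy as np

def euclidean_distance(x, z):
    """Euclidean distance between two vectors of the same length (here, 784)."""
    return np.sqrt(np.sum((x - z) ** 2))   # same value as np.linalg.norm(x - z)
```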
Nearest neighbor classification
Training images x^(1), ..., x^(60000), labels y^(1), ..., y^(60000)
To classify a new image x:
• Find its nearest neighbor amongst the x^(i) using Euclidean distance in R^784
• Return y^(i)
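A minimal sketch of this rule in numpy, assuming train_x is the 60,000 × 784 array of training images and train_y the corresponding labels:

```python
import numpy as np

def nn_classify(x, train_x, train_y):
    """1-NN rule: return the label of the training image closest to x."""
    # Squared Euclidean distances to all training images at once;
    # dropping the square root does not change which point is closest.
    distances = np.sum((train_x - x) ** 2, axis=1)
    return train_y[np.argmin(distances)]
```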
How accurate is this classifier?
Accuracy of nearest neighbor on MNIST
Training set of 60,000 points.
• What is the error rate on training points? Zero.
In general, training error is an overly optimistic predictor of future
performance.
• A better gauge: separate test set of 10,000 points.
Test error = fraction of test points incorrectly classified.
• What test error would we expect for a random classifier?
(One that picks a label 0-9 at random?) 90%.
• Test error of nearest neighbor: 3.09%.
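A sketch of how that test error would be computed, reusing the nn_classify helper from the earlier sketch:

```python
import numpy as np

def test_error(test_x, test_y, train_x, train_y):
    """Fraction of test points misclassified by the 1-NN rule above."""
    predictions = np.array([nn_classify(x, train_x, train_y) for x in test_x])
    return np.mean(predictions != test_y)

# On the full MNIST test set this brute-force loop is slow
# (10,000 x 60,000 distance computations) but matches the definition above.
```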
Examples of errors
Test set of 10,000 points:
• 309 are misclassified
• Error rate 3.09%
Examples of errors:
[images: misclassified test queries alongside their nearest neighbors in the training set]
Ideas for improvement: (1) k-NN, (2) a better distance function.
K-nearest neighbor classification
Classify a point using the labels of its k nearest neighbors among the training points.
MNIST:
k               1     3     5     7     9     11
Test error (%)  3.09  2.94  3.13  3.10  3.43  3.34
In real life, there’s no test set. How to decide which k is best?
1. Hold-out set.
• Let S be the training set.
• Choose a subset V ⊂ S as a validation set.
• What fraction of V is misclassified by finding the k-nearest neighbors in S \ V? (See the sketch after this list.)
2. Leave-one-out cross-validation.
• For each point x ∈ S, find the k-nearest neighbors in S \ {x}.
• What fraction are misclassified?
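A sketch of option 1, the hold-out set, using scikit-learn's KNeighborsClassifier (an assumption; any k-NN implementation would do) and the train_x, train_y arrays from the earlier loading sketch:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Hold out part of the training set S as a validation set V (here 10,000 points);
# the classifier searches for neighbors only among the remaining points S \ V.
S_x, V_x, S_y, V_y = train_test_split(train_x, train_y,
                                      test_size=10000, random_state=0)

for k in [1, 3, 5, 7, 9, 11]:
    knn = KNeighborsClassifier(n_neighbors=k).fit(S_x, S_y)
    validation_error = 1.0 - knn.score(V_x, V_y)   # fraction of V misclassified
    print(k, validation_error)
```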
Cross-validation
How to estimate the error of k-NN for a particular k?
10-fold cross-validation
• Divide the training set into 10 equal pieces.
Training set (call it S): 60,000 points
Call the pieces S_1, S_2, ..., S_10: 6,000 points each.
• For each piece S_i:
• Classify each point in S_i using k-NN with training set S − S_i
• Let ε_i = fraction of S_i that is incorrectly classified
• Take the average of these 10 numbers:
estimated error with k-NN = (ε_1 + · · · + ε_10) / 10
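A sketch of this 10-fold procedure, again assuming scikit-learn's KNeighborsClassifier and a training set stored in arrays S_x, S_y; if the training points happen to be ordered by label, they should be shuffled before splitting into pieces:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def cv_error(S_x, S_y, k, n_folds=10):
    """Estimate the error of k-NN by 10-fold cross-validation, as described above."""
    pieces = np.array_split(np.arange(len(S_x)), n_folds)      # indices of S_1, ..., S_10
    errors = []
    for piece in pieces:
        keep = np.ones(len(S_x), dtype=bool)
        keep[piece] = False                                     # training set S - S_i
        knn = KNeighborsClassifier(n_neighbors=k).fit(S_x[keep], S_y[keep])
        errors.append(1.0 - knn.score(S_x[piece], S_y[piece]))  # eps_i
    return np.mean(errors)                                      # (eps_1 + ... + eps_10) / 10
```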
Another improvement: better distance functions
The Euclidean (ℓ2) distance between these two images is very high!
Much better idea: distance measures that are invariant under:
• Small translations and rotations. e.g. tangent distance.
• A broader family of natural deformations. e.g. shape context.
Test error rates (%):
ℓ2     tangent distance   shape context
3.09   1.10               0.63
Related problem: feature selection
Feature selection/reweighting is part of picking a distance function.
And, one noisy feature can wreak havoc with nearest neighbor!
[figure: nearest neighbor behavior with the original features versus with a noisy feature added]
Algorithmic issue: speeding up NN search
Naive search takes time O(n) for training set of size n: slow!
There are data structures for speeding up nearest neighbor search, like:
1. Locality sensitive hashing
2. Ball trees
3. K-d trees
These are part of standard Python libraries for NN, and help a lot.
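For example, scikit-learn's KNeighborsClassifier can be asked to build one of these structures (a sketch, reusing the arrays from the earlier loading sketch):

```python
from sklearn.neighbors import KNeighborsClassifier

# The same classifier as before, but with a spatial data structure behind the search.
# scikit-learn accepts algorithm='kd_tree', 'ball_tree', 'brute', or 'auto'.
knn = KNeighborsClassifier(n_neighbors=1, algorithm='ball_tree')
knn.fit(train_x, train_y)
predictions = knn.predict(test_x)
```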
Example: k-d trees for NN search
A hierarchical, rectilinear spatial partition.
For a data set S ⊂ R^d:
• Pick a coordinate 1 ≤ i ≤ d.
• Compute v = median({x_i : x ∈ S}).
• Split S into two halves:
S_L = {x ∈ S : x_i < v}
S_R = {x ∈ S : x_i ≥ v}
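A minimal sketch of this split step, with S stored as an n × d numpy array; building the full tree would recurse on S_L and S_R, cycling through coordinates:

```python
import numpy as np

def kd_split(S, i):
    """One split step of k-d tree construction: split S on the median of coordinate i."""
    v = np.median(S[:, i])      # v = median({x_i : x in S})
    S_L = S[S[:, i] < v]        # left half:  x_i < v
    S_R = S[S[:, i] >= v]       # right half: x_i >= v
    return S_L, S_R, v
```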