Using kNN in scikit-learn: Score and Update k
In a previous lesson, you created X and y data from a dataset about abalones. X contains the physical measurements of the abalones, while y stores their rings, which you tried to predict using kNN.
00:36 You fit that model to your training data and then made predictions with it for the features of both the training and test sets. Now, let’s determine how good those predictions are by scoring them against the actual values.
MSE can be difficult to understand at first glance because it reports error in squared units. Let’s switch over to RMSE, or root mean squared error. So rmse_train will be equal to the square root of the training MSE. You can just use NumPy’s square root function, np.sqrt(), to calculate it. rmse_train is about 1.65, which is in the same units as your original target.
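Here is a minimal sketch of that scoring step. It assumes a synthetic stand-in for the abalone data, so it won't reproduce the 1.65 figure. The names knn_model, train_preds, and mse_train, as well as the choice of three neighbors, are illustrative assumptions rather than the lesson's exact code.

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

# Synthetic stand-in for the abalone features and rings (an assumption)
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 7))
y = X @ rng.normal(size=7) + rng.normal(scale=0.5, size=500)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Fit kNN and predict on the same data it was trained on
knn_model = KNeighborsRegressor(n_neighbors=3)  # k=3 is an assumed value
knn_model.fit(X_train, y_train)
train_preds = knn_model.predict(X_train)

# MSE is in squared units; RMSE is back in the units of the target
mse_train = mean_squared_error(y_train, train_preds)
rmse_train = np.sqrt(mse_train)
print(f"train MSE: {mse_train:.3f}, train RMSE: {rmse_train:.3f}")
```

Because RMSE is just the square root of MSE, the two always rank models the same way. RMSE is simply easier to read against the target.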
02:27 But for a more realistic result, you should see how your model performs on data that it hasn’t actually ever seen, and that’s the test set. So let’s find the mean squared error of the actual target values in your test set compared to the predictions of your kNN model.
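As a rough sketch of scoring on unseen data, the snippet below computes RMSE for both splits side by side. The data is synthetic and the helper rmse_for() is a hypothetical convenience function, not part of scikit-learn or of the lesson's code.

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

# Synthetic stand-in for the abalone data (an assumption)
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 7))
y = X @ rng.normal(size=7) + rng.normal(scale=0.5, size=500)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

knn_model = KNeighborsRegressor(n_neighbors=3)  # k=3 is an assumed value
knn_model.fit(X_train, y_train)


def rmse_for(model, features, targets):
    """Hypothetical helper: RMSE of a fitted model on one split."""
    preds = model.predict(features)
    return np.sqrt(mean_squared_error(targets, preds))


rmse_train = rmse_for(knn_model, X_train, y_train)
rmse_test = rmse_for(knn_model, X_test, y_test)
print(f"train RMSE: {rmse_train:.3f}, test RMSE: {rmse_test:.3f}")
```

The test score is the one to trust as an estimate of real-world performance, since the model never saw those rows during fitting.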
The k is a so-called hyperparameter of kNN, and it needs to be adjusted to an appropriate value each time you apply kNN to a new dataset. Imagine you set k equal to 1, so you’d only consider the closest neighbor when making a prediction.
03:48 Your predictions probably wouldn’t be very good because they would vary a lot from one point to another. That’s called high variance. However, if you set k to be very high, say the size of the entire dataset, you might be using neighbors that are very far away to make predictions, and you’d lose out on the nuances of your dataset. That’s called high bias.
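To make the two extremes concrete, here is a small sketch on synthetic data (an assumption, not the abalone set). With k equal to 1, each training point is its own nearest neighbor, so the training error drops to zero even though the model may not generalize. With k equal to the size of the training set, every prediction collapses to the mean of the training targets.

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.neighbors import KNeighborsRegressor

# Synthetic training data with distinct rows (an assumption)
rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 7))
y_train = X_train @ rng.normal(size=7) + rng.normal(scale=0.5, size=200)

# k = 1: the model memorizes the training set (high variance)
knn_1 = KNeighborsRegressor(n_neighbors=1).fit(X_train, y_train)
rmse_train_1 = np.sqrt(mean_squared_error(y_train, knn_1.predict(X_train)))

# k = n: every point's neighbors are the whole training set (high bias)
knn_all = KNeighborsRegressor(n_neighbors=len(X_train)).fit(X_train, y_train)
preds_all = knn_all.predict(X_train)

print(f"train RMSE with k=1: {rmse_train_1:.3f}")
print(f"distinct predictions with k=n: {np.unique(preds_all.round(8)).size}")
```

A useful k sits somewhere between these two extremes, which is why it has to be tuned for each dataset.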
Okay, so the predictions have been made, and you can go ahead and score these against the actual rings of your test abalones. Set mse_test_25 equal to the mean_squared_error() of your y_test and your new test predictions.
Let’s take a look at rmse_test_25, and you can see that this value has now decreased from 2.38 to 2.17. By considering twenty-five neighbors when making a prediction, your kNN model has less variability from point to point. That makes your test error lower and means that your kNN generalizes better to unseen data.
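The refit-and-rescore step might look like the sketch below. It runs on synthetic data, so it won't reproduce the 2.38 and 2.17 values, and whether 25 neighbors beats a smaller k depends on the dataset. The names knn_25 and test_preds_25 are illustrative assumptions.

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

# Synthetic stand-in for the abalone data (an assumption)
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 7))
y = X @ rng.normal(size=7) + rng.normal(scale=0.5, size=500)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Updating k means creating and fitting a new model with n_neighbors=25
knn_25 = KNeighborsRegressor(n_neighbors=25)
knn_25.fit(X_train, y_train)
test_preds_25 = knn_25.predict(X_test)

# Score the new predictions against the held-out targets
mse_test_25 = mean_squared_error(y_test, test_preds_25)
rmse_test_25 = np.sqrt(mse_test_25)
print(f"test RMSE with k=25: {rmse_test_25:.3f}")
```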
The choice of 25 neighbors was somewhat arbitrary in this lesson, but in actuality, you would likely use a validation set or a cross-validation method like GridSearchCV to select the best hyperparameter k for your particular situation. While this process is outside the scope of this lesson, you can learn more about validation or cross-validation elsewhere on the Real Python platform.
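For a taste of what that could look like, here is a hedged sketch that uses scikit-learn's GridSearchCV to try several values of k with five-fold cross-validation. The synthetic data, the search range of 1 to 30, and the choice of RMSE as the scoring metric are all assumptions made for illustration.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsRegressor

# Synthetic stand-in for the abalone data (an assumption)
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 7))
y = X @ rng.normal(size=7) + rng.normal(scale=0.5, size=500)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Try each k in the grid, scoring by cross-validated RMSE on the training set
param_grid = {"n_neighbors": list(range(1, 31))}  # assumed search range
grid = GridSearchCV(
    KNeighborsRegressor(),
    param_grid,
    scoring="neg_root_mean_squared_error",  # scikit-learn maximizes scores
    cv=5,
)
grid.fit(X_train, y_train)

best_k = grid.best_params_["n_neighbors"]
best_cv_rmse = -grid.best_score_  # flip the sign back to a positive RMSE
print(f"best k: {best_k}, cross-validated RMSE: {best_cv_rmse:.3f}")
```

The search only ever looks at the training data, so the test set stays untouched for a final, honest check of the chosen k.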
06:56 You’ve now completed building a kNN model in Python’s scikit-learn. Coming up next, you’ll conclude this course by reviewing all you’ve learned about kNN, including its primary attributes, the main steps of the algorithm, and the code you used to make kNN predictions in Python.