Shaking off the Rust is a series of exercises with the Rust programming language. The purpose of the series is to improve both my and my dear reader’s abilities with Rust by building things. Plus, by actually building stuff, we'll learn about an array of technological concepts in the process. In this installment, we’re going to implement a classic machine learning algorithm.
After reading this installment, you'll have experience with:

- the if let syntax
- the Clone trait
- the where clause

This installment’s Github repo: https://github.com/josht-jpg/k-nearust-neighbors
The k-nearest neighbors (KNN) algorithm is simple.
It’s a good algorithm for classifying data [1]. Suppose we have a labeled dataset, which we'll denote D, and an unlabeled data point, which we'll denote p, and we want to predict the correct label for p. We can do that with KNN.
KNN works like this:
For some integer k, we find the k data points in D nearest to p — the k nearest neighbors. For an example of what I mean by nearest: in the graph below, the blue data points are the 3 nearest neighbors of the red data point.
For our computer to find data points nearby p, we need a way to measure the distance between data points. We can use the Pythagorean formula to do that [2]:

The distance between two data points v and w with features $(v_1, \dots, v_n)$ and $(w_1, \dots, w_n)$ is

$\sqrt{(v_1 - w_1)^2 + (v_2 - w_2)^2 + \cdots + (v_n - w_n)^2}$
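For a quick worked example (using the same two vectors we'll use in our tests later on): the distance between $(1, 5, -3)$ and $(0.5, 2, 3)$ is $\sqrt{0.5^2 + 3^2 + (-6)^2} = \sqrt{45.25} \approx 6.73$.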
And to predict the label for p, we pick the most common label among its k nearest labeled data points.
Finally, we have to handle the possible scenario of a tie for the most common label. There are a few ways to handle this. Our approach will be to decrement k until there’s no longer a tie. That is, we remove the furthest of the k nearest labels and recount the most common labels.
We, the implementors of KNN, specify the value for k. Choosing a good value for k is usually a process of trying and testing several values [3].
If you feel like watching a quick video on KNN, here’s a great one from StatQuest with Josh Starmer (and I can’t say enough good things about that YouTube channel. Thank you, Josh):
To get started, we’ll create a new library called k-nearust-neighbors.
cargo new k-nearust-neighbors --lib
cd k-nearust-neighbors
We’ll also bring in the following crates as dev-dependencies.
# Cargo.toml
# ...
[dev-dependencies]
rand = "0.8.4"
reqwest = "0.11.10"
tokio-test = "*"
rand is a crate for random number generation. You can read more about it here: https://crates.io/crates/rand. reqwest and tokio_test will be used to get some data to test our KNN classifier on. You can read about the reqwest crate here and the tokio_test crate here.
And we’ll toss this use declaration into our lib.rs file.
// lib.rs
use std::{
    collections::HashMap,
    ops::{Add, Sub},
};
We’ll use a struct to represent a data point. We’ll name it LabeledPoint. Each LabeledPoint will have two fields:

- label: a string slice that categorizes the data,
- point: a vector containing the data point’s features (which will be represented as f64s).

Here is LabeledPoint in Rust:
// lib.rs
#[derive(Clone)] ➀
struct LabeledPoint<'a> { ➁
    label: &'a str, ➂
    point: Vec<f64>,
}
➀ - We’re using the derive attribute to implement the Clone trait. This means we’re asking Rust to generate code for the Clone trait’s default implementation and apply it to LabeledPoint. The Clone trait will let us explicitly create deep copies of LabeledPoint instances [5]. This will allow us to call the to_vec method on a slice of LabeledPoints (which we’ll do in our KNN implementation).
➁ - We declare that our struct is generic over the lifetime parameter 'a. In Rust, a lifetime is the scope for which a reference is valid [5]. Rust’s lifetimes are part of what makes the language special. If you’d like to learn about them, I recommend reading section 10.3 of The Rust Programming Language book or watching this stream from Ryan Levick.
➂ - For a struct to hold a reference, it must have a lifetime annotation for that reference [5]. label is a string slice, and string slices are references. Thus, we must add a lifetime annotation to label, which we do by putting 'a after & in label: &'a str. Awesome.
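To make the struct concrete, here's what a single LabeledPoint might look like for one iris flower (the label and feature values here are just illustrative; we'll load the real data later):

let flower = LabeledPoint {
    label: "Iris-setosa",
    point: vec![5.1, 3.5, 1.4, 0.2],
};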
To implement KNN, we need to compare distances between data points. This calls for some linear algebra. Fun.
(But if linear algebra does not sound like fun to you, feel free to copy and paste this code into your lib.rs and skip to the next section - I’ll forgive you).
A quick review of vectors may be helpful (here I’m talking about vectors from linear algebra, not vectors from Rust).
A vector can be thought of as a list of numbers.
There are a few ways that a vector can be interpreted. One common interpretation is a point in space [6]. Take a vector with two elements, for example. This is what it looks like as a point in space:
This is how we will interpret vectors: as points in space. Cool.
Sidenote: the above plot was created with Rust! It was made with the plotters crate. Feel free to take a look at the code I slapped together to generate this plot: https://github.com/josht-jpg/vector_plot
We’re going to define a trait called LinearAlg:
trait LinearAlg<T>
where
    T: Add + Sub, ➀
{
    fn dot(&self, w: &[T]) -> T;
    fn subtract(&self, w: &[T]) -> Vec<T>;
    fn sum_of_squares(&self) -> T;
    fn distance(&self, w: &[T]) -> f64;
}
➀ - Rust’s where clause lets us specify that the generic type T must implement the Add and Sub traits [7].
And we’ll make LinearAlg an extension trait, implementing it for the standard library’s Vec<f64> type.
impl LinearAlg<f64> for Vec<f64> { /*...*/ }
We’re going to take a test-first approach here. For each method in this implementation of LinearAlg, I’ll go over the math behind the operation, then provide a test for the method, and then some code that implements the method.
It’s a good exercise to try writing each method yourself and running the test before looking at my implementation. There’s a good chance you’ll like your implementation more (and if you do, please share it with me at joshtaylor361@gmail.com).
The dot product of two vectors v and w with elements $(v_1, \dots, v_n)$ and $(w_1, \dots, w_n)$ is the sum of their element-wise products: $v \cdot w = v_1 w_1 + v_2 w_2 + \cdots + v_n w_n$ [8].
Here’s our test for dot:
// lib.rs
/*...*/
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn linear_alg() {
        let v = vec![1., 5., -3.];
        let w = vec![0.5, 2., 3.];

        assert_eq!(v.dot(&w), 1.5)
    }
}
Run that test with the command cargo test linear_alg in the root of your k-nearust-neighbors folder. Congratulations if that passes for you.
Here is my implementation of dot:
fn dot(&self, w: &[f64]) -> f64 {
    assert_eq!(self.len(), w.len());
    self.iter().zip(w).map(|(v_i, w_i)| v_i * w_i).sum()
}
To test subtract, add the following assertion to your linear_alg test function:
#[test]
fn linear_alg() {
    /*...*/

    assert_eq!(v.subtract(&w), vec![0.5, 3., -6.]);
}
Nice. I hope you were able to make that test pass (but of course no worries if you weren’t). Here’s an implementation of subtract:
fn subtract(&self, w: &[f64]) -> Vec<f64> {
    assert_eq!(self.len(), w.len());
    self.iter().zip(w).map(|(v_i, w_i)| v_i - w_i).collect()
}
For some vector v with elements $(v_1, \dots, v_n)$, the sum of squares is $v \cdot v = v_1^2 + v_2^2 + \cdots + v_n^2$.
Here’s a test for sum_of_squares:
#[test]
fn linear_alg() {
    /*...*/

    assert_eq!(v.sum_of_squares(), 35.);
}
And here’s my implementation:
fn sum_of_squares(&self) -> f64 {
    self.dot(&self)
}
As usual, here’s a test for distance:
assert_eq!(v.distance(&w), 45.25f64.sqrt())
Hallelujah. Here’s some Rust:
fn distance(&self, w: &[f64]) -> f64 {
    assert_eq!(self.len(), w.len());
    self.subtract(w).sum_of_squares().sqrt()
}
Implementing LinearAlg for the Vec<f64> type is all we need for our KNN implementation, so we’ll leave it there. Great.
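And so it's all in one place, here's the full extension trait implementation, assembled from the methods above:

impl LinearAlg<f64> for Vec<f64> {
    fn dot(&self, w: &[f64]) -> f64 {
        assert_eq!(self.len(), w.len());
        self.iter().zip(w).map(|(v_i, w_i)| v_i * w_i).sum()
    }

    fn subtract(&self, w: &[f64]) -> Vec<f64> {
        assert_eq!(self.len(), w.len());
        self.iter().zip(w).map(|(v_i, w_i)| v_i - w_i).collect()
    }

    fn sum_of_squares(&self) -> f64 {
        self.dot(&self)
    }

    fn distance(&self, w: &[f64]) -> f64 {
        assert_eq!(self.len(), w.len());
        self.subtract(w).sum_of_squares().sqrt()
    }
}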
I’d like to continue our test-first approach. So before we get to the most important function of this installment, which will be called knn_classify, we’ll write a test for it.
But we’re going to need some data to test on. A classic dataset to test KNN on is the iris flower dataset. This dataset contains 150 rows, where each row contains a flower’s sepal length, sepal width, petal length, petal width, and type. The dataset has three types of iris: Setosa, Versicolor, and Virginica.
Here’s some code for you to put in lib.rs. It gets the iris dataset and converts it to a format we can work with. Look through the code if you’re curious, but I won’t be explaining any of it. I’d like to save time for more interesting stuff.
#[cfg(test)]
mod tests {
    use super::*;
    /*...*/

    macro_rules! await_fn {
        ($arg:expr) => {{
            tokio_test::block_on($arg)
        }};
    }

    async fn get_iris_data() -> Result<String, reqwest::Error> {
        let body = reqwest::get(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
        )
        .await?
        .text()
        .await?;

        Ok(body)
    }

    type GenericResult<T> = Result<T, Box<dyn std::error::Error>>;

    fn process_iris_data(body: &str) -> GenericResult<Vec<LabeledPoint>> {
        body.split("\n")
            .filter(|data_point| data_point.len() > 0)
            .map(|data_point| -> GenericResult<LabeledPoint> {
                let columns = data_point.split(",").collect::<Vec<&str>>();
                let (label, point) = columns.split_last().ok_or("Cannot split last")?;
                let point = point
                    .iter()
                    .map(|feature| feature.parse::<f64>())
                    .collect::<Result<Vec<f64>, std::num::ParseFloatError>>()?;

                Ok(LabeledPoint { label, point })
            })
            .collect::<GenericResult<Vec<LabeledPoint>>>()
    }
}
Sweet.
Next, we need to split that data into a training set and a testing set. Here’s a function to do just that:
mod tests {
    /*...*/
    use rand::{seq::SliceRandom, thread_rng};

    fn split_data<T>(data: &[T], prob: f64) -> (Vec<T>, Vec<T>)
    where
        T: Clone, ➀
    {
        let mut data_copy = data.to_vec(); ➁
        data_copy.shuffle(&mut thread_rng()); ➂

        let split_index = ((data.len() as f64) * prob).round() as usize;

        (
            data_copy[..split_index].to_vec(),
            data_copy[split_index..].to_vec(),
        )
    }
}
➀ - Using the where clause to specify that T must implement the Clone trait.
➁ - data.to_vec() copies the data slice into a new Vec [9]. This allows us to shuffle our data without taking a mutable reference to the data.
➂ - The shuffle method shuffles a mutable slice in place [10]. shuffle is from the rand crate’s SliceRandom trait, which is an extension trait on slices. So we get to use shuffle on mutable slices after importing the trait.

The thread_rng function, which also comes from rand, gives us a random number generator.
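If you'd like to see those two pieces in isolation, here's a small (hypothetical) test you could drop into the tests module, which already has the rand imports in scope:

#[test]
fn shuffle_demo() {
    let data = [1, 2, 3, 4, 5];

    // to_vec copies the slice into a new, owned Vec that we're free to mutate
    let mut data_copy = data.to_vec();

    // shuffle is available because the SliceRandom trait is in scope
    data_copy.shuffle(&mut thread_rng());

    // Same elements, possibly in a different order
    assert_eq!(data_copy.len(), 5);
}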
Great. We’ve got everything we need to test our (currently unimplemented) knn_classify function.
We’re going to set k to 5; we’ll classify new data points based on their 5 nearest neighbors. In a real application of KNN, it would be a good idea to test out a few more values of k.
Here’s the test:
fn knn_classify(k: u8, data_points: &[LabeledPoint], new_point: &[f64]) -> Option<String> {
    todo!()
}

#[cfg(test)]
mod tests {
    /*...*/

    fn count_correct_classifications(
        train_set: &[LabeledPoint],
        test_set: &[LabeledPoint],
        k: u8,
    ) -> u32 {
        let mut num_correct: u32 = 0;

        for iris in test_set.iter() {
            let predicted = knn_classify(k, &train_set, &iris.point);
            let actual = iris.label;

            if let Some(predicted) = predicted { ➀
                if predicted == actual {
                    num_correct += 1;
                }
            }
        }

        num_correct
    }

    #[test]
    fn iris() -> GenericResult<()> {
        let raw_iris_data = await_fn!(get_iris_data())?;
        let iris_data = process_iris_data(&raw_iris_data)?;

        let (train_set, test_set) = split_data(&iris_data, 0.70); ➁
        assert_eq!(train_set.len(), 105);
        assert_eq!(test_set.len(), 45);

        let k = 5;
        let num_correct = count_correct_classifications(&train_set, &test_set, k);
        let percent_correct = num_correct as f32 / test_set.len() as f32;

        assert!(percent_correct > 0.9); ➂

        Ok(())
    }
}
➀ - The if let syntax is a lovely way for us to match one pattern and ignore all other patterns [5]. So an alternative (but less Rustic) way to write this block is:
match predicted {
    Some(predicted) => {
        if predicted == actual {
            num_correct += 1;
        }
    }
    _ => (),
}
➁ - Splitting 70% of the data into a set for training the classifier, and 30% into a set for testing.
➂ - If our classifier is working, it should correctly classify at least 90% of the testing set.
I’ll start with pseudocode for our KNN classifier. Try to write your own Rust implementation based on this pseudocode, and run the test we wrote to see if your implementation works.
function knn_classify(k, data_points, new_point)
    arguments {
        k: number of neighbors we use to classify our new data point
        data_points: our labeled data points
        new_point: the data point we want to classify
    }
    returning: predicted label for new_point
{
    sorted_data_points = sort_by_distance_from(data_points, new_point)

    k_nearest_labels = empty list
    for i from 0 to k {
        k_nearest_labels.append(sorted_data_points[i].label)
    }

    predicted_label = find_most_common_label(k_nearest_labels)
    return predicted_label
}

function find_most_common_label(labels)
    arguments {
        labels: a list of labels
    }
    returning: most common value in the passed in list of labels
{
    label_counts = new Hash Map

    for label in labels {
        if label is a key in label_counts {
            label_counts[label] += 1
        } else {
            label_counts.add_key_value_pair((label, 1))
        }
    }

    if there are no ties for most common label in label_counts {
        return key with highest value in label_counts
    } else {
        new_labels = all elements in labels but the last
        return find_most_common_label(new_labels)
    }
}
Great. Now here is some Rust for you.
fn knn_classify(k: u8, data_points: &[LabeledPoint], new_point: &[f64]) -> Option<String> {
    let mut data_points_copy = data_points.to_vec();

    data_points_copy.sort_unstable_by(|a, b| { ➀
        let dist_a = a.point.distance(new_point);
        let dist_b = b.point.distance(new_point);

        dist_a
            .partial_cmp(&dist_b)
            .expect("Cannot compare floating point numbers, encountered a NAN") ➁
    });

    let k_nearest_labels = &data_points_copy[..(k as usize)]
        .iter()
        .map(|a| a.label)
        .collect::<Vec<&str>>();

    let predicted_label = find_most_common_label(&k_nearest_labels);

    predicted_label
}

fn find_most_common_label(labels: &[&str]) -> Option<String> {
    let mut label_counts: HashMap<&str, u32> = HashMap::new(); ➂

    for label in labels.iter() {
        let current_label_count = if let Some(current_label_count) = label_counts.get(label) { ➃
            *current_label_count
        } else {
            0
        };

        label_counts.insert(label, current_label_count + 1);
    }

    let most_common = label_counts
        .iter()
        .max_by(|(_label_a, count_a), (_label_b, count_b)| count_a.cmp(&count_b)); ➄

    if let Some((most_common_label, most_common_label_count)) = most_common {
        let is_tie_for_most_common = label_counts
            .iter()
            .any(|(label, count)| count == most_common_label_count && label != most_common_label); ➅

        if !is_tie_for_most_common {
            return Some((*most_common_label).to_string());
        } else {
            let (_last, labels) = labels.split_last()?; ➆
            return find_most_common_label(&labels);
        }
    }

    None
}
➀ - sort_unstable_by allows us to specify the way we want our vector sorted. We’re using this to sort the data points in data_points_copy by their distance from the data point we’re classifying.
Digression on stable vs. unstable sorting:

As the name suggests, sort_unstable_by is unstable [11]. This means that when the sorting algorithm comes across two equal elements, it is allowed to swap them - whereas a stable sort will not swap equal elements [12]. Unstable sorting is generally faster and requires less memory than stable sorting.

Rust’s sort_by method is stable. So in cases like ours, where you don’t care if equal elements are swapped, it’s preferable to use sort_unstable_by over sort_by.
➁ - partial_cmp returns a value specifying whether dist_a is greater than, less than, or equal to dist_b, if such a comparison can be made.

In more gory detail, partial_cmp returns an Option containing a variant of the Ordering enum if an ordering exists (if not, it will return None).
The creators of Rust have intentionally not implemented Ord for f64. This is because an f64 can be a NAN (not a number), and in Rust NAN != NAN. So f64 does not form a total order.

If some or all of that didn’t make sense, do not worry. Just note that we have to use partial_cmp for f64 types rather than cmp.
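Here's a quick illustration of partial_cmp behaving well for ordinary floats and returning None when a NAN is involved:

use std::cmp::Ordering;

let a: f64 = 1.0;
assert_eq!(a.partial_cmp(&2.0), Some(Ordering::Less));

// NAN isn't comparable to anything, not even itself
assert_eq!(a.partial_cmp(&f64::NAN), None);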
➂ - We create a HashMap with &str keys and u32 values. The keys are the distinct elements of the labels slice. The values are the number of times each label shows up in labels.
➃ - Rust is an expression language [13]. This means that things like if and match will produce values. So we can assign a variable to the result of an if-else block. Cool.
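As a tiny example, the value of an if-else block can be bound straight to a variable:

let n = 4;
// The if-else block is an expression, so it produces a value
let parity = if n % 2 == 0 { "even" } else { "odd" };
assert_eq!(parity, "even");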
➄ - Here we want to find the most common label in label_counts. We iterate through each key-value pair of label_counts and use max_by to find the key-value pair with the highest value. max_by returns the element with the maximum value with respect to a custom comparison function [14].
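As a small illustration of max_by on a HashMap like ours (the labels and counts here are made up):

use std::collections::HashMap;

let mut label_counts: HashMap<&str, u32> = HashMap::new();
label_counts.insert("Iris-setosa", 2);
label_counts.insert("Iris-versicolor", 3);

// max_by hands us the key-value pair with the highest count
let most_common = label_counts
    .iter()
    .max_by(|(_, count_a), (_, count_b)| count_a.cmp(count_b));

assert_eq!(most_common, Some((&"Iris-versicolor", &3)));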
➅ - The any method checks if the provided predicate is true for any element in the iterator [15]. We use any to see if there are any ties for most common label in label_counts.
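For instance:

let labels = ["Iris-setosa", "Iris-versicolor", "Iris-setosa"];

// any returns true as soon as one element satisfies the predicate
assert!(labels.iter().any(|&label| label == "Iris-versicolor"));
assert!(!labels.iter().any(|&label| label == "Iris-virginica"));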
➆ - split_last returns an Option containing a tuple of the last element of a slice and a slice containing the rest of the elements.
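For example:

let labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"];

if let Some((last, rest)) = labels.split_last() {
    assert_eq!(*last, "Iris-virginica");
    assert_eq!(rest, &labels[..2]);
}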
And that’s our K Nearust Neighbors classifier. Amazing.
As always, questions and feedback are welcome and appreciated: joshtaylor361@gmail.com.
Have a great rest of your day.
[1] - Joel Grus. (2019). Data Science from Scratch: First Principles with Python, 2nd edition. O’Reilly Media.
[2] - Aditya Y. Bhargava. (2015). Grokking Algorithms. Manning Publications.
[3] - Josh Starmer. (2017). StatQuest: K-nearest neighbors, Clearly Explained. StatQuest with Josh Starmer.
[4] - Rust Documentation for Trait std::ops::Add.
[5] - Steve Klabnik and Carol Nichols. (2018). The Rust Programming Language. No Starch Press.
[6] - Shin Takahashi and Iroha Inoue. (2012). The Manga Guide to Linear Algebra. No Starch Press.
[7] - Rust documentation for the where clause.
[8] - Sheldon Axler. (1995). Linear Algebra Done Right. Springer.
[9] - Rust documentation for to_vec.
[10] - rand crate documentation of shuffle.
[11] - Rust documentation for sort_unstable_by.
[12] - Jon Gjengset. (2020). Crust of Rust: Sorting Algorithms.
[13] - Jason Orendorff, Jim Blandy, and Leonora F. S. Tindall. (2021). Programming Rust. O’Reilly Media.
[14] - Rust documentation for max_by.
[15] - Rust documentation for any.
Creating and running Shaking off the Rust is one of the most fulfilling things I do. But it's exhausting. By supporting me, even if it's just a dollar, you'll allow me to put more time into building this series. I really appreciate any support.
The only way to support me right now is by sponsoring me on Github. I'll probably also set up Patreon and Donorbox pages soon.
Thank you so much!