Difficulty: Intermediate

*Shaking off the Rust* is a series of exercises with the Rust programming language. The purpose of the series is to improve both my and my dear reader’s abilities with Rust by building things. Plus, by actually building stuff, we’ll learn about an array of technological concepts in the process. In this installment, we’re going to implement a classic machine learning algorithm.

After reading this installment, you'll have experience with:

- Rust’s `if let` syntax
- Rust’s `Clone` trait
- Extension traits
- A little bit of lifetime parameters
- Rust’s `where` clause
- The $k$ nearest neighbors algorithm
- Much more!

This installment’s GitHub repo: https://github.com/josht-jpg/k-nearust-neighbors

The $k$ nearest neighbors (KNN) algorithm is simple.

It’s a good algorithm for classifying data [1]. Suppose we have a labeled dataset, which we'll denote $D$, and an unlabeled data point, which we'll denote $d$, and we want to predict the correct label for $d$. We can do that with KNN.

KNN works like this:

For some integer $k$, we find the $k$ data points in $D$ nearest to $d$ — the $k$ nearest neighbors. For an example of what I mean by *nearest*: in the graph below, the blue data points are the 3 nearest neighbors of the red data point.

For our computer to find data points near $d$, we need a way to measure the distance between data points. We can use the Pythagorean formula to do that [2]:

The distance between two data points $x$ and $y$ with features $x_1, ..., x_n$ and $y_1, ..., y_n$ is $\sqrt{(x_1 - y_1)^2 + ... + (x_n - y_n)^2}$.
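For instance, with $x = (1, 2, 2)$ and $y = (0, 0, 0)$, the distance is $\sqrt{(1-0)^2 + (2-0)^2 + (2-0)^2} = \sqrt{9} = 3$.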

And to predict the label for $d$, we pick the most common label from its $k$ nearest labeled data points.

Finally, we have to handle the possible scenario of a tie for the most common label. There are a few ways to handle this. Our approach will be to decrement $k$ until there’s no longer a tie. That is, we remove the label of the furthest of our nearest neighbors and recount the most common labels. For example, if $k = 4$ and the nearest labels are [A, A, B, B], we drop the label of the 4th-nearest point and recount with $k = 3$.

We, the implementors of KNN, specify the value for $k$. Choosing a good value for $k$ is usually a process of trying and testing several values [3].

If you feel like watching a quick video on KNN, here’s a great one from StatQuest with Josh Starmer (and I can’t say enough good things about that YouTube channel. Thank you, Josh).

To get started, we’ll create a new library called `k-nearust-neighbors`.

```
cargo new k-nearust-neighbors --lib
cd k-nearust-neighbors
```

We’ll also bring in the following crates as dev-dependencies.

```
# Cargo.toml
# ...
[dev-dependencies]
rand = "0.8.4"
reqwest = "0.11.10"
tokio-test = "*"
```

`rand` is a crate for random number generation. You can read more about it here: https://crates.io/crates/rand. `reqwest` and `tokio-test` will be used to get some data to test our KNN classifier on. You can read about the `reqwest` crate at https://crates.io/crates/reqwest and the `tokio-test` crate at https://crates.io/crates/tokio-test.

And we’ll toss this `use` declaration into our `lib.rs` file.

```
// lib.rs
use std::{
    collections::HashMap,
    ops::{Add, Sub},
};
```

We’ll use a `struct` to represent a data point. We’ll name it `LabeledPoint`. Each `LabeledPoint` will have two fields:

- `label`: a string slice that categorizes the data,
- `point`: a vector containing the data point’s features (which will be represented as `f64`s).

Here is `LabeledPoint` in Rust:

```
// lib.rs
#[derive(Clone)] ➀
struct LabeledPoint<'a> { ➁
    label: &'a str, ➂
    point: Vec<f64>,
}
```


**➀** - We’re using the derive attribute to implement the `Clone` trait. This means we’re asking Rust to generate code for the `Clone` trait’s default implementation and apply it to `LabeledPoint`. The `Clone` trait will let us explicitly create deep copies of `LabeledPoint` instances [5]. This will allow us to call the `to_vec` method on a slice of `LabeledPoint`s (which we’ll do in our KNN implementation).
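As a quick illustration (not part of our library’s code), this is the kind of call that deriving `Clone` unlocks, assuming we’re inside `lib.rs` where `LabeledPoint` is in scope:

```
// A slice of LabeledPoints...
let points: &[LabeledPoint] = &[
    LabeledPoint { label: "setosa", point: vec![1.4, 0.2] },
    LabeledPoint { label: "virginica", point: vec![5.1, 1.8] },
];

// ...can be copied into an owned Vec, because to_vec clones each element.
let owned_copy: Vec<LabeledPoint> = points.to_vec();
```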

**➁** - We declare that our struct is generic over the *lifetime* parameter `'a`. In Rust, a lifetime is the scope for which a reference is valid [5]. Rust’s lifetimes are part of what makes the language special. If you’d like to learn about them, I recommend reading section 10.3 of The Rust Programming Language book or watching this stream from Ryan Levick.

*Sidenote: I highly recommend all of Ryan Levick’s content. Thanks for doing what you do, Ryan.*

**➂** - For a `struct` to hold a reference, it must have a lifetime annotation for that reference [5]. `label` is a string slice, and string slices are references. Thus, we must add a lifetime annotation to `label`, which we do by putting `'a` after `&` in `label: &'a str`. Awesome.

To implement KNN, we need to compare distances between data points. This calls for some linear algebra. Fun.

*(But if linear algebra does not sound like fun to you, feel free to copy and paste this code into your `lib.rs` and skip to the next section. I’ll forgive you.)*

A quick review of vectors may be helpful (here I’m talking about vectors from linear algebra, not vectors from Rust).

A vector can be thought of as a list of numbers.

There are a few ways that a vector can be interpreted. One common interpretation is a point in space [6]. Take the vector $[2.5, 3, 4]$, for example. This is what it looks like as a point in space:

This is how we will interpret vectors: as points in space. Cool.

*Sidenote: the above plot was created with Rust! It was made with the plotters crate. Feel free to take a look at the code I slapped together to generate this plot: https://github.com/josht-jpg/vector_plot*

We’re going to define a trait called `LinearAlg`:

```
trait LinearAlg<T>
where
    T: Add + Sub, ➀
{
    fn dot(&self, w: &[T]) -> T;
    fn subtract(&self, w: &[T]) -> Vec<T>;
    fn sum_of_squares(&self) -> T;
    fn distance(&self, w: &[T]) -> f64;
}
```

**➀** - Rust’s `where` clause lets us specify that the generic type `T` must implement the `Add` and `Sub` traits [7].
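To make the `where` clause concrete, here’s a small, hypothetical function (not part of our classifier) with the same shape of bound:

```
use std::ops::Add;

// T must support + (producing another T) and be cheaply copyable.
fn sum_all<T>(items: &[T], zero: T) -> T
where
    T: Add<Output = T> + Copy,
{
    items.iter().fold(zero, |acc, &item| acc + item)
}

// sum_all(&[1.0, 2.0, 3.0], 0.) == 6.0
```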

And we’ll make `LinearAlg` an extension trait, implementing it for the standard library’s `Vec<f64>` type.

`impl LinearAlg<f64> for Vec<f64> { /*...*/ }`
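If the extension trait pattern is new to you, here’s a minimal, made-up example of the idea: adding a method to a type we don’t own.

```
// A hypothetical extension trait on Vec<f64>:
trait Describe {
    fn describe(&self) -> String;
}

impl Describe for Vec<f64> {
    fn describe(&self) -> String {
        format!("a vector with {} elements", self.len())
    }
}

// Anywhere Describe is in scope, vec![1., 2.].describe() now works.
```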

We’re going to take a test-first approach here. For each method in this implementation of `LinearAlg`, I’ll go over the math behind the operation, then provide a test for the method, and then some code that implements the method.

It’s a good exercise to try writing each method yourself and running the test before looking at my implementation. There’s a good chance you’ll like your implementation more (and if you do, please share it with me at joshtaylor361@gmail.com).

- Dot Product - For two vectors $\bold{v}$ and $\bold{w}$ of the same length, the dot product of $\bold{v}$ and $\bold{w}$ is the result of coupling up each corresponding element in $\bold{v}$ and $\bold{w}$, multiplying those two elements together, and adding each result.
$\bold{v} \cdot \bold{w}= \begin{bmatrix} \bold{v}_{1} \\ \bold{v}_{2} \\ \vdots \\ \bold{v}_{n}\end{bmatrix} \cdot \begin{bmatrix} \bold{w}_{1} \\ \bold{w}_{2} \\ \vdots \\ \bold{w}_{n}\end{bmatrix} = \bold{v}_1 \cdot \bold{w}_1 + \bold{v}_2 \cdot \bold{w}_2 + ... + \bold{v}_n \cdot \bold{w}_n$ [8].

Here’s our test for `dot`:

```
// lib.rs

/*...*/

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn linear_alg() {
        let v = vec![1., 5., -3.];
        let w = vec![0.5, 2., 3.];

        assert_eq!(v.dot(&w), 1.5)
    }
}
```

Run that test with the command `cargo test linear_alg` in the root of your `k-nearust-neighbors` folder. Congratulations if that passes for you.

Here is my implementation of `dot`:

```
fn dot(&self, w: &[f64]) -> f64 {
    assert_eq!(self.len(), w.len());
    self.iter().zip(w).map(|(v_i, w_i)| v_i * w_i).sum()
}
```

- Subtract - This one’s simpler. For two vectors of the same length $\bold{v}$ and $\bold{w}$,
$\bold{v} - \bold{w} = \begin{bmatrix} \bold{v}_{1} \\ \bold{v}_{2} \\ \vdots \\ \bold{v}_{n}\end{bmatrix} - \begin{bmatrix} \bold{w}_{1} \\ \bold{w}_{2} \\ \vdots \\ \bold{w}_{n}\end{bmatrix} = \begin{bmatrix} \bold{v}_{1} - \bold{w}_{1} \\ \bold{v}_{2} - \bold{w}_{2} \\ \vdots \\ \bold{v}_{n} - \bold{w}_{n}\end{bmatrix}$

To test `subtract`, add the following assertion to your `linear_alg` test function:

```
#[test]
fn linear_alg() {
    /*...*/

    assert_eq!(v.subtract(&w), vec![0.5, 3., -6.]);
}
```

Nice. I hope you were able to make that test pass (but of course no worries if you weren’t). Here’s an implementation of `subtract`:

```
fn subtract(&self, w: &[f64]) -> Vec<f64> {
    assert_eq!(self.len(), w.len());
    self.iter().zip(w).map(|(v_i, w_i)| v_i - w_i).collect()
}
```

- Sum of Squares - A vector’s *sum of squares* is the result of squaring each of its elements and adding everything up. For some vector $\bold{v}$ with elements $\bold{v}_1, ..., \bold{v}_n$, $\text{sum of squares}(\bold{v}) = (\bold{v}_1)^2 + ... + (\bold{v}_n)^2$

Here’s a test for `sum_of_squares`:

```
#[test]
fn linear_alg() {
    /*...*/

    assert_eq!(v.sum_of_squares(), 35.);
}
```

And here’s my implementation:

```
fn sum_of_squares(&self) -> f64 {
    self.dot(&self)
}
```

- Distance - The distance between two vectors $\bold{v}$ and $\bold{w}$ is defined as
$\sqrt{(\bold{v}_1 - \bold{w}_1)^2 + ... + (\bold{v}_n - \bold{w}_n)^2}$

As usual, here’s a test for `distance`:

```
assert_eq!(v.distance(&w), 45.25f64.sqrt())
```

Hallelujah. Here’s some Rust:

```
fn distance(&self, w: &[f64]) -> f64 {
    assert_eq!(self.len(), w.len());
    self.subtract(w).sum_of_squares().sqrt()
}
```

Implementing `LinearAlg` for the `Vec<f64>` type is all we need for our KNN implementation, so we’ll leave it there. Great.

I’d like to continue our test-first approach. So before we get to the most important function of this installment, which will be called `knn_classify`, we’ll write a test for it.

But we’re going to need some data to test on. A classic dataset to test KNN on is the iris flower dataset. This dataset contains 150 rows, where each row contains a flower’s sepal length, sepal width, petal length, petal width, and type. For example, the first row is `5.1,3.5,1.4,0.2,Iris-setosa`. The dataset has three types of iris: Setosa, Versicolor, and Virginica.

Here’s some code for you to put in `lib.rs`. It gets the iris dataset and converts it to a format we can work with. Look through the code if you’re curious, but I won’t be explaining any of it. I’d like to save time for more interesting stuff.

```
#[cfg(test)]
mod tests {
    use super::*;
    /*...*/

    macro_rules! await_fn {
        ($arg:expr) => {{
            tokio_test::block_on($arg)
        }};
    }

    async fn get_iris_data() -> Result<String, reqwest::Error> {
        let body = reqwest::get(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
        )
        .await?
        .text()
        .await?;

        Ok(body)
    }

    type GenericResult<T> = Result<T, Box<dyn std::error::Error>>;

    fn process_iris_data(body: &str) -> GenericResult<Vec<LabeledPoint>> {
        body.split("\n")
            .filter(|data_point| data_point.len() > 0)
            .map(|data_point| -> GenericResult<LabeledPoint> {
                let columns = data_point.split(",").collect::<Vec<&str>>();
                let (label, point) = columns.split_last().ok_or("Cannot split last")?;
                let point = point
                    .iter()
                    .map(|feature| feature.parse::<f64>())
                    .collect::<Result<Vec<f64>, std::num::ParseFloatError>>()?;

                Ok(LabeledPoint { label, point })
            })
            .collect::<GenericResult<Vec<LabeledPoint>>>()
    }
}
```

Sweet.

Next, we need to split that data into a training set and a testing set. Here’s a function to do just that:

```
mod tests {
    /*...*/
    use rand::{seq::SliceRandom, thread_rng};

    fn split_data<T>(data: &[T], prob: f64) -> (Vec<T>, Vec<T>)
    where
        T: Clone, ➀
    {
        let mut data_copy = data.to_vec(); ➁
        data_copy.shuffle(&mut thread_rng()); ➂

        let split_index = ((data.len() as f64) * prob).round() as usize;

        (
            data_copy[..split_index].to_vec(),
            data_copy[split_index..].to_vec(),
        )
    }
}
```

**➀** - Using the `where` clause to specify that `T` must implement the `Clone` trait.

**➁** - `data.to_vec()` copies the `data` slice into a new `Vec` [9]. This allows us to shuffle our data without taking a mutable reference to the data.

**➂** - The `shuffle` method will shuffle up a mutable slice in place [10]. `shuffle` is from the `rand` crate’s `SliceRandom` trait, which is an extension trait on slices. So we get to use `shuffle` on mutable slices after importing the trait.

The `thread_rng` function, which also comes from `rand`, returns a handle to the thread-local random number generator.
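Here’s a tiny standalone sketch of those two pieces working together (assuming the `rand` crate is available, as it is in our dev-dependencies):

```
use rand::{seq::SliceRandom, thread_rng};

let mut data = vec![1, 2, 3, 4, 5];

// shuffle is available on this mutable slice because SliceRandom is in scope.
data.shuffle(&mut thread_rng());
```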

Great. We’ve got everything we need to test our (currently unimplemented) `knn_classify` function.

We’re going to set $k$ to 5; we’ll classify new data points based on their 5 nearest neighbors. In a real application of KNN, it would be a good idea to test out a few more values of $k$.

Here’s the test:

```
fn knn_classify(k: u8, data_points: &[LabeledPoint], new_point: &[f64]) -> Option<String> {
    todo!()
}

#[cfg(test)]
mod tests {
    /*...*/

    fn count_correct_classifications(
        train_set: &[LabeledPoint],
        test_set: &[LabeledPoint],
        k: u8,
    ) -> u32 {
        let mut num_correct: u32 = 0;

        for iris in test_set.iter() {
            let predicted = knn_classify(k, &train_set, &iris.point);
            let actual = iris.label;

            if let Some(predicted) = predicted { ➀
                if predicted == actual {
                    num_correct += 1;
                }
            }
        }

        num_correct
    }

    #[test]
    fn iris() -> GenericResult<()> {
        let raw_iris_data = await_fn!(get_iris_data())?;
        let iris_data = process_iris_data(&raw_iris_data)?;

        let (train_set, test_set) = split_data(&iris_data, 0.70); ➁
        assert_eq!(train_set.len(), 105);
        assert_eq!(test_set.len(), 45);

        let k = 5;
        let num_correct = count_correct_classifications(&train_set, &test_set, k);
        let percent_correct = num_correct as f32 / test_set.len() as f32;

        assert!(percent_correct > 0.9); ➂
        Ok(())
    }
}
```

**➀** - The `if let` syntax is a lovely way for us to match one pattern and ignore all other patterns [5]. So an alternative (but less Rustic) way to write this block is:

```
match predicted {
    Some(predicted) => {
        if predicted == actual {
            num_correct += 1;
        }
    }
    _ => (),
}
```

**➁** - Splitting 70% of the data into a set for training the classifier, and 30% into a set for testing.

**➂** - If our classifier is working, it should correctly classify at least 90% of the testing set.

I’ll start with pseudocode for our KNN classifier. Try to write your own Rust implementation based on this pseudocode, and run the test we wrote to see if your implementation works.

```
function knn_classify(k, data_points, new_point)
    arguments {
        k: number of neighbors we use to classify our new data point
        data_points: our labeled data points
        new_point: the data point we want to classify
    }
    returning: predicted label for new_point
{
    sorted_data_points = sort_by_distance_from(data_points, new_point)

    k_nearest_labels = empty list
    for i from 0 to k {
        k_nearest_labels.append(sorted_data_points[i].label)
    }

    predicted_label = find_most_common_label(k_nearest_labels)
    return predicted_label
}

function find_most_common_label(labels)
    arguments {
        labels: a list of labels
    }
    returning: most common value in the passed in list of labels
{
    label_counts = new Hash Map

    for label in labels {
        if label is a key in label_counts {
            label_counts[label] += 1
        } else {
            label_counts.add_key_value_pair((label, 1))
        }
    }

    if there are no ties for most common label in label_counts {
        return key with highest value in label_counts
    } else {
        new_labels = all elements in labels but the last
        return find_most_common_label(new_labels)
    }
}
```

Great. Now here is some Rust for you.

```
fn knn_classify(k: u8, data_points: &[LabeledPoint], new_point: &[f64]) -> Option<String> {
    let mut data_points_copy = data_points.to_vec();

    data_points_copy.sort_unstable_by(|a, b| { ➀
        let dist_a = a.point.distance(new_point);
        let dist_b = b.point.distance(new_point);

        dist_a
            .partial_cmp(&dist_b)
            .expect("Cannot compare floating point numbers, encountered a NAN") ➁
    });

    let k_nearest_labels = &data_points_copy[..(k as usize)]
        .iter()
        .map(|a| a.label)
        .collect::<Vec<&str>>();

    let predicted_label = find_most_common_label(&k_nearest_labels);
    predicted_label
}

fn find_most_common_label(labels: &[&str]) -> Option<String> {
    let mut label_counts: HashMap<&str, u32> = HashMap::new(); ➂

    for label in labels.iter() {
        let current_label_count = if let Some(current_label_count) = label_counts.get(label) { ➃
            *current_label_count
        } else {
            0
        };

        label_counts.insert(label, current_label_count + 1);
    }

    let most_common = label_counts
        .iter()
        .max_by(|(_label_a, count_a), (_label_b, count_b)| count_a.cmp(&count_b)); ➄

    if let Some((most_common_label, most_common_label_count)) = most_common {
        let is_tie_for_most_common = label_counts
            .iter()
            .any(|(label, count)| count == most_common_label_count && label != most_common_label); ➅

        if !is_tie_for_most_common {
            return Some((*most_common_label).to_string());
        } else {
            let (_last, labels) = labels.split_last()?; ➆
            return find_most_common_label(&labels);
        }
    }

    None
}
```

**➀** - `sort_unstable_by` allows us to specify the way we want our vector sorted. We’re using this to sort the data points in `data_points_copy` by their distance from the data point we’re classifying.

*Digression on stable vs. unstable sorting:*

As the name suggests, `sort_unstable_by` is unstable [11]. This means that when the sorting algorithm comes across two equal elements, it is allowed to swap them, whereas a stable sort will not swap equal elements [12].

Rust’s `sort_by`

method is stable. So in cases like ours where you don’t care if equal elements are swapped, it’s preferable to use `sort_unstable_by`

over `sort_by`

.
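As a quick sketch of `sort_unstable_by` with a custom comparison (the same shape we use above):

```
let mut distances = vec![3.2_f64, 0.5, 1.7];

// Sort ascending; we need partial_cmp because f64 is only partially ordered.
distances.sort_unstable_by(|a, b| a.partial_cmp(b).expect("encountered a NAN"));

assert_eq!(distances, vec![0.5, 1.7, 3.2]);
```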

**➁** - `partial_cmp` returns a value specifying whether `dist_a` is greater than, less than, or equal to `dist_b`, if such a comparison can be made.

In more gory detail, `partial_cmp` returns an `Option` containing a variant of the `Ordering` enum if an ordering exists (if not, it returns `None`).

The creators of Rust have intentionally not implemented `Ord` for `f64`. This is because an `f64` can be a `NAN` (not a number), and in Rust `NAN != NAN`. So `f64` does not form a total order.

*If some or all of that didn’t make sense, do not worry. Just note that we have to use `partial_cmp` for `f64` types rather than `cmp`.*
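A few lines that make the `NAN` situation concrete:

```
use std::cmp::Ordering;

// NAN is never equal to anything, including itself:
assert!(f64::NAN != f64::NAN);

// partial_cmp returns None when no ordering exists...
assert_eq!(1.0_f64.partial_cmp(&f64::NAN), None);

// ...and Some(Ordering) when one does:
assert_eq!(1.0_f64.partial_cmp(&2.0), Some(Ordering::Less));
```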

**➂** - We create a `HashMap` with `&str` keys and `u32` values. The keys are the distinct elements of the `labels` slice. The values are the number of times each label shows up in `labels`.

**➃** - Rust is an *expression language* [13]. This means that things like `if` and `match` will produce values. So we can assign a variable to the result of an `if-else` block. Cool.
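For example:

```
// `if` is an expression, so it evaluates to a value we can bind directly:
let parity = if 7 % 2 == 0 { "even" } else { "odd" };
assert_eq!(parity, "odd");
```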

**➄** - Here we want to find the most common label in `label_counts`. We iterate through each key-value pair of `label_counts` and use `max_by` to find the key-value pair with the highest value. `max_by` returns the element that gives the maximum value with respect to a custom comparison function [14].
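A small illustration of `max_by` on key-value pairs, similar to what we do with `label_counts`:

```
let counts = [("setosa", 3), ("virginica", 5)];

// Compare pairs by their count (the second element of each tuple):
let most_common = counts.iter().max_by(|a, b| a.1.cmp(&b.1));

assert_eq!(most_common, Some(&("virginica", 5)));
```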

**➅** - The `any` method checks if the provided predicate is `true` for any element in the iterator [15]. We use `any` to see if there are any ties for the most common label in `label_counts`.
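For instance:

```
let counts = [1, 2, 3];

// true: the predicate holds for at least one element (3):
assert!(counts.iter().any(|&count| count > 2));

// false: no element is greater than 5:
assert!(!counts.iter().any(|&count| count > 5));
```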

**➆** - `split_last` returns an `Option` containing a tuple of the last element of a slice and a slice containing the rest of it.
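For example:

```
let labels = ["setosa", "versicolor", "virginica"];

if let Some((last, rest)) = labels.split_last() {
    assert_eq!(*last, "virginica");
    assert_eq!(rest, ["setosa", "versicolor"]);
}
```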

And that’s our K Nea*rust* Neighbors classifier. Amazing.

As always, questions and feedback are welcome and appreciated: joshtaylor361@gmail.com.

Have a great rest of your day.

[1] - Joel Grus. (2019). *Data Science from Scratch: First Principles with Python*, 2nd edition. O’Reilly Media.

[2] - Aditya Y. Bhargava. (2015). *Grokking Algorithms*. Manning Publications.

[3] - Josh Starmer. (2017). *StatQuest: K-nearest neighbors, Clearly Explained*. StatQuest with Josh Starmer.

[4] - Rust documentation for the trait std::ops::Add.

[5] - Steve Klabnik and Carol Nichols. (2018). *The Rust Programming Language*. No Starch Press.

[6] - Shin Takahashi and Iroha Inoue. (2012). *The Manga Guide to Linear Algebra*. No Starch Press.

[7] - Rust documentation for the where clause.

[8] - Sheldon Axler. (1995). *Linear Algebra Done Right*. Springer.

[9] - Rust documentation for to_vec.

[10] - rand crate documentation for shuffle.

[11] - Rust documentation for sort_unstable_by.

[12] - Jon Gjengset. (2020). *Crust of Rust: Sorting Algorithms*.

[13] - Jason Orendorff, Jim Blandy, and Leonora F.S. Tindall. (2021). *Programming Rust*. O’Reilly Media.

[14] - Rust documentation for max_by.

[15] - Rust documentation for any.

Creating and running *Shaking off the Rust* is one of the most fulfilling things I do. But it's exhausting. By supporting me, even if it's just a dollar, you'll allow me to put more time into building this series. I really appreciate any support.

The only way to support me right now is by sponsoring me on GitHub. I’ll probably also set up Patreon and Donorbox pages soon.

Thank you so much!
