Shaking off the Rust

K Nearust Neighbors

May 1st, 2022

Difficulty: Intermediate

Shaking off the Rust is a series of exercises with the Rust programming language. The purpose of the series is to improve both my and my dear reader’s abilities with Rust by building things. Plus, by actually building stuff, we'll learn about an array of technological concepts in the process. In this installment, we’re going to implement a classic machine learning algorithm.

After reading this installment, you'll have experience with:

  • Rust’s if let syntax
  • Rust’s Clone trait
  • Extension traits
  • A little bit of lifetime parameters
  • Rust’s where clause
  • The k nearest neighbors algorithm
  • Much more!

This installment’s GitHub repo: https://github.com/josht-jpg/k-nearust-neighbors

K Nearust Neighbors


The k nearest neighbors (KNN) algorithm is simple.

It’s a good algorithm for classifying data [1]. Suppose we have a labeled dataset, which we'll denote D, and an unlabeled data point, which we'll denote d, and we want to predict the correct label for d. We can do that with KNN.

KNN works like this:

For some integer k, we find the k data points in D nearest to d - the k nearest neighbors. For an example of what I mean by nearest: in the graph below, the blue data points are the 3 nearest neighbors of the red data point.

[Figure: a red data point and its 3 nearest neighbors, shown in blue]

For our computer to find the data points near d, we need a way to measure the distance between data points. We can use the Pythagorean formula to do that [2]:

The distance between two data points x and y with features $x_1, ..., x_n$ and $y_1, ..., y_n$ is $\sqrt{(x_1 - y_1)^2 + ... + (x_n - y_n)^2}$.
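
For example, the distance between the points $(1, 2)$ and $(4, 6)$ is $\sqrt{(1 - 4)^2 + (2 - 6)^2} = \sqrt{9 + 16} = 5$.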

And to predict the label for d, we pick the most common label among its k nearest labeled data points.

Finally, we have to handle the possible scenario of a tie for the most common label. There are a few ways to handle this. Our approach will be to decrement k until there’s no longer a tie. That is, we remove the furthest of the nearest labels and recount the most common labels. For example, if the 4 nearest labels, ordered from nearest to furthest, are [Setosa, Versicolor, Setosa, Versicolor], we drop the last one and recount [Setosa, Versicolor, Setosa], leaving Setosa as the winner.

We, the implementors of KNN, specify the value for k. Choosing a good value for k is usually a process of trying and testing several values [3].

If you feel like watching a quick video on KNN, there’s a great one from StatQuest with Josh Starmer (and I can’t say enough good things about that YouTube channel. Thank you, Josh).

Getting Started


To get started, we’ll create a new library called k-nearust-neighbors.

cargo new k-nearust-neighbors --lib
cd k-nearust-neighbors

We’ll also bring in the following crates as dev-dependencies.

# Cargo.toml

# ...

[dev-dependencies]
rand = "0.8.4"
reqwest = "0.11.10"
tokio-test = "*"

rand is a crate for random number generation. You can read more about it here: https://crates.io/crates/rand. reqwest and tokio-test will be used to fetch some data to test our KNN classifier on. You can read about both of those crates on crates.io as well.

And we’ll toss this use declaration into our lib.rs file.

// lib.rs

use std::{
    collections::HashMap,
    ops::{Add, Sub}, 
};

    - Add and Sub are traits that specify how the addition and subtraction operators work. If some type T implements Add, then for two values a and b of type T, we can write a + b. Same idea for Sub [4].
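
    For instance, here’s a minimal sketch of what implementing Add buys us. Meters is a made-up type, just for illustration:

    use std::ops::Add;

    struct Meters(f64);

    impl Add for Meters {
        type Output = Meters;

        fn add(self, other: Meters) -> Meters {
            Meters(self.0 + other.0)
        }
    }

    #[test]
    fn add_meters() {
        // Because Meters implements Add, the + operator now works on it.
        assert_eq!((Meters(1.5) + Meters(2.0)).0, 3.5);
    }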

Representing a Data Point in our Code


We’ll use a struct to represent a data point. We’ll name it LabeledPoint. Each LabeledPoint will have fields:

  • label: a string slice that categorizes the data,
  • point: a vector containing the data point’s features (which will be represented as f64s).

Here is LabeledPoint in Rust:

// lib.rs

#[derive(Clone)]  
struct LabeledPoint<'a> { 
	label: &'a str,     
	point: Vec<f64>,
}

    ➀ - We’re using the derive attribute to implement the Clone trait. This means we’re asking Rust to generate code for the Clone trait’s default implementation and apply it to LabeledPoint. The Clone trait will let us explicitly create deep copies of LabeledPoint instances [5]. This will allow us to call the to_vec method on a slice of LabeledPoints (which we’ll do in our KNN implementation).

    - We declare that our struct is generic over the lifetime parameter 'a. In Rust, a lifetime is the scope for which a reference is valid [5]. Rust’s lifetimes are part of what makes the language special. If you’d like to learn about them, I recommend reading section 10.3 of The Rust Programming Language book or watching this stream from Ryan Levick.

    • Sidenote: I highly recommend all of Ryan Levick’s content. Thanks for doing what you do, Ryan.

    - For a struct to hold a reference, it must have a lifetime annotation for that reference [5]. label is a string slice, and string slices are references. Thus, we must add a lifetime annotation to label, which we do by putting 'a after & in label: &'a str. Awesome.
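
    To make those two points concrete, here’s a hypothetical test (not part of the repo) that builds a LabeledPoint and clones it:

    #[test]
    fn labeled_point_clone() {
        let label = String::from("setosa");

        let original = LabeledPoint {
            label: &label, // the lifetime 'a is tied to how long `label` lives
            point: vec![5.1, 3.5, 1.4, 0.2],
        };

        // #[derive(Clone)] gives us an explicit deep copy
        let copy = original.clone();
        assert_eq!(copy.point, original.point);
    }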

Smells like Linear Algebra


To implement KNN, we need to compare distances between data points. This calls for some linear algebra. Fun.

(But if linear algebra does not sound like fun to you, feel free to copy and paste this code into your lib.rs and skip to the next section - I’ll forgive you).

A quick review of vectors may be helpful (here I’m talking about vectors from linear algebra, not vectors from Rust).


A vector can be thought of as a list of numbers.

There are a few ways that a vector can be interpreted. One common interpretation is a point in space [6]. Take the vector [2.5, 3, 4], for example. This is what it looks like as a point in space:

[Figure: the vector [2.5, 3, 4] plotted as a point in 3D space]

This is how we will interpret vectors: as points in space. Cool.

Sidenote: the above plot was created with Rust! It was made with the plotters crate. Feel free to take a look at the code I slapped together to generate this plot: https://github.com/josht-jpg/vector_plot


We’re going to define a trait called LinearAlg:

trait LinearAlg<T>
where
    T: Add + Sub,  
{
    fn dot(&self, w: &[T]) -> T;

    fn subtract(&self, w: &[T]) -> Vec<T>;

    fn sum_of_squares(&self) -> T;

    fn distance(&self, w: &[T]) -> f64;
}

    - Rust’s where clause lets us specify that the generic type T must implement the Add and Sub traits [7].

And we’ll make LinearAlg an extension trait, implementing it for the standard library’s Vec<f64> type.

impl LinearAlg<f64> for Vec<f64> { /*...*/ }

We’re going to take a test-first approach here. For each method in this implementation of LinearAlg, I’ll go over the math behind the operation, then provide a test for the method, and then some code that implements the method.

It’s a good exercise to try writing each method yourself and running the test before looking at my implementation. There’s a good chance you’ll like your implementation more (and if you do, please share it with me at joshtaylor361@gmail.com).

  • Dot Product - For two vectors $\bold{v}$ and $\bold{w}$ of the same length, the dot product of $\bold{v}$ and $\bold{w}$ is the result of coupling up each corresponding element in $\bold{v}$ and $\bold{w}$, multiplying those two elements together, and adding up the results.

    $\bold{v} \cdot \bold{w} = \begin{bmatrix} \bold{v}_{1} \\ \bold{v}_{2} \\ \vdots \\ \bold{v}_{n} \end{bmatrix} \cdot \begin{bmatrix} \bold{w}_{1} \\ \bold{w}_{2} \\ \vdots \\ \bold{w}_{n} \end{bmatrix} = \bold{v}_1 \cdot \bold{w}_1 + \bold{v}_2 \cdot \bold{w}_2 + ... + \bold{v}_n \cdot \bold{w}_n$ [8].

    Here’s our test for dot:

    // lib.rs
    
    /*...*/
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn linear_alg() {
            let v = vec![1., 5., -3.];
            let w = vec![0.5, 2., 3.];
    
            assert_eq!(v.dot(&w), 1.5)
        }
    }

    Run that test with the command cargo test linear_alg in the root of your k-nearust-neighbors folder. Congratulations if that passes for you.

    Here is my implementation of dot:

    fn dot(&self, w: &[f64]) -> f64 {
    	assert_eq!(self.len(), w.len());
    	self.iter().zip(w).map(|(v_i, w_i)| v_i * w_i).sum()
    }

  • Subtract - This one’s simpler. For two vectors $\bold{v}$ and $\bold{w}$ of the same length,

    $\bold{v} - \bold{w} = \begin{bmatrix} \bold{v}_{1} \\ \bold{v}_{2} \\ \vdots \\ \bold{v}_{n} \end{bmatrix} - \begin{bmatrix} \bold{w}_{1} \\ \bold{w}_{2} \\ \vdots \\ \bold{w}_{n} \end{bmatrix} = \begin{bmatrix} \bold{v}_{1} - \bold{w}_{1} \\ \bold{v}_{2} - \bold{w}_{2} \\ \vdots \\ \bold{v}_{n} - \bold{w}_{n} \end{bmatrix}$

    To test subtract, add the following assertion to your linear_alg test function:

    #[test]
    fn linear_alg() {
    	/*...*/
    	assert_eq!(v.subtract(&w), vec![0.5, 3., -6.]);
    }

    Nice. I hope you were able to make that test pass (but of course no worries if you weren’t). Here’s an implementation of subtract:

    fn subtract(&self, w: &[f64]) -> Vec<f64> {
    	assert_eq!(self.len(), w.len());
    	self.iter().zip(w).map(|(v_i, w_i)| v_i - w_i).collect()
    }
  • Sum of Squares - A vector’s sum of squares is the result of squaring each of its elements and adding everything up:

    For some vector $\bold{v}$ with elements $\bold{v}_1, ..., \bold{v}_n$, $\text{sum of squares}(\bold{v}) = (\bold{v}_1)^2 + ... + (\bold{v}_n)^2$.

    Here’s a test for sum_of_squares:

    #[test]
    fn linear_alg() {
    	/*...*/
    	assert_eq!(v.sum_of_squares(), 35.);
    }

    And here’s my implementation:

    fn sum_of_squares(&self) -> f64 {
    	self.dot(&self)
    }
  • Distance - The distance between two vectors $\bold{v}$ and $\bold{w}$ is defined as

    $\sqrt{(\bold{v}_1 - \bold{w}_1)^2 + ... + (\bold{v}_n - \bold{w}_n)^2}$

    As usual, here’s a test for distance:

    assert_eq!(v.distance(&w), 45.25f64.sqrt())

    Hallelujah. Here’s some Rust:

    fn distance(&self, w: &[f64]) -> f64 {
    	assert_eq!(self.len(), w.len());
    	self.subtract(w).sum_of_squares().sqrt()
    }

Implementing LinearAlg for the Vec<f64> type is all we need for our KNN implementation, so we’ll leave it there. Great.
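
For reference, here is the full extension trait implementation, with the four methods from above assembled in one place:

impl LinearAlg<f64> for Vec<f64> {
    fn dot(&self, w: &[f64]) -> f64 {
        assert_eq!(self.len(), w.len());
        self.iter().zip(w).map(|(v_i, w_i)| v_i * w_i).sum()
    }

    fn subtract(&self, w: &[f64]) -> Vec<f64> {
        assert_eq!(self.len(), w.len());
        self.iter().zip(w).map(|(v_i, w_i)| v_i - w_i).collect()
    }

    fn sum_of_squares(&self) -> f64 {
        self.dot(&self)
    }

    fn distance(&self, w: &[f64]) -> f64 {
        assert_eq!(self.len(), w.len());
        self.subtract(w).sum_of_squares().sqrt()
    }
}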

Loading Data, Processing Data, and Splitting Data. Gnarly.


I’d like to continue our test-first approach. So before we get to the most important function of this installment, which will be called knn_classify, we’ll write a test for it.

But we’re going to need some data to test on. A classic dataset to test KNN on is the iris flower dataset. The dataset has 150 rows, where each row contains a flower’s petal length, petal width, sepal length, sepal width, and type. There are three types of iris: Setosa, Versicolor, and Virginica.

Here’s some code for you to put in lib.rs. It gets the iris dataset and converts it to a format we can work with. Look through the code if you’re curious, but I won’t be explaining any of it. I’d like to save time for more interesting stuff.

#[cfg(test)]
mod tests {
    use super::*;

    /*...*/

    macro_rules! await_fn {
        ($arg:expr) => {{
            tokio_test::block_on($arg)
        }};
    }

    async fn get_iris_data() -> Result<String, reqwest::Error> {
        let body = reqwest::get(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
        )
        .await?
        .text()
        .await?;

        Ok(body)
    }

    type GenericResult<T> = Result<T, Box<dyn std::error::Error>>;

    fn process_iris_data(body: &str) -> GenericResult<Vec<LabeledPoint>> {
        body.split("\n")
            .filter(|data_point| data_point.len() > 0)
            .map(|data_point| -> GenericResult<LabeledPoint> {
                let columns = data_point.split(",").collect::<Vec<&str>>();
                let (label, point) = columns.split_last().ok_or("Cannot split last")?;
                let point = point
                    .iter()
                    .map(|feature| feature.parse::<f64>())
                    .collect::<Result<Vec<f64>, std::num::ParseFloatError>>()?;

                Ok(LabeledPoint { label, point })
            })
            .collect::<GenericResult<Vec<LabeledPoint>>>()
    }
}

Sweet.

Next, we need to split that data into a training set and a testing set. Here’s a function to do just that:

mod tests {

    /*...*/

    use rand::{seq::SliceRandom, thread_rng};

    fn split_data<T>(data: &[T], prob: f64) -> (Vec<T>, Vec<T>)
    where
        T: Clone,
    {
        let mut data_copy = data.to_vec();
        data_copy.shuffle(&mut thread_rng());
        let split_index = ((data.len() as f64) * prob).round() as usize;

        (
            data_copy[..split_index].to_vec(),
            data_copy[split_index..].to_vec(),
        )
    }
}

    - Using the where clause to specify that T must implement the Clone trait.

    - data.to_vec() copies the data slice into a new Vec [9]. This allows us to shuffle our data without taking a mutable reference to the data.

    - The shuffle method will shuffle up a mutable slice in place [10]. shuffle is from the rand crate’s SliceRandom trait, which is an extension trait on slices. So we get to use shuffle on mutable slices after importing the trait.

    The thread_rng function, which also comes from rand, gives us a random number generator that’s local to the current thread.
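
As a quick sanity check on split_data, here’s a hypothetical test (not in the repo): splitting ten items with prob set to 0.7 should give us seven items in the first half and three in the second.

#[test]
fn split_data_proportions() {
    let data: Vec<i32> = (0..10).collect();
    let (first, second) = split_data(&data, 0.7);

    assert_eq!(first.len(), 7);
    assert_eq!(second.len(), 3);
}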

Great. We’ve got everything we need to test our (currently unimplemented) knn_classify function.

We’re going to set k to 5; we’ll classify new data points based on their 5 nearest neighbors. In a real application of KNN, it would be a good idea to test out a few more values of k.

Here’s the test:

fn knn_classify(k: u8, data_points: &[LabeledPoint], new_point: &[f64]) -> Option<String> {
	todo!()
}

#[cfg(test)]
mod tests {
    /*...*/

    fn count_correct_classifications(
        train_set: &[LabeledPoint],
        test_set: &[LabeledPoint],
        k: u8,
    ) -> u32 {
        let mut num_correct: u32 = 0;

        for iris in test_set.iter() {
            let predicted = knn_classify(k, &train_set, &iris.point);
            let actual = iris.label;

            if let Some(predicted) = predicted { 
                if predicted == actual {
                    num_correct += 1;
                }
            }
        }

        num_correct
    }

    #[test]
    fn iris() -> GenericResult<()> {
        let raw_iris_data = await_fn!(get_iris_data())?;
        let iris_data = process_iris_data(&raw_iris_data)?;

        let (train_set, test_set) = split_data(&iris_data, 0.70); 
        assert_eq!(train_set.len(), 105);
        assert_eq!(test_set.len(), 45);

        let k = 5;
        let num_correct = count_correct_classifications(&train_set, &test_set, k);
        let percent_correct = num_correct as f32 / test_set.len() as f32;

        assert!(percent_correct > 0.9);

        Ok(())
    }

}

    - The if let syntax is a lovely way for us to match one pattern and ignore all other patterns [5]. So an alternative (but less Rustic) way to write this block is:

    match predicted {
    	Some(predicted) => {
    		if predicted == actual {
    			num_correct += 1;
    		}
    	}
    	_ => (),
    }

    - Splitting 70% of the data into a set for training the classifier, and 30% into a set for testing.

    - If our classifier is working, it should correctly classify at least 90% of the testing set.

Won’t you be my Neighbor? Implementing KNN


I’ll start with pseudocode for our KNN classifier. Try to write your own Rust implementation based on this pseudocode, and run the test we wrote to see if your implementation works.

function knn_classify(k, data_points, new_point) 
arguments {
	k: number of neighbors we use to classify our new data point
	data_points: our labeled data points
	new_point: the data point we want to classify	
}
returning: predicted label for new_point
{ 
	sorted_data_points = sort_by_distance_from(data_points, new_point)	
  
	k_nearest_labels = empty list

	for i from 0 to k - 1 {
		k_nearest_labels.append(sorted_data_points[i].label)
	}

	predicted_label = find_most_common_label(k_nearest_labels)
	return predicted_label
}

function find_most_common_label(labels) 
arguments {
	labels: a list of labels
}
returning: most common value in the passed in list of labels
{
  label_counts = new Hash Map

  for label in labels {
		if label is a key in label_counts {
			label_counts[label] += 1
		} else {
			label_counts.add_key_value_pair((label, 1))
		}
  }

	if there are no ties for most common label in label_counts {
		return key with highest value in label_counts
	} else {
		new_labels = all elements in labels but the last
		return find_most_common_label(new_labels)
	}
}

Great. Now here is some Rust for you.

fn knn_classify(k: u8, data_points: &[LabeledPoint], new_point: &[f64]) -> Option<String> {
    let mut data_points_copy = data_points.to_vec();

    data_points_copy.sort_unstable_by(|a, b| { 
        let dist_a = a.point.distance(new_point);
        let dist_b = b.point.distance(new_point);

        dist_a
            .partial_cmp(&dist_b)
            .expect("Cannot compare floating point numbers, encoutered a NAN") 
    });

    let k_nearest_labels = &data_points_copy[..(k as usize)]
        .iter()
        .map(|a| a.label)
        .collect::<Vec<&str>>();

    let predicted_label = find_most_common_label(&k_nearest_labels);
    predicted_label
}

fn find_most_common_label(labels: &[&str]) -> Option<String> {
    let mut label_counts: HashMap<&str, u32> = HashMap::new(); 

    for label in labels.iter() {
        let current_label_count = if let Some(current_label_count) = label_counts.get(label) { 
            *current_label_count
        } else {
            0
        };
        label_counts.insert(label, current_label_count + 1);
    }

    let most_common = label_counts
        .iter()
        .max_by(|(_label_a, count_a), (_label_b, count_b)| count_a.cmp(&count_b)); 

    if let Some((most_common_label, most_common_label_count)) = most_common {
        let is_tie_for_most_common = label_counts
            .iter()
            .any(|(label, count)| count == most_common_label_count && label != most_common_label); 

        if !is_tie_for_most_common {
            return Some((*most_common_label).to_string());
        } else {
            let (_last, labels) = labels.split_last()?; 
            return find_most_common_label(&labels);
        }
    }

    None
}

    - sort_unstable_by allows us to specify the way we want our vector sorted. We’re using this to sort the data points in data_copy by their distance from the data point we’re classifying.

    Digression on stable vs. unstable sorting:


    As the name suggests, sort_unstable_by is unstable [11]. This means that when the sorting algorithm comes across two equal elements, it is allowed to swap them - whereas a stable sort will not swap equal elements [12].

    Unstable sorting is generally faster and requires less memory than stable sorting.

    Rust’s sort_by method is stable. So in cases like ours where you don’t care if equal elements are swapped, it’s preferable to use sort_unstable_by over sort_by.
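
    As a tiny illustration (a hypothetical test, not part of our classifier), consider sorting pairs by their first field only. A stable sort keeps the two entries with equal keys in their original relative order, while an unstable sort is allowed to swap them:

    #[test]
    fn stable_vs_unstable() {
        let mut pairs = vec![(1, "first"), (0, "other"), (1, "second")];

        // Stable sort: the two (1, _) entries keep their original relative order.
        pairs.sort_by(|a, b| a.0.cmp(&b.0));
        assert_eq!(pairs, vec![(0, "other"), (1, "first"), (1, "second")]);

        // Unstable sort: the two (1, _) entries may end up in either order,
        // which is fine when you don't care about the order of equal elements.
        pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0));
    }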


    - partial_cmp returns a value specifying whether dist_a is greater than, less than, or equal to dist_b, if such a comparison can be made.

    In gorier detail, partial_cmp returns an Option containing a variant of the Ordering enum if an ordering exists (and None if it doesn’t).

    The creators of Rust have intentionally not implemented Ord for f64. This is because an f64 can be a NaN (not a number), and in Rust NaN != NaN. So f64 does not form a total order.

    If some or all of that didn’t make sense, do not worry. Just note that we have to use partial_cmp for f64 types rather than cmp.
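
    For example, with the standard library’s Ordering enum:

    use std::cmp::Ordering;

    assert_eq!(1.0_f64.partial_cmp(&2.0), Some(Ordering::Less));
    assert_eq!(f64::NAN.partial_cmp(&1.0), None); // no ordering exists for NaN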

    - We create a HashMap with &str keys and u32 values. The keys are the distinct elements of the labels slice. The values are the number of times each label shows up in labels.

    - Rust is an expression language [13]. This means that things like if and match will produce values. So we can assign a variable to the result of an if-else block. Cool.
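
    For example, we can assign a variable to the value an if-else block produces:

    let n = 3;
    let parity = if n % 2 == 0 { "even" } else { "odd" };
    assert_eq!(parity, "odd");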

    - Here we want to find the most common label in label_counts. We iterate through each key-value pair of label_counts and use max_by to find the key-value pair with the highest value. max_by returns the element with the maximum value with respect to a custom comparison function [14].

    - The any method checks if the provided predicate is true for any elements in the iterator [15]. We use any to see if there are any ties for most common label in label_counts.

    - split_last returns an Option containing a tuple of the last element of a slice and a slice containing the rest of the elements.
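
    For example:

    let labels = ["setosa", "versicolor", "virginica"];
    let (last, rest) = labels.split_last().unwrap();

    assert_eq!(*last, "virginica");
    assert_eq!(rest, ["setosa", "versicolor"]);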

And that’s our K Nearust Neighbors classifier. Amazing.

As always, questions and feedback are welcome and appreciated: joshtaylor361@gmail.com.

Have a great rest of your day.


References


[1] - Joel Grus. (2019). Data Science from Scratch: First Principles with Python, 2nd edition. O’Reilly Media.

[2] - Aditya Y. Bhargava. (2015). Grokking Algorithms. Manning Publications.

[3] - Josh Starmer. (2017). StatQuest: K-nearest neighbors, Clearly Explained. StatQuest with Josh Starmer.

[4] - Rust Documentation for Trait std::ops::Add.

[5] - Steve Klabnik and Carol Nichols. (2018). The Rust Programming Language. No Starch Press.

[6] - Shin Takahashi and Iroha Inoue. (2012). The Manga Guide to Linear Algebra. No Starch Press.

[7] - Rust documentation for the where clause.

[8] - Sheldon Axler. (1995). Linear Algebra Done Right. Springer.

[9] - Rust documentation for to_vec.

[10] - rand crate documentation of shuffle.

[11] - Rust documentation for sort_unstable_by.

[12] - Jon Gjengset. (2020). Crust of Rust: Sorting Algorithms.

[13] - Jason Orendorff, Jim Blandy, and Leonora F .S. Tindall. (2021). Programming Rust. O’Reilly Media.

[14] - Rust documentation for max_by.

[15] - Rust documentation for any.


Support Me

Creating and running Shaking off the Rust is one of the most fulfilling things I do. But it's exhausting. By supporting me, even if it's just a dollar, you'll allow me to put more time into building this series. I really appreciate any support.

The only way to support me right now is by sponsoring me on GitHub. I'll probably also set up Patreon and Donorbox pages soon.

Thank you so much!
