Shaking off the Rust

Naive Bayes Classifier

January 1st, 2022

Difficulty: Advanced

Shaking off the Rust is a series of exercises with the Rust programming language. Its purpose is to improve both my and my dear reader's abilities with Rust by building interesting things. Plus, by actually building stuff, we get the bonus benefit of learning about an array of technological concepts in the process. In this installment, we'll implement the Naive Bayes classifier.

You may encounter a few unfamiliar terms or concepts in this article. Don’t be discouraged. Look these up if you have the time. But either way, this article’s main ideas will not be lost on you.

What is a Naive Bayes Classifier?


The Naive Bayes classifier is a machine learning algorithm based on Bayes' Theorem. Bayes' Theorem gives us a way to update the probability of a hypothesis $H$, given some data $D$.

Expressed mathematically, we have: $P(H|D) = \frac{P(D|H)P(H)}{P(D)}$, where $P(H|D)$ is the probability of $H$ given $D$. If we accumulate more data, we can update $P(H|D)$ accordingly.
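Just to make the formula concrete, here it is as a tiny Rust function with made-up numbers (the function name and values are mine, purely for illustration):

// Bayes' Theorem: P(H|D) = P(D|H) * P(H) / P(D)
fn posterior(p_d_given_h: f64, p_h: f64, p_d: f64) -> f64 {
    p_d_given_h * p_h / p_d
}

// With made-up values P(D|H) = 0.8, P(H) = 0.1, and P(D) = 0.2,
// posterior(0.8, 0.1, 0.2) evaluates to 0.4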

Naive Bayesian models rest on a big assumption: that each feature is independent of every other feature, given the class [1]. For spam filtering, that means that once we know whether a message is spam, the presence or absence of any one word tells us nothing about the presence or absence of any other word.

We do not expect this assumption to be true; it's a crude simplification. But it's still useful, allowing us to create efficient classifiers that work quite well [2].
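Concretely, for a message made up of words $w_1, \dots, w_n$, the naive assumption is what lets us factor the probability of the whole message into a product of per-word probabilities:

$$P(w_1, \dots, w_n | Spam) = \prod_{i=1}^{n} P(w_i | Spam)$$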

We’ll leave our description of Naive Bayes there. A lot more could be said, but the main point of this article is to practice Rust.

If you'd like to learn more about the algorithm, the references at the end of this article are a good place to start.

The canonical application of the Naive Bayes classifier is a spam classifier. That is what we’ll build. You can find all the code here: https://github.com/josht-jpg/shaking-off-the-rust

Getting Started


We’ll begin by creating a new library with Cargo.

cargo new naive_bayes --lib
cd naive_bayes

Tokenization


Our classifier will take in a message as input and return a classification of spam or not spam.

To work with the message we're given, we'll want to tokenize it. Our tokenized representation will be a set of lowercase words, where order and repeated entries are disregarded. Rust's std::collections::HashSet structure is a great way to achieve this.

The function we’ll write to perform tokenization will require the use of the regex crate. Make sure you include this dependency in your Cargo.toml file:

[dependencies]
regex = "^1.5.4"

And here’s the tokenize function:

// lib.rs

// We'll need HashMap later
use std::collections::{HashMap, HashSet};

extern crate regex;
use regex::Regex;

pub fn tokenize(lower_case_text: &str) -> HashSet<&str> {
    Regex::new(r"[a-z0-9']+")
        .unwrap()
        .find_iter(lower_case_text)
        .map(|mat| mat.as_str())
        .collect()
}

This function uses a regular expression to match maximal runs of lowercase letters, digits, and apostrophes. Anything else (usually whitespace or punctuation) acts as a separator between matches (you can read more about regex in Rust here) [4]. That is, we're identifying and isolating the words in the input text.
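To make that behaviour concrete, here's a small test you could drop into a #[cfg(test)] module in lib.rs (the sample string is mine; it assumes tokenize and HashSet are in scope, e.g. via use super::*):

#[test]
fn tokenize_example() {
    // The input is already lowercase, since tokenize expects lowercased text.
    let tokens = tokenize("free bitcoin!!! free deals");

    // Order and duplicate entries are discarded.
    assert_eq!(tokens, HashSet::from(["free", "bitcoin", "deals"]));
}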

Some Handy Structures


Using a struct to represent a message will be helpful. This struct will contain a string slice for the message's text, and a Boolean value indicating whether or not the message is spam. We'll make the struct pub, since our library's public train method will take a slice of Messages:

pub struct Message<'a> {
    pub text: &'a str,
    pub is_spam: bool,
}

The 'a is a lifetime parameter annotation. If you’re unfamiliar with lifetimes, and want to learn about them, I recommend reading section 10.3 of The Rust Programming Language Book.
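As a quick illustration (this snippet isn't part of the classifier, and the message text is made up), the lifetime simply ties a Message to the text it borrows:

fn lifetime_example() {
    let text = String::from("free bitcoin deals"); // owned String
    let message = Message {
        text: &text, // the Message borrows `text`, so 'a is the lifetime of that borrow
        is_spam: true,
    };
    // `message` cannot outlive `text`; the compiler enforces this through 'a.
    assert!(message.is_spam);
}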

A struct will also be useful to represent our classifier. Before creating that struct, we need a short digression on Laplace Smoothing.

Digression on Laplace Smoothing:


Assume that, in our training data, the word fubar appears in some non-spam messages but does not appear in any spam messages. Then the Naive Bayes classifier will assign a spam probability of 0 to any message that contains the word fubar [5].

Unless we’re talking about my success with online dating, it’s not smart to assign a probability of 0 to an event just because it hasn’t happened yet.

Enter Laplace Smoothing. This is the technique of adding $\alpha$ to the number of observations of each token [5]. Let's see this mathematically: without Laplace Smoothing, the probability of seeing a word $w$ in a spam message is:

$$P(w|S) = \frac{\text{number of spam messages containing } w}{\text{total number of spam messages}}$$

With Laplace Smoothing, it’s:

$$P(w|S) = \frac{\alpha + \text{number of spam messages containing } w}{2\alpha + \text{total number of spam messages}}$$
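As a preview of how this will look in code, here's a minimal sketch (the function name is mine; the real version appears later in probabilities_of_token):

// Laplace-smoothed estimate of P(w | S).
// `spam_containing_w`: number of spam messages containing w
// `spam_total`: total number of spam messages
fn smoothed_spam_probability(spam_containing_w: i32, spam_total: i32, alpha: f64) -> f64 {
    (spam_containing_w as f64 + alpha) / (spam_total as f64 + 2. * alpha)
}

With alpha = 1, a word that has never appeared in a spam message now gets a small non-zero probability of 1 / (total + 2) instead of 0.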


Back to our classifier struct:

pub struct NaiveBayesClassifier {
    pub alpha: f64,
    pub tokens: HashSet<String>,
    pub token_ham_counts: HashMap<String, i32>,
    pub token_spam_counts: HashMap<String, i32>,
    pub spam_messages_count: i32,
    pub ham_messages_count: i32,
}

The implementation block for NaiveBayesClassifier will center around a train method and a predict method.

Training our classifier


The train method will take in a slice of Messages and loop through each Message, doing the following:

  • Checks whether the message is spam and updates spam_messages_count or ham_messages_count accordingly. We'll create the helper function increment_message_classifications_count for this.
  • Tokenizes the message's contents with our tokenize function.
  • Loops through each token in the message and:
    • Inserts the token into the tokens HashSet
    • Updates token_spam_counts or token_ham_counts. We'll create the helper function increment_token_count for this.

Here's the pseudocode for our train method. If you feel like it, try to convert the pseudocode into Rust before looking at my implementation below. Don't hesitate to send me your implementation; I'd love to see it!

implementation block for NaiveBayesClassifier {

    train(self, messages) {
        for each message in messages {
            self.increment_message_classifications_count(message)

            lower_case_text = to_lowercase(message.text)
            for each token in tokenize(lower_case_text) {
                self.tokens.insert(token)
                self.increment_token_count(token, message.is_spam)
            }
        }
    }

    increment_message_classifications_count(self, message) {
        if message.is_spam {
            self.spam_messages_count = self.spam_messages_count + 1
        } else {
            self.ham_messages_count = self.ham_messages_count + 1
        }
    }

    increment_token_count(self, token, is_spam) {
        if token is not a key of self.token_spam_counts {
            insert record with key=token and value=0 into self.token_spam_counts
        }

        if token is not a key of self.token_ham_counts {
            insert record with key=token and value=0 into self.token_ham_counts
        }

        if is_spam {
            self.token_spam_counts[token] = self.token_spam_counts[token] + 1
        } else {
            self.token_ham_counts[token] = self.token_ham_counts[token] + 1
        }
    }

}

And here’s my Rust implementation:

impl NaiveBayesClassifier {
    pub fn train(&mut self, messages: &[Message]) {
        for message in messages.iter() {
            self.increment_message_classifications_count(message);
            for token in tokenize(&message.text.to_lowercase()) {
                self.tokens.insert(token.to_string());
                self.increment_token_count(token, message.is_spam);
            }
        }
    }

    fn increment_message_classifications_count(&mut self, message: &Message) {
        if message.is_spam {
            self.spam_messages_count += 1;
        } else {
            self.ham_messages_count += 1;
        }
    }

    fn increment_token_count(&mut self, token: &str, is_spam: bool) {
        if !self.token_spam_counts.contains_key(token) {
            self.token_spam_counts.insert(token.to_string(), 0);
        }

        if !self.token_ham_counts.contains_key(token) {
            self.token_ham_counts.insert(token.to_string(), 0);
        }

        if is_spam {
            self.increment_spam_count(token);
        } else {
            self.increment_ham_count(token);
        }
    }

    fn increment_spam_count(&mut self, token: &str) {
        *self.token_spam_counts.get_mut(token).unwrap() += 1;
    }

    fn increment_ham_count(&mut self, token: &str) {
        *self.token_ham_counts.get_mut(token).unwrap() += 1;
    }
}

Notice that incrementing a value in a HashMap is pretty cumbersome. A novice Rust programmer would have difficulty understanding what *self.token_spam_counts.get_mut(token).unwrap() += 1 is doing.

In an attempt to make the code more explicit, I've created the increment_spam_count and increment_ham_count functions. But I'm not happy with that; it still feels cumbersome. Reach out to me if you have suggestions for a better approach.
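One alternative worth mentioning is the standard library's entry API, which folds the 'insert 0 if missing, then increment' dance into a single line. A minimal sketch (not what the code above uses):

use std::collections::HashMap;

fn increment(counts: &mut HashMap<String, i32>, token: &str) {
    // Inserts 0 for a missing key, then increments the value in place.
    *counts.entry(token.to_string()).or_insert(0) += 1;
}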

Predicting with our classifier


The predict method will take a string slice and return the model’s calculated probability of spam.

We'll create two helper functions, probabilities_of_message and probabilities_of_token, to do the heavy lifting for predict.

probabilities_of_message returns $P(Message|Spam)$ and $P(Message|Ham)$.

probabilities_of_token returns $P(Token|Spam)$ and $P(Token|Ham)$.

Calculating the probability that the input message is spam involves multiplying together, for every word in our vocabulary, either that word's probability of appearing in a spam message (if it appears in the input) or the probability of it not appearing (if it doesn't).

Since probabilities are floating point numbers between 0 and 1, multiplying many probabilities together can result in underflow [1]. This is when an operation results in a number smaller than what the computer can accurately store [6][7]. Thus, we’ll use logarithms and exponentials to transform the task into a series of additions:

$$\prod_{i=0}^{n} p_i = \exp\left(\sum_{i=0}^{n} \log(p_i)\right)$$

Which we can do because for any positive real numbers $a$ and $b$, $ab = \exp(\log(ab)) = \exp(\log(a) + \log(b))$ (if you find that math physically painful to look at, don't worry, it's by no means crucial to this article).
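To see why this matters, here's a small standalone demonstration (the numbers are made up): multiplying a thousand tiny probabilities directly underflows to 0, while the sum of their logarithms stays perfectly representable.

fn main() {
    let probs = vec![1e-5_f64; 1000]; // a thousand small probabilities

    // Multiplying them directly underflows to 0.
    let direct: f64 = probs.iter().product();
    assert_eq!(direct, 0.);

    // Summing their logarithms stays in range: 1000 * ln(1e-5) ≈ -11512.9
    let log_sum: f64 = probs.iter().map(|p| p.ln()).sum();
    println!("{}", log_sum);
}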

Once again I’ll start with pseudocode for the predict method:

implementation block for NaiveBayesClassifier {
	/*...*/

	predict(self, text) {
		lower_case_text = to_lowercase(text)
		message_tokens = tokenize(lower_case_text)
		(prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens)
		return prob_if_spam / (prob_if_spam + prob_if_ham)
	}
	
	probabilities_of_message(self, message_tokens) {
		log_prob_if_spam = 0
		log_prob_if_ham = 0

		for each token in self.tokens {
			(prob_if_spam, prob_if_ham) = self.probabilities_of_token(token)

			if message_tokens contains token {
				log_prob_if_spam = log_prob_if_spam + ln(prob_if_spam)
				log_prob_if_ham = log_prob_if_ham + ln(prob_if_ham)
			} else {
				log_prob_if_spam = log_prob_if_spam + ln(1 - prob_if_spam)
				log_prob_if_ham = log_prob_if_ham + ln(1 - prob_if_ham)
			}
		}

		prob_if_spam = exp(log_prob_if_spam)
		prob_if_ham = exp(log_prob_if_ham)

		return (prob_if_spam, prob_if_ham)
	}

	probabilities_of_token(self, token) {
		prob_of_token_spam = (self.token_spam_counts[token] + self.alpha) 
						/ (self.spam_messages_count + 2 * self.alpha)
        
		prob_of_token_ham = (self.token_ham_counts[token] + self.alpha) 
						/ (self.ham_messages_count + 2 * self.alpha)

		return (prob_of_token_spam, prob_of_token_ham)
	}
	
	
}

And here’s the Rust code:

impl NaiveBayesClassifier {

    /*...*/

    pub fn predict(&self, text: &str) -> f64 {
        let lower_case_text = text.to_lowercase();
        let message_tokens = tokenize(&lower_case_text);
        let (prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens);

        return prob_if_spam / (prob_if_spam + prob_if_ham);
    }

    fn probabilities_of_message(&self, message_tokens: HashSet<&str>) -> (f64, f64) {
        let mut log_prob_if_spam = 0.;
        let mut log_prob_if_ham = 0.;

        for token in self.tokens.iter() {
            let (prob_if_spam, prob_if_ham) = self.probabilities_of_token(&token);

            if message_tokens.contains(token.as_str()) {
                log_prob_if_spam += prob_if_spam.ln();
                log_prob_if_ham += prob_if_ham.ln();
            } else {
                log_prob_if_spam += (1. - prob_if_spam).ln();
                log_prob_if_ham += (1. - prob_if_ham).ln();
            }
        }

        let prob_if_spam = log_prob_if_spam.exp();
        let prob_if_ham = log_prob_if_ham.exp();

        return (prob_if_spam, prob_if_ham);
    }

    fn probabilities_of_token(&self, token: &str) -> (f64, f64) {
        let prob_of_token_spam = (self.token_spam_counts[token] as f64 + self.alpha)
            / (self.spam_messages_count as f64 + 2. * self.alpha);

        let prob_of_token_ham = (self.token_ham_counts[token] as f64 + self.alpha)
            / (self.ham_messages_count as f64 + 2. * self.alpha);

        return (prob_of_token_spam, prob_of_token_ham);
    }
}
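Putting train and predict together, usage looks roughly like this (new_classifier is the small constructor we'll define in the next section, and the messages are made up):

fn main() {
    let mut model = new_classifier(1.);
    model.train(&[
        Message { text: "free bitcoin deals", is_spam: true },
        Message { text: "dinner at my place on christmas", is_spam: false },
    ]);

    // The model's estimated probability that the text is spam.
    let spam_probability = model.predict("free christmas deals");
    println!("{}", spam_probability);
}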

Testing our Classifier


Let’s give our model a test. The test below goes through Naive Bayes manually, then checks that our model gives the same result.

You may find it worthwhile to go through the test’s logic, or you may just want to paste the code to the bottom of your lib.rs file to check that your code works.

// ...lib.rs

pub fn new_classifier(alpha: f64) -> NaiveBayesClassifier {
    return NaiveBayesClassifier {
        alpha,
        tokens: HashSet::new(),
        token_ham_counts: HashMap::new(),
        token_spam_counts: HashMap::new(),
        spam_messages_count: 0,
        ham_messages_count: 0,
    };
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn naive_bayes() {
        let train_messages = [
            Message {
                text: "Free Bitcoin viagra XXX christmas deals 😻😻😻",
                is_spam: true,
            },
            Message {
                text: "My dear Granddaughter, please explain Bitcoin over Christmas dinner",
                is_spam: false,
            },
            Message {
                text: "Here in my garage...",
                is_spam: true,
            },
        ];

        let alpha = 1.;
        let num_spam_messages = 2.;
        let num_ham_messages = 1.;

        let mut model = new_classifier(alpha);
        model.train(&train_messages);

        let mut expected_tokens: HashSet<String> = HashSet::new();
        for message in train_messages.iter() {
            for token in tokenize(&message.text.to_lowercase()) {
                expected_tokens.insert(token.to_string());
            }
        }

        let input_text = "Bitcoin crypto academy Christmas deals";

        let probs_if_spam = [
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), //"Free"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), //"viagra"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), //"XXX"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "christmas"  (present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), //"my"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), //"dear"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), //"granddaughter"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), //"please"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), //"explain"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), //"over"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), //"dinner"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), //"here"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), //"in"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), //"garage"  (not present)
        ];

        let probs_if_ham = [
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), //"Free"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), //"viagra"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), //"XXX"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "christmas"  (present)
            (0. + alpha) / (num_ham_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), //"my"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), //"dear"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), //"granddaughter"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), //"please"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), //"explain"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), //"over"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), //"dinner"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), //"here"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), //"in"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), //"garage"  (not present)
        ];

        let p_if_spam_log: f64 = probs_if_spam.iter().map(|p| p.ln()).sum();
        let p_if_spam = p_if_spam_log.exp();

        let p_if_ham_log: f64 = probs_if_ham.iter().map(|p| p.ln()).sum();
        let p_if_ham = p_if_ham_log.exp();

        // P(message | spam) / (P(message | spam) + P(message | ham)) rounds to 0.97
        assert!((model.predict(input_text) - p_if_spam / (p_if_spam + p_if_ham)).abs() < 0.000001);
    }
}

Now run cargo test. If that passes for you, well done, you’ve implemented a Naive Bayes classifier in Rust!

Thank you for coding with me, friends. Feel free to reach out if you have any feedback or questions: joshtaylor361@gmail.com. Cheers!

References


[1] Grus, J. (2019). Data Science from Scratch: First Principles with Python, 2nd edition. O’Reilly Media.

[2] Downey, A. (2021). Think Bayes: Bayesian Statistics in Python, 2nd edition. O'Reilly Media.

[3] Murphy, K. (2012). Machine Learning: A Probabilistic Perspective. MIT Press.

[4] Dhinakaran, V. (2017). Rust Cookbook. Packt.

[5] Ng, A. (2018). Stanford CS229: Lecture 5 - GDA & Naive Bayes.

[6] Burden, R., Faires, J., & Burden, A. (2015). Numerical Analysis, 10th edition. Brooks Cole.

[7] Underflow. Techopedia.


Support Me

Creating and running Shaking off the Rust is one of the most fulfilling things I do. But it's exhausting. By supporting me, even if it's just a dollar, you'll allow me to put more time into building this series. I really appreciate any support.

The only way to support me right now is by sponsoring me on Github. I'll probably also set up Patreon and Donorbox pages soon.

Thank you so much!
