{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Naive Bayes example\n",
    "Let's explore an implementation of Naive Bayes! We build a spam classifier from scratch, check it against a small example worked out by hand, and then train and evaluate it on the subject lines of the SpamAssassin public corpus."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# create a simple function to tokenize messages into distinct words\n",
    "from typing import Set\n",
    "import re\n",
    "\n",
    "def tokenize(text: str) -> Set[str]:\n",
    "    \"\"\"Return the distinct lowercase words (runs of a-z, 0-9 and apostrophes) in `text`.\"\"\"\n",
    "    text = text.lower()                          # Convert to lowercase,\n",
    "    all_words = re.findall(\"[a-z0-9']+\", text)   # extract the words, and\n",
    "    return set(all_words)                        # remove duplicates.\n",
    "\n",
    "assert tokenize(\"Data Science is science\") == {\"data\", \"science\", \"is\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# define a type for our training data\n",
    "from typing import NamedTuple\n",
    "\n",
    "class Message(NamedTuple):\n",
    "    \"\"\"One labelled example: the message text and whether it is spam.\"\"\"\n",
    "    text: str\n",
    "    is_spam: bool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# As our classifier needs to keep track of tokens, counts, and labels from the training data, we’ll make it a class.\n",
    "from typing import Dict, Iterable, List, Set, Tuple\n",
    "import math\n",
    "from collections import defaultdict\n",
    "\n",
    "class NaiveBayesClassifier:\n",
    "    \"\"\"Naive Bayes spam classifier based on which vocabulary words a message does or does not contain.\n",
    "\n",
    "    `k` is a smoothing pseudocount added to every token count, so a token never seen\n",
    "    with one label still gets a small non-zero probability for it. Spam and ham are\n",
    "    treated as equally likely a priori: no P(spam) prior enters `predict`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, k: float = 0.5) -> None:\n",
    "        self.k = k  # smoothing factor\n",
    "\n",
    "        self.tokens: Set[str] = set()\n",
    "        self.token_spam_counts: Dict[str, int] = defaultdict(int)\n",
    "        self.token_ham_counts: Dict[str, int] = defaultdict(int)  # we refer to nonspam emails as ham emails\n",
    "        self.spam_messages = self.ham_messages = 0\n",
    "\n",
    "    # Next, we’ll give it a method to train it on a bunch of messages\n",
    "    def train(self, messages: Iterable[Message]) -> None:\n",
    "        \"\"\"Add the message counts and per-token counts of the labelled `messages` to the model.\n",
    "\n",
    "        Counts accumulate, so calling train() again extends the existing model.\n",
    "        \"\"\"\n",
    "        for message in messages:\n",
    "            # Increment message counts\n",
    "            if message.is_spam:\n",
    "                self.spam_messages += 1\n",
    "            else:\n",
    "                self.ham_messages += 1\n",
    "\n",
    "            # Increment word counts; tokenize() returns a set, so each token\n",
    "            # is counted at most once per message.\n",
    "            for token in tokenize(message.text):\n",
    "                self.tokens.add(token)\n",
    "                if message.is_spam:\n",
    "                    self.token_spam_counts[token] += 1\n",
    "                else:\n",
    "                    self.token_ham_counts[token] += 1\n",
    "\n",
    "    # Ultimately we’ll want to predict P(spam | token).\n",
    "    # As we saw earlier, to apply Bayes’s theorem we need to know P(token | spam) and P(token | ham)\n",
    "    # for each token in the vocabulary. So we’ll create a “private” helper function to compute those:\n",
    "\n",
    "    def _probabilities(self, token: str) -> Tuple[float, float]:\n",
    "        \"\"\"returns P(token | spam) and P(token | ham)\"\"\"\n",
    "        # Use .get() rather than [...]: indexing a defaultdict inserts missing keys, so merely\n",
    "        # asking for a probability would silently add 0-count entries to the trained model.\n",
    "        spam = self.token_spam_counts.get(token, 0)\n",
    "        ham = self.token_ham_counts.get(token, 0)\n",
    "\n",
    "        p_token_spam = (spam + self.k) / (self.spam_messages + 2 * self.k)\n",
    "        p_token_ham = (ham + self.k) / (self.ham_messages + 2 * self.k)\n",
    "\n",
    "        return p_token_spam, p_token_ham\n",
    "\n",
    "    # finally we have the predict function\n",
    "    def predict(self, text: str) -> float:\n",
    "        \"\"\"Return P(spam | text), treating spam and ham as equally likely a priori.\n",
    "\n",
    "        Every vocabulary token contributes log P(token | label) if it appears in `text`\n",
    "        and log(1 - P(token | label)) if it does not; words never seen in training are ignored.\n",
    "        \"\"\"\n",
    "        text_tokens = tokenize(text)\n",
    "        log_prob_if_spam = log_prob_if_ham = 0.0\n",
    "\n",
    "        # Iterate through each word in our vocabulary\n",
    "        for token in self.tokens:\n",
    "            prob_if_spam, prob_if_ham = self._probabilities(token)\n",
    "\n",
    "            # If *token* appears in the message,\n",
    "            # add the log probability of seeing it\n",
    "            if token in text_tokens:\n",
    "                log_prob_if_spam += math.log(prob_if_spam)\n",
    "                log_prob_if_ham += math.log(prob_if_ham)\n",
    "\n",
    "            # Otherwise add the log probability of _not_ seeing it,\n",
    "            # which is log(1 - probability of seeing it)\n",
    "            else:\n",
    "                log_prob_if_spam += math.log(1.0 - prob_if_spam)\n",
    "                log_prob_if_ham += math.log(1.0 - prob_if_ham)\n",
    "\n",
    "        # p_spam / (p_spam + p_ham) == 1 / (1 + exp(log_ham - log_spam)).\n",
    "        # Exponentiating each log-likelihood separately can underflow both to 0.0\n",
    "        # (ZeroDivisionError) with a large vocabulary, so evaluate the logistic form,\n",
    "        # picking the branch whose exp() argument is <= 0 so it cannot overflow either.\n",
    "        diff = log_prob_if_ham - log_prob_if_spam\n",
    "        if diff >= 0:\n",
    "            odds = math.exp(-diff)\n",
    "            return odds / (1.0 + odds)\n",
    "        return 1.0 / (1.0 + math.exp(diff))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing Our Model\n",
    "Let’s make sure our model works by writing some unit tests for it."
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# A tiny training set, small enough to check every number by hand.\n",
    "messages = [Message(\"spam rules\", is_spam=True),\n",
    "            Message(\"ham rules\", is_spam=False),\n",
    "            Message(\"hello ham\", is_spam=False)]\n",
    "\n",
    "toy_model = NaiveBayesClassifier(k=0.5)\n",
    "toy_model.train(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First, let’s check that it got the counts right:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "assert toy_model.tokens == {\"spam\", \"ham\", \"rules\", \"hello\"}\n",
    "assert toy_model.spam_messages == 1\n",
    "assert toy_model.ham_messages == 2\n",
    "assert toy_model.token_spam_counts == {\"spam\": 1, \"rules\": 1}\n",
    "assert toy_model.token_ham_counts == {\"ham\": 2, \"rules\": 1, \"hello\": 1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now let’s make a prediction.\n",
    "We’ll also (laboriously) go through our Naive Bayes logic by hand, and make sure that we get the same result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "text = \"hello spam\"\n",
    "\n",
    "probs_if_spam = [\n",
    "    (1 + 0.5) / (1 + 2 * 0.5),      # \"spam\"  (present in the text)\n",
    "    1 - (0 + 0.5) / (1 + 2 * 0.5),  # \"ham\"   (not present in the text)\n",
    "    1 - (1 + 0.5) / (1 + 2 * 0.5),  # \"rules\" (not present in the text)\n",
    "    (0 + 0.5) / (1 + 2 * 0.5)       # \"hello\" (present in the text)\n",
    "]\n",
    "\n",
    "probs_if_ham = [\n",
    "    (0 + 0.5) / (2 + 2 * 0.5),      # \"spam\"  (present in the text)\n",
    "    1 - (2 + 0.5) / (2 + 2 * 0.5),  # \"ham\"   (not present in the text)\n",
    "    1 - (1 + 0.5) / (2 + 2 * 0.5),  # \"rules\" (not present in the text)\n",
    "    (1 + 0.5) / (2 + 2 * 0.5),      # \"hello\" (present in the text)\n",
    "]\n",
    "\n",
    "p_if_spam = math.exp(sum(math.log(p) for p in probs_if_spam))\n",
    "p_if_ham = math.exp(sum(math.log(p) for p in probs_if_ham))\n",
    "expected = p_if_spam / (p_if_spam + p_if_ham)  # Should be about 0.83\n",
    "\n",
    "# Compare in both directions: a one-sided `predicted - expected <= tol` check\n",
    "# would also pass whenever the model's answer is far too small.\n",
    "assert math.isclose(toy_model.predict(text), expected, abs_tol=1e-4)\n",
    "# If this test passes, it seems like our model is doing what we think it is."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Our Model\n",
    "A popular (if somewhat old) dataset is the SpamAssassin public corpus. We’ll look at the files prefixed with 20021010.\n",
    "Here is a script that will download and unpack them to the directory of your choice (or you can do it manually). It skips archives that are already unpacked, so re-running the notebook does not download them again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from io import BytesIO  # So we can treat bytes as a file.\n",
    "import requests         # To download the files, which\n",
    "import tarfile          # are in .tar.bz2 format.\n",
    "\n",
    "BASE_URL = \"https://spamassassin.apache.org/old/publiccorpus\"\n",
    "FILES = [\"20021010_easy_ham.tar.bz2\",\n",
    "         \"20021010_hard_ham.tar.bz2\",\n",
    "         \"20021010_spam.tar.bz2\"]\n",
    "\n",
    "# This is where the data will end up,\n",
    "# in /spam, /easy_ham, and /hard_ham subdirectories.\n",
    "# Change this to where you want the data.\n",
    "OUTPUT_DIR = 'spam_data'\n",
    "\n",
    "def safe_extractall(tf: tarfile.TarFile, path: str) -> None:\n",
    "    \"\"\"Extract every member of `tf` under `path`, refusing members whose names escape it.\"\"\"\n",
    "    root = os.path.abspath(path)\n",
    "    for member in tf.getmembers():\n",
    "        target = os.path.abspath(os.path.join(root, member.name))\n",
    "        if os.path.commonpath([root, target]) != root:\n",
    "            raise ValueError(f\"refusing to extract {member.name!r} outside {root}\")\n",
    "    if hasattr(tarfile, 'data_filter'):  # extraction filters: Python 3.12+ and security backports\n",
    "        tf.extractall(path, filter='data')\n",
    "    else:\n",
    "        tf.extractall(path)\n",
    "\n",
    "for filename in FILES:\n",
    "    # e.g. \"20021010_easy_ham.tar.bz2\" -> OUTPUT_DIR/easy_ham\n",
    "    folder = os.path.join(OUTPUT_DIR, filename.split('_', 1)[1].split('.')[0])\n",
    "    if os.path.isdir(folder):\n",
    "        continue  # already unpacked; delete the folder to force a fresh download\n",
    "\n",
    "    # Use requests to get the file contents at each URL.\n",
    "    response = requests.get(f\"{BASE_URL}/{filename}\", timeout=60)\n",
    "    response.raise_for_status()  # fail loudly instead of trying to untar an error page\n",
    "\n",
    "    # Wrap the in-memory bytes so we can use them as a \"file.\"\n",
    "    fin = BytesIO(response.content)\n",
    "\n",
    "    # And extract all the files to the specified output dir.\n",
    "    with tarfile.open(fileobj=fin, mode='r:bz2') as tf:\n",
    "        safe_extractall(tf, OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### After downloading the data you should have three folders: spam, easy_ham, and hard_ham. Check for yourself.\n",
    "To keep things really simple, we’ll just look at the subject lines of each email."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How do we identify the subject line?\n",
    "Each email begins with header lines, and the subject is the first line that starts with `Subject:`. We keep the text after that prefix, and we label each message by the folder it came from (`spam` vs. `easy_ham`/`hard_ham`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "\n",
    "# modify the path to wherever you've put the files\n",
    "path = 'spam_data/*/*'\n",
    "\n",
    "SUBJECT_PREFIX = \"Subject:\"\n",
    "\n",
    "data: List[Message] = []\n",
    "\n",
    "# glob.glob returns every filename that matches the wildcarded path, but in arbitrary\n",
    "# (filesystem-dependent) order; sort it so the seeded shuffle below is reproducible.\n",
    "for filename in sorted(glob.glob(path)):\n",
    "    # Label by the containing folder (spam/, easy_ham/, hard_ham/), not the whole path.\n",
    "    folder = os.path.basename(os.path.dirname(filename))\n",
    "    is_spam = \"ham\" not in folder\n",
    "\n",
    "    # There are some garbage characters in the emails; the errors='ignore'\n",
    "    # skips them instead of raising an exception.\n",
    "    with open(filename, encoding='utf-8', errors='ignore') as email_file:\n",
    "        for line in email_file:\n",
    "            if line.startswith(SUBJECT_PREFIX):\n",
    "                # Slice off the literal prefix: str.lstrip(\"Subject: \") would strip any of\n",
    "                # those *characters*, e.g. turning \"Subject: best deal\" into \"st deal\".\n",
    "                subject = line[len(SUBJECT_PREFIX):].strip()\n",
    "                data.append(Message(subject, is_spam))\n",
    "                break  # done with this file\n",
    "\n",
    "assert data, f\"no messages found under {path!r}; did the download succeed?\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now we can split the data into training data and test data, and then we’re ready to build a classifier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "from typing import TypeVar, List, Tuple\n",
    "X = TypeVar('X')  # generic type to represent a data point\n",
    "\n",
    "def split_data(data: List[X], prob: float) -> Tuple[List[X], List[X]]:\n",
    "    \"\"\"Split data into fractions [prob, 1 - prob]\"\"\"\n",
    "    data = data[:]                 # Make a shallow copy\n",
    "    random.shuffle(data)           # because shuffle modifies the list.\n",
    "    cut = int(len(data) * prob)    # Use prob to find a cutoff\n",
    "    return data[:cut], data[cut:]  # and split the shuffled list there.\n",
    "\n",
    "random.seed(0)  # just so you get the same answers as me\n",
    "train_messages, test_messages = split_data(data, 0.75)\n",
    "assert len(train_messages) + len(test_messages) == len(data)\n",
    "\n",
    "model = NaiveBayesClassifier()\n",
    "model.train(train_messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let’s generate some predictions and check how our model does\n",
    "For reference, an earlier run of this notebook, made before the subject-line parsing was fixed, found 86 true positives, 669 true negatives, 30 false positives and 40 false negatives (precision ≈ 0.74, recall ≈ 0.68)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "predictions = [(message, model.predict(message.text))\n",
    "               for message in test_messages]\n",
    "\n",
    "# Assume that spam_probability > 0.5 corresponds to spam prediction\n",
    "# and count the combinations of (actual is_spam, predicted is_spam)\n",
    "confusion_matrix = Counter((message.is_spam, spam_probability > 0.5)\n",
    "                           for message, spam_probability in predictions)\n",
    "\n",
    "print(confusion_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_pos = confusion_matrix[(True, True)]\n",
    "false_neg = confusion_matrix[(True, False)]\n",
    "false_pos = confusion_matrix[(False, True)]\n",
    "\n",
    "recall = true_pos / (true_pos + false_neg)     # share of actual spam that we flagged\n",
    "precision = true_pos / (true_pos + false_pos)  # share of spam flags that were correct\n",
    "print(\"recall:\", recall)\n",
    "print(\"precision:\", precision)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We can also inspect the model’s innards to see which words are least and most indicative of spam:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def p_spam_given_token(token: str, model: NaiveBayesClassifier) -> float:\n",
    "    \"\"\"Return P(spam | token present), ignoring all other tokens and assuming equal priors.\"\"\"\n",
    "    # We probably shouldn't call private methods, but it's for a good cause.\n",
    "    prob_if_spam, prob_if_ham = model._probabilities(token)\n",
    "\n",
    "    return prob_if_spam / (prob_if_spam + prob_if_ham)\n",
    "\n",
    "# Sort by spamminess; break ties alphabetically so the printed lists are reproducible\n",
    "# (model.tokens is a set, whose iteration order can vary between runs).\n",
    "words = sorted(model.tokens, key=lambda t: (p_spam_given_token(t, model), t))\n",
    "\n",
    "print(\"spammiest_words\", words[-10:])\n",
    "print(\"hammiest_words\", words[:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How could we get better performance?\n",
    "Get more data? Some other ideas to try:\n",
    "- Look at the message body, not just the subject line.\n",
    "- Ignore rare tokens (e.g. add a minimum-count threshold) to shrink the vocabulary.\n",
    "- Normalize words with a stemmer, so that e.g. *cheap* and *cheapest* share a token.\n",
    "- Use the class prior P(spam) instead of assuming spam and ham are equally likely.\n",
    "- Add features such as \"contains a number\" or \"contains a URL\"."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}