diff --git a/parsernaam/notebooks/pos_name_parser.ipynb b/parsernaam/notebooks/pos_name_parser.ipynb
new file mode 100644
index 0000000..14c44fe
--- /dev/null
+++ b/parsernaam/notebooks/pos_name_parser.ipynb
@@ -0,0 +1,452 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "5f1fd291-b290-403e-a284-f9fb0288c233",
+ "metadata": {},
+ "source": [
+ "### Name Parsing With Context"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "40aa5bc2-797a-4bcb-b9ea-c3c0dbbaa125",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import random\n",
+ "import time\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "import torch.nn.utils.rnn as rnn_utils\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from torch.nn.utils.rnn import pad_sequence"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "d7035459-09b6-466f-9b40-a09899c88334",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(8909970, 3)"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df = pd.read_csv(\"../data/fl_reg_name_race_2022.csv.gz\")\n",
+ "df.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "6d4ecbb7-fe73-47b3-a8e4-36de4e1e318f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " last_name | \n",
+ " first_name | \n",
+ " race | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Binkley | \n",
+ " Kathryn | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Brock | \n",
+ " Lakaya | \n",
+ " 3.0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Fontaine | \n",
+ " Charles | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Posselt | \n",
+ " Suzanne | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " Haeseler | \n",
+ " Bala | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " last_name first_name race\n",
+ "0 Binkley Kathryn 5.0\n",
+ "1 Brock Lakaya 3.0\n",
+ "2 Fontaine Charles 5.0\n",
+ "3 Posselt Suzanne 5.0\n",
+ "4 Haeseler Bala 5.0"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "cf8282c2-d002-4f4e-a08e-71d25cc1b9c8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df.drop_duplicates(subset=['last_name', 'first_name'], inplace = True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "aae5230f-6eee-40ac-b95d-34550876879b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenized_sentences = []\n",
+ "pos_labels = []\n",
+ "\n",
+ "random.seed(42)\n",
+ "\n",
+ "for _, row in df.sample(n = 1000000).iterrows():\n",
+ " if random.random() < 0.5:\n",
+ " tokenized_sentences.append([row['last_name'], row['first_name']])\n",
+ " pos_labels.append(['last_name', 'first_name'])\n",
+ " else:\n",
+ " tokenized_sentences.append([row['first_name'], row['last_name']])\n",
+ " pos_labels.append(['first_name', 'last_name'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "bb0a6dd8-6471-440e-be1c-baee3c706c08",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[['Borinquen', 'Medina-Torres'], ['Uddin', 'Asma'], ['Mecoli', 'Donald'], ['Ibargoyen', 'Yansari'], ['Judy', 'Bales'], ['Marlyn', 'Diaz Alvarado'], ['Holly', 'Korman'], ['Masso', 'Justina'], ['Drake', 'Brandon'], ['Swabowicz', 'Michael']]\n",
+ "[['first_name', 'last_name'], ['last_name', 'first_name'], ['last_name', 'first_name'], ['last_name', 'first_name'], ['first_name', 'last_name'], ['first_name', 'last_name'], ['first_name', 'last_name'], ['last_name', 'first_name'], ['last_name', 'first_name'], ['last_name', 'first_name']]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tokenized_sentences[:10])\n",
+ "print(pos_labels[:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "46036d2c-0e01-4d12-888e-e5047780ee33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vocab = {word: idx for idx, word in enumerate(set(word for sentence in tokenized_sentences for word in sentence))}\n",
+ "pos_tags = {tag: idx for idx, tag in enumerate(set(tag for label_set in pos_labels for tag in label_set))}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "8a3bf61a-38ca-4642-aed6-c0dce1fcf627",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Convert words and tags to indices\n",
+ "tokenized_sentences_idx = [[vocab[word] for word in sentence] for sentence in tokenized_sentences]\n",
+ "pos_labels_idx = [[pos_tags[tag] for tag in label_set] for label_set in pos_labels]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "9ab88886-6b78-4183-91e3-83a6615a7757",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training set shapes: 900000 900000\n",
+ "Testing set shapes: 100000 100000\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Split the data into training and testing sets (90% train, 10% test)\n",
+ "train_sentences, test_sentences, train_pos_labels, test_pos_labels = train_test_split(\n",
+ " tokenized_sentences_idx, pos_labels_idx, test_size=0.1, random_state=42\n",
+ ")\n",
+ "\n",
+ "# Check the shapes of the sets\n",
+ "print(\"Training set shapes:\", len(train_sentences), len(train_pos_labels))\n",
+ "print(\"Testing set shapes:\", len(test_sentences), len(test_pos_labels))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "8a327f45-e6c3-4e9a-af7a-0935cc37ebfb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hyperparameters\n",
+ "embedding_dim = 100\n",
+ "hidden_dim = 128\n",
+ "vocab_size = len(vocab)\n",
+ "output_size = len(pos_tags)\n",
+ "\n",
+ "# Simple BiLSTM model for POS tagging\n",
+ "class BiLSTMPOSTagger(nn.Module):\n",
+ " def __init__(self, embedding_dim, hidden_dim, vocab_size, output_size):\n",
+ " super(BiLSTMPOSTagger, self).__init__()\n",
+ " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
+ " self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)\n",
+ " self.hidden2pos = nn.Linear(hidden_dim * 2, output_size)\n",
+ " \n",
+ " def forward(self, sentence):\n",
+ " embeds = self.embedding(sentence)\n",
+ " lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))\n",
+ " pos_space = self.hidden2pos(lstm_out.view(len(sentence), -1))\n",
+ " pos_scores = F.log_softmax(pos_space, dim=1)\n",
+ " return pos_scores"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "3c555faf-230c-44b7-91a2-5cb69c0a94f8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/3, Avg. Training Loss: 0.2509, Avg. Test Loss: 0.1945\n",
+ "Epoch 2/3, Avg. Training Loss: 0.1574, Avg. Test Loss: 0.1728\n",
+ "Epoch 3/3, Avg. Training Loss: 0.1187, Avg. Test Loss: 0.1708\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Initialize the model, loss function, and optimizer\n",
+ "model = BiLSTMPOSTagger(embedding_dim, hidden_dim, vocab_size, output_size)\n",
+ "loss_function = nn.CrossEntropyLoss()\n",
+ "optimizer = optim.SGD(model.parameters(), lr=0.1)\n",
+ "log_interval = 100 # Log every 100 iterations\n",
+ "\n",
+ "# Training loop\n",
+ "num_epochs = 3\n",
+ "best_test_loss = float('inf')\n",
+ "patience = 2 # Number of epochs to wait for improvement\n",
+ "\n",
+ "for epoch in range(num_epochs):\n",
+ " start_time = time.time()\n",
+ " model.train() # Set the model to training mode\n",
+ " total_loss = 0.0\n",
+ "\n",
+ " # Training\n",
+ " for sentence, tags in zip(train_sentences, train_pos_labels):\n",
+ " model.zero_grad()\n",
+ " sentence_in = torch.tensor(sentence, dtype=torch.long)\n",
+ " targets = torch.tensor(tags, dtype=torch.long)\n",
+ "\n",
+ " tag_scores = model(sentence_in)\n",
+ " loss = loss_function(tag_scores, targets)\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " total_loss += loss.item()\n",
+ "\n",
+ " # Log at specified intervals\n",
+ " if i % log_interval == 0 and i > 0:\n",
+ " avg_loss = total_loss / log_interval\n",
+ " print(f\"Iteration {i}/{len(train_sentences)}, Avg. Loss: {avg_loss:.4f}\")\n",
+ " total_loss = 0.0\n",
+ " \n",
+ " # Calculate average training loss for this epoch\n",
+ " avg_train_loss = total_loss / len(train_sentences)\n",
+ " print(f\"Epoch {epoch + 1}/{num_epochs}, Time: {epoch_time:.2f} seconds\")\n",
+ "\n",
+ " # Evaluate on the test set\n",
+ " model.eval() # Set the model to evaluation mode\n",
+ " total_test_loss = 0.0\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for sentence, tags in zip(test_sentences, test_pos_labels):\n",
+ " sentence_in = torch.tensor(sentence, dtype=torch.long)\n",
+ " targets = torch.tensor(tags, dtype=torch.long)\n",
+ "\n",
+ " tag_scores = model(sentence_in)\n",
+ " loss = loss_function(tag_scores, targets)\n",
+ " total_test_loss += loss.item()\n",
+ "\n",
+ " # Calculate average test loss for this epoch\n",
+ " avg_test_loss = total_test_loss / len(test_sentences)\n",
+ "\n",
+ " # Print the losses for this epoch\n",
+ " print(f\"Epoch {epoch + 1}/{num_epochs}, Avg. Training Loss: {avg_train_loss:.4f}, Avg. Test Loss: {avg_test_loss:.4f}\")\n",
+ "\n",
+ " # Check for early stopping\n",
+ " if avg_test_loss < best_test_loss:\n",
+ " best_test_loss = avg_test_loss\n",
+ " patience_counter = 0\n",
+ " else:\n",
+ " patience_counter += 1\n",
+ " if patience_counter >= patience:\n",
+ " print(\"Early stopping. Test loss hasn't improved for\", patience, \"epochs.\")\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "2e681409-ae64-48b5-8a0c-05153beb38d8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.save(model.state_dict(), 'naamparser_pos_model.pth')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "cb678d97-19b4-43cb-933d-8880848126f0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Predicted POS tags:\n",
+ "Sentence: Jon Smith\n",
+ "POS tags: first_name last_name\n",
+ "\n",
+ "Sentence: Rodriguez Julia\n",
+ "POS tags: last_name first_name\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "def predict_pos_tags(model, vocab, pos_tags, sentences):\n",
+ " \"\"\"\n",
+ " Predicts POS tags for tokenized sentences.\n",
+ "\n",
+ " Args:\n",
+ " model (nn.Module): Trained POS tagging model.\n",
+ " vocab (dict): Vocabulary mapping words to indices.\n",
+ " pos_tags (dict): POS tag mapping.\n",
+ " sentences (list of list): Tokenized sentences.\n",
+ "\n",
+ " Returns:\n",
+ " list of list: Predicted POS tags for each word in each sentence.\n",
+ " \"\"\"\n",
+ " # Convert words to indices\n",
+ " sentences_idx = [[vocab.get(word, 0) for word in sentence] for sentence in sentences]\n",
+ "\n",
+ " # Predict POS tags\n",
+ " predicted_tags = []\n",
+ " with torch.no_grad():\n",
+ " model.eval() # Set the model to evaluation mode\n",
+ " for sentence_idx in sentences_idx:\n",
+ " sentence_in = torch.tensor(sentence_idx, dtype=torch.long)\n",
+ " tag_scores = model(sentence_in)\n",
+ " _, predicted = torch.max(tag_scores, dim=1)\n",
+ " predicted_tags.append([list(pos_tags.keys())[list(pos_tags.values()).index(tag)] for tag in predicted.numpy()])\n",
+ "\n",
+ " return predicted_tags\n",
+ "\n",
+ "# Example usage\n",
+ "sample_sentences = [['Jon', 'Smith'], ['Rodriguez', 'Julia']]\n",
+ "predicted_tags = predict_pos_tags(model, vocab, pos_tags, sample_sentences)\n",
+ "\n",
+ "print(\"Predicted POS tags:\")\n",
+ "for i in range(len(sample_sentences)):\n",
+ " print(\"Sentence:\", ' '.join(sample_sentences[i]))\n",
+ " print(\"POS tags:\", ' '.join(predicted_tags[i]))\n",
+ " print()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "340fee1c-00b9-4fef-96e1-2101884ecd86",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}