diff --git a/parsernaam/notebooks/pos_name_parser.ipynb b/parsernaam/notebooks/pos_name_parser.ipynb new file mode 100644 index 0000000..14c44fe --- /dev/null +++ b/parsernaam/notebooks/pos_name_parser.ipynb @@ -0,0 +1,452 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5f1fd291-b290-403e-a284-f9fb0288c233", + "metadata": {}, + "source": [ + "### Name Parsing With Context" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "40aa5bc2-797a-4bcb-b9ea-c3c0dbbaa125", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import random\n", + "import time\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.utils.rnn as rnn_utils\n", + "from sklearn.model_selection import train_test_split\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torch.nn.utils.rnn import pad_sequence" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d7035459-09b6-466f-9b40-a09899c88334", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(8909970, 3)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv(\"../data/fl_reg_name_race_2022.csv.gz\")\n", + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6d4ecbb7-fe73-47b3-a8e4-36de4e1e318f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
last_namefirst_namerace
0BinkleyKathryn5.0
1BrockLakaya3.0
2FontaineCharles5.0
3PosseltSuzanne5.0
4HaeselerBala5.0
\n", + "
" + ], + "text/plain": [ + " last_name first_name race\n", + "0 Binkley Kathryn 5.0\n", + "1 Brock Lakaya 3.0\n", + "2 Fontaine Charles 5.0\n", + "3 Posselt Suzanne 5.0\n", + "4 Haeseler Bala 5.0" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cf8282c2-d002-4f4e-a08e-71d25cc1b9c8", + "metadata": {}, + "outputs": [], + "source": [ + "df.drop_duplicates(subset=['last_name', 'first_name'], inplace = True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "aae5230f-6eee-40ac-b95d-34550876879b", + "metadata": {}, + "outputs": [], + "source": [ + "tokenized_sentences = []\n", + "pos_labels = []\n", + "\n", + "random.seed(42)\n", + "\n", + "for _, row in df.sample(n = 1000000).iterrows():\n", + " if random.random() < 0.5:\n", + " tokenized_sentences.append([row['last_name'], row['first_name']])\n", + " pos_labels.append(['last_name', 'first_name'])\n", + " else:\n", + " tokenized_sentences.append([row['first_name'], row['last_name']])\n", + " pos_labels.append(['first_name', 'last_name'])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bb0a6dd8-6471-440e-be1c-baee3c706c08", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[['Borinquen', 'Medina-Torres'], ['Uddin', 'Asma'], ['Mecoli', 'Donald'], ['Ibargoyen', 'Yansari'], ['Judy', 'Bales'], ['Marlyn', 'Diaz Alvarado'], ['Holly', 'Korman'], ['Masso', 'Justina'], ['Drake', 'Brandon'], ['Swabowicz', 'Michael']]\n", + "[['first_name', 'last_name'], ['last_name', 'first_name'], ['last_name', 'first_name'], ['last_name', 'first_name'], ['first_name', 'last_name'], ['first_name', 'last_name'], ['first_name', 'last_name'], ['last_name', 'first_name'], ['last_name', 'first_name'], ['last_name', 'first_name']]\n" + ] + } + ], + "source": [ + "print(tokenized_sentences[:10])\n", + "print(pos_labels[:10])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "46036d2c-0e01-4d12-888e-e5047780ee33", + "metadata": {}, + "outputs": [], + "source": [ + "vocab = {word: idx for idx, word in enumerate(set(word for sentence in tokenized_sentences for word in sentence))}\n", + "pos_tags = {tag: idx for idx, tag in enumerate(set(tag for label_set in pos_labels for tag in label_set))}" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8a3bf61a-38ca-4642-aed6-c0dce1fcf627", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert words and tags to indices\n", + "tokenized_sentences_idx = [[vocab[word] for word in sentence] for sentence in tokenized_sentences]\n", + "pos_labels_idx = [[pos_tags[tag] for tag in label_set] for label_set in pos_labels]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "9ab88886-6b78-4183-91e3-83a6615a7757", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training set shapes: 900000 900000\n", + "Testing set shapes: 100000 100000\n" + ] + } + ], + "source": [ + "# Split the data into training and testing sets (90% train, 10% test)\n", + "train_sentences, test_sentences, train_pos_labels, test_pos_labels = train_test_split(\n", + " tokenized_sentences_idx, pos_labels_idx, test_size=0.1, random_state=42\n", + ")\n", + "\n", + "# Check the shapes of the sets\n", + "print(\"Training set shapes:\", len(train_sentences), len(train_pos_labels))\n", + "print(\"Testing set shapes:\", len(test_sentences), len(test_pos_labels))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8a327f45-e6c3-4e9a-af7a-0935cc37ebfb", + "metadata": {}, + "outputs": [], + "source": [ + "# Hyperparameters\n", + "embedding_dim = 100\n", + "hidden_dim = 128\n", + "vocab_size = len(vocab)\n", + "output_size = len(pos_tags)\n", + "\n", + "# Simple BiLSTM model for POS tagging\n", + "class BiLSTMPOSTagger(nn.Module):\n", + " def __init__(self, embedding_dim, hidden_dim, vocab_size, output_size):\n", + " super(BiLSTMPOSTagger, self).__init__()\n", + " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", + " self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)\n", + " self.hidden2pos = nn.Linear(hidden_dim * 2, output_size)\n", + " \n", + " def forward(self, sentence):\n", + " embeds = self.embedding(sentence)\n", + " lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))\n", + " pos_space = self.hidden2pos(lstm_out.view(len(sentence), -1))\n", + " pos_scores = F.log_softmax(pos_space, dim=1)\n", + " return pos_scores" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "3c555faf-230c-44b7-91a2-5cb69c0a94f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/3, Avg. Training Loss: 0.2509, Avg. Test Loss: 0.1945\n", + "Epoch 2/3, Avg. Training Loss: 0.1574, Avg. Test Loss: 0.1728\n", + "Epoch 3/3, Avg. Training Loss: 0.1187, Avg. Test Loss: 0.1708\n" + ] + } + ], + "source": [ + "# Initialize the model, loss function, and optimizer\n", + "model = BiLSTMPOSTagger(embedding_dim, hidden_dim, vocab_size, output_size)\n", + "loss_function = nn.CrossEntropyLoss()\n", + "optimizer = optim.SGD(model.parameters(), lr=0.1)\n", + "log_interval = 100 # Log every 100 iterations\n", + "\n", + "# Training loop\n", + "num_epochs = 3\n", + "best_test_loss = float('inf')\n", + "patience = 2 # Number of epochs to wait for improvement\n", + "\n", + "for epoch in range(num_epochs):\n", + " start_time = time.time()\n", + " model.train() # Set the model to training mode\n", + " total_loss = 0.0\n", + "\n", + " # Training\n", + " for sentence, tags in zip(train_sentences, train_pos_labels):\n", + " model.zero_grad()\n", + " sentence_in = torch.tensor(sentence, dtype=torch.long)\n", + " targets = torch.tensor(tags, dtype=torch.long)\n", + "\n", + " tag_scores = model(sentence_in)\n", + " loss = loss_function(tag_scores, targets)\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += loss.item()\n", + "\n", + " # Log at specified intervals\n", + " if i % log_interval == 0 and i > 0:\n", + " avg_loss = total_loss / log_interval\n", + " print(f\"Iteration {i}/{len(train_sentences)}, Avg. Loss: {avg_loss:.4f}\")\n", + " total_loss = 0.0\n", + " \n", + " # Calculate average training loss for this epoch\n", + " avg_train_loss = total_loss / len(train_sentences)\n", + " print(f\"Epoch {epoch + 1}/{num_epochs}, Time: {epoch_time:.2f} seconds\")\n", + "\n", + " # Evaluate on the test set\n", + " model.eval() # Set the model to evaluation mode\n", + " total_test_loss = 0.0\n", + "\n", + " with torch.no_grad():\n", + " for sentence, tags in zip(test_sentences, test_pos_labels):\n", + " sentence_in = torch.tensor(sentence, dtype=torch.long)\n", + " targets = torch.tensor(tags, dtype=torch.long)\n", + "\n", + " tag_scores = model(sentence_in)\n", + " loss = loss_function(tag_scores, targets)\n", + " total_test_loss += loss.item()\n", + "\n", + " # Calculate average test loss for this epoch\n", + " avg_test_loss = total_test_loss / len(test_sentences)\n", + "\n", + " # Print the losses for this epoch\n", + " print(f\"Epoch {epoch + 1}/{num_epochs}, Avg. Training Loss: {avg_train_loss:.4f}, Avg. Test Loss: {avg_test_loss:.4f}\")\n", + "\n", + " # Check for early stopping\n", + " if avg_test_loss < best_test_loss:\n", + " best_test_loss = avg_test_loss\n", + " patience_counter = 0\n", + " else:\n", + " patience_counter += 1\n", + " if patience_counter >= patience:\n", + " print(\"Early stopping. Test loss hasn't improved for\", patience, \"epochs.\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "2e681409-ae64-48b5-8a0c-05153beb38d8", + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(model.state_dict(), 'naamparser_pos_model.pth')" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "cb678d97-19b4-43cb-933d-8880848126f0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicted POS tags:\n", + "Sentence: Jon Smith\n", + "POS tags: first_name last_name\n", + "\n", + "Sentence: Rodriguez Julia\n", + "POS tags: last_name first_name\n", + "\n" + ] + } + ], + "source": [ + "def predict_pos_tags(model, vocab, pos_tags, sentences):\n", + " \"\"\"\n", + " Predicts POS tags for tokenized sentences.\n", + "\n", + " Args:\n", + " model (nn.Module): Trained POS tagging model.\n", + " vocab (dict): Vocabulary mapping words to indices.\n", + " pos_tags (dict): POS tag mapping.\n", + " sentences (list of list): Tokenized sentences.\n", + "\n", + " Returns:\n", + " list of list: Predicted POS tags for each word in each sentence.\n", + " \"\"\"\n", + " # Convert words to indices\n", + " sentences_idx = [[vocab.get(word, 0) for word in sentence] for sentence in sentences]\n", + "\n", + " # Predict POS tags\n", + " predicted_tags = []\n", + " with torch.no_grad():\n", + " model.eval() # Set the model to evaluation mode\n", + " for sentence_idx in sentences_idx:\n", + " sentence_in = torch.tensor(sentence_idx, dtype=torch.long)\n", + " tag_scores = model(sentence_in)\n", + " _, predicted = torch.max(tag_scores, dim=1)\n", + " predicted_tags.append([list(pos_tags.keys())[list(pos_tags.values()).index(tag)] for tag in predicted.numpy()])\n", + "\n", + " return predicted_tags\n", + "\n", + "# Example usage\n", + "sample_sentences = [['Jon', 'Smith'], ['Rodriguez', 'Julia']]\n", + "predicted_tags = predict_pos_tags(model, vocab, pos_tags, sample_sentences)\n", + "\n", + "print(\"Predicted POS tags:\")\n", + "for i in range(len(sample_sentences)):\n", + " print(\"Sentence:\", ' '.join(sample_sentences[i]))\n", + " print(\"POS tags:\", ' '.join(predicted_tags[i]))\n", + " print()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "340fee1c-00b9-4fef-96e1-2101884ecd86", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}