MBTI chatbot

This repository contains the code to build a chatbot that can take on any one of the 16 Myers-Briggs personality types.

Scraping data

The chatbot is fine-tuned for each personality using posts and comments belonging to the corresponding subreddit (for example, r/infj for the INFJ type).

reddit_scraper.py is the script that scrapes a given subreddit. To run it, you first need a MySQL database instance to connect to. You also need some parameters for your Reddit account and for the MySQL database: all of them must be set in a config.py file that follows the schema of config.example.py.
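For orientation, config.py is a plain Python module of constants. The sketch below shows the general shape only: every key name except NUMBER_OF_CONTEXTS and NUMBER_OF_PROCESSES (both mentioned later in this README) is a hypothetical placeholder, so copy the real names from config.example.py.

```python
# config.py: hypothetical sketch, use the key names from config.example.py

# Reddit API credentials (key names assumed for illustration)
REDDIT_CLIENT_ID = "your-client-id"
REDDIT_CLIENT_SECRET = "your-client-secret"
REDDIT_USER_AGENT = "python:mbti-chatbot:v0.1 (by /u/your-username)"

# MySQL connection parameters (key names assumed for illustration)
MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "mbti"
MYSQL_PASSWORD = "change-me"
MYSQL_DATABASE = "mbti"

# Parameters named in this README; the values here are only examples
NUMBER_OF_CONTEXTS = 7
NUMBER_OF_PROCESSES = 4
```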

The script first loads all posts into a table called posts, and then their comments into a table called comments. Although this second step is parallelized, it takes many hours. For this reason, once you have downloaded the posts you are interested in with the main script (~20 min), you can use the comments_scraper.py script to download the associated comments separately. If you interrupt it, it resumes from where it left off the next time you run it.
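One way to make a comment download resumable is to store each post's comments and a per-post "done" flag in the same transaction, and on every run fetch only the posts that are not flagged yet. The sketch below illustrates that idea with the standard-library sqlite3 module standing in for MySQL; the comments_done column, the simplified table layouts and the fetch_comments callback are assumptions, not the schema or API used by comments_scraper.py.

```python
import sqlite3

# Hypothetical schema for illustration only: the real tables live in MySQL
# and are created by reddit_scraper.py with their own columns.
SCHEMA = """
CREATE TABLE IF NOT EXISTS posts (
    id TEXT PRIMARY KEY,
    comments_done INTEGER NOT NULL DEFAULT 0  -- assumed progress flag
);
CREATE TABLE IF NOT EXISTS comments (
    id TEXT PRIMARY KEY,
    post_id TEXT NOT NULL REFERENCES posts(id),
    body TEXT
);
"""


def pending_post_ids(conn):
    """Ids of posts whose comments have not been fully stored yet."""
    rows = conn.execute(
        "SELECT id FROM posts WHERE comments_done = 0 ORDER BY id"
    ).fetchall()
    return [post_id for (post_id,) in rows]


def save_comments(conn, post_id, comments):
    """Store a post's comments and mark the post as done atomically.

    If the process is killed before the commit, neither the comments nor the
    flag are saved, so the post is simply fetched again on the next run.
    """
    with conn:  # commits on success, rolls back on exception
        conn.executemany(
            "INSERT OR IGNORE INTO comments (id, post_id, body) VALUES (?, ?, ?)",
            [(comment_id, post_id, body) for comment_id, body in comments],
        )
        conn.execute("UPDATE posts SET comments_done = 1 WHERE id = ?", (post_id,))


def scrape_remaining(conn, fetch_comments):
    """Resume loop: only posts not yet marked as done are fetched."""
    for post_id in pending_post_ids(conn):
        save_comments(conn, post_id, fetch_comments(post_id))
```

`INSERT OR IGNORE` is SQLite syntax; the MySQL equivalent is `INSERT IGNORE`.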

Training

Preparing data

To train the model, I first converted the data into the conversational dataset format, i.e. a CSV table with the following structure:

| id | response | context | context/0 | ... | context/n |
|---|---|---|---|---|---|
| s892nn | I'm fine | It's ok. What about you? | How's life? | ... | Hi! |

Here, context/n is the beginning of the conversation, and the columns move forward in time up to the most recent exchange: context/0, then context, and finally response, which is the latest message in the conversation. You can change the number of context columns by overriding the NUMBER_OF_CONTEXTS parameter in the config.py file.
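To make the column order concrete, here is a minimal sketch that maps a conversation (oldest turn first) onto one row of that table. It assumes NUMBER_OF_CONTEXTS counts the context/i columns only (not context or response) and that missing older turns are left empty; build_row is a hypothetical helper, not a function from create_conversational_dataset.py.

```python
def build_row(conv_id, turns, number_of_contexts):
    """Map a conversation (oldest turn first) onto the dataset columns.

    response is the last turn, context the one before it, and context/0 ...
    context/{number_of_contexts - 1} go further back in time. Assumption:
    turns older than the window are dropped, and missing ones are None.
    """
    if len(turns) < 2:
        raise ValueError("a conversation needs at least two turns")
    newest_first = list(reversed(turns))
    row = {"id": conv_id, "response": newest_first[0], "context": newest_first[1]}
    older = newest_first[2:]
    for i in range(number_of_contexts):
        row[f"context/{i}"] = older[i] if i < len(older) else None
    return row
```

For example, `build_row("s892nn", ["Hi!", "How's life?", "It's ok. What about you?", "I'm fine"], 2)` puts "How's life?" in context/0 and "Hi!" in context/1, matching the example row above.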

The script create_conversational_dataset.py generates the CSV from the SQL tables created during the scraping phase and saves it as a pickle file in the data folder. A conversation is built either from a post and one of its direct comments, or from a post, a comment and the chain of replies below it.
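Building a conversation amounts to walking from a comment up through its parents until the post is reached. The sketch below assumes comments are available as a dict keyed by id, with a parent_id in Reddit's fullname format ("t3_" prefix for a post, "t1_" for a comment); the field names post["text"], "body" and "parent_id" are illustrative, and the real script may read different columns from the comments table.

```python
def build_conversation(post, comment_id, comments):
    """Return the turns (oldest first) from a post down to comment_id.

    post: {"id": ..., "text": ...}
    comments: {comment_id: {"body": ..., "parent_id": ...}}
    parent_id uses Reddit fullnames: "t3_<id>" for a post, "t1_<id>" for a comment.
    Returns None if the chain is broken (e.g. a deleted parent was never stored).
    """
    turns = []
    current = comments.get(comment_id)
    while current is not None:
        turns.append(current["body"])
        parent = current["parent_id"]
        if parent == "t3_" + post["id"]:
            # Reached the post: a direct comment yields a two-turn conversation.
            turns.append(post["text"])
            return list(reversed(turns))
        current = comments.get(parent.removeprefix("t1_"))  # Python 3.9+
    return None
```

A direct comment gives the "post + comment" case; a deeper reply gives the "post + comment + chain" case.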

The script runs in parallel, so it writes N separate CSVs (N is set by the NUMBER_OF_PROCESSES parameter), which are then concatenated into the final pickle file.
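The split, write-in-parallel, then merge pattern can be sketched with the standard library alone. This is an illustration under assumptions: the column subset, file names and the list-of-dicts pickle are hypothetical (the real script may well pickle a pandas DataFrame instead), and each worker only does a trivial row conversion where the real one does the conversation building.

```python
import csv
import pickle
from multiprocessing import Pool
from pathlib import Path

FIELDS = ["id", "response", "context"]  # illustrative subset of the real columns


def split(items, n):
    """Split items into n contiguous chunks whose sizes differ by at most one."""
    size, extra = divmod(len(items), n)
    chunks, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < extra else 0)
        chunks.append(items[start:end])
        start = end
    return chunks


def write_part(job):
    """Worker: turn one chunk of (id, turns) pairs into rows and write its own CSV."""
    index, conversations, out_dir = job
    path = Path(out_dir) / f"part_{index}.csv"
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        writer.writeheader()
        for conv_id, turns in conversations:
            writer.writerow({"id": conv_id, "response": turns[-1], "context": turns[-2]})
    return str(path)


def merge_parts(paths, pickle_path):
    """Concatenate the part CSVs in order and pickle the combined rows."""
    rows = []
    for path in paths:
        with open(path, newline="", encoding="utf-8") as f:
            rows.extend(csv.DictReader(f))
    with open(pickle_path, "wb") as f:
        pickle.dump(rows, f)
    return rows


def build_dataset(conversations, out_dir, number_of_processes):
    """Write one CSV per process in parallel, then merge them into data.pkl."""
    jobs = [
        (i, chunk, out_dir)
        for i, chunk in enumerate(split(conversations, number_of_processes))
    ]
    with Pool(number_of_processes) as pool:
        paths = pool.map(write_part, jobs)  # results keep the order of jobs
    return merge_parts(paths, Path(out_dir) / "data.pkl")
```

Contiguous chunks plus the ordered results of `Pool.map` keep the merged rows in their original order. On platforms that start workers with the spawn method (Windows, macOS), call `build_dataset` from inside an `if __name__ == "__main__":` block.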

Model

The notebook training.py fine-tunes the DialoGPT-medium language model on the conversational data; it is mainly an adaptation of the code you can find in this notebook.
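This README does not say exactly how a row becomes a training example. A common pattern in DialoGPT fine-tuning code is to order the turns oldest first and append the end-of-sequence token after each one, so the model learns to produce the response after the context. The sketch below shows that step on plain strings, using GPT-2's `<|endoftext|>` token, which DialoGPT reuses; whether training.py does exactly this is an assumption. In tokenized code you would typically call `tokenizer.encode` on each turn and append `tokenizer.eos_token_id` instead. The row layout matches the table above, and row_to_training_text is a hypothetical helper.

```python
EOS = "<|endoftext|>"  # GPT-2 end-of-sequence token, also used by DialoGPT


def row_to_training_text(row, number_of_contexts):
    """Join a dataset row into one oldest-first string, with EOS after each turn.

    Column order: context/{n}, ..., context/0, context, response.
    Empty or missing context columns are skipped.
    """
    columns = [f"context/{i}" for i in reversed(range(number_of_contexts))]
    columns += ["context", "response"]
    turns = [row[c] for c in columns if row.get(c)]
    return "".join(turn + EOS for turn in turns)
```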

Executing demo.py will start the conversation.

Running

To run the code in this repository, create a virtual environment and install the dependencies with the following commands:

virtualenv venv 
source ./venv/bin/activate
pip install -r requirements.txt