-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
omame
committed
May 20, 2021
0 parents
commit 218bb25
Showing
11 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.idea/ | ||
venv/ | ||
token.txt | ||
data.json | ||
config.json | ||
__pycache__/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# bob v2 | ||
bob is a chatbot designed for use in TFT | ||
|
||
## how to use | ||
1. install the requirements | ||
2. copy `config.example.json` to `config.json` and edit it | ||
2. run main.py | ||
|
||
is it hard/?????? | ||
|
||
## config options | ||
### `scrape_channel` | ||
channel which bob will scrape for data | ||
|
||
### `target_channel` | ||
channel in which bob will be talking in | ||
|
||
### `debug_channel` | ||
channel in which bob will be talking in when in debug mode | ||
|
||
### `scrape_amount` | ||
amount of messages bob will go through for its dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"scrape_channel": 0, | ||
"target_channel": 1, | ||
"debug_channel": 2, | ||
"scrape_amount": 50000 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import json | ||
import os.path | ||
|
||
default_config = { | ||
"scrape_channel": 0, | ||
"target_channel": 0, | ||
"debug_channel": 0, | ||
"scrape_amount": 50000 | ||
} | ||
|
||
if not os.path.exists("config.json"): | ||
print("WARNING: Config file not found! Using default config.") | ||
config = default_config | ||
else: | ||
with open("config.json") as file: | ||
config = json.load(file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from qna import helpers, json | ||
import argparse | ||
import discord | ||
from configreader import config | ||
|
||
version = 'v2.1.0 stable' | ||
|
||
parser = argparse.ArgumentParser(description=f'bob {version}') | ||
parser.add_argument('--debug', '-d', action='store_true', help='enable debug mode') | ||
parser.add_argument('--version', '-v', action='version', version=version) | ||
args = parser.parse_args() | ||
|
||
channel_id = config["debug_channel"] if args.debug else config["target_channel"] | ||
client = discord.Client() | ||
questions = [] | ||
|
||
|
||
def debug(*arg, **kwargs): | ||
if args.debug: | ||
print(*arg, **kwargs) | ||
|
||
|
||
async def debug_send(message: discord.Message, *arg, **kwargs): | ||
if args.debug: | ||
await message.channel.send(*arg, **kwargs) | ||
|
||
|
||
@client.event | ||
async def on_ready(): | ||
print(f'bob {version} is ready!') | ||
channel = client.get_channel(channel_id) | ||
await channel.send(f'ℹ️ **bob {version} is ready** with {len(questions)} questions in the dataset') | ||
await channel.edit(topic=f'**bob {version}** | trained on {len(questions)} questions') | ||
|
||
|
||
@client.event | ||
async def on_message(message: discord.Message): | ||
if message.channel.id != channel_id: | ||
return | ||
if message.author.bot: | ||
return | ||
content = helpers.sanitize_question(message.clean_content) | ||
if message.is_system(): | ||
content = helpers.sanitize_question(message.system_content) | ||
question = helpers.get_closest_question(questions, content) | ||
response = helpers.pick_response(question) | ||
debug(question) | ||
debug(content, '->', response.text) | ||
embed = discord.Embed(title="Details", color=0x2273e6) | ||
embed.add_field(name="Matched Question", value=question.text) | ||
embed.add_field(name="# of Responses", value=len(question.responses)) | ||
try: | ||
await message.reply(response.text, embed=embed) | ||
except discord.errors.HTTPException: | ||
await message.channel.send(response.text, embed=embed) | ||
|
||
|
||
if __name__ == '__main__': | ||
debug('initializing...') | ||
with open('data.json') as file: | ||
questions = json.json_to_questions(file.read()) | ||
debug('loaded', len(questions), 'questions') | ||
|
||
with open('token.txt') as file: | ||
client.run(file.readline()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import classes, helpers, json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
class Response: | ||
def __init__(self, response: str): | ||
self.text = response | ||
self.count = 1 | ||
|
||
def __repr__(self): | ||
attrs = [ | ||
('text', self.text), | ||
('count', self.count) | ||
] | ||
joined = ' '.join('%s=%r' % t for t in attrs) | ||
return f'<{self.__class__.__name__} {joined}>' | ||
|
||
|
||
class Question: | ||
def __init__(self, question: str): | ||
self.text = question | ||
self.responses = [] | ||
|
||
def __repr__(self): | ||
attrs = [ | ||
('text', self.text), | ||
('responses', self.responses) | ||
] | ||
joined = ' '.join('%s=%r' % t for t in attrs) | ||
return f'<{self.__class__.__name__} {joined}>' | ||
|
||
def add_response(self, resp: Response): | ||
for response in self.responses: | ||
if response.text == resp.text: | ||
response.count += resp.count | ||
return | ||
self.responses.append(resp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import string | ||
|
||
import Levenshtein | ||
|
||
from .classes import Question, Response | ||
import typing | ||
import random | ||
|
||
|
||
def calculate_popularity(question: Question) -> typing.Dict[Response, float]: | ||
popularity = {} | ||
total = 0 | ||
for response in question.responses: | ||
total += response.count | ||
for response in question.responses: | ||
popularity.update({response: response.count / total}) | ||
return popularity | ||
|
||
|
||
def pick_response(question: Question) -> Response: | ||
popularity = calculate_popularity(question) | ||
|
||
keys = list(popularity.keys()) | ||
random_prob = random.random() | ||
pop = 0 | ||
for response in keys: | ||
pop = pop + popularity[response] | ||
if pop > random_prob: | ||
return response | ||
|
||
|
||
def get_closest_question(questions: typing.List[Question], message: str) -> Question: | ||
lowest = None | ||
target = None | ||
for question in questions: | ||
dist = Levenshtein.distance(question.text, message) | ||
if lowest is None or lowest > dist: | ||
lowest = dist | ||
target = question | ||
|
||
return target | ||
|
||
|
||
def sanitize_question(content: str) -> str: | ||
return content.lower().translate(str.maketrans('', '', string.punctuation)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import json | ||
import typing | ||
from qna.classes import Question, Response | ||
|
||
|
||
def response_to_dict(response: Response) -> dict: | ||
data = { | ||
'text': response.text, | ||
'count': response.count | ||
} | ||
return data | ||
|
||
|
||
def dict_to_response(response_dict: dict) -> Response: | ||
out = Response(response_dict['text']) | ||
out.count = response_dict['count'] | ||
return out | ||
|
||
|
||
def question_to_dict(question: Question) -> dict: | ||
data = { | ||
'text': question.text, | ||
'responses': [] | ||
} | ||
for response in question.responses: | ||
data['responses'].append(response_to_dict(response)) | ||
|
||
return data | ||
|
||
|
||
def dict_to_question(question_dict: dict) -> Question: | ||
out = Question(question_dict['text']) | ||
for response in question_dict['responses']: | ||
out.add_response(dict_to_response(response)) | ||
return out | ||
|
||
|
||
def questions_to_json(questions: typing.List[Question]) -> str: | ||
out = [] | ||
for question in questions: | ||
out.append(question_to_dict(question)) | ||
|
||
return json.dumps(out, indent='\t') | ||
|
||
|
||
def json_to_questions(questions_json: str) -> typing.List[Question]: | ||
questions_list = json.loads(questions_json) | ||
out = [] | ||
for question in questions_list: | ||
out.append(dict_to_question(question)) | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
discord.py | ||
python-Levenshtein |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import discord | ||
from qna.classes import Question, Response | ||
from qna import json | ||
from qna.helpers import sanitize_question | ||
from configreader import config | ||
|
||
client = discord.Client() | ||
limit = config["scrape_amount"] | ||
|
||
|
||
@client.event | ||
async def on_ready(): | ||
questions = {} | ||
chat = client.get_channel(config["scrape_channel"]) | ||
print('> Scraping...') | ||
count = 0 # this is lowkey better ngl | ||
async for message in chat.history(limit=None): | ||
if message.is_system(): | ||
continue # we don't want to deal with these | ||
if count >= limit: | ||
with open('scraped-data.json', 'w+') as file: | ||
file.write(json.questions_to_json(list(questions.values()))) | ||
await client.close() | ||
print('\n> Done.') | ||
return | ||
if message.reference: | ||
try: | ||
reply: discord.Message = await chat.fetch_message(message.reference.message_id) | ||
except discord.errors.NotFound: | ||
continue # too bad!!! | ||
content = sanitize_question(reply.clean_content) | ||
if content not in questions.keys(): | ||
questions.update({content: Question(content)}) | ||
questions[content].add_response(Response(message.clean_content)) | ||
print(content, '->', message.clean_content) | ||
else: | ||
continue # worthless junk | ||
count += 1 | ||
print(count, '/', limit) | ||
# print(questions) | ||
|
||
|
||
if __name__ == '__main__': | ||
with open('token.txt') as token: | ||
client.run(token.readline()) |