-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
53 lines (44 loc) · 2.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from Parser import Parser
from Processor import Processor
from Speech import TTS, STT
import argparse
def parse_args(argv=None):
    """Parse the chatbot's command-line options.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, exactly as the
            original zero-argument version did; passing an explicit list lets
            the CLI be driven from code or tests.

    Returns:
        dict: Mapping of option destination name -> value. Every option is
        present as a key even when not given on the command line: flags
        default to ``False``, ``corpus`` to ``None``, and
        ``encoder``/``decoder``/``tokenizer`` to the file names declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--tts", action="store_true", help="Whether to include TTS and STT functionality.")
    parser.add_argument("-qa", "--qa_save", action="store_true", help="Whether to save parser-generated question/answer data.")
    parser.add_argument("-g", "--google", action="store_true", help="Whether to parse google corpus instead of parse simple corpus.")
    parser.add_argument("-c", "--corpus", type=str, help="Path to the corpus for the parser to parse.")
    parser.add_argument("-m", "--model_save", action="store_true", help="Whether to save processor-generated models.")
    parser.add_argument("-l", "--load", action="store_true", help="Whether to load model files into the bot.")
    parser.add_argument("-e", "--encoder", type=str, default="encoder.h5", help="Path to encoder file to load.")
    parser.add_argument("-d", "--decoder", type=str, default="decoder.h5", help="Path to decoder file to load.")
    parser.add_argument("-t", "--tokenizer", type=str, default="tokenizer.pickle", help="Path to tokenizer file to load.")
    # parse_args(None) falls back to sys.argv[1:], preserving the old behaviour.
    args = parser.parse_args(argv)
    # Callers index the result by key, so hand back a plain dict.
    return vars(args)
if __name__ == "__main__":
args = parse_args()
if not args["load"] and "corpus" not in args:
raise ValueError("Need a corpus for the parser to parse.")
bot = Processor()
if args["load"]:
bot.load_all(args["encoder"], args["decoder"], args["tokenizer"])
else:
p = Parser(args["corpus"])
q, a = p.main(args["google"], args["qa_save"])
bot = Processor()
bot.main(q, a)
if args["model_save"]:
bot.save_model(bot.encoder, name="google_enc.h5")
bot.save_model(bot.decoder, name="google_dec.h5")
bot.save_tokenizer(bot.tokenizer, name="google.token.pickle")
if args["tts"]:
tts = TTS()
stt = STT()
while True:
inp = stt.speech_to_text()
ans = bot.ask_question(inp)
print("The bot said: " + ans)
tts.text_to_speech(ans)
else:
inp = input("Input: ")
print(bot.ask_question(inp))