
Commit

Merge pull request #35 from SWM-SMART/dev
Dev
minseok-oh authored Nov 19, 2023
2 parents df052b7 + 6079438 commit c87d9e1
Showing 9 changed files with 42 additions and 28 deletions.
6 changes: 1 addition & 5 deletions app/api/deps.py
@@ -25,16 +25,12 @@
 stt_controller: Optional[STTController] = None
 
 def init_model() -> None:
-    global summary_model, summary_tokenizer, summary_controller, llm_controller, mindmap_controller, keyword_controller, s3_controller, stt_controller
+    global summary_model, summary_tokenizer, summary_controller, s3_controller, stt_controller
 
     nltk.download('punkt')
    summary_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARY_MODEL_PATH)
    summary_tokenizer = AutoTokenizer.from_pretrained(SUMMARY_MODEL_PATH)
    summary_controller = SummaryController(summary_model, summary_tokenizer)
 
-    llm_controller = LLMController()
-    mindmap_controller = MindMapController(llm_controller)
-    keyword_controller = KeywordsController(llm_controller)
-
    s3_controller = boto3.client(
        service_name="s3",
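With this change, init_model() no longer builds the LLM, mind-map, and keyword controllers as module globals; the controller classes below now take a plain LLMController in their constructors. A minimal sketch of per-request wiring that would fit the new signatures; the dependency functions here are illustrative, not part of this commit:

```python
from app.controller.llm import LLMController
from app.controller.keywords import KeywordsController
from app.controller.mindmap import MindMapController

def get_keyword_controller() -> KeywordsController:
    # Build a fresh controller per request; nothing shared at module scope.
    return KeywordsController(LLMController())

def get_mindmap_controller() -> MindMapController:
    return MindMapController(LLMController())
```

Endpoints could then receive these via FastAPI's Depends(), matching the get_s3_controller dependency already imported in keywords.py below.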
5 changes: 2 additions & 3 deletions app/api/v1/endpoints/keywords.py
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
 
 from app.schemas.database import DatabaseInfo
 from app.schemas.context import Keywords
@@ -7,12 +7,10 @@
 from app.api.deps import get_s3_controller
 import botocore
 from langchain.document_loaders import PyPDFLoader, TextLoader
-
 from app.core.config import (
     S3_PREFIX,
     S3_BUCKET_NAME
 )
-
 router = APIRouter()
 
 @router.post("/keywords", response_model=None)
@@ -32,4 +30,5 @@ def get_keywords(
         loader = TextLoader(f'app/static/{db.key}')
 
     doc = loader.load_and_split()
+    if len(doc) == 0: return HTTPException(status_code=404, detail="Error")
     return keyword_controller.get_keywords(doc)
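Note that the new guard here (and the identical one in mindmap.py below) returns the HTTPException rather than raising it; FastAPI only turns the exception into a 404 response when it is raised, so as committed the endpoint serializes the exception object in a 200 response. A minimal self-contained sketch of the raising form; the route and payload are illustrative:

```python
from fastapi import FastAPI, HTTPException

app = FastAPI()

@app.post("/check")
def check(doc: list[str]) -> dict:
    # raise, not return: FastAPI builds the 404 only for a raised exception
    if len(doc) == 0:
        raise HTTPException(status_code=404, detail="Error")
    return {"count": len(doc)}
```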
3 changes: 2 additions & 1 deletion app/api/v1/endpoints/mindmap.py
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
 
 from app.controller.llm import LLMController
 from app.controller.mindmap import MindMapController
@@ -34,6 +34,7 @@ def get_mind_map(
     else:
         loader = TextLoader(f'app/static/{document.key}')
     doc = loader.load_and_split()
+    if len(doc) == 0: return HTTPException(status_code=404, detail="Error")
     ret = mindmap_controller.get_mindmap(doc, document.keywords)
     print(ret)
     return ret
8 changes: 4 additions & 4 deletions app/api/v1/endpoints/question.py
@@ -1,12 +1,12 @@
 from fastapi import APIRouter, Depends
 
-from app.schemas.context import Text, Context
+from app.schemas.context import QuestionText, Context
 from app.controller.llm import LLMController
 
 router = APIRouter()
 
-@router.post("/question", response_model=Text)
-def get_answer(context: Context) -> Text:
+@router.post("/question", response_model=QuestionText)
+def get_answer(context: Context) -> QuestionText:
     controller = LLMController()
     question = f"'{context.summary}'에서 {context.keyword}의 의미를 알려줘"
-    return Text(text=controller.request_base(question))
+    return QuestionText(text=controller.request_base(question))
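The Korean f-string asks the LLM for the meaning of context.keyword within context.summary. A sketch of the new request/response contract; the Context fields are inferred from that f-string, and these schema definitions are illustrative stand-ins for the ones in app/schemas/context.py:

```python
from pydantic import BaseModel

class Context(BaseModel):
    summary: str
    keyword: str

class QuestionText(BaseModel):
    text: str

def build_question(context: Context) -> str:
    # Mirrors the endpoint's prompt: "tell me the meaning of {keyword} in '{summary}'"
    return f"'{context.summary}'에서 {context.keyword}의 의미를 알려줘"

ctx = Context(summary="운영체제 스케줄링 요약", keyword="라운드 로빈")
print(build_question(ctx))            # prompt string sent to the LLM
print(QuestionText(text="예시 응답"))  # shape of the endpoint's response
```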
10 changes: 7 additions & 3 deletions app/controller/keywords.py
@@ -1,15 +1,19 @@
 from fastapi import Depends
 from app.controller.llm import LLMController
 from langchain.schema.document import Document
+from langchain.text_splitter import CharacterTextSplitter
 
 from app.schemas.context import Keywords
 
 class KeywordsController:
     prompt = "Question: 키워드들을 뽑아서 '(', ',', ')'로 정리해서 알려줘 \\nAnswer: ("
-    def __init__(self, llm: LLMController = Depends(LLMController)):
+    def __init__(self, llm: LLMController):
         self.llm = llm
 
     def get_keywords(self, document: Document) -> Keywords:
-        self.llm.set_document(document)
-        answer = self.llm.request(self.prompt).content
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
+        texts = text_splitter.split_documents(document)
+        prompt = f"#Context: {texts} \\n#Question: 한글 키워드들을 20개 이내로 뽑아서 '(', ',', ')'로 정리해서 알려줘 \\n#Answer: ("
+
+        answer = self.llm.request_base(prompt)
         return Keywords(keywords=list(map(lambda word: word.strip(), list(answer[1:-1].split(',')))))
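The rewritten get_keywords splits the document into 1000-character chunks, interpolates the chunk list into a prompt asking (in Korean) for up to 20 Korean keywords in the form '(kw1, kw2, ...)', then strips the outer parentheses and splits on commas. A sketch of just that parsing step, with a made-up answer string:

```python
# Made-up LLM answer in the "(...)" shape the prompt requests.
answer = "(운영체제, 프로세스, 스케줄링)"

# Same parse as the return line above: drop the parentheses, split, strip.
keywords = [word.strip() for word in answer[1:-1].split(',')]
assert keywords == ["운영체제", "프로세스", "스케줄링"]
print(keywords)
```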
31 changes: 21 additions & 10 deletions app/controller/mindmap.py
@@ -1,6 +1,7 @@
 from fastapi import Depends
 from app.controller.llm import LLMController
 from langchain.schema.document import Document
+from langchain.text_splitter import CharacterTextSplitter
 
 from app.schemas.mindmap import MindMap
 from app.schemas.context import Keywords
@@ -10,7 +11,6 @@
 from typing import List
 
 class MindMapController:
-    prompt = "Question: 문맥 내에서 %s들의 계층 구조를 MarkDown의 '-'로 알려줘 \nAnswer: -"
     def __init__(self, llm: LLMController = Depends(LLMController)):
         self.llm = llm
 
@@ -23,23 +23,31 @@ def delete_stopwords(self, html: str) -> str:
 
     def parse_html(self, markdown: str, keywords: List[str]) -> MindMap:
         mindmap = MindMap()
-        mindmap.keywords = keywords
-        keyword2index = {v: i for i, v in enumerate(keywords)}
+        mindmap.keywords = []
+        keyword2index = {}
+        lines = markdown.split('\n')
+
+        for index, line in enumerate(lines):
+            if len(line.split('- ')) == 1: continue
+            word = '- '.join(line.split('- ')[1:])
+            keyword2index[word] = index-1
+            mindmap.keywords.append(word)
+        print(mindmap.keywords)
+        print(keyword2index)
 
         current = 0
         stack = []
-        lines = markdown.split('\n')
         for line in lines:
             if len(line.split('- ')) == 1: continue
             sep = line.split('- ')[0]
-            word = line.split('- ')[-1]
-            if word in keywords:
+            word = '- '.join(line.split('- ')[1:])
+            if word in mindmap.keywords:
                 sep = len(sep)
                 mindmap.graph[str(keyword2index[word])] = []
+                print(stack)
 
                 if sep == 0:
                     mindmap.root = sep
 
                 if current == sep:
                     if len(stack) != 0: stack.pop()
                     if len(stack) != 0: mindmap.graph[str(stack[-1])].append(keyword2index[word])
@@ -54,8 +62,11 @@ def parse_html(self, markdown: str, keywords: List[str]) -> MindMap:
         return mindmap
 
     def get_mindmap(self, document: Document, keywords: List[str]) -> MindMap:
-        self.llm.set_document(document)
-        answer = self.llm.request(self.prompt % keywords).content
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
+        texts = text_splitter.split_documents(document)
+
+        prompt = f"#Context: {texts} \\nQuestion: 문맥 내에서 {keywords}들의 계층 구조를 MarkDown의 '-'로 알려줘 \nAnswer: -"
+        answer = self.llm.request_base(prompt)
         print(answer)
         answer = self.delete_stopwords(answer)
         return self.parse_html(answer, keywords)
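parse_html now derives the keyword list from the model's own markdown bullets (first loop) and then uses the indentation before each '- ' to connect parents to children with a stack (second loop). A simplified, self-contained reimplementation of that idea, keyed by word rather than index for readability; the sample markdown is made up:

```python
# Made-up markdown of the kind the LLM returns for the prompt above.
markdown = "- OS\n  - Process\n    - Scheduling\n  - Memory"

graph: dict[str, list[str]] = {}
stack: list[tuple[int, str]] = []           # (indent depth, word)
for line in markdown.split('\n'):
    head, _, word = line.partition('- ')    # text before the first '- ' is the indent
    if not word:
        continue
    depth = len(head)
    graph[word] = []
    while stack and stack[-1][0] >= depth:  # climb back up to this bullet's parent
        stack.pop()
    if stack:
        graph[stack[-1][1]].append(word)    # attach to the nearest shallower bullet
    stack.append((depth, word))

print(graph)  # {'OS': ['Process', 'Memory'], 'Process': ['Scheduling'], ...}
```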
2 changes: 1 addition & 1 deletion app/controller/stt.py
@@ -10,7 +10,7 @@ def __init__(self):
 
     def convert_to_wav(self, prefix):
         output_path = f'app/static/{prefix}.wav'
-        y, sr = librosa.load(f'app/static/{prefix}.m4a', sr=16000)
+        y, sr = librosa.load(f'app/static/{prefix}.mp3', sr=16000)
         sf.write(output_path, y, sr)
 
     def speech_to_text(self, prefix):
2 changes: 1 addition & 1 deletion app/schemas/__init__.py
@@ -1,4 +1,4 @@
 from app.schemas.mindmap import MindMap
 from app.schemas.database import DatabaseInfo
 from app.schemas.document import Document
-from app.schemas.context import Keywords, Text, Context
+from app.schemas.context import Keywords, Text, Context, QuestionText
3 changes: 3 additions & 0 deletions app/schemas/context.py
@@ -21,4 +21,7 @@ class Segment(BaseModel):
 
 class SpeechText(BaseModel):
     segments: List[Segment]
     text: str
+
+class QuestionText(BaseModel):
+    text: str