From f52f043bdf6ebdefaf8e1e1b11e34b7a22b5a1b9 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Fri, 23 Aug 2024 05:53:47 +0200 Subject: [PATCH] update streamlit-pdf-viewer, fix chat wobbling --- requirements.txt | 4 ++-- streamlit_app.py | 56 ++++++++++++++++++++++++++++++------------------ 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/requirements.txt b/requirements.txt index e8db7d3..9576f1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ grobid_tei_xml==0.1.3 tqdm==4.66.2 pyyaml==6.0.1 pytest==8.1.1 -streamlit==1.36.0 +streamlit==1.37.0 lxml Beautifulsoup4 python-dotenv @@ -24,6 +24,6 @@ typing-inspect==0.9.0 typing_extensions==4.11.0 pydantic==2.6.4 sentence_transformers==2.6.1 -streamlit-pdf-viewer==0.0.14 +streamlit-pdf-viewer==0.0.17 umap-learn plotly \ No newline at end of file diff --git a/streamlit_app.py b/streamlit_app.py index a59a02b..edbcea9 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -31,8 +31,8 @@ ] OPEN_MODELS = { - 'mistral-7b-instruct-v0.3': 'mistralai/Mistral-7B-Instruct-v0.2', - # 'Phi-3-mini-128k-instruct': "microsoft/Phi-3-mini-128k-instruct", + 'Mistral-Nemo-Instruct-2407': 'mistralai/Mistral-Nemo-Instruct-2407', + 'mistral-7b-instruct-v0.3': 'mistralai/Mistral-7B-Instruct-v0.3', 'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct" } @@ -109,6 +109,20 @@ } ) +st.markdown( + """ + + """, + unsafe_allow_html=True +) + def new_file(): st.session_state['loaded_embeddings'] = None @@ -154,8 +168,8 @@ def init_qa(model, embeddings_name=None, api_key=None): chat = HuggingFaceEndpoint( repo_id=OPEN_MODELS[model], temperature=0.01, - max_new_tokens=2048, - model_kwargs={"max_length": 4096} + max_new_tokens=4092, + model_kwargs={"max_length": 8192} ) embeddings = HuggingFaceEmbeddings( model_name=OPEN_EMBEDDINGS[embeddings_name]) @@ -401,21 +415,21 @@ def generate_color_gradient(num_elements): with right_column: if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id: + # messages.chat_message("user").markdown(question) + st.session_state.messages.append({"role": "user", "mode": mode, "content": question}) + for message in st.session_state.messages: - with messages.chat_message(message["role"]): - if message['mode'] == "llm": - messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) - elif message['mode'] == "embeddings": - messages.chat_message(message["role"]).write(message["content"]) - if message['mode'] == "question_coefficient": - messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) + # with messages.chat_message(message["role"]): + if message['mode'] == "llm": + messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) + elif message['mode'] == "embeddings": + messages.chat_message(message["role"]).write(message["content"]) + elif message['mode'] == "question_coefficient": + messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) if model not in st.session_state['rqa']: st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `") st.stop() - messages.chat_message("user").markdown(question) - st.session_state.messages.append({"role": "user", "mode": mode, "content": question}) - text_response = None if mode == "embeddings": with placeholder: @@ -472,10 +486,10 @@ def generate_color_gradient(num_elements): with left_column: if st.session_state['binary']: - pdf_viewer( - input=st.session_state['binary'], - annotation_outline_size=2, - annotations=st.session_state['annotations'], - render_text=True, - height=600 - ) + with st.container(height=600): + pdf_viewer( + input=st.session_state['binary'], + annotation_outline_size=2, + annotations=st.session_state['annotations'], + render_text=True + )