Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Darren Edge committed Nov 11, 2024
2 parents 7be689d + 9aaa808 commit 001c183
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 27 deletions.
69 changes: 46 additions & 23 deletions app/workflows/query_text_data/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,28 @@ async def create(sv: SessionVariables, workflow=None):
)
with uploader_tab:
st.markdown("##### Upload data for processing")
files = st.file_uploader(
"Upload PDF text files",
type=["pdf", "txt", "json", "csv"],
accept_multiple_files=True,
key="qtd_uploader_" + st.session_state[f"{workflow}_uploader_index"],

upload_type = st.radio(
"Upload type",
options=["Raw files", "Processed data"],
key=f"{workflow}_data_source",
)
# window_size = st.selectbox(
# "Analysis time window",
# key=sv.analysis_window_size.key,
# options=[str(x) for x in input_processor.PeriodOption._member_names_],
# )
# window_period = input_processor.PeriodOption[window_size]
# window_period = input_processor.PeriodOption.NONE
files = None
file_chunks = None
if upload_type == "Raw files":
files = st.file_uploader(
"Upload PDF text files",
type=["pdf", "txt", "json", "csv"],
accept_multiple_files=True,
key="qtd_uploader_" + st.session_state[f"{workflow}_uploader_index"],
)
else:
file_chunks = st.file_uploader(
"Upload processed chunks",
type=["csv"],
key="chunk_uploader_" + st.session_state[f"{workflow}_uploader_index"],
)

local_embedding = st.toggle(
"Use local embeddings",
key=sv.answer_local_embedding_enabled.key,
Expand All @@ -84,28 +93,32 @@ async def create(sv: SessionVariables, workflow=None):
)
qtd.set_embedder(embedder.create_embedder(local_embedding))

if files is not None and st.button("Process files"):
if st.button("Process files") and (
files is not None or file_chunks is not None
):
qtd.reset_workflow()

file_pb, file_callback = functions.create_progress_callback(
"Loaded {} of {} files..."
)
qtd.process_data_from_files(
input_file_bytes={file.name: file.getvalue() for file in files},
callbacks=[file_callback],
)
if upload_type == "Raw files":
file_pb, file_callback = functions.create_progress_callback(
"Loaded {} of {} files..."
)
qtd.process_data_from_files(
input_file_bytes={file.name: file.getvalue() for file in files},
callbacks=[file_callback],
)
file_pb.empty()
else:
qtd.import_chunks_from_str(file_chunks)

chunk_pb, chunk_callback = functions.create_progress_callback(
"Processed {} of {} chunks..."
)
qtd.process_text_chunks(callbacks=[chunk_callback])

embed_pb, embed_callback = functions.create_progress_callback(
"Embedded {} of {} text chunks..."
)
await qtd.embed_text_chunks(callbacks=[embed_callback])
chunk_pb.empty()
file_pb.empty()
embed_pb.empty()
st.rerun()

Expand All @@ -121,6 +134,16 @@ async def create(sv: SessionVariables, workflow=None):
message += "."
message = message.replace("**1** periods", "**1** period")
st.success(message)

if qtd.label_to_chunks and upload_type == "Raw files":
st.download_button(
label="Download chunk data",
help="Export chunk data as CSV",
data=qtd.get_chunks_as_df().to_csv(),
file_name=f"processed_data_{len(qtd.label_to_chunks)}_query_text.csv",
mime="text/csv",
)

with graph_tab:
if qtd.stage.value < QueryTextDataStage.CHUNKS_PROCESSED.value:
st.warning("Process files to continue.")
Expand Down
2 changes: 1 addition & 1 deletion intelligence_toolkit/compare_case_groups/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_filter_options(self, input_df: pl.DataFrame) -> list[str]:
return sorted_atts

def _select_columns_ranked_df(self, ranked_df: pl.DataFrame) -> None:
columns = self.groups
columns = self.groups.copy()
default_columns = [
"group_count",
"group_rank",
Expand Down
24 changes: 23 additions & 1 deletion intelligence_toolkit/query_text_data/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project.

from collections import defaultdict
from enum import Enum

import networkx as nx
Expand Down Expand Up @@ -316,5 +316,27 @@ def prepare_for_new_answer(self) -> None:
self.answer_object = None
self.stage = QueryTextDataStage.CHUNKS_MINED

def get_chunks_as_df(self) -> pd.DataFrame:
flat_data = []
for key, json_list in self.label_to_chunks.items():
for json_str in json_list:
item_data = {
"file_name": key,
"text_to_label_str": json_str,
}
flat_data.append(item_data)

return pd.DataFrame(flat_data)

def import_chunks_from_str(self, data: str) -> None:
chunks_df = pd.read_csv(data)
data_imported = defaultdict(list)
for _, row in chunks_df.iterrows():
key = row["file_name"]
row_data = row["text_to_label_str"]
data_imported[key].append(row_data)

self.label_to_chunks = data_imported

def __repr__(self):
return f"QueryTextData()"
4 changes: 2 additions & 2 deletions intelligence_toolkit/query_text_data/input_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def convert_file_bytes_to_chunks(

if file_name.endswith(".csv"):
df = pd.read_csv(io.BytesIO(bytes))
text_to_chunks = convert_df_to_chunks(df, file_name)
text_chunks = convert_df_to_chunks(df, file_name)
else:
if file_name.endswith(".pdf"):
page_texts = []
Expand All @@ -89,7 +89,7 @@ def convert_file_bytes_to_chunks(
}
text_chunks[index] = dumps(chunk, indent=2, ensure_ascii=False)

text_to_chunks[file_name] = text_chunks
text_to_chunks[file_name] = text_chunks
return text_to_chunks


Expand Down

0 comments on commit 001c183

Please sign in to comment.