Skip to content

Commit

Permalink
add data model on api
Browse files Browse the repository at this point in the history
  • Loading branch information
dayesouza committed Oct 29, 2024
1 parent 6264a65 commit de18d40
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 11 deletions.
1 change: 1 addition & 0 deletions app/workflows/detect_entity_networks/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_intro():
async def create(sv: rn_variables.SessionVariables, workflow=None):
sv_home = SessionVariables("home")
ui_components.check_ai_configuration()
den = sv.workflow_object.value

intro_tab, uploader_tab, process_tab, view_tab, report_tab, examples_tab = st.tabs(
[
Expand Down
136 changes: 134 additions & 2 deletions example_notebooks/detect_entity_networks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,137 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")\n",
"import os"
"import os\n",
"from toolkit.detect_entity_networks.api import DetectEntityNetworks\n",
"from toolkit.AI.openai_configuration import OpenAIConfiguration\n",
"import polars as pl"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded data\n"
]
}
],
"source": [
"# Create the workflow object\n",
"den = DetectEntityNetworks()\n",
"# Set the AI configuration\n",
"ai_configuration = OpenAIConfiguration(\n",
" {\n",
" \"api_type\": \"OpenAI\",\n",
" \"api_key\": os.environ[\"OPENAI_API_KEY\"],\n",
" \"model\": \"gpt-4o\",\n",
" }\n",
")\n",
"den.set_ai_configuration(ai_configuration)\n",
"\n",
"data_path = \"../example_outputs/detect_entity_networks/company_grievances/company_grievances_input.csv\"\n",
"entity_df = pl.read_csv(data_path)\n",
"\n",
"print(\"Loaded data\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summary\n",
"Number of entities: 3602, Number of attributes: 18549, Number of flags: 0, Number of groups: 0, Number of links: 41727\n"
]
}
],
"source": [
"# set entity-attributes\n",
"from toolkit.detect_entity_networks.prepare_model import format_data_columns\n",
"\n",
"\n",
"entity_id_column = \"name\"\n",
"columns_to_link = [\"address\", \"city\", \"email\", \"phone\", \"owner\"]\n",
"entity_df = format_data_columns(entity_df, columns_to_link, entity_id_column)\n",
"den.add_attribute_links(entity_df, entity_id_column, columns_to_link)\n",
"\n",
"summary = den.get_model_summary_value()\n",
"print(\"Summary\")\n",
"print(summary)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summary\n",
"Number of entities: 3687, Number of attributes: 18663, Number of flags: 8108, Number of groups: 0, Number of links: 41971\n"
]
}
],
"source": [
"# set flags\n",
"from toolkit.detect_entity_networks.classes import FlagAggregatorType\n",
"\n",
"\n",
"entity_id_column = \"name\"\n",
"columns_to_link = [\n",
" \"safety_grievances\",\n",
" \"pay_grievances\",\n",
" \"conditions_grievances\",\n",
" \"treatment_grievances\",\n",
" \"workload_grievances\",\n",
"]\n",
"flag_format = FlagAggregatorType.Count\n",
"den.add_flag_links(entity_df, entity_id_column, columns_to_link, flag_format)\n",
"summary = den.get_model_summary_value()\n",
"print(\"Summary\")\n",
"print(summary)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summary\n",
"Number of entities: 3687, Number of attributes: 18663, Number of flags: 8108, Number of groups: 634, Number of links: 41971\n"
]
}
],
"source": [
"# set groups\n",
"entity_id_column = \"name\"\n",
"columns_to_link = [\"sector\", \"country\"]\n",
"den.add_groups(entity_df, entity_id_column, columns_to_link)\n",
"\n",
"summary = den.get_model_summary_value()\n",
"print(\"Summary\")\n",
"print(summary)"
]
}
],
Expand All @@ -19,7 +143,15 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
Expand Down
99 changes: 98 additions & 1 deletion toolkit/detect_entity_networks/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,105 @@
# Licensed under the MIT license. See LICENSE file in the project.
#

import networkx as nx
import polars as pl

from toolkit.detect_entity_networks.classes import FlagAggregatorType
from toolkit.detect_entity_networks.config import ENTITY_LABEL
from toolkit.detect_entity_networks.prepare_model import (
build_flag_links,
build_flags,
build_groups,
build_main_graph,
generate_attribute_links,
)
from toolkit.helpers.classes import IntelligenceWorkflow
from toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR


class DetectEntityNetworks(IntelligenceWorkflow):
    """Workflow object that accumulates the entity-network data model.

    Attribute links, flag links and group links are added through the
    ``add_*`` methods; the graph and the aggregated flags are rebuilt from
    the accumulated links each time new ones are added.
    """

    def __init__(self, **kwargs) -> None:
        """Initialise an empty model; ``kwargs`` are forwarded to the base class."""
        super().__init__(**kwargs)
        # All mutable state lives on the instance so that separate workflow
        # objects never share the same list.
        self.attribute_links = []
        self.attributes_list = []
        self.flag_links = []
        self.group_links = []
        self.graph = None
        self.integrated_flags = pl.DataFrame()
        # Populated by add_flag_links(); defined here so that reading them
        # before any flags were added does not raise AttributeError.
        self.max_entity_flags = 0
        self.mean_flagged_flags = 0

    def _build_graph(self) -> None:
        """Rebuild ``self.graph`` from every attribute link collected so far."""
        self.graph = build_main_graph(self.attribute_links)

    def add_attribute_links(
        self, data_df: pl.DataFrame, entity_id_column: str, columns_to_link: list[str]
    ) -> nx.Graph:
        """Generate attribute links from ``data_df`` and rebuild the graph.

        Args:
            data_df: Input data.
            entity_id_column: Name of the column holding the entity identifier.
            columns_to_link: Names of the columns to link to each entity.

        Returns:
            The graph rebuilt from all attribute links accumulated so far.
        """
        links = generate_attribute_links(data_df, entity_id_column, columns_to_link)
        self.attribute_links.extend(links)
        self._build_graph()
        return self.graph

    def add_flag_links(
        self,
        data_df: pl.DataFrame,
        entity_id_column: str,
        flag_column: list[str],
        flag_format: FlagAggregatorType,
    ) -> None:
        """Generate flag links from ``data_df`` and refresh the flag aggregates.

        Args:
            data_df: Input data.
            entity_id_column: Name of the column holding the entity identifier.
            flag_column: Flag column names; forwarded to ``build_flag_links``
                as its ``flag_columns`` argument.
            flag_format: How flag values are aggregated.
        """
        links = build_flag_links(
            data_df,
            entity_id_column,
            flag_format,
            flag_column,
        )
        self.flag_links.extend(links)

        # The aggregates are recomputed from the full accumulated list, not
        # only from the links added by this call.
        (
            self.integrated_flags,
            self.max_entity_flags,
            self.mean_flagged_flags,
        ) = build_flags(self.flag_links)

    def add_groups(
        self,
        data_df: pl.DataFrame,
        entity_id_column: str,
        value_cols: list[str],
    ) -> None:
        """Build group links from ``data_df`` and store them on the model.

        Args:
            data_df: Input data.
            entity_id_column: Name of the column holding the entity identifier.
            value_cols: Names of the columns that define the groups.
        """
        # NOTE(review): this replaces any previously stored group links,
        # whereas attribute and flag links accumulate across calls — confirm
        # that overwriting is intended.
        self.group_links = build_groups(
            value_cols,
            data_df,
            entity_id_column,
        )

    def get_model_summary_data(self) -> dict:
        """Return counts describing the current model.

        Also refreshes ``self.attributes_list`` with the non-entity nodes of
        the graph when a graph exists.

        Returns:
            A dict with the keys ``entities``, ``attributes``, ``flags``,
            ``groups`` and ``links``. Graph-derived counts are 0 until
            attribute links have been added.
        """
        num_entities = 0
        num_attributes = 0
        num_flags = 0
        num_links = 0

        # A group is counted once per distinct pair of the second and third
        # fields of a group link.
        groups = {
            f"{link[1]}{ATTRIBUTE_VALUE_SEPARATOR}{link[2]}"
            for link_list in self.group_links
            for link in link_list
        }

        if self.graph is not None:
            all_nodes = self.graph.nodes()
            entity_nodes = [node for node in all_nodes if node.startswith(ENTITY_LABEL)]
            self.attributes_list = [
                node for node in all_nodes if not node.startswith(ENTITY_LABEL)
            ]
            num_entities = len(entity_nodes)
            num_attributes = len(all_nodes) - num_entities
            # Counted inside the guard: there are no edges to count (and no
            # graph object to ask) before the first attribute links are added.
            num_links = len(self.graph.edges())

        if len(self.integrated_flags) > 0:
            num_flags = self.integrated_flags["count"].sum()

        return {
            "entities": num_entities,
            "attributes": num_attributes,
            "flags": num_flags,
            "groups": len(groups),
            "links": num_links,
        }

    def get_model_summary_value(self) -> str:
        """Return the model summary formatted as a single display line."""
        summary = self.get_model_summary_data()
        return f"Number of entities: {summary['entities']}, Number of attributes: {summary['attributes']}, Number of flags: {summary['flags']}, Number of groups: {summary['groups']}, Number of links: {summary['links']}"
13 changes: 5 additions & 8 deletions toolkit/detect_entity_networks/prepare_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def generate_attribute_links(
data_df: pl.DataFrame,
entity_id_column: str,
columns_to_link: list[str],
existing_links: list | None = None,
) -> list:
"""
Generate attribute links for the given entity and columns.
Expand All @@ -61,7 +60,7 @@ def generate_attribute_links(
Returns:
list: A list of attribute links.
"""
attribute_links = existing_links or []
attribute_links = []

for value_col in columns_to_link:
data_df = data_df.with_columns([pl.lit(value_col).alias("attribute_col")])
Expand Down Expand Up @@ -107,9 +106,8 @@ def build_flag_links(
entity_col: str,
flag_agg: FlagAggregatorType,
flag_columns: list[str],
existing_flag_links: list | None = None,
) -> list[Any]:
flag_links = existing_flag_links or []
flag_links = []

if entity_col not in df_flag.columns:
msg = f"Column {entity_col} not found in the DataFrame."
Expand All @@ -131,10 +129,10 @@ def build_flag_links(
.to_numpy()
.tolist()
)
if flag_agg == FlagAggregatorType.Instance.value:
if flag_agg == FlagAggregatorType.Instance:
gdf = gdf.with_columns([pl.lit(1).alias("count_col")])
flag_links.extend([[val[0], value_col, val[1], 1] for val in vals])
elif flag_agg == FlagAggregatorType.Count.value:
elif flag_agg == FlagAggregatorType.Count:
flag_links.extend([[val[0], value_col, value_col, val[1]] for val in vals])

return flag_links
Expand Down Expand Up @@ -175,9 +173,8 @@ def build_groups(
value_cols: list[str],
df_groups: pl.DataFrame,
entity_col: str,
existing_group_links: list | None = None,
) -> list[Any]:
group_links = existing_group_links or []
group_links = []

if df_groups.is_empty():
return group_links
Expand Down

0 comments on commit de18d40

Please sign in to comment.