Skip to content

Commit

Permalink
continue refactoring the notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjanovsky committed Nov 14, 2023
1 parent ab55c4b commit 0e4a68f
Show file tree
Hide file tree
Showing 6 changed files with 3,859 additions and 285 deletions.
4,072 changes: 3,821 additions & 251 deletions notebooks/cc/references.ipynb

Large diffs are not rendered by default.

62 changes: 31 additions & 31 deletions notebooks/fixed_sankey_plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# type: ignore

# ruff: noqa: UP007
"""
This is a fork of https://github.com/anazalea/pySankey/blob/master/pysankey/sankey.py.
We had some problems with the original plot, most likely related to resizing (the exact cause is no longer remembered).
Expand All @@ -9,7 +9,7 @@
import logging
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -35,7 +35,7 @@ class LabelMismatch(PySankeyException):
LOGGER = logging.getLogger(__name__)


def check_data_matches_labels(labels: Union[List[str], Set[str]], data: Series, side: str) -> None:
def check_data_matches_labels(labels: Union[list[str], set[str]], data: Series, side: str) -> None:
"""Check whether data matches labels.
Raise a LabelMismatch Exception if not."""
if len(labels) > 0:
Expand All @@ -55,19 +55,19 @@ def check_data_matches_labels(labels: Union[List[str], Set[str]], data: Series,


def sankey(
left: Union[List, ndarray, Series],
left: Union[list, ndarray, Series],
right: Union[ndarray, Series],
leftWeight: Optional[ndarray] = None,
rightWeight: Optional[ndarray] = None,
colorDict: Optional[Dict[str, str]] = None,
leftLabels: Optional[List[str]] = None,
rightLabels: Optional[List[str]] = None,
colorDict: Optional[dict[str, str]] = None,
leftLabels: Optional[list[str]] = None,
rightLabels: Optional[list[str]] = None,
aspect: int = 4,
rightColor: bool = False,
fontsize: int = 14,
figureName: Optional[str] = None,
closePlot: bool = False,
figSize: Optional[Tuple[int, int]] = None,
figSize: Optional[tuple[int, int]] = None,
ax: Optional[Any] = None,
) -> Any:
"""
Expand Down Expand Up @@ -158,7 +158,7 @@ def save_image(figureName: Optional[str]) -> None:
LOGGER.info("Sankey diagram generated in '%s'", file_name)


def identify_labels(dataFrame: DataFrame, leftLabels: List[str], rightLabels: List[str]) -> Tuple[ndarray, ndarray]:
def identify_labels(dataFrame: DataFrame, leftLabels: list[str], rightLabels: list[str]) -> tuple[ndarray, ndarray]:
# Identify left labels
if len(leftLabels) == 0:
leftLabels = pd.Series(dataFrame.left.unique()).unique()
Expand All @@ -175,14 +175,14 @@ def identify_labels(dataFrame: DataFrame, leftLabels: List[str], rightLabels: Li
def init_values(
ax: Optional[Any],
closePlot: bool,
figSize: Optional[Tuple[int, int]],
figSize: Optional[tuple[int, int]],
figureName: Optional[str],
left: Union[List, ndarray, Series],
leftLabels: Optional[List[str]],
left: Union[list, ndarray, Series],
leftLabels: Optional[list[str]],
leftWeight: Optional[ndarray],
rightLabels: Optional[List[str]],
rightLabels: Optional[list[str]],
rightWeight: Optional[ndarray],
) -> Tuple[Any, List[str], ndarray, List[str], ndarray]:
) -> tuple[Any, list[str], ndarray, list[str], ndarray]:
deprecation_warnings(closePlot, figSize, figureName)
if ax is None:
ax = plt.gca()
Expand All @@ -202,7 +202,7 @@ def init_values(
return ax, leftLabels, leftWeight, rightLabels, rightWeight


def deprecation_warnings(closePlot: bool, figSize: Optional[Tuple[int, int]], figureName: Optional[str]) -> None:
def deprecation_warnings(closePlot: bool, figSize: Optional[tuple[int, int]], figureName: Optional[str]) -> None:
warn = []
if figureName is not None:
msg = "use of figureName in sankey() is deprecated"
Expand All @@ -223,10 +223,10 @@ def deprecation_warnings(closePlot: bool, figSize: Optional[Tuple[int, int]], fi
)


def determine_widths(dataFrame: DataFrame, leftLabels: ndarray, rightLabels: ndarray) -> Tuple[Dict, Dict]:
def determine_widths(dataFrame: DataFrame, leftLabels: ndarray, rightLabels: ndarray) -> tuple[dict, dict]:
# Determine widths of individual strips
ns_l: Dict = defaultdict()
ns_r: Dict = defaultdict()
ns_l: dict = defaultdict()
ns_r: dict = defaultdict()
for leftLabel in leftLabels:
left_dict = {}
right_dict = {}
Expand All @@ -244,12 +244,12 @@ def determine_widths(dataFrame: DataFrame, leftLabels: ndarray, rightLabels: nda

def draw_vertical_bars(
ax: Any,
colorDict: Union[Dict[str, Tuple[float, float, float]], Dict[str, str]],
colorDict: Union[dict[str, tuple[float, float, float]], dict[str, str]],
fontsize: int,
leftLabels: ndarray,
leftWidths: Dict,
leftWidths: dict,
rightLabels: ndarray,
rightWidths: Dict,
rightWidths: dict,
xMax: float64,
) -> None:
# Draw vertical bars on left and right of each label's section & print label
Expand Down Expand Up @@ -286,8 +286,8 @@ def draw_vertical_bars(


def create_colors(
allLabels: ndarray, colorDict: Optional[Dict[str, str]]
) -> Union[Dict[str, Tuple[float, float, float]], Dict[str, str]]:
allLabels: ndarray, colorDict: Optional[dict[str, str]]
) -> Union[dict[str, tuple[float, float, float]], dict[str, str]]:
# If no colorDict given, make one
if colorDict is None:
colorDict = {}
Expand All @@ -306,7 +306,7 @@ def create_colors(


def _create_dataframe(
left: Union[List, ndarray, Series],
left: Union[list, ndarray, Series],
leftWeight: Union[ndarray, Series],
right: Union[ndarray, Series],
rightWeight: Union[ndarray, Series],
Expand Down Expand Up @@ -336,15 +336,15 @@ def _create_dataframe(

def plot_strips(
ax: Any,
colorDict: Union[Dict[str, Tuple[float, float, float]], Dict[str, str]],
colorDict: Union[dict[str, tuple[float, float, float]], dict[str, str]],
dataFrame: DataFrame,
leftLabels: ndarray,
leftWidths: Dict,
ns_l: Dict,
ns_r: Dict,
leftWidths: dict,
ns_l: dict,
ns_r: dict,
rightColor: bool,
rightLabels: ndarray,
rightWidths: Dict,
rightWidths: dict,
xMax: float64,
) -> None:
# Plot strips
Expand Down Expand Up @@ -380,9 +380,9 @@ def plot_strips(
ax.axis("off")


def _get_positions_and_total_widths(df: DataFrame, labels: ndarray, side: str) -> Tuple[Dict, float64]:
def _get_positions_and_total_widths(df: DataFrame, labels: ndarray, side: str) -> tuple[dict, float64]:
"""Determine positions of label patches and total widths"""
widths: Dict = defaultdict()
widths: dict = defaultdict()
for i, label in enumerate(labels):
label_widths = {}
label_widths[side] = df[df[side] == label][side + "Weight"].sum()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"scipy>=1.9.0",
"networkx",
"pydantic",
"pydantic-settings",
"psutil",
"pytesseract",
]
Expand Down
3 changes: 2 additions & 1 deletion src/sec_certs/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from typing import Literal, Optional

import yaml
from pydantic import AnyHttpUrl, BaseSettings, Field
from pydantic import AnyHttpUrl, Field
from pydantic_settings import BaseSettings


class Configuration(BaseSettings):
Expand Down
3 changes: 2 additions & 1 deletion src/sec_certs/model/references_nlp/annotator_trainer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

Check warning on line 1 in src/sec_certs/model/references_nlp/annotator_trainer.py

View check run for this annotation

Codecov / codecov/patch

src/sec_certs/model/references_nlp/annotator_trainer.py#L1

Added line #L1 was not covered by tests

import logging
from collections.abc import Callable
from functools import partial
from typing import Callable, Final, Literal
from typing import Final, Literal

Check warning on line 6 in src/sec_certs/model/references_nlp/annotator_trainer.py

View check run for this annotation

Codecov / codecov/patch

src/sec_certs/model/references_nlp/annotator_trainer.py#L3-L6

Added lines #L3 - L6 were not covered by tests

import pandas as pd
from datasets import ClassLabel, Dataset, Features, NamedSplit, Value
Expand Down
3 changes: 2 additions & 1 deletion src/sec_certs/model/references_nlp/segment_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import json
import logging
import re
from collections.abc import Iterable
from dataclasses import dataclass
from importlib.resources import files
from pathlib import Path
from typing import Any, Iterable, Literal
from typing import Any, Literal

Check warning on line 11 in src/sec_certs/model/references_nlp/segment_extractor.py

View check run for this annotation

Codecov / codecov/patch

src/sec_certs/model/references_nlp/segment_extractor.py#L3-L11

Added lines #L3 - L11 were not covered by tests

# import langdetect
import numpy as np
Expand Down

0 comments on commit 0e4a68f

Please sign in to comment.