Commit

Fix test
Signed-off-by: zethson <[email protected]>
Zethson committed Nov 14, 2024
1 parent bb787ca commit 8ee0905
Showing 4 changed files with 12 additions and 14 deletions.
14 changes: 7 additions & 7 deletions src/scportrait/data/_datasets.py
@@ -8,7 +8,7 @@ def dataset_1() -> Path:
"""Download and extract the example dataset 1 images.
Returns:
- Path: Path to the downloaded and extracted images.
+ Path to the downloaded and extracted images.
"""
data_dir = Path(get_data_dir())
save_path = data_dir / "example_1_images"
@@ -25,7 +25,7 @@ def dataset_2() -> Path:
"""Download and extract the example dataset 2 images.
Returns:
- Path: Path to the downloaded and extracted images.
+ Path to the downloaded and extracted images.
"""
data_dir = Path(get_data_dir())
save_path = data_dir / "example_2_images"
@@ -42,7 +42,7 @@ def dataset_3() -> Path:
"""Download and extract the example dataset 3 images.
Returns:
- Path: Path to the downloaded and extracted images.
+ Path to the downloaded and extracted images.
"""
data_dir = Path(get_data_dir())
save_path = data_dir / "example_3_images"
@@ -59,7 +59,7 @@ def dataset_4() -> Path:
"""Download and extract the example dataset 4 images.
Returns:
- Path: Path to the downloaded and extracted images.
+ Path to the downloaded and extracted images.
"""
data_dir = Path(get_data_dir())
save_path = data_dir / "example_4_images"
@@ -76,7 +76,7 @@ def dataset_5() -> Path:
"""Download and extract the example dataset 5 images.
Returns:
- Path: Path to the downloaded and extracted images.
+ Path to the downloaded and extracted images.
"""
data_dir = Path(get_data_dir())
save_path = data_dir / "example_5_images"
@@ -93,7 +93,7 @@ def dataset_6() -> Path:
"""Download and extract the example dataset 6 images.
Returns:
- Path: Path to the downloaded and extracted images.
+ Path to the downloaded and extracted images.
"""
data_dir = Path(get_data_dir())
save_path = data_dir / "example_6_images"
@@ -110,7 +110,7 @@ def dataset_stitching_example() -> Path:
"""Download and extract the example dataset for stitching images.
Returns:
- Path: Path to the downloaded and extracted images.
+ Path to the downloaded and extracted images.
"""
data_dir = Path(get_data_dir())
save_path = data_dir / "stitching_example"
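The only change in this file is the Google-style Returns: entry: since the return type is already annotated in the signature (-> Path), the redundant "Path:" prefix is dropped from each of the seven docstrings. Below is a minimal, hypothetical sketch of the pattern these helpers share, with the corrected docstring; the stdlib download-and-extract logic is an illustration written for this note, not scportrait's actual implementation.

from pathlib import Path
import urllib.request
import zipfile

def example_dataset(name: str, url: str, data_dir: Path) -> Path:
    """Download and extract an example dataset archive.

    Returns:
        Path to the downloaded and extracted images.
    """
    save_path = data_dir / name
    if not save_path.exists():
        archive = data_dir / f"{name}.zip"
        urllib.request.urlretrieve(url, archive)  # fetch the archive
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(save_path)  # unpack into the dataset folder
        archive.unlink()  # drop the archive once extracted
    return save_path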
2 changes: 1 addition & 1 deletion src/scportrait/pipeline/_utils/segmentation.py
@@ -271,7 +271,7 @@ def _return_edge_labels_2d(input_map: NDArray) -> list[int]:
.union(set(last_column.flatten()))
)

- full_union = {np.uint64(i) for i in full_union}
+ full_union = set([np.uint64(i) for i in full_union]) # noqa: C403
full_union.discard(0)

return list(full_union)
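The change above swaps a set comprehension back to an explicit set() call and silences the flake8-comprehensions C403 warning for it. For context, a self-contained sketch of the edge-label idea this helper implements: collect every label that touches the border of a 2D label map, cast to np.uint64, and drop the background label 0. The function below is an illustration written for this note, not the scportrait implementation itself.

import numpy as np
from numpy.typing import NDArray

def edge_labels_2d(label_map: NDArray) -> list[int]:
    """Return all segmentation labels touching the border of a 2D label map."""
    first_row, last_row = label_map[0, :], label_map[-1, :]
    first_column, last_column = label_map[:, 0], label_map[:, -1]

    full_union = (
        set(first_row.flatten())
        .union(set(last_row.flatten()))
        .union(set(first_column.flatten()))
        .union(set(last_column.flatten()))
    )

    full_union = set([np.uint64(i) for i in full_union])  # noqa: C403
    full_union.discard(0)  # 0 is background, not a cell label
    return list(full_union)

# a 4x4 map where label 1 touches the border and label 2 does not
labels = np.zeros((4, 4), dtype=np.uint64)
labels[0, 0] = 1
labels[2, 2] = 2
print(edge_labels_2d(labels))  # only label 1 is reported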
4 changes: 2 additions & 2 deletions src/scportrait/pipeline/_utils/spatialdata_helper.py
@@ -37,7 +37,7 @@ def check_memory(item: xarray.DataArray) -> bool:


def generate_region_annotation_lookuptable(sdata: SpatialData) -> dict[str, list[tuple[str, TableModel]]]:
"""Generate lookup table for region annotations.
"""Generate a lookup table for the region annotation tables contained in a SpatialData object ordered according to the region they annotate.
Args:
sdata: SpatialData object to process
@@ -75,7 +75,7 @@ def remap_region_annotation_table(table: TableModel, region_name: str) -> TableM
table.obs["region"] = table.obs["region"].astype("category")

if "spatialdata_attrs" in table.uns:
- del table.uns["spatialdata_attrs"]
+ del table.uns["spatialdata_attrs"] # remove the spatialdata attributes so that the table can be re-written

return TableModel.parse(table, region_key="region", region=region_name, instance_key="cell_id")

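Both edits in this file are documentation: the docstring now states that the lookup table groups annotation tables by the region they annotate, and the uns deletion gets a comment explaining why the stale spatialdata_attrs block must be removed before TableModel.parse can re-register the table. A hedged usage sketch of how the two helpers might be combined follows; the zarr path and region names are hypothetical, and only the helper names and signatures come from this diff.

from spatialdata import read_zarr
from scportrait.pipeline._utils.spatialdata_helper import (
    generate_region_annotation_lookuptable,
    remap_region_annotation_table,
)

sdata = read_zarr("project.zarr")  # hypothetical SpatialData store

# group all annotation tables by the region (e.g. a segmentation mask) they annotate
lookup = generate_region_annotation_lookuptable(sdata)

# re-point every table that annotated "old_segmentation" at a renamed region
for table_name, table in lookup.get("old_segmentation", []):
    sdata.tables[table_name] = remap_region_annotation_table(table, region_name="new_segmentation")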
6 changes: 2 additions & 4 deletions src/scportrait/tools/ml/utils.py
@@ -38,7 +38,7 @@ def combine_datasets_balanced(
elements = [len(el) for el in list_of_datasets]
rows = np.arange(len(list_of_datasets))

- # create dataset fraction array
+ # create dataset fraction array of len(list_of_datasets)
mat = csr_matrix((elements, (rows, class_labels))).toarray()
cells_per_class = np.sum(mat, axis=0)
normalized = mat / cells_per_class
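The updated comment spells out that the fraction matrix has one row per dataset in list_of_datasets. A small standalone sketch of the csr_matrix trick used here: the COO-style constructor scatters each dataset's size into the column of its class label, giving a datasets-by-classes count matrix (the numbers below are made up).

import numpy as np
from scipy.sparse import csr_matrix

# three datasets with 100, 50 and 150 samples, assigned to classes 0, 1, 1
elements = [100, 50, 150]
rows = np.arange(3)
class_labels = [0, 1, 1]

# mat[i, c] = number of samples dataset i contributes to class c
mat = csr_matrix((elements, (rows, class_labels))).toarray()
cells_per_class = np.sum(mat, axis=0)  # total samples per class: 100 and 200
normalized = mat / cells_per_class  # each dataset's share of its class
print(normalized)
# [[1.   0.  ]
#  [0.   0.25]
#  [0.   0.75]]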
@@ -168,9 +168,7 @@ def split_dataset_fractions(
else:
residual_size = len(dataset) - train_size - test_size - val_size
if residual_size < 0:
- raise ValueError(
-     f"Dataset with length {len(dataset)} is too small to be split into " f"requested sizes"
- )
+ raise ValueError(f"Dataset with length {len(dataset)} is too small to be split into requested sizes")

if seed is not None:
gen = torch.Generator()
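The raise is simply collapsed onto one line, removing a leftover f-string split from an earlier formatting pass; behaviour is unchanged. For context, a minimal sketch of the residual-split pattern the surrounding code uses, with a seeded torch.Generator for reproducible splits (an illustration, not the scportrait function).

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(100))
train_size, test_size, val_size = 60, 20, 10

residual_size = len(dataset) - train_size - test_size - val_size
if residual_size < 0:
    raise ValueError(f"Dataset with length {len(dataset)} is too small to be split into requested sizes")

gen = torch.Generator()
gen.manual_seed(42)  # fixed seed so the split is reproducible

train, test, val, residual = random_split(
    dataset, [train_size, test_size, val_size, residual_size], generator=gen
)
print(len(train), len(test), len(val), len(residual))  # 60 20 10 10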
