diff --git a/.gitignore b/.gitignore
index a4ea5ff..a78a9e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,9 @@ wheels/
# venv
.venv
example_inputs
+.idea
+*.log
+.vscode
# IDE files
.idea/
@@ -19,3 +22,4 @@ src/.DS_Store
# in progress / trials
scripts/
+/.vscode/settings.json
diff --git a/README.md b/README.md
index d54485a..c6589c1 100644
--- a/README.md
+++ b/README.md
@@ -86,33 +86,55 @@ visualization:
## Models
-CitySeg currently supports OneFormer models. The verified models include:
+CitySeg currently supports Mask2Former, MaskFormer, BEiT, and SegFormer models. The verified models include:
-- `shi-labs/oneformer_ade20k_swin_large`
-- `shi-labs/oneformer_cityscapes_swin_large`
-- `shi-labs/oneformer_ade20k_dinat_large`
-- `shi-labs/oneformer_cityscapes_dinat_large`
+- "facebook/mask2former-swin-large-cityscapes-semantic"
+- "facebook/mask2former-swin-large-mapillary-vistas-semantic"
+- "facebook/maskformer-swin-small-ade" (sort of, this often leads to segfaults. Recommend using `disable_tqdm` in the config.)
+- "microsoft/beit-large-finetuned-ade-640-640"
+- "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
+- "zoheb/mit-b5-finetuned-sidewalk-semantic"
+- "nickmuchi/segformer-b4-finetuned-segments-sidewalk"
+
+Mask2Former models are by far the most stable.
+
+Some models load correctly but consistently produce segfaults on my machine:
+
+- "facebook/maskformer-swin-large-ade"
+- "nvidia/segformer-b5-finetuned-ade-640-640"
+- "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"
+- "zoheb/mit-b5-finetuned-sidewalk-semantic" (use `model_type: segformer` in the config)
+
+Confirmed not to work due to issues with the Hugging Face pipeline:
+
+- "shi-labs/oneformer_ade20k_dinat_large"
**Note on `dinat` models:** The `dinat` backbone models require the `natten` package, which may have installation issues on some systems. These models are also significantly slower than the `swin` backbone models, especially when forced to run on CPU. However, they may produce better quality outputs in some cases.
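+
+For reference, model selection lives in the `model` block of the YAML config, with `disable_tqdm` at the top level. A minimal sketch (key names follow the sample `src/cityseg/config.yaml`; the values are illustrative):
+
+```yaml
+model:
+  name: "facebook/mask2former-swin-large-cityscapes-semantic"
+  model_type: null # auto-detected from the model name when null
+  device: null # "cuda", "mps", "cpu", or null for auto-detection
+disable_tqdm: false # set to true if tqdm appears to trigger segfaults
+```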
## Project Structure
-The project is organized into several Python modules:
+The project is organized into several Python modules, each serving a specific purpose within the CitySeg pipeline:
+
+- `main.py`: Entry point of the application, responsible for initializing and running the segmentation pipeline.
+- `config.py`: Defines configuration classes and handles loading and validating configuration settings.
+- `pipeline.py`: Implements the core segmentation pipeline, including model loading and inference.
+- `processors.py`: Contains classes for processing images, videos, and directories, managing the segmentation workflow.
+- `segmentation_analyzer.py`: Provides functionality for analyzing segmentation results, including computing statistics and generating reports.
+- `video_file_iterator.py`: Implements an iterator for efficiently processing multiple video files in a directory.
+- `visualization_handler.py`: Handles the visualization of segmentation results using color palettes.
+- `file_handler.py`: Manages file operations related to saving and loading segmentation data and metadata.
+- `utils.py`: Provides utility functions for various tasks, including data handling and logging.
+- `palettes.py`: Defines color palettes for different datasets used in segmentation.
+- `exceptions.py`: Custom exception classes for error handling throughout the pipeline.
-- `main.py`: Entry point of the application
-- `config.py`: Defines configuration classes for the pipeline
-- `pipeline.py`: Implements the core segmentation pipeline
-- `processors.py`: Contains classes for processing images, videos, and directories
-- `utils.py`: Provides utility functions for analysis, file operations, and logging
-- `palettes.py`: Defines color palettes for different datasets
-- `exceptions.py`: Custom exception classes for error handling
+This modular structure allows for easy maintenance and extension of the CitySeg pipeline, facilitating the addition of new features and models.
## Logging
The pipeline uses the `loguru` library for flexible and configurable logging. You can set the log level and enable verbose output using command-line arguments:
```
-python main.py --config path/to/your/config.yaml --log-level INFO --verbose
+python main.py --config path/to/your/config.yaml --log-level INFO  # one of DEBUG, INFO, WARNING, ERROR, CRITICAL; add --verbose for verbose output
```
Logs are output to both the console and a file (`segmentation.log`). The file log is in JSON format for easy parsing and analysis.
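+
+If the file log uses loguru's serialized one-record-per-line format (as the JSON output suggests), it can be parsed line by line. A minimal sketch:
+
+```python
+import json
+
+with open("segmentation.log") as f:
+    records = [json.loads(line) for line in f if line.strip()]
+
+# Each serialized record carries structured fields under "record"
+print(records[0]["record"]["message"])
+```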
@@ -136,4 +158,4 @@ CitySeg is released under the BSD 3-Clause License. See the `LICENSE` file for d
## Contact
-For support or inquiries, please open an issue on the GitHub repository or contact [Your Name/Email].
\ No newline at end of file
+For support or inquiries, please open an issue on the GitHub repository or contact Andrew Mitchell.
\ No newline at end of file
diff --git a/docs/api/handlers.md b/docs/api/handlers.md
new file mode 100644
index 0000000..1f5884f
--- /dev/null
+++ b/docs/api/handlers.md
@@ -0,0 +1,10 @@
+# Handlers
+
+::: cityseg.file_handler
+
+
+::: cityseg.visualization_handler
+ options:
+ show_root_heading: true
+ members: true
+ parameter_headings: true
diff --git a/docs/api/processors.md b/docs/api/processors.md
index 10f3e05..adb361c 100644
--- a/docs/api/processors.md
+++ b/docs/api/processors.md
@@ -1,6 +1,13 @@
# Processors Module
::: cityseg.processors
- options:
- members: true
- parameter_headings: true
+
+::: cityseg.processing_plan.ProcessingPlan
+ options:
+ show_root_heading: true
+ show_root_full_path: false
+
+::: cityseg.video_file_iterator.VideoFileIterator
+ options:
+ show_root_heading: true
+ show_root_full_path: false
\ No newline at end of file
diff --git a/docs/api/segments_analysis.md b/docs/api/segments_analysis.md
new file mode 100644
index 0000000..a81a25a
--- /dev/null
+++ b/docs/api/segments_analysis.md
@@ -0,0 +1 @@
+# Segmentation Analysis
+
+::: cityseg.segmentation_analyzer
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 76abae4..7dfb1a3 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -32,12 +32,19 @@ For more detailed information on how to use CitySeg, check out our [Getting Star
## Project Structure
-CitySeg is organized into several Python modules:
+The project is organized into several Python modules, each serving a specific purpose within the CitySeg pipeline:
+
+- `main.py`: Entry point of the application, responsible for initializing and running the segmentation pipeline.
+- `config.py`: Defines configuration classes and handles loading and validating configuration settings.
+- `pipeline.py`: Implements the core segmentation pipeline, including model loading and inference.
+- `processors.py`: Contains classes for processing images, videos, and directories, managing the segmentation workflow.
+- `segmentation_analyzer.py`: Provides functionality for analyzing segmentation results, including computing statistics and generating reports.
+- `video_file_iterator.py`: Implements an iterator for efficiently processing multiple video files in a directory.
+- `visualization_handler.py`: Handles the visualization of segmentation results using color palettes.
+- `file_handler.py`: Manages file operations related to saving and loading segmentation data and metadata.
+- `utils.py`: Provides utility functions for various tasks, including data handling and logging.
+- `palettes.py`: Defines color palettes for different datasets used in segmentation.
+- `exceptions.py`: Custom exception classes for error handling throughout the pipeline.
-- `config.py`: Configuration classes for the pipeline
-- `pipeline.py`: Core segmentation pipeline implementation
-- `processors.py`: Classes for processing images, videos, and directories
-- `utils.py`: Utility functions for analysis, file operations, and logging
-- `exceptions.py`: Custom exception classes for error handling
For detailed API documentation, visit our [API Reference](api/config.md) section.
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 4445aa2..88122bb 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -42,6 +42,8 @@ nav:
- Config: api/config.md
- Pipeline: api/pipeline.md
- Processors: api/processors.md
+ - Segments Analysis: api/segments_analysis.md
+ - Handlers: api/handlers.md
- Utils: api/utils.md
- Palettes: api/palettes.md
- Exceptions: api/exceptions.md
@@ -104,6 +106,7 @@ plugins:
merge_init_into_class: true
ignore_init_summary: true
show_labels: false
+ parameter_headings: true
show_if_no_docstring: false
docstring_section_style: spacy
diff --git a/pyproject.toml b/pyproject.toml
index e110bff..e0c1b74 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "cityseg"
-version = "0.2.11"
+version = "0.3.0"
description = "A flexible and efficient semantic segmentation pipeline for processing images and videos"
authors = [
{ name = "Andrew Mitchell", email = "mitchellacoustics15@gmail.com" }
@@ -42,9 +42,8 @@ dev-dependencies = [
"jupyter>=1.0.0",
"pytest>=8.3.2",
"rich[jupyter]>=13.7.1",
- "natten==0.17.1",
+ # "natten==0.17.1",
"matplotlib>=3.9.1",
- "yappi>=1.6.0",
]
[tool.hatch.metadata]
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 546a618..da34caa 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -230,7 +230,6 @@ mpmath==1.3.0
# via sympy
mypy-extensions==1.0.0
# via black
-natten==0.17.1
nbclient==0.10.0
# via nbconvert
nbconvert==7.16.4
@@ -280,7 +279,6 @@ packaging==24.1
# via jupyterlab-server
# via matplotlib
# via mkdocs
- # via natten
# via nbconvert
# via pytest
# via qtconsole
@@ -422,7 +420,6 @@ tokenizers==0.19.1
# via transformers
torch==2.4.0
# via cityseg
- # via natten
# via torchvision
torchvision==0.19.0
# via cityseg
@@ -479,4 +476,3 @@ websocket-client==1.8.0
# via jupyter-server
widgetsnbextension==4.0.11
# via ipywidgets
-yappi==1.6.0
diff --git a/src/cityseg/SemanticSidewalk_id2label.json b/src/cityseg/SemanticSidewalk_id2label.json
new file mode 100644
index 0000000..7c6f74f
--- /dev/null
+++ b/src/cityseg/SemanticSidewalk_id2label.json
@@ -0,0 +1 @@
+{"0": "unlabeled", "1": "flat-road", "2": "flat-sidewalk", "3": "flat-crosswalk", "4": "flat-cyclinglane", "5": "flat-parkingdriveway", "6": "flat-railtrack", "7": "flat-curb", "8": "human-person", "9": "human-rider", "10": "vehicle-car", "11": "vehicle-truck", "12": "vehicle-bus", "13": "vehicle-tramtrain", "14": "vehicle-motorcycle", "15": "vehicle-bicycle", "16": "vehicle-caravan", "17": "vehicle-cartrailer", "18": "construction-building", "19": "construction-door", "20": "construction-wall", "21": "construction-fenceguardrail", "22": "construction-bridge", "23": "construction-tunnel", "24": "construction-stairs", "25": "object-pole", "26": "object-trafficsign", "27": "object-trafficlight", "28": "nature-vegetation", "29": "nature-terrain", "30": "sky", "31": "void-ground", "32": "void-dynamic", "33": "void-static", "34": "void-unclear"}
\ No newline at end of file
diff --git a/src/cityseg/__init__.py b/src/cityseg/__init__.py
index 5ef8d97..74437b4 100644
--- a/src/cityseg/__init__.py
+++ b/src/cityseg/__init__.py
@@ -3,8 +3,7 @@
This package provides a flexible and efficient semantic segmentation pipeline
for processing images and videos. It supports multiple segmentation models
-and datasets, with capabilities for tiling large images, mixed-precision
-processing, and comprehensive result analysis.
+and datasets.
Main components:
- Config: Configuration class for the pipeline
@@ -20,27 +19,36 @@
For detailed usage instructions, please refer to the package documentation.
"""
-__version__ = "0.2.0"
+__version__ = "0.3.0"
from . import palettes
from .config import Config
from .exceptions import ConfigurationError, InputError, ModelError, ProcessingError
+from .file_handler import FileHandler
from .pipeline import SegmentationPipeline, create_segmentation_pipeline
+from .processing_plan import ProcessingPlan
from .processors import DirectoryProcessor, SegmentationProcessor, create_processor
-from .utils import analyze_segmentation_map, setup_logging
+from .segmentation_analyzer import SegmentationAnalyzer
+from .utils import setup_logging
+from .video_file_iterator import VideoFileIterator
+from .visualization_handler import VisualizationHandler
__all__ = [
"Config",
"SegmentationPipeline",
"create_segmentation_pipeline",
"SegmentationProcessor",
+ "SegmentationAnalyzer",
"DirectoryProcessor",
"create_processor",
"ConfigurationError",
"InputError",
"ModelError",
"ProcessingError",
- "analyze_segmentation_map",
"setup_logging",
"palettes",
+ "FileHandler",
+ "VisualizationHandler",
+ "ProcessingPlan",
+ "VideoFileIterator",
]
diff --git a/src/cityseg/config.py b/src/cityseg/config.py
index 8cb08df..b465c01 100644
--- a/src/cityseg/config.py
+++ b/src/cityseg/config.py
@@ -13,6 +13,7 @@
from typing import Any, Dict, List, Optional, Union
import yaml
+from loguru import logger
class InputType(Enum):
@@ -36,14 +37,45 @@ class ModelConfig:
"""
name: str
- model_type: Optional[str] = (
- None # Can be 'oneformer', 'mask2former', or None for auto-detection
- )
+ model_type: Optional[str] = None
max_size: Optional[int] = None
device: Optional[str] = None
+ dataset: Optional[str] = None
+ num_workers: Optional[int] = 8
+ pipe_batch: Optional[int] = 1
- # TODO: impelement model_type auto-detection
- # TODO: implement device auto-detection
+    def __post_init__(self):
+        """
+        Post-initialization hook that auto-detects the model type if not
+        provided and adjusts worker settings for MPS devices.
+        """
+        self.auto_detect_model_type()
+        if self.device == "mps" and (self.num_workers is None or self.num_workers > 0):
+            logger.warning(
+                "MPS is not compatible with multiple workers in pytorch. Setting num_workers to 0."
+            )
+            self.num_workers = 0
+
+ def auto_detect_model_type(self):
+ """
+ Automatically detect the model type from the model name if not provided.
+ """
+
+ def auto_model_type(model_name: str) -> str:
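+            # e.g. "facebook/mask2former-swin-large-cityscapes-semantic" -> "mask2former"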
+ return model_name.split("/")[-1].split("-")[0]
+
+ if self.model_type is None:
+ try:
+ self.model_type = auto_model_type(self.name)
+ except IndexError:
+ logger.warning(
+ "Unable to auto-detect model type from the model name and none provided."
+ )
+ return
+ logger.info(f"Auto-detected model type: {self.model_type}")
+ elif self.model_type != auto_model_type(self.name):
+ logger.warning(
+ f"Model type does not match auto-detected model type. Using provided model type: {self.model_type}"
+ )
@dataclass
@@ -83,6 +115,7 @@ class Config:
visualization (VisualizationConfig): The visualization configuration.
input_type (InputType): The type of input (automatically determined).
force_reprocess (bool): Whether to force reprocessing of existing results.
+ disable_tqdm (bool): Whether to disable the progress bar display.
"""
input: Union[Path, str]
@@ -100,6 +133,7 @@ class Config:
visualization: VisualizationConfig = field(default_factory=VisualizationConfig)
input_type: InputType = field(init=False)
force_reprocess: bool = False
+ disable_tqdm: bool = False
def __post_init__(self):
"""
@@ -232,6 +266,7 @@ def from_yaml(cls, config_path: Path) -> "Config":
analyze_results=config_dict.get("analyze_results", True),
visualization=vis_config,
force_reprocess=config_dict.get("force_reprocess", False),
+ disable_tqdm=config_dict.get("disable_tqdm", False),
)
def to_dict(self) -> Dict[str, Any]:
@@ -257,6 +292,7 @@ def to_dict(self) -> Dict[str, Any]:
"visualization": asdict(self.visualization),
"input_type": self.input_type.value,
"force_reprocess": self.force_reprocess,
+ "disable_tqdm": self.disable_tqdm,
}
diff --git a/src/cityseg/config.yaml b/src/cityseg/config.yaml
index 40fa552..197857b 100644
--- a/src/cityseg/config.yaml
+++ b/src/cityseg/config.yaml
@@ -6,21 +6,24 @@ ignore_files: null # Optional: list of file names to ignore
# Model configuration
model:
- name: "facebook/mask2former-swin-large-mapillary-vistas-semantic"
- model_type: "mask2former" # Optional: can be 'oneformer', 'mask2former', or null for auto-detection
+ name: "facebook/mask2former-swin-large-cityscapes-semantic"
+  model_type: null # Optional: 'mask2former', 'maskformer', 'beit', 'segformer', or null for auto-detection
max_size: null # Optional: maximum size for input images/frames
device: "mps" # Options: "cuda", "cpu", "mps", or null for auto-detection
+ dataset: "semantic-sidewalk" # Optional: dataset name for model-specific postprocessing
+ num_workers: 0 # Number of workers for data loading
+ pipe_batch: 5 # Number of frames to process in each batch. Recommend setting this equal to batch_size below.
# Processing configuration
-frame_step: 10 # Process every 5th frame
+frame_step: 1 # Process every frame (e.g. set to 5 to process every 5th frame)
batch_size: 5 # Number of frames to process in each batch
output_fps: null # Optional: FPS for output video (if different from input)
# Output options
-save_raw_segmentation: false
+save_raw_segmentation: true
save_colored_segmentation: true
save_overlay: true
-analyze_results: false
+analyze_results: true
# Visualization configuration
visualization:
@@ -29,3 +32,4 @@ visualization:
# Advanced options
force_reprocess: false # Set to true to reprocess even if output files exist
+disable_tqdm: false # Set to true to disable progress bars. In some cases, tqdm seems to lead to segfaults.
\ No newline at end of file
diff --git a/src/cityseg/file_handler.py b/src/cityseg/file_handler.py
new file mode 100644
index 0000000..1317020
--- /dev/null
+++ b/src/cityseg/file_handler.py
@@ -0,0 +1,186 @@
+"""
+This module provides a class for handling file operations related to segmentation data.
+
+It includes functionalities for saving and loading segmentation data in HDF files,
+verifying the integrity of HDF and video files, and checking the validity of analysis files.
+
+Classes:
+ FileHandler: A class for handling file operations related to segmentation data and metadata.
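+
+Typical usage (illustrative; all methods are static):
+
+    FileHandler.save_hdf_file(path, seg_data, metadata)
+    hdf_file, metadata = FileHandler.load_hdf_file(path)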
+"""
+
+import json
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+import cv2
+import h5py
+import numpy as np
+from loguru import logger
+
+from .config import Config
+
+
+class FileHandler:
+ """
+ A class for handling file operations related to segmentation data and metadata.
+
+ This class provides methods for saving and loading segmentation data in HDF files,
+ verifying the integrity of HDF and video files, and checking analysis files.
+
+ Methods:
+ save_hdf_file: Saves segmentation data and metadata to an HDF file.
+ load_hdf_file: Loads segmentation data and metadata from an HDF file.
+ verify_hdf_file: Verifies the integrity of an HDF file.
+ verify_video_file: Verifies the integrity of a video file.
+ verify_analysis_files: Verifies the analysis files for counts and percentages.
+ """
+
+ @staticmethod
+ def save_hdf_file(
+ file_path: Path, segmentation_data: np.ndarray, metadata: Dict[str, Any]
+ ) -> None:
+ """
+ Saves segmentation data and metadata to an HDF file.
+
+ Args:
+ file_path (Path): Path to the HDF file.
+ segmentation_data (np.ndarray): Segmentation data to be saved.
+ metadata (Dict[str, Any]): Metadata associated with the segmentation data.
+ """
+ with h5py.File(file_path, "w") as f:
+ f.create_dataset("segmentation", data=segmentation_data, compression="gzip")
+ if "palette" in metadata and isinstance(metadata["palette"], np.ndarray):
+ metadata["palette"] = metadata["palette"].tolist()
+ json_metadata = json.dumps(metadata)
+ f.create_dataset("metadata", data=json_metadata)
+
+ @staticmethod
+ def load_hdf_file(file_path: Path) -> Tuple[h5py.File, Dict[str, Any]]:
+ """
+ Loads segmentation data and metadata from an HDF file.
+
+ Args:
+ file_path (Path): Path to the HDF file.
+
+ Returns:
+ Tuple[h5py.File, Dict[str, Any]]: Loaded HDF file and metadata.
+ """
+ hdf_file = h5py.File(file_path, "r")
+ json_metadata = hdf_file["metadata"][()]
+ metadata = json.loads(json_metadata)
+ if "palette" in metadata and isinstance(metadata["palette"], list):
+ metadata["palette"] = np.array(metadata["palette"], np.uint8)
+ return hdf_file, metadata
+
+ @staticmethod
+ def verify_hdf_file(file_path: Path, config: Config) -> bool:
+ """
+ Verifies the integrity of an HDF file.
+
+ Args:
+ file_path (Path): Path to the HDF file.
+ config (Config): Configuration object for comparison.
+
+ Returns:
+ bool: True if the HDF file is valid and up-to-date, False otherwise.
+ """
+ try:
+ with h5py.File(file_path, "r") as f:
+ if "segmentation" not in f or "metadata" not in f:
+ logger.warning(
+ f"HDF file at {file_path} is missing required datasets"
+ )
+ return False
+
+ json_metadata = f["metadata"][()]
+ metadata = json.loads(json_metadata)
+
+ if metadata.get("frame_step") != config.frame_step:
+ logger.warning(
+ f"HDF file frame step ({metadata.get('frame_step')}) does not match current config ({config.frame_step})"
+ )
+ return False
+
+ segmentation_data = f["segmentation"]
+ if len(segmentation_data) == 0:
+ logger.warning(
+ f"HDF file at {file_path} contains no segmentation data"
+ )
+ return False
+
+ first_frame = segmentation_data[0]
+ last_frame = segmentation_data[-1]
+ if first_frame.shape != last_frame.shape:
+ logger.warning(
+ f"Inconsistent frame shapes in HDF file at {file_path}"
+ )
+ return False
+
+ logger.debug(f"HDF file at {file_path} is valid and up-to-date")
+ return True
+ except Exception as e:
+ logger.error(f"Error verifying HDF file at {file_path}: {str(e)}")
+ return False
+
+ @staticmethod
+ def verify_video_file(file_path: Path) -> bool:
+ """
+ Verifies the integrity of a video file.
+
+ Args:
+ file_path (Path): Path to the video file.
+
+ Returns:
+ bool: True if the video file is valid and up-to-date, False otherwise.
+ """
+ try:
+ cap = cv2.VideoCapture(str(file_path))
+ if not cap.isOpened():
+ logger.warning(f"Unable to open video file at {file_path}")
+ return False
+
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+ ret, first_frame = cap.read()
+ if not ret:
+ logger.warning(
+ f"Unable to read first frame from video file at {file_path}"
+ )
+                cap.release()
+                return False
+
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
+ ret, last_frame = cap.read()
+ if not ret:
+ logger.warning(
+ f"Unable to read last frame from video file at {file_path}"
+ )
+                cap.release()
+                return False
+
+ cap.release()
+ logger.debug(f"Video file at {file_path} is valid and up-to-date")
+ return True
+ except Exception as e:
+ logger.error(f"Error verifying video file at {file_path}: {str(e)}")
+ return False
+
+ @staticmethod
+ def verify_analysis_files(counts_file: Path, percentages_file: Path) -> bool:
+ """
+ Verifies the analysis files for counts and percentages.
+
+ Args:
+ counts_file (Path): Path to the counts file.
+ percentages_file (Path): Path to the percentages file.
+
+ Returns:
+ bool: True if the analysis files are valid, False otherwise.
+ """
+ try:
+ if counts_file.stat().st_size == 0 or percentages_file.stat().st_size == 0:
+ logger.info("One or both analysis files are empty")
+ return False
+ return True
+ except Exception as e:
+ logger.error(f"Error verifying analysis files: {str(e)}")
+ return False
diff --git a/src/cityseg/main.py b/src/cityseg/main.py
index d51d475..6f2321a 100644
--- a/src/cityseg/main.py
+++ b/src/cityseg/main.py
@@ -75,5 +75,7 @@ def main() -> None:
import warnings
warnings.filterwarnings("ignore")
+
+ # import sys
# sys.argv = ["", "config.yaml", "--verbose"]
main()
diff --git a/src/cityseg/pipeline.py b/src/cityseg/pipeline.py
index 4a23074..f6093a1 100644
--- a/src/cityseg/pipeline.py
+++ b/src/cityseg/pipeline.py
@@ -6,19 +6,29 @@
create detailed segmentation maps with associated metadata.
"""
+import json
+import warnings
+from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
+from loguru import logger
from PIL.Image import Image
from transformers import (
AutoImageProcessor,
AutoModelForSemanticSegmentation,
+ AutoProcessor,
+ BeitForSemanticSegmentation,
ImageSegmentationPipeline,
Mask2FormerForUniversalSegmentation,
- OneFormerProcessor,
+ MaskFormerForInstanceSegmentation,
+ OneFormerForUniversalSegmentation,
+ SegformerForSemanticSegmentation,
)
+from .config import ModelConfig
+
class SegmentationPipeline(ImageSegmentationPipeline):
"""
@@ -86,8 +96,9 @@ def create_single_segmentation_map(
"palette": self.palette,
}
+ @staticmethod
def _is_single_image_result(
- self, result: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]
+ result: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
) -> bool:
"""
Determine if the result is for a single image or multiple images.
@@ -127,8 +138,9 @@ def __call__(
Returns:
List[Dict[str, Any]]: A list of dictionaries containing segmentation maps and metadata.
"""
- result = super().__call__(images, **kwargs)
-
+ # logger.debug("Pass image(s) up to HF pipeline...")
+ result = super().__call__(images, subtask="semantic", **kwargs)
+ # logger.debug("Received result from HF pipeline")
if self._is_single_image_result(result):
return [
self.create_single_segmentation_map(
@@ -144,8 +156,9 @@ def __call__(
]
+@logger.catch
def create_segmentation_pipeline(
- model_name: str, device: Optional[str] = None, **kwargs: Any
+ config: ModelConfig, **kwargs: Any
) -> SegmentationPipeline:
"""
Create and return a SegmentationPipeline instance based on the specified model.
@@ -154,23 +167,63 @@ def create_segmentation_pipeline(
model name, and creates a SegmentationPipeline instance with these components.
Args:
- model_name (str): The name or path of the pre-trained model to use.
- device (Optional[str]): The device to use for processing (e.g., "cpu", "cuda"). If None, it will be automatically determined.
+        config (ModelConfig): Model configuration specifying the model name, type, device, dataset, and worker settings.
**kwargs: Additional keyword arguments to pass to the SegmentationPipeline constructor.
Returns:
SegmentationPipeline: An instance of the SegmentationPipeline class.
"""
+ model_name = config.name
+ model_type = config.model_type
+ device = config.device
+ dataset = config.dataset
+
if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = (
+ "cuda"
+ if torch.cuda.is_available()
+ else "mps"
+ if torch.backends.mps.is_available()
+ else "cpu"
+ )
# Initialize the appropriate model and image processor based on the model name
- if "oneformer" in model_name.lower():
- model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
- image_processor = OneFormerProcessor.from_pretrained(model_name)
- elif "mask2former" in model_name.lower():
+ if "oneformer" == model_type:
+ warnings.warn(
+ "OneFormer models are experimental and may not be fully supported"
+ )
+ try:
+ model = OneFormerForUniversalSegmentation.from_pretrained(model_name)
+ image_processor = AutoProcessor.from_pretrained(model_name)
+ except ValueError as e:
+ logger.error(f"Error loading model: {e}")
+
+ elif "mask2former" == model_type:
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
image_processor = AutoImageProcessor.from_pretrained(model_name)
+
+ elif "maskformer" == model_type:
+ model = MaskFormerForInstanceSegmentation.from_pretrained(model_name)
+ image_processor = AutoImageProcessor.from_pretrained(model_name)
+
+ elif "beit" == model_type:
+ if device != "cpu":
+ logger.warning(
+ "Beit models are not supported on GPU and will be loaded on CPU"
+ )
+ device = "cpu"
+ model = BeitForSemanticSegmentation.from_pretrained(model_name)
+ image_processor = AutoImageProcessor.from_pretrained(model_name)
+
+ elif "segformer" == model_type:
+ model = SegformerForSemanticSegmentation.from_pretrained(model_name)
+ image_processor = AutoImageProcessor.from_pretrained(model_name)
+
+ if dataset == "sidewalk-semantic":
+ logging.debug("Loading Sidewalk Semantic dataset label mappings...")
+ with open("SemanticSidewalk_id2label.json") as f:
+ id2label = json.load(f)
+ model.config.id2label = id2label
else:
model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
image_processor = AutoImageProcessor.from_pretrained(model_name)
@@ -180,5 +233,6 @@ def create_segmentation_pipeline(
image_processor=image_processor,
device=device,
subtask="semantic",
+ num_workers=config.num_workers,
**kwargs,
)
diff --git a/src/cityseg/processing_plan.py b/src/cityseg/processing_plan.py
new file mode 100644
index 0000000..c5f0d55
--- /dev/null
+++ b/src/cityseg/processing_plan.py
@@ -0,0 +1,123 @@
+"""
+This module defines a class for creating and managing the processing plan for video segmentation tasks.
+
+It determines which processing steps need to be executed based on the configuration
+and the existence of previously generated outputs.
+
+Classes:
+ ProcessingPlan: A class to create and manage the processing plan for video segmentation tasks.
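+
+Typical usage (illustrative):
+
+    plan = ProcessingPlan(config).plan
+    if plan["process_video"]:
+        ...  # run frame segmentation and HDF generation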
+"""
+
+from typing import Dict
+
+from loguru import logger
+
+from .config import Config
+from .file_handler import FileHandler
+
+
+class ProcessingPlan:
+ """
+ A class to create and manage the processing plan for video segmentation tasks.
+
+ This class determines which processing steps need to be executed based on the
+ configuration and the existence of previously generated outputs.
+
+ Attributes:
+ config (Config): Configuration object containing processing parameters.
+ plan (Dict[str, bool]): A dictionary indicating which processing steps to execute.
+ """
+
+ def __init__(self, config: Config):
+ """
+ Initializes the ProcessingPlan with the given configuration.
+
+ Args:
+ config (Config): Configuration object for the processing plan.
+ """
+ self.config = config
+ self.plan = self._create_processing_plan()
+
+ def _create_processing_plan(self) -> Dict[str, bool]:
+ """
+ Creates a processing plan based on the configuration and existing outputs.
+
+ This method checks if force reprocessing is enabled or if existing outputs
+ are valid to determine which processing steps should be executed.
+
+ Returns:
+ Dict[str, bool]: A dictionary indicating which processing steps to execute.
+ """
+ if self.config.force_reprocess:
+ logger.info("Force reprocessing enabled. All steps will be executed.")
+ return {
+ "process_video": True,
+ "generate_hdf": True,
+ "generate_colored_video": self.config.save_colored_segmentation,
+ "generate_overlay_video": self.config.save_overlay,
+ "analyze_results": self.config.analyze_results,
+ }
+
+ existing_outputs = self._check_existing_outputs()
+
+ plan = {
+ "process_video": not existing_outputs["hdf_file_valid"],
+ "generate_hdf": not existing_outputs["hdf_file_valid"],
+ "generate_colored_video": self.config.save_colored_segmentation
+ and not existing_outputs["colored_video_valid"],
+ "generate_overlay_video": self.config.save_overlay
+ and not existing_outputs["overlay_video_valid"],
+ "analyze_results": self.config.analyze_results
+ and not existing_outputs["analysis_files_valid"],
+ }
+
+ logger.debug(f"Created processing plan: {plan}")
+ return plan
+
+ def _check_existing_outputs(self) -> Dict[str, bool]:
+ """
+ Checks the validity of existing output files.
+
+ This method verifies the existence and validity of the HDF file, colored video,
+ overlay video, and analysis files.
+
+ Returns:
+ Dict[str, bool]: A dictionary indicating the validity of existing outputs.
+ """
+ output_path = self.config.get_output_path()
+ hdf_path = output_path.with_name(f"{output_path.stem}_segmentation.h5")
+ colored_video_path = output_path.with_name(f"{output_path.stem}_colored.mp4")
+ overlay_video_path = output_path.with_name(f"{output_path.stem}_overlay.mp4")
+ counts_file = output_path.with_name(f"{output_path.stem}_category_counts.csv")
+ percentages_file = output_path.with_name(
+ f"{output_path.stem}_category_percentages.csv"
+ )
+
+ results = {
+ "hdf_file_valid": False,
+ "colored_video_valid": False,
+ "overlay_video_valid": False,
+ "analysis_files_valid": False,
+ }
+
+ if hdf_path.exists():
+ results["hdf_file_valid"] = FileHandler.verify_hdf_file(
+ hdf_path, self.config
+ )
+
+ if colored_video_path.exists():
+ results["colored_video_valid"] = FileHandler.verify_video_file(
+ colored_video_path
+ )
+
+ if overlay_video_path.exists():
+ results["overlay_video_valid"] = FileHandler.verify_video_file(
+ overlay_video_path
+ )
+
+ if counts_file.exists() and percentages_file.exists():
+ results["analysis_files_valid"] = FileHandler.verify_analysis_files(
+ counts_file, percentages_file
+ )
+
+ return results
diff --git a/src/cityseg/processors.py b/src/cityseg/processors.py
index fbf2753..f165073 100644
--- a/src/cityseg/processors.py
+++ b/src/cityseg/processors.py
@@ -13,465 +13,248 @@
+import csv
import time
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Tuple, Union
import cv2
import h5py
import numpy as np
-import pandas as pd
from loguru import logger
from PIL import Image
from .config import Config, ConfigHasher, InputType
from .exceptions import InputError, ProcessingError
+from .file_handler import FileHandler
from .pipeline import create_segmentation_pipeline
-from .utils import (
- get_video_files,
-)
+from .processing_plan import ProcessingPlan
+from .segmentation_analyzer import SegmentationAnalyzer
+from .utils import get_segmentation_data_batch, tqdm_context
+from .video_file_iterator import VideoFileIterator
+from .visualization_handler import VisualizationHandler
-class ProcessingHistory:
+class ImageProcessor:
"""
- A class to manage and persist the processing history of video segmentation tasks.
+ Processes individual images using semantic segmentation models.
- This class keeps track of individual processing runs, including timestamps,
- configuration hashes, and which outputs were generated in each run.
+ This class handles the segmentation of single images, including saving results
+ and analyzing the segmentation output.
Attributes:
- runs (List[Dict]): A list of dictionaries, each representing a processing run.
+ config (Config): Configuration object containing processing parameters.
+ pipeline: Segmentation pipeline for processing images.
+ file_handler (FileHandler): Handles file operations.
+ visualizer (VisualizationHandler): Handles visualization of segmentation results.
+ analyzer (SegmentationAnalyzer): Analyzes segmentation results.
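+
+    Typical usage (illustrative):
+
+        ImageProcessor(config).process()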
"""
- def __init__(self):
- """Initialize an empty ProcessingHistory."""
- self.runs = []
-
- def add_run(
- self, timestamp: str, config_hash: str, outputs_generated: Dict[str, bool]
- ) -> None:
- """
- Add a new processing run to the history.
-
- Args:
- timestamp (str): The timestamp of the processing run.
- config_hash (str): A hash of the relevant configuration used for the run.
- outputs_generated (Dict[str, bool]): A dictionary indicating which outputs were generated.
- """
- self.runs.append(
- {
- "timestamp": timestamp,
- "config_hash": config_hash,
- "outputs_generated": outputs_generated,
- }
- )
- logger.debug(f"Added new processing run to history. Timestamp: {timestamp}")
-
- def save(self, file_path: Path) -> None:
+ def __init__(self, config: Config):
"""
- Save the processing history to a JSON file.
+ Initializes the ImageProcessor with the given configuration.
Args:
- file_path (Path): The path where the history file will be saved.
+ config (Config): Configuration object for the processor.
"""
- with file_path.open("w") as f:
- json.dump({"runs": self.runs}, f)
- logger.debug(f"Saved processing history to {file_path}")
+ self.config = config
+ self.pipeline = create_segmentation_pipeline(config.model)
+ self.file_handler = FileHandler()
+ self.visualizer = VisualizationHandler()
+ self.analyzer = SegmentationAnalyzer()
- @classmethod
- def load(cls, file_path: Path) -> "ProcessingHistory":
+ def process(self) -> None:
"""
- Load a processing history from a JSON file.
+ Processes the input image according to the configuration.
- Args:
- file_path (Path): The path to the history file.
+ This method handles the entire image processing pipeline, including
+ segmentation, result saving, and analysis.
- Returns:
- ProcessingHistory: A ProcessingHistory object populated with the loaded data.
+ Raises:
+ ProcessingError: If an error occurs during image processing.
"""
- history = cls()
- if file_path.exists():
- with file_path.open("r") as f:
- data = json.load(f)
- history.runs = data["runs"]
- logger.debug(f"Loaded processing history from {file_path}")
- else:
- logger.info(f"No processing history file found at {file_path}")
- return history
-
+ logger.info(f"Processing image: {self.config.input}")
+ try:
+ image = Image.open(self.config.input).convert("RGB")
+ if self.config.model.max_size:
+ image.thumbnail(
+ (self.config.model.max_size, self.config.model.max_size)
+ )
-class SegmentationProcessor:
- """
- A processor for semantic segmentation of images and videos.
+ result = self.pipeline([image])[0]
- This class handles the segmentation process for individual image and video files,
- including caching of results, visualization generation, and statistical analysis.
+ self._save_results(image, result)
+ self._analyze_results(result["seg_map"])
- Attributes:
- config (Config): Configuration object containing processing parameters.
- pipeline (SegmentationPipeline): The segmentation pipeline used for processing.
- logger (Logger): Logger instance for tracking processing events.
- processing_history (ProcessingHistory): Object to track processing history.
- processing_plan (Dict[str, bool]): Plan determining which processing steps to execute.
- """
+ logger.info("Image processing complete")
+ except Exception as e:
+ logger.exception(f"Error during image processing: {str(e)}")
+ raise ProcessingError(f"Error during image processing: {str(e)}")
- def __init__(self, config: Config):
+ def _save_results(self, image: Image.Image, result: Dict[str, Any]) -> None:
"""
- Initialize the SegmentationProcessor.
+ Saves the segmentation results based on the configuration.
Args:
- config (Config): Configuration object containing processing parameters.
+ image (Image.Image): The original input image.
+ result (Dict[str, Any]): The segmentation result dictionary.
"""
- self.config = config
- self.pipeline = create_segmentation_pipeline(
- model_name=config.model.name,
- device=config.model.device,
- )
- self.palette = (
- self.pipeline.palette
- if self.pipeline.palette is not None
- else self._generate_palette(255)
- ) # Pre-compute palette for up to 256 classes
- self.logger = logger.bind(
- processor_type=self.__class__.__name__,
- input_type=self.config.input_type.value,
- )
- self.processing_history = self._load_processing_history()
- self.processing_plan = self._create_processing_plan()
- self.logger.debug(
- "SegmentationProcessor initialized for input video.",
- video_input=str(self.config.input),
- )
+ output_path = self.config.get_output_path()
- def _load_processing_history(self) -> ProcessingHistory:
- """
- Load the processing history from a file or create a new one if not found.
+ # Save raw segmentation
+ if self.config.save_raw_segmentation:
+ raw_seg_path = output_path.with_name(
+ f"{output_path.stem}_raw_segmentation.png"
+ )
+ Image.fromarray(result["seg_map"].astype(np.uint8)).save(raw_seg_path)
+ logger.info(f"Raw segmentation saved to {raw_seg_path}")
- Returns:
- ProcessingHistory: The loaded or newly created processing history.
- """
- history_file = self._get_history_file_path()
- try:
- history = ProcessingHistory.load(history_file)
- self.logger.debug("Processing history loaded successfully")
- return history
- except Exception as e:
- self.logger.info(
- f"Failed to load processing history: {str(e)}. Starting with a new history."
+ # Save colored segmentation
+ if self.config.save_colored_segmentation:
+ colored_seg_path = output_path.with_name(
+ f"{output_path.stem}_colored_segmentation.png"
+ )
+ colored_seg = self.visualizer.visualize_segmentation(
+ np.array(image), result["seg_map"], result["palette"], colored_only=True
)
- return ProcessingHistory()
+ Image.fromarray(colored_seg).save(colored_seg_path)
+ logger.info(f"Colored segmentation saved to {colored_seg_path}")
+
+ # Save overlay
+ if self.config.save_overlay:
+ overlay_path = output_path.with_name(f"{output_path.stem}_overlay.png")
+ overlay = self.visualizer.visualize_segmentation(
+ np.array(image),
+ result["seg_map"],
+ result["palette"],
+ colored_only=False,
+ )
+ Image.fromarray(overlay).save(overlay_path)
+ logger.info(f"Overlay saved to {overlay_path}")
- def _get_history_file_path(self) -> Path:
+ def _analyze_results(self, seg_map: np.ndarray) -> None:
"""
- Get the file path for the processing history JSON file.
+ Analyzes the segmentation results and saves the analysis.
- Returns:
- Path: The path to the processing history file.
+ Args:
+ seg_map (np.ndarray): The segmentation map to analyze.
"""
output_path = self.config.get_output_path()
- return output_path.with_name(f"{output_path.stem}_processing_history.json")
-
- def _create_processing_plan(self) -> Dict[str, bool]:
- """
- Create a processing plan based on the current configuration and existing outputs.
-
- Returns:
- Dict[str, bool]: A dictionary representing the processing plan.
- """
- if self.config.force_reprocess:
- self.logger.info("Force reprocessing enabled. All steps will be executed.")
- return {
- "process_video": True,
- "generate_hdf": True,
- "generate_colored_video": self.config.save_colored_segmentation,
- "generate_overlay_video": self.config.save_overlay,
- "analyze_results": self.config.analyze_results,
- }
-
- existing_outputs = self._check_existing_outputs()
-
- plan = {
- "process_video": not existing_outputs["hdf_file_valid"],
- "generate_hdf": not existing_outputs["hdf_file_valid"],
- "generate_colored_video": self.config.save_colored_segmentation
- and not existing_outputs["colored_video_valid"],
- "generate_overlay_video": self.config.save_overlay
- and not existing_outputs["overlay_video_valid"],
- "analyze_results": self.config.analyze_results
- and not existing_outputs["analysis_files_valid"],
- }
+        num_categories = len(self.pipeline.model.config.id2label)
- self.logger.debug(f"Created processing plan: {plan}")
- return plan
+ analysis = self.analyzer.analyze_segmentation_map(seg_map, num_categories)
- def _check_existing_outputs(self) -> Dict[str, bool]:
- """
- Check the validity of existing output files.
-
- Returns:
- Dict[str, bool]: A dictionary indicating the validity of each output type.
- """
- output_path = self.config.get_output_path()
- hdf_path = output_path.with_name(f"{output_path.stem}_segmentation.h5")
- colored_video_path = output_path.with_name(f"{output_path.stem}_colored.mp4")
- overlay_video_path = output_path.with_name(f"{output_path.stem}_overlay.mp4")
counts_file = output_path.with_name(f"{output_path.stem}_category_counts.csv")
percentages_file = output_path.with_name(
f"{output_path.stem}_category_percentages.csv"
)
- results = {
- "hdf_file_valid": False,
- "colored_video_valid": False,
- "overlay_video_valid": False,
- "analysis_files_valid": False,
- }
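+        # analyze_segmentation_map is assumed to return {category_id: (pixel_count, percentage)}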
+ with open(counts_file, "w", newline="") as f:
+ writer = csv.writer(f)
+ writer.writerow(["category_id", "pixel_count"])
+ for category_id, (pixel_count, _) in analysis.items():
+ writer.writerow([category_id, pixel_count])
- if hdf_path.exists():
- results["hdf_file_valid"] = self._verify_hdf_file(hdf_path)
- self.logger.debug(f"HDF file validity: {results['hdf_file_valid']}")
+ with open(percentages_file, "w", newline="") as f:
+ writer = csv.writer(f)
+ writer.writerow(["category_id", "percentage"])
+ for category_id, (_, percentage) in analysis.items():
+ writer.writerow([category_id, percentage])
- if colored_video_path.exists():
- results["colored_video_valid"] = self._verify_video_file(colored_video_path)
- self.logger.debug(
- f"Colored video validity: {results['colored_video_valid']}"
- )
+ logger.info(f"Category counts saved to {counts_file}")
+ logger.info(f"Category percentages saved to {percentages_file}")
- if overlay_video_path.exists():
- results["overlay_video_valid"] = self._verify_video_file(overlay_video_path)
- self.logger.debug(
- f"Overlay video validity: {results['overlay_video_valid']}"
- )
- if counts_file.exists() and percentages_file.exists():
- results["analysis_files_valid"] = self._verify_analysis_files(
- counts_file, percentages_file
- )
- self.logger.debug(
- f"Analysis files validity: {results['analysis_files_valid']}"
- )
+class VideoProcessor:
+ """
+ Processes video files using semantic segmentation models.
+
+ This class handles the segmentation of video frames, including saving results,
+ generating output videos, and analyzing the segmentation output.
- return results
+ Attributes:
+ config (Config): Configuration object containing processing parameters.
+ pipeline: Segmentation pipeline for processing video frames.
+ processing_plan (ProcessingPlan): Plan for video processing steps.
+ file_handler (FileHandler): Handles file operations.
+ visualizer (VisualizationHandler): Handles visualization of segmentation results.
+ analyzer (SegmentationAnalyzer): Analyzes segmentation results.
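+
+    Typical usage (illustrative):
+
+        VideoProcessor(config).process()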
+ """
- def _verify_analysis_files(self, counts_file: Path, percentages_file: Path) -> bool:
+ def __init__(self, config: Config):
"""
- Verify the integrity of analysis files.
+ Initializes the VideoProcessor with the given configuration.
Args:
- counts_file (Path): Path to the category counts CSV file.
- percentages_file (Path): Path to the category percentages CSV file.
-
- Returns:
- bool: True if both files are valid, False otherwise.
+ config (Config): Configuration object for the processor.
"""
- try:
- # Perform basic checks on the files
- if counts_file.stat().st_size == 0 or percentages_file.stat().st_size == 0:
- self.logger.warning("One or both analysis files are empty")
- return False
-
- # You could add more sophisticated checks here, such as:
- # - Verifying the number of rows matches the expected frame count
- # - Checking that the headers are correct
- # - Validating that the data is within expected ranges
-
- return True
- except Exception as e:
- self.logger.error(f"Error verifying analysis files: {str(e)}")
- return False
+ self.config = config
+ self.pipeline = create_segmentation_pipeline(config.model)
+ self.processing_plan = ProcessingPlan(config)
+ self.file_handler = FileHandler()
+ self.visualizer = VisualizationHandler()
+ self.analyzer = SegmentationAnalyzer()
+ logger.debug(f"VideoProcessor initialized with config: {config}")
def process(self) -> None:
"""
- Process the input based on its type (image or video).
-
- Raises:
- ValueError: If the input type is not supported.
- """
- if self.config.input_type == InputType.SINGLE_IMAGE:
- self.process_image()
- elif self.config.input_type == InputType.SINGLE_VIDEO:
- self.process_video()
- else:
- raise ValueError(f"Unsupported input type: {self.config.input_type}")
-
- def process_image(self) -> None:
- """
- Process a single image file.
+ Processes the input video according to the configuration and processing plan.
- This method handles loading the image, running it through the segmentation pipeline,
- saving the results, and analyzing the segmentation map.
+ This method handles the entire video processing pipeline, including
+ frame segmentation, result saving, video generation, and analysis.
Raises:
- ProcessingError: If an error occurs during image processing.
- """
- self.logger.info(f"Processing image: {self.config.input}")
- try:
- # Load and preprocess the image
- image = Image.open(self.config.input).convert("RGB")
- if self.config.model.max_size:
- image.thumbnail(
- (self.config.model.max_size, self.config.model.max_size)
- )
-
- # Run segmentation
- result = self.pipeline([image])[0]
-
- # Save and analyze results
- self.save_results(image, result)
- self.analyze_results(result["seg_map"])
-
- self.logger.info("Image processing complete")
- except Exception as e:
- self.logger.exception(f"Error during image processing: {str(e)}")
- raise ProcessingError(f"Error during image processing: {str(e)}")
-
- def process_video(self) -> None:
- """
- Process a single video file according to the current processing plan.
+ ProcessingError: If an error occurs during video processing.
"""
- self.logger.info(f"Processing video: {self.config.input.name}")
+ logger.info(f"Processing video: {self.config.input.name}")
try:
output_path = self.config.get_output_path()
hdf_path = output_path.with_name(f"{output_path.stem}_segmentation.h5")
- if self.processing_plan["process_video"]:
- self.logger.debug("Executing video frame processing")
- segmentation_data, metadata = self.process_video_frames()
- if self.processing_plan["generate_hdf"]:
- self.logger.debug(
- f"Saving segmentation data to HDF file: {hdf_path}"
+ if self.processing_plan.plan["process_video"]:
+ # logger.debug("Executing video frame processing")
+ segmentation_data, metadata = self._process_video_frames()
+ if self.processing_plan.plan["generate_hdf"]:
+ logger.debug(f"Saving segmentation data to HDF file: {hdf_path}")
+ self.file_handler.save_hdf_file(
+ hdf_path, segmentation_data, metadata
)
- self.save_hdf_file(hdf_path, segmentation_data, metadata)
else:
- self.logger.info(
+ logger.info(
f"Loading existing segmentation data from HDF file: {hdf_path.name}"
)
- hdf_file, metadata = self.load_hdf_file(hdf_path)
+                self.hdf_file, metadata = self.file_handler.load_hdf_file(hdf_path)
+                hdf_file = self.hdf_file  # keep the handle on self so the finally block can close it
segmentation_data = hdf_file[
"segmentation"
] # This is now a h5py.Dataset
# Generate videos based on the processing plan
if (
- self.processing_plan["generate_colored_video"]
- or self.processing_plan["generate_overlay_video"]
+ self.processing_plan.plan["generate_colored_video"]
+ or self.processing_plan.plan["generate_overlay_video"]
):
self.generate_videos(segmentation_data, metadata)
- # Analyze results if needed
- if self.processing_plan["analyze_results"]:
- self.logger.debug("Analyzing segmentation results")
- self.analyze_results(segmentation_data, metadata)
+ if self.processing_plan.plan["analyze_results"]:
+ logger.debug("Analyzing segmentation results")
+ SegmentationAnalyzer.analyze_results(
+ segmentation_data, metadata, output_path
+ )
- # Update processing history
self._update_processing_history()
- self.logger.info("Video processing complete")
+ logger.info("Video processing complete")
except Exception as e:
- self.logger.exception(f"Error during video processing: {str(e)}")
+ logger.exception(f"Error during video processing: {str(e)}")
raise ProcessingError(f"Error during video processing: {str(e)}")
finally:
if hasattr(self, "hdf_file"):
self.hdf_file.close()
- def _verify_hdf_file(self, file_path: Path) -> bool:
- """
- Verify the integrity and relevance of an existing HDF file.
-
- Args:
- file_path (Path): Path to the HDF file.
-
- Returns:
- bool: True if the file is valid and up-to-date, False otherwise.
- """
- try:
- with h5py.File(file_path, "r") as f:
- if "segmentation" not in f or "metadata" not in f:
- self.logger.warning(
- f"HDF file at {file_path} is missing required datasets"
- )
- return False
-
- json_metadata = f["metadata"][()]
- metadata = json.loads(json_metadata)
-
- if metadata.get("frame_step") != self.config.frame_step:
- self.logger.warning(
- f"HDF file frame step ({metadata.get('frame_step')}) does not match current config ({self.config.frame_step})"
- )
- return False
-
- segmentation_data = f["segmentation"]
- if len(segmentation_data) == 0:
- self.logger.warning(
- f"HDF file at {file_path} contains no segmentation data"
- )
- return False
-
- first_frame = segmentation_data[0]
- last_frame = segmentation_data[-1]
- if first_frame.shape != last_frame.shape:
- self.logger.warning(
- f"Inconsistent frame shapes in HDF file at {file_path}"
- )
- return False
-
- self.logger.debug(f"HDF file at {file_path} is valid and up-to-date")
- return True
- except Exception as e:
- self.logger.error(f"Error verifying HDF file at {file_path}: {str(e)}")
- return False
-
- def _verify_video_file(self, file_path: Path) -> bool:
+ def _process_video_frames(self) -> Tuple[np.ndarray, Dict[str, Any]]:
"""
- Verify the integrity and relevance of an existing video file.
-
- Args:
- file_path (Path): Path to the video file.
+ Processes video frames in batches and returns segmentation data and metadata.
Returns:
- bool: True if the file is valid and up-to-date, False otherwise.
+ Tuple[np.ndarray, Dict[str, Any]]: Segmentation data and metadata.
"""
- try:
- cap = cv2.VideoCapture(str(file_path))
- if not cap.isOpened():
- self.logger.warning(f"Unable to open video file at {file_path}")
- return False
-
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
- cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
- ret, first_frame = cap.read()
- if not ret:
- self.logger.warning(
- f"Unable to read first frame from video file at {file_path}"
- )
- return False
-
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
- ret, last_frame = cap.read()
- if not ret:
- self.logger.warning(
- f"Unable to read last frame from video file at {file_path}"
- )
- return False
-
- cap.release()
- self.logger.debug(f"Video file at {file_path} is valid and up-to-date")
- return True
- except Exception as e:
- self.logger.error(f"Error verifying video file at {file_path}: {str(e)}")
- return False
-
- def process_video_frames(self) -> Tuple[np.ndarray, Dict[str, Any]]:
- """
- Process video frames to generate segmentation maps.
-
- Returns:
- Tuple[np.ndarray, Dict[str, Any]]: A tuple containing the segmentation data
- and metadata.
- """
- # Lazy import of tqdm_context to avoid circular imports
- from .utils import tqdm_context
-
cap = cv2.VideoCapture(str(self.config.input))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
@@ -479,22 +262,28 @@ def process_video_frames(self) -> Tuple[np.ndarray, Dict[str, Any]]:
segmentation_data = []
+ logger.info(f"Processing video frames in batches of {self.config.batch_size}")
with tqdm_context(
- total=total_frames // self.config.frame_step, desc="Processing frames"
+ total=total_frames // self.config.frame_step,
+ desc="Processing frames",
+ disable=self.config.disable_tqdm,
) as pbar:
for batch in self._frame_generator(
cv2.VideoCapture(str(self.config.input))
):
- batch_results = self._process_batch(batch)
+ # logger.debug("Loading batch into pipeline...")
+ batch_results = self.pipeline(batch)
+ # logger.debug("Adding batch results to segmentation data...")
segmentation_data.extend(
[result["seg_map"] for result in batch_results]
)
pbar.update(len(batch))
+ cap.release()
metadata = {
"model_name": self.config.model.name,
"original_video": str(self.config.input.name),
- "palette": self.pipeline.palette.tolist()
+ "palette": np.array(self.pipeline.palette.tolist(), np.uint8)
if self.pipeline.palette is not None
else None,
"label_ids": self.pipeline.model.config.id2label,
@@ -506,159 +295,47 @@ def process_video_frames(self) -> Tuple[np.ndarray, Dict[str, Any]]:
return np.array(segmentation_data), metadata
- def _initialize_video_capture(
- self,
- ) -> Tuple[cv2.VideoCapture, int, float, int, int]:
- """
- Initialize video capture and retrieve video properties.
-
- Returns:
- Tuple[cv2.VideoCapture, int, float, int, int]: A tuple containing the video capture object,
- frame count, FPS, width, and height of the video.
- """
- cap = cv2.VideoCapture(str(self.config.input))
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
- original_fps = cap.get(cv2.CAP_PROP_FPS)
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
- return cap, frame_count, original_fps, width, height
-
- def _initialize_video_writers(
- self, width: int, height: int, fps: float
- ) -> Dict[str, cv2.VideoWriter]:
- """
- Initialize video writers for output videos.
-
- Args:
- width (int): Width of the video frame.
- height (int): Height of the video frame.
- fps (float): Frames per second of the output video.
-
- Returns:
- Dict[str, cv2.VideoWriter]: A dictionary of initialized video writers.
- """
- writers = {}
- output_base = self.config.get_output_path()
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
-
- if self.processing_plan.get("generate_colored_video", False):
- colored_path = output_base.with_name(f"{output_base.stem}_colored.mp4")
- writers["colored"] = cv2.VideoWriter(
- str(colored_path), fourcc, fps, (width, height)
- )
-
- if self.processing_plan.get("generate_overlay_video", False):
- overlay_path = output_base.with_name(f"{output_base.stem}_overlay.mp4")
- writers["overlay"] = cv2.VideoWriter(
- str(overlay_path), fourcc, fps, (width, height)
- )
-
- return writers
-
- def _process_batch(self, batch: List[Image.Image]) -> List[Dict[str, Any]]:
- """
- Process a batch of images through the segmentation pipeline.
-
- Args:
- batch (List[Image.Image]): A list of PIL Image objects to process.
-
- Returns:
- List[Dict[str, Any]]: A list of dictionaries containing segmentation results.
- """
- return self.pipeline(batch)
-
- def _write_output_frames(
- self,
- batch: List[Image.Image],
- batch_results: List[Dict[str, Any]],
- video_writers: Dict[str, cv2.VideoWriter],
- ) -> None:
- """
- Write processed frames to output video files.
-
- Args:
- batch (List[Image.Image]): A list of original PIL Image objects.
- batch_results (List[Dict[str, Any]]): A list of segmentation results.
- video_writers (Dict[str, cv2.VideoWriter]): A dictionary of video writers.
- """
- for pil_image, result in zip(batch, batch_results):
- if "colored" in video_writers:
- colored_seg = self.visualize_segmentation(
- pil_image, result["seg_map"], result["palette"], colored_only=True
- )
- video_writers["colored"].write(
- cv2.cvtColor(np.array(colored_seg), cv2.COLOR_RGB2BGR)
- )
- if "overlay" in video_writers:
- overlay = self.visualize_segmentation(
- pil_image, result["seg_map"], result["palette"], colored_only=False
- )
- video_writers["overlay"].write(
- cv2.cvtColor(np.array(overlay), cv2.COLOR_RGB2BGR)
- )
-
- def save_hdf_file(
- self, file_path: Path, segmentation_data: np.ndarray, metadata: Dict[str, Any]
- ) -> None:
- """
- Save segmentation data and metadata to an HDF5 file.
-
- Args:
- file_path (Path): Path to save the HDF5 file.
- segmentation_data (np.ndarray): Array of segmentation maps.
- metadata (Dict[str, Any]): Metadata dictionary.
- """
- with h5py.File(file_path, "w") as f:
- f.create_dataset("segmentation", data=segmentation_data, compression="gzip")
-
- # Convert all metadata to JSON-compatible format
- json_metadata = json.dumps(metadata)
- f.create_dataset("metadata", data=json_metadata)
-
- def load_hdf_file(self, file_path: Path) -> Tuple[h5py.File, Dict[str, Any]]:
- """
- Load and return the HDF5 file handle and metadata.
-
- Args:
- file_path (Path): Path to the HDF5 file.
-
- Returns:
- Tuple[h5py.File, Dict[str, Any]]: A tuple containing the HDF5 file handle
- and metadata.
- """
- self.hdf_file = h5py.File(file_path, "r")
- json_metadata = self.hdf_file["metadata"][()]
- metadata = json.loads(json_metadata)
- if "palette" in metadata:
- metadata["palette"] = np.array(metadata["palette"], np.uint8)
- return self.hdf_file, metadata
-
- def get_segmentation_data_batch(self, start: int, end: int) -> np.ndarray:
+ def _frame_generator(self, cap: cv2.VideoCapture) -> Iterator[List[Image.Image]]:
"""
- Get a batch of segmentation data from the HDF5 file.
+ Generates batches of frames from a video capture object.
Args:
- start (int): Start index of the batch.
- end (int): End index of the batch.
+ cap (cv2.VideoCapture): The video capture object.
- Returns:
- np.ndarray: A batch of segmentation data.
+ Yields:
+ Iterator[List[Image.Image]]: Batches of frames as PIL Image objects.
"""
- return self.hdf_file["segmentation"][start:end]
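+        # The inner loop below reads frame_step frames and keeps only the last
+        # one, so every frame_step-th frame enters a batch; a partially filled
+        # final batch is still yielded before the generator stops.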
+ while True:
+ frames = []
+ for _ in range(self.config.batch_size):
+ for _ in range(self.config.frame_step):
+ ret, frame = cap.read()
+ if not ret:
+ break
+ if not ret:
+ break
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ pil_image = Image.fromarray(rgb_frame)
+ frames.append(pil_image)
+ if not frames:
+ break
+ yield frames
def generate_videos(
self, segmentation_data: h5py.Dataset, metadata: Dict[str, Any]
) -> None:
"""
- Generate output videos based on the processing plan, using batched processing.
+ Generates output videos based on the processing plan, using batched processing.
+
+ Args:
+ segmentation_data (h5py.Dataset): The segmentation data for all frames.
+ metadata (Dict[str, Any]): Metadata about the video and segmentation.
"""
if not (
- self.processing_plan.get("generate_colored_video", False)
- or self.processing_plan.get("generate_overlay_video", False)
+ self.processing_plan.plan.get("generate_colored_video", False)
+ or self.processing_plan.plan.get("generate_overlay_video", False)
):
- self.logger.info(
- "No video generation required according to the processing plan."
- )
+ logger.info("No video generation required according to the processing plan")
return
start_time = time.time()
@@ -670,47 +347,82 @@ def generate_videos(
output_base = self.config.get_output_path()
video_writers = self._initialize_video_writers(width, height, fps)
- chunk_size = 100 # Adjust this value based on your memory constraints and performance needs
+ chunk_size = 100 # Adjust this value based on available memory
for chunk_start in range(0, len(segmentation_data), chunk_size):
chunk_end = min(chunk_start + chunk_size, len(segmentation_data))
- seg_chunk = self.get_segmentation_data_batch(chunk_start, chunk_end)
+ seg_chunk = get_segmentation_data_batch(
+ segmentation_data, chunk_start, chunk_end
+ )
frames = self._get_video_frames_batch(
cap, chunk_start, chunk_end, metadata["frame_step"]
)
- if self.processing_plan.get("generate_colored_video", False):
- colored_frames = self.visualize_segmentation(
- frames, seg_chunk, metadata["palette"], colored_only=True
+ if self.processing_plan.plan.get("generate_colored_video", False):
+ colored_frames = self.visualizer.visualize_segmentation(
+ frames, seg_chunk, metadata["palette"], colored_only=True
+ )
+ for colored_frame in colored_frames:
+ video_writers["colored"].write(
+ cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
)
- for colored_frame in colored_frames:
- video_writers["colored"].write(
- cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
- )
- if self.processing_plan.get("generate_overlay_video", False):
- overlay_frames = self.visualize_segmentation(
- frames, seg_chunk, metadata["palette"], colored_only=False
+ if self.processing_plan.plan.get("generate_overlay_video", False):
+ overlay_frames = self.visualizer.visualize_segmentation(
+ frames, seg_chunk, metadata["palette"], colored_only=False
+ )
+ for overlay_frame in overlay_frames:
+ video_writers["overlay"].write(
+ cv2.cvtColor(overlay_frame, cv2.COLOR_RGB2BGR)
)
- for overlay_frame in overlay_frames:
- video_writers["overlay"].write(
- cv2.cvtColor(overlay_frame, cv2.COLOR_RGB2BGR)
- )
for writer in video_writers.values():
writer.release()
- cap.release()
- self.logger.debug(
- f"Video generation took {time.time() - start_time:.4f} seconds"
+ cap.release()
+ logger.debug(
+ f"Video generation completed in {time.time() - start_time:.2f} seconds"
)
- self.logger.debug(f"Videos saved to: {output_base}")
+ logger.debug(f"Videos saved to: {output_base}")
+
+ def _initialize_video_writers(
+ self, width: int, height: int, fps: float
+ ) -> Dict[str, cv2.VideoWriter]:
+ """
+ Initializes video writers for output videos.
+
+ Args:
+ width (int): Width of the video frame.
+ height (int): Height of the video frame.
+ fps (float): Frames per second of the output video.
+
+ Returns:
+ Dict[str, cv2.VideoWriter]: A dictionary of initialized video writers.
+ """
+ writers = {}
+ output_base = self.config.get_output_path()
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+
+ if self.processing_plan.plan.get("generate_colored_video", False):
+ colored_path = output_base.with_name(f"{output_base.stem}_colored.mp4")
+ writers["colored"] = cv2.VideoWriter(
+ str(colored_path), fourcc, fps, (width, height)
+ )
+
+ if self.processing_plan.plan.get("generate_overlay_video", False):
+ overlay_path = output_base.with_name(f"{output_base.stem}_overlay.mp4")
+ writers["overlay"] = cv2.VideoWriter(
+ str(overlay_path), fourcc, fps, (width, height)
+ )
+
+ return writers
+ @staticmethod
def _get_video_frames_batch(
- self, cap: cv2.VideoCapture, start: int, end: int, frame_step: int
+ cap: cv2.VideoCapture, start: int, end: int, frame_step: int
) -> List[np.ndarray]:
"""
- Get a batch of video frames.
+ Gets a batch of video frames.
Args:
cap (cv2.VideoCapture): Video capture object.
@@ -730,231 +442,159 @@ def _get_video_frames_batch(
frames.append(frame)
return frames
- def _update_processing_history(self) -> None:
- """
- Update the processing history with the current run's information.
- """
- config_hash = ConfigHasher.calculate_hash(self.config)
- self.processing_history.add_run(
- timestamp=datetime.now().isoformat(),
- config_hash=config_hash,
- outputs_generated=self.processing_plan,
- )
- self.processing_history.save(self._get_history_file_path())
-
- def analyze_results(
- self, segmentation_data: h5py.Dataset, metadata: Dict[str, Any]
+ def _create_video(
+ self,
+ cap: cv2.VideoCapture,
+ segmentation_data: h5py.Dataset,
+ metadata: Dict[str, Any],
+ output_path: Path,
+ colored_only: bool,
) -> None:
"""
- Analyze segmentation results and generate statistics using chunked processing.
+ Creates a video from segmentation data.
Args:
- segmentation_data (h5py.Dataset): Memory-mapped segmentation data.
- metadata (Dict[str, Any]): Metadata dictionary.
+ cap (cv2.VideoCapture): Video capture object of the original video.
+ segmentation_data (h5py.Dataset): Segmentation data for all frames.
+ metadata (Dict[str, Any]): Metadata about the video and segmentation.
+ output_path (Path): Path to save the output video.
+ colored_only (bool): If True, create colored segmentation; if False, create overlay.
"""
- output_path = self.config.get_output_path()
- counts_file = output_path.with_name(f"{output_path.stem}_category_counts.csv")
- percentages_file = output_path.with_name(
- f"{output_path.stem}_category_percentages.csv"
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ out = cv2.VideoWriter(
+ str(output_path),
+ fourcc,
+ metadata["fps"],
+ (metadata["width"], metadata["height"]),
)
- id2label = metadata["label_ids"]
- headers = ["Frame"] + [id2label[i] for i in sorted(id2label.keys())]
-
- chunk_size = 100 # Adjust based on memory constraints
+ frame_index = 0
+ seg_index = 0
+ palette = np.array(metadata["palette"], dtype=np.uint8)
- with open(counts_file, "w", newline="") as cf, open(
- percentages_file, "w", newline=""
- ) as pf:
- counts_writer = csv.writer(cf)
- percentages_writer = csv.writer(pf)
- counts_writer.writerow(headers)
- percentages_writer.writerow(headers)
-
- for chunk_start in range(0, len(segmentation_data), chunk_size):
- chunk_end = min(chunk_start + chunk_size, len(segmentation_data))
- seg_chunk = self.get_segmentation_data_batch(chunk_start, chunk_end)
-
- for frame_idx, seg_map in enumerate(seg_chunk, start=chunk_start):
- analysis = self.analyze_segmentation_map(seg_map, len(id2label))
- frame_number = frame_idx * metadata["frame_step"]
+ with tqdm_context(
+ total=metadata["frame_count"],
+ desc=f"Generating {'colored' if colored_only else 'overlay'} video",
+ disable=self.config.disable_tqdm,
+ ) as pbar:
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
- counts_row = [frame_number] + [
- analysis[i][0] for i in sorted(analysis.keys())
- ]
- percentages_row = [frame_number] + [
- analysis[i][1] for i in sorted(analysis.keys())
- ]
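+                # Frames that fall between segmentation samples are written
+                # through unchanged, so the output keeps the source frame rate.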
+ if frame_index % metadata["frame_step"] == 0:
+ seg_map = segmentation_data[seg_index]
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ visualized = self.visualizer.visualize_segmentation(
+ frame_rgb, seg_map, palette, colored_only=colored_only
+ )
+ out.write(cv2.cvtColor(visualized, cv2.COLOR_RGB2BGR))
+ seg_index += 1
+ else:
+ out.write(frame)
- counts_writer.writerow(counts_row)
- percentages_writer.writerow(percentages_row)
+ frame_index += 1
+ pbar.update(1)
- self._generate_category_stats(
- counts_file, output_path.with_name(f"{output_path.stem}_counts_stats.csv")
- )
- self._generate_category_stats(
- percentages_file,
- output_path.with_name(f"{output_path.stem}_percentages_stats.csv"),
+ out.release()
+ logger.info(
+ f"{'Colored' if colored_only else 'Overlay'} video saved to {output_path}"
)
- @staticmethod
- def analyze_segmentation_map(
- seg_map: np.ndarray, num_categories: int
- ) -> Dict[int, tuple[int, float]]:
- """
- Analyze a segmentation map to compute pixel counts and percentages for each category.
-
- Args:
- seg_map (np.ndarray): The segmentation map to analyze.
- num_categories (int): The total number of categories in the segmentation.
-
- Returns:
- Dict[int, tuple[int, float]]: A dictionary where keys are category IDs and values
- are tuples of (pixel count, percentage) for each category.
+ def _update_processing_history(self) -> None:
"""
- unique, counts = np.unique(seg_map, return_counts=True)
- total_pixels = seg_map.size
- category_analysis = {i: (0, 0.0) for i in range(num_categories)}
-
- for category_id, pixel_count in zip(unique, counts):
- percentage = (pixel_count / total_pixels) * 100
- category_analysis[int(category_id)] = (int(pixel_count), float(percentage))
-
- return category_analysis
-
- def _generate_category_stats(self, input_file: Path, output_file: Path) -> None:
+ Updates the processing history JSON file with the current processing information.
"""
- Generate category statistics from input CSV file.
+ output_path = self.config.get_output_path()
+ history_file = output_path.with_name(
+ f"{output_path.stem}_processing_history.json"
+ )
- Args:
- input_file (Path): Path to the input CSV file (counts or percentages).
- output_file (Path): Path to save the output statistics CSV file.
- """
try:
- # Read the entire CSV file
- df = pd.read_csv(input_file)
+ if history_file.exists():
+ with open(history_file, "r") as f:
+ history = json.load(f)
+ else:
+ history = []
- # Exclude the 'Frame' column and calculate statistics
- category_columns = df.columns[1:]
- stats = df[category_columns].agg(["mean", "median", "std", "min", "max"])
+ current_entry = {
+ "timestamp": datetime.now().isoformat(),
+ "config_hash": ConfigHasher.calculate_hash(self.config),
+ "input_file": str(self.config.input),
+ "output_file": str(output_path),
+ }
- # Transpose the results for a more readable output
- stats = stats.transpose()
+ history.append(current_entry)
- # Save the statistics to the output file
- stats.to_csv(output_file)
+ with open(history_file, "w") as f:
+ json.dump(history, f, indent=2)
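+            # The history file holds a JSON list of entries shaped like:
+            # [{"timestamp": ..., "config_hash": ..., "input_file": ..., "output_file": ...}]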
- self.logger.info(f"Category statistics saved to {output_file}")
+ logger.info(f"Processing history updated in {history_file}")
except Exception as e:
- self.logger.error(f"Error generating category stats: {str(e)}")
- raise
-
- def visualize_segmentation(
- self,
- images: Union[np.ndarray, List[np.ndarray]],
- seg_maps: Union[np.ndarray, List[np.ndarray]],
- palette: Optional[np.ndarray],
- colored_only: bool = False,
- ) -> Union[np.ndarray, List[np.ndarray]]:
- """
- Visualize the segmentation maps for multiple images.
-
- Args:
- images (Union[np.ndarray, List[np.ndarray]]): The original images or a single image.
- seg_maps (Union[np.ndarray, List[np.ndarray]]): The segmentation maps or a single segmentation map.
- palette (Optional[np.ndarray]): The color palette for visualization.
- colored_only (bool): If True, return only the colored segmentation maps.
-
- Returns:
- Union[np.ndarray, List[np.ndarray]]: The visualized segmentation maps or overlays.
- """
- if palette is None:
- palette = self._generate_palette(256) # Assuming max 256 classes
+ logger.error(f"Error updating processing history: {str(e)}")
- # Convert single image/seg_map to list for uniform processing
- if isinstance(images, np.ndarray) and images.ndim == 3:
- images = [images]
- seg_maps = [seg_maps]
- results = []
- for image, seg_map in zip(images, seg_maps):
- # Vectorized color application
- color_seg = palette[seg_map]
+class SegmentationProcessor:
+ """
+ Handles segmentation processing for both images and videos.
- if colored_only:
- results.append(color_seg)
- else:
- img = image * 0.5 + color_seg * 0.5
- results.append(img.astype(np.uint8))
+ This class serves as a facade for ImageProcessor and VideoProcessor,
+ delegating the processing based on the input type.
- return results[0] if len(results) == 1 else results
+ Attributes:
+ config (Config): Configuration object containing processing parameters.
+ image_processor (ImageProcessor): Processor for handling image inputs.
+ video_processor (VideoProcessor): Processor for handling video inputs.
+ """
- def _frame_generator(self, cap: cv2.VideoCapture) -> Iterator[List[Image.Image]]:
+ def __init__(self, config: Config):
"""
- Generate batches of frames from a video capture object.
+ Initializes the SegmentationProcessor with the given configuration.
Args:
- cap (cv2.VideoCapture): The video capture object.
-
- Yields:
- Iterator[List[Image.Image]]: Batches of frames as PIL Image objects.
+ config (Config): Configuration object for the processor.
"""
- while True:
- frames = []
- for _ in range(self.config.batch_size):
- for _ in range(self.config.frame_step):
- ret, frame = cap.read()
- if not ret:
- break
- if not ret:
- break
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- pil_image = Image.fromarray(rgb_frame)
- frames.append(pil_image)
- if not frames:
- break
- yield frames
+ self.config = config
+ self.image_processor = ImageProcessor(config)
+ self.video_processor = VideoProcessor(config)
+ logger.debug(f"SegmentationProcessor initialized with config: {config}")
- def _generate_palette(self, num_colors: int) -> np.ndarray:
+    def process(self) -> None:
"""
- Generate a color palette for visualization.
-
- Args:
- num_colors (int): The number of colors to generate.
+ Processes the input based on its type (image or video).
- Returns:
- np.ndarray: The generated color palette.
+ Raises:
+ ValueError: If the input type is not supported.
"""
-
- return np.array(
- [
- [(i * 100) % 255, (i * 150) % 255, (i * 200) % 255]
- for i in range(num_colors)
- ],
- dtype=np.uint8,
- )
+ if self.config.input_type == InputType.SINGLE_IMAGE:
+ self.image_processor.process()
+ elif self.config.input_type == InputType.SINGLE_VIDEO:
+ self.video_processor.process()
+ else:
+ raise ValueError(f"Unsupported input type: {self.config.input_type}")
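+
+
+# Minimal usage sketch (assumes a validated Config for a single image or video):
+#
+#     processor = SegmentationProcessor(config)
+#     processor.process()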
class DirectoryProcessor:
"""
- A processor for handling multiple video files in a directory.
+ Processes multiple video files in a directory.
- This class manages the processing of multiple video files within a specified
- directory, utilizing the SegmentationProcessor for individual video processing.
+ This class handles the batch processing of video files found in a specified directory.
Attributes:
config (Config): Configuration object containing processing parameters.
- logger (Logger): Logger instance for tracking processing events.
+ video_iterator (VideoFileIterator): Iterator for video files in the directory.
+ logger: Logger instance for this processor.
"""
def __init__(self, config: Config):
"""
- Initialize the DirectoryProcessor.
+ Initializes the DirectoryProcessor with the given configuration.
Args:
- config (Config): Configuration object containing processing parameters.
+ config (Config): Configuration object for the processor.
"""
self.config = config
+ self.video_iterator = VideoFileIterator(config.input)
self.logger = logger.bind(
processor_type=self.__class__.__name__,
input_type=self.config.input_type.value,
@@ -965,24 +605,19 @@ def __init__(self, config: Config):
def process(self) -> None:
"""
- Process all video files in the specified directory.
+ Processes all video files in the specified directory.
- This method identifies all video files in the input directory,
- processes each video using SegmentationProcessor, and handles any errors
- that occur during processing.
+ This method iterates through all video files, processing each one
+ according to the configuration.
Raises:
- InputError: If no video files are found in the specified directory.
+ InputError: If no video files are found in the directory.
"""
- # Lazy import of tqdm_context
- from .utils import tqdm_context
-
self.logger.debug(
"Starting directory processing", input_path=str(self.config.input)
)
- video_files = self.get_video_files()
- if not video_files:
+ if not self.video_iterator.video_files:
self.logger.error("No video files found")
raise InputError(f"No video files found in directory: {self.config.input}")
@@ -991,8 +626,12 @@ def process(self) -> None:
f"Output directory set: {str(output_dir)}", output_dir=str(output_dir)
)
- with tqdm_context(total=len(video_files), desc="Processing videos") as pbar:
- for video_file in video_files:
+ with tqdm_context(
+ total=len(self.video_iterator.video_files),
+ desc="Processing videos",
+ disable=self.config.disable_tqdm,
+ ) as pbar:
+ for video_file in self.video_iterator:
if video_file.name in self.config.ignore_files:
self.logger.info(
f"Ignoring video file: {str(video_file.name)}",
@@ -1001,7 +640,7 @@ def process(self) -> None:
pbar.update(1)
continue
try:
- self.process_single_video(video_file, output_dir)
+ self._process_single_video(video_file, output_dir)
except Exception as e:
self.logger.error(
"Error processing video",
@@ -1016,20 +655,9 @@ def process(self) -> None:
"Finished processing all videos", input_directory=str(self.config.input)
)
- def get_video_files(self) -> List[Path]:
- """
- Get a list of video files in the input directory.
-
- Returns:
- List[Path]: A list of paths to video files.
- """
- video_files = get_video_files(self.config.input)
- self.logger.info(f"Found {len(video_files)} video files in {self.config.input}")
- return video_files
-
- def process_single_video(self, video_file: Path, output_dir: Path) -> None:
+ def _process_single_video(self, video_file: Path, output_dir: Path) -> None:
"""
- Process a single video file.
+ Processes a single video file.
Args:
video_file (Path): Path to the video file to process.
@@ -1038,7 +666,9 @@ def process_single_video(self, video_file: Path, output_dir: Path) -> None:
Raises:
ProcessingError: If an error occurs during video processing.
"""
- video_config = self.create_video_config(video_file, output_dir)
+ logger.debug("Creating video config...", video_file=str(video_file))
+ video_config = self._create_video_config(video_file, output_dir)
+ logger.debug("Video config created", video_config=video_config)
try:
processor = SegmentationProcessor(video_config)
@@ -1049,16 +679,16 @@ def process_single_video(self, video_file: Path, output_dir: Path) -> None:
)
raise ProcessingError(f"Error processing video {video_file}: {str(e)}")
- def create_video_config(self, video_file: Path, output_dir: Path) -> Config:
+ def _create_video_config(self, video_file: Path, output_dir: Path) -> Config:
"""
- Create a configuration object for processing a single video.
+ Creates a configuration object for processing a single video.
Args:
video_file (Path): Path to the video file.
output_dir (Path): Directory to save the processing results.
Returns:
- Config: A configuration object for the video processing.
+ Config: Configuration object for the video processor.
"""
return Config(
input=video_file,
@@ -1066,11 +696,13 @@ def create_video_config(self, video_file: Path, output_dir: Path) -> Config:
output_prefix=None,
model=self.config.model,
frame_step=self.config.frame_step,
+ batch_size=self.config.batch_size,
save_raw_segmentation=self.config.save_raw_segmentation,
save_colored_segmentation=self.config.save_colored_segmentation,
save_overlay=self.config.save_overlay,
visualization=self.config.visualization,
force_reprocess=self.config.force_reprocess,
+ disable_tqdm=self.config.disable_tqdm,
)
@@ -1078,14 +710,13 @@ def create_processor(
config: Config,
) -> Union[SegmentationProcessor, DirectoryProcessor]:
"""
- Create and return the appropriate processor based on the input type.
+ Creates and returns the appropriate processor based on the input type.
Args:
config (Config): Configuration object containing processing parameters.
Returns:
- Union[SegmentationProcessor, DirectoryProcessor]: An instance of either
- SegmentationProcessor or DirectoryProcessor, depending on the input type.
+ Union[SegmentationProcessor, DirectoryProcessor]: The appropriate processor instance.
"""
if config.input_type == InputType.DIRECTORY:
return DirectoryProcessor(config)
diff --git a/src/cityseg/segmentation_analyzer.py b/src/cityseg/segmentation_analyzer.py
new file mode 100644
index 0000000..46e3a2b
--- /dev/null
+++ b/src/cityseg/segmentation_analyzer.py
@@ -0,0 +1,148 @@
+"""
+This module provides a class for analyzing segmentation results.
+
+It includes methods to analyze segmentation maps, compute pixel counts and percentages
+for each category, and generate statistics for the analysis results.
+
+Classes:
+ SegmentationAnalyzer: A class for analyzing segmentation results.
+"""
+
+import csv
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+import h5py
+import numpy as np
+import pandas as pd
+from loguru import logger
+
+from cityseg.utils import get_segmentation_data_batch
+
+
+class SegmentationAnalyzer:
+ """
+ A class for analyzing segmentation results.
+
+ This class provides methods to analyze segmentation maps, compute pixel counts and
+ percentages for each category, and generate statistics for the analysis results.
+
+ Methods:
+ analyze_segmentation_map: Analyzes a segmentation map to compute pixel counts and percentages.
+ analyze_results: Analyzes segmentation data and saves counts and percentages to CSV files.
+ generate_category_stats: Generates statistics for category counts or percentages.
+ """
+
+ @staticmethod
+ def analyze_segmentation_map(
+ seg_map: np.ndarray, num_categories: int
+    ) -> Dict[int, Tuple[int, float]]:
+ """
+ Analyzes a segmentation map to compute pixel counts and percentages for each category.
+
+ Args:
+ seg_map (np.ndarray): The segmentation map to analyze.
+ num_categories (int): The total number of categories in the segmentation.
+
+ Returns:
+ Dict[int, Tuple[int, float]]: A dictionary where keys are category IDs and values
+ are tuples of (pixel count, percentage) for each category.
+ """
+ unique, counts = np.unique(seg_map, return_counts=True)
+ total_pixels = seg_map.size
+ category_analysis = {i: (0, 0.0) for i in range(num_categories)}
+
+ for category_id, pixel_count in zip(unique, counts):
+ percentage = (pixel_count / total_pixels) * 100
+ category_analysis[int(category_id)] = (int(pixel_count), float(percentage))
+
+ return category_analysis
+
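+    # Worked example (mirrors tests/test_segmentation_analyzer.py): a 3x3 map
+    # containing three pixels each of classes 0, 1 and 2 yields
+    # {0: (3, 33.3...), 1: (3, 33.3...), 2: (3, 33.3...)}.
+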
+ @staticmethod
+ def analyze_results(
+ segmentation_data: h5py.Dataset, metadata: Dict[str, Any], output_path: Path
+ ) -> None:
+ """
+ Analyzes segmentation data and saves counts and percentages to CSV files.
+
+ This method processes the segmentation data in chunks, computes the analysis for each
+ frame, and writes the results to separate CSV files for counts and percentages.
+
+ Args:
+ segmentation_data (h5py.Dataset): The segmentation data to analyze.
+ metadata (Dict[str, Any]): Metadata containing label IDs and frame step.
+ output_path (Path): The path where the output CSV files will be saved.
+ """
+ counts_file = output_path.with_name(f"{output_path.stem}_category_counts.csv")
+ percentages_file = output_path.with_name(
+ f"{output_path.stem}_category_percentages.csv"
+ )
+
+ id2label = metadata["label_ids"]
+ headers = ["Frame"] + [id2label[i] for i in sorted(id2label.keys())]
+
+ chunk_size = 100 # Adjust based on memory constraints
+
+ with open(counts_file, "w", newline="") as cf, open(
+ percentages_file, "w", newline=""
+ ) as pf:
+ counts_writer = csv.writer(cf)
+ percentages_writer = csv.writer(pf)
+ counts_writer.writerow(headers)
+ percentages_writer.writerow(headers)
+
+ for chunk_start in range(0, len(segmentation_data), chunk_size):
+ chunk_end = min(chunk_start + chunk_size, len(segmentation_data))
+ seg_chunk = get_segmentation_data_batch(
+ segmentation_data, chunk_start, chunk_end
+ )
+
+ for frame_idx, seg_map in enumerate(seg_chunk, start=chunk_start):
+ analysis = SegmentationAnalyzer.analyze_segmentation_map(
+ seg_map, len(id2label)
+ )
+ frame_number = frame_idx * metadata["frame_step"]
+
+ counts_row = [frame_number] + [
+ analysis[i][0] for i in sorted(analysis.keys())
+ ]
+ percentages_row = [frame_number] + [
+ analysis[i][1] for i in sorted(analysis.keys())
+ ]
+
+ counts_writer.writerow(counts_row)
+ percentages_writer.writerow(percentages_row)
+
+ logger.info(f"Category counts saved to {counts_file}")
+ logger.info(f"Category percentages saved to {percentages_file}")
+
+ SegmentationAnalyzer.generate_category_stats(
+ counts_file, output_path.with_name(f"{output_path.stem}_counts_stats.csv")
+ )
+ SegmentationAnalyzer.generate_category_stats(
+ percentages_file,
+ output_path.with_name(f"{output_path.stem}_percentages_stats.csv"),
+ )
+
+ @staticmethod
+ def generate_category_stats(input_file: Path, output_file: Path) -> None:
+ """
+ Generates statistics for category counts or percentages.
+
+ This method reads the input CSV file, computes statistics (mean, median, std, min, max)
+ for each category, and saves the results to the specified output file.
+
+ Args:
+ input_file (Path): Path to the input CSV file containing category data.
+ output_file (Path): Path to save the generated statistics.
+ """
+ try:
+ df = pd.read_csv(input_file)
+ category_columns = df.columns[1:]
+ stats = df[category_columns].agg(["mean", "median", "std", "min", "max"])
+ stats = stats.transpose()
+ stats.to_csv(output_file)
+ logger.info(f"Category statistics saved to {output_file}")
+ except Exception as e:
+ logger.error(f"Error generating category stats: {str(e)}")
+ raise
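+
+
+# Illustrative invocation, assuming the HDF5 layout the pipeline writes
+# (a "segmentation" dataset plus JSON-encoded "metadata"); the output_path
+# stem is only used to derive the CSV file names:
+#
+#     with h5py.File("video_out.h5", "r") as f:
+#         metadata = json.loads(f["metadata"][()])
+#         SegmentationAnalyzer.analyze_results(
+#             f["segmentation"], metadata, Path("video_out.h5")
+#         )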
diff --git a/src/cityseg/utils.py b/src/cityseg/utils.py
index 806c0fd..fa627cd 100644
--- a/src/cityseg/utils.py
+++ b/src/cityseg/utils.py
@@ -8,184 +8,29 @@
import sys
from contextlib import contextmanager
-from pathlib import Path
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Iterator
-import cv2
+import h5py
import numpy as np
-import pandas as pd
from loguru import logger
-from tqdm.asyncio import tqdm
+from tqdm.auto import tqdm
-def analyze_segmentation_map(
- seg_map: np.ndarray, num_categories: int
-) -> Dict[int, Tuple[int, float]]:
+def get_segmentation_data_batch(
+ segmentation_data: h5py.Dataset, start: int, end: int
+) -> np.ndarray:
"""
- Analyze a segmentation map to compute pixel counts and percentages for each category.
+    Get a batch of segmentation data from an HDF5 dataset.
Args:
- seg_map (np.ndarray): The segmentation map to analyze.
- num_categories (int): The total number of categories in the segmentation.
+        segmentation_data (h5py.Dataset): The HDF5 dataset containing segmentation maps.
+ start (int): Start index of the batch.
+ end (int): End index of the batch.
Returns:
- Dict[int, Tuple[int, float]]: A dictionary where keys are category IDs and values
- are tuples of (pixel count, percentage) for each category.
+ np.ndarray: A batch of segmentation data.
"""
- unique, counts = np.unique(seg_map, return_counts=True)
- total_pixels = seg_map.size
- category_analysis = {i: (0, 0.0) for i in range(num_categories)}
-
- for category_id, pixel_count in zip(unique, counts):
- percentage = (pixel_count / total_pixels) * 100
- category_analysis[int(category_id)] = (int(pixel_count), float(percentage))
-
- return category_analysis
-
-
-def generate_category_stats(csv_file_path: Path) -> pd.DataFrame:
- """
- Generate statistical summaries for each category from a CSV file.
-
- Args:
- csv_file_path (Path): Path to the CSV file containing category data.
-
- Returns:
- pd.DataFrame: A DataFrame containing statistical summaries for each category.
- """
- df = pd.read_csv(csv_file_path)
- category_columns = df.columns[1:]
-
- stats = []
- for category in category_columns:
- category_data = df[category]
- stats.append(
- {
- "Category": category,
- "Mean": category_data.mean(),
- "Median": category_data.median(),
- "Std Dev": category_data.std(),
- "Min": category_data.min(),
- "Max": category_data.max(),
- }
- )
-
- return pd.DataFrame(stats)
-
-
-def initialize_csv_files(
- output_prefix: Path, category_names: Dict[int, str]
-) -> Tuple[Path, Path]:
- """
- Initialize CSV files for storing category counts and percentages.
-
- Args:
- output_prefix (Path): The prefix for output file names.
- category_names (Dict[int, str]): A dictionary mapping category IDs to names.
-
- Returns:
- Tuple[Path, Path]: Paths to the created counts and percentages CSV files.
- """
- counts_file = output_prefix.with_name(f"{output_prefix.stem}_category_counts.csv")
- percentages_file = output_prefix.with_name(
- f"{output_prefix.stem}_category_percentages.csv"
- )
-
- headers = ["Frame"] + [category_names[i] for i in sorted(category_names.keys())]
-
- pd.DataFrame(columns=headers).to_csv(counts_file, index=False)
- pd.DataFrame(columns=headers).to_csv(percentages_file, index=False)
-
- return counts_file, percentages_file
-
-
-def append_to_csv_files(
- counts_file: Path,
- percentages_file: Path,
- frame_count: int,
- analysis: Dict[int, Tuple[int, float]],
-) -> None:
- """
- Append analysis results to the counts and percentages CSV files.
-
- Args:
- counts_file (Path): Path to the counts CSV file.
- percentages_file (Path): Path to the percentages CSV file.
- frame_count (int): The current frame count.
- analysis (Dict[int, Tuple[int, float]]): Analysis results for the current frame.
- """
- counts_row = [frame_count] + [analysis[i][0] for i in sorted(analysis.keys())]
- percentages_row = [frame_count] + [analysis[i][1] for i in sorted(analysis.keys())]
-
- pd.DataFrame([counts_row]).to_csv(counts_file, mode="a", header=False, index=False)
- pd.DataFrame([percentages_row]).to_csv(
- percentages_file, mode="a", header=False, index=False
- )
-
-
-def get_video_files(directory: Path) -> List[Path]:
- """
- Get a list of video files in the specified directory.
-
- Args:
- directory (Path): The directory to search for video files.
-
- Returns:
- List[Path]: A list of paths to video files found in the directory.
- """
- video_extensions = [".mp4", ".avi", ".mov"]
- video_files = []
- for ext in video_extensions:
- video_files.extend(directory.glob(f"*{ext}"))
- return video_files
-
-
-def save_segmentation_map(
- seg_map: np.ndarray, output_prefix: Path, frame_count: Optional[int] = None
-) -> None:
- """
- Save a segmentation map as a NumPy array file.
-
- Args:
- seg_map (np.ndarray): The segmentation map to save.
- output_prefix (Path): The prefix for the output file name.
- frame_count (Optional[int]): The frame count, if applicable.
- """
- filename = f"{output_prefix.stem}_segmap{'_' + str(frame_count) if frame_count is not None else ''}.npy"
- np.save(output_prefix.with_name(filename), seg_map)
-
-
-def save_colored_segmentation(
- colored_seg: np.ndarray, output_prefix: Path, frame_count: Optional[int] = None
-) -> None:
- """
- Save a colored segmentation map as an image file.
-
- Args:
- colored_seg (np.ndarray): The colored segmentation map to save.
- output_prefix (Path): The prefix for the output file name.
- frame_count (Optional[int]): The frame count, if applicable.
- """
- filename = f"{output_prefix.stem}_colored{'_' + str(frame_count) if frame_count is not None else ''}.png"
- cv2.imwrite(
- str(output_prefix.with_name(filename)),
- cv2.cvtColor(colored_seg, cv2.COLOR_RGB2BGR),
- )
-
-
-def save_overlay(
- overlay: np.ndarray, output_prefix: Path, frame_count: Optional[int] = None
-) -> None:
- """
- Save an overlay image as a file.
-
- Args:
- overlay (np.ndarray): The overlay image to save.
- output_prefix (Path): The prefix for the output file name.
- frame_count (Optional[int]): The frame count, if applicable.
- """
- filename = f"{output_prefix.stem}_overlay{'_' + str(frame_count) if frame_count is not None else ''}.png"
- cv2.imwrite(str(output_prefix.with_name(filename)), overlay)
+ return segmentation_data[start:end]
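+
+
+# Example: for a dataset of shape (100, H, W), get_segmentation_data_batch(
+# dset, 90, 110) returns a (10, H, W) array; like NumPy, h5py clips slice
+# bounds instead of raising (see tests/test_utils.py).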
def setup_logging(log_level: str, verbose: bool = False) -> None:
diff --git a/src/cityseg/video_file_iterator.py b/src/cityseg/video_file_iterator.py
new file mode 100644
index 0000000..a3300eb
--- /dev/null
+++ b/src/cityseg/video_file_iterator.py
@@ -0,0 +1,66 @@
+"""
+This module provides an iterator class for iterating over video files in a specified directory.
+
+It retrieves and stores video files from the given input path and provides an iterator
+interface to access these files.
+
+Classes:
+ VideoFileIterator: An iterator class for iterating over video files in a specified directory.
+"""
+
+from pathlib import Path
+from typing import Iterator, List
+
+from loguru import logger
+
+
+class VideoFileIterator:
+ """
+ An iterator class for iterating over video files in a specified directory.
+
+ This class retrieves and stores video files from the given input path and
+ provides an iterator interface to access these files.
+
+ Attributes:
+ input_path (Path): The path to the directory containing video files.
+ video_files (List[Path]): A list of video file paths found in the input directory.
+ """
+
+ def __init__(self, input_path: Path):
+ """
+ Initializes the VideoFileIterator with the specified input path.
+
+ Args:
+ input_path (Path): The path to the directory containing video files.
+ """
+ self.input_path = input_path
+ self.video_files = self._get_video_files()
+
+ def _get_video_files(self) -> List[Path]:
+ """
+ Retrieves a list of video files from the input directory.
+
+ This method uses the utility function `get_video_files` to find all video files
+ in the specified input path and logs the number of files found.
+
+ Returns:
+ List[Path]: A list of paths to the video files found in the input directory.
+ """
+ video_extensions = [".mp4", ".avi", ".mov"]
+ video_files = [
+ f for f in self.input_path.glob("*") if f.suffix.lower() in video_extensions
+ ]
+ logger.info(f"Found {len(video_files)} video files in {self.input_path}")
+        return video_files
+
+ def __iter__(self) -> Iterator[Path]:
+ """
+ Returns an iterator over the video files.
+
+ This method allows the VideoFileIterator to be used in a for-loop or any context
+ that requires an iterable.
+
+ Returns:
+ Iterator[Path]: An iterator over the video file paths.
+ """
+ return iter(self.video_files)
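+
+
+# Minimal usage sketch (directory path is illustrative):
+#
+#     for video_path in VideoFileIterator(Path("./videos")):
+#         logger.info(f"Queued {video_path.name}")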
diff --git a/src/cityseg/visualization_handler.py b/src/cityseg/visualization_handler.py
new file mode 100644
index 0000000..6da0af0
--- /dev/null
+++ b/src/cityseg/visualization_handler.py
@@ -0,0 +1,102 @@
+"""
+This module provides a class for visualizing segmentation results using color palettes.
+
+It includes methods to visualize segmentation maps with color palettes and options for
+displaying colored or blended results.
+
+Classes:
+ VisualizationHandler: A class for visualizing segmentation results using color palettes.
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+from loguru import logger
+
+
+class VisualizationHandler:
+ """
+ A class for visualizing segmentation results using color palettes.
+
+ This class provides methods to visualize segmentation maps with color palettes
+ and options for displaying colored or blended results.
+
+ Methods:
+ visualize_segmentation: Visualizes segmentation results with color palettes.
+ _generate_palette: Generates a color palette for visualization.
+ """
+
+ @staticmethod
+ def visualize_segmentation(
+ images: Union[np.ndarray, List[np.ndarray]],
+ seg_maps: Union[np.ndarray, List[np.ndarray]],
+ palette: Optional[np.ndarray] = None,
+ colored_only: bool = False,
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """
+ Visualizes segmentation results using color palettes.
+
+ This method takes input images and their corresponding segmentation maps,
+ applies the specified color palette, and returns the visualized results.
+
+ Args:
+ images (Union[np.ndarray, List[np.ndarray]]): Input images or a list of images.
+ seg_maps (Union[np.ndarray, List[np.ndarray]]): Segmentation maps or a list of maps.
+ palette (Optional[np.ndarray]): Color palette for visualization. If None, a default palette is generated.
+ colored_only (bool): Flag to indicate if only colored results are desired (True) or blended with the original images (False).
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: Visualized segmentation results, either as a single array or a list of arrays.
+ """
+ logger.debug(
+ f"Visualizing segmentation for {len(images) if isinstance(images, list) else 1} images"
+ )
+ if palette is None:
+ palette = VisualizationHandler._generate_palette(256)
+ if isinstance(palette, list):
+ palette = np.array(palette, dtype=np.uint8)
+
+ if isinstance(images, np.ndarray) and images.ndim == 3:
+ images = [images]
+ seg_maps = [seg_maps]
+
+ results = []
+ for image, seg_map in zip(images, seg_maps):
+ color_seg = palette[seg_map]
+
+ if colored_only:
+ results.append(color_seg)
+ else:
+ img = image * 0.5 + color_seg * 0.5
+ results.append(img.astype(np.uint8))
+
+ return results[0] if len(results) == 1 else results
+
+ @staticmethod
+ def _generate_palette(num_colors: int) -> np.ndarray:
+ """
+ Generates a color palette for visualization.
+
+ This method creates a color palette with a specified number of colors,
+ which can be used to visualize segmentation results.
+
+ Args:
+ num_colors (int): Number of colors to generate in the palette.
+
+ Returns:
+ np.ndarray: Color palette array for visualization, with shape (num_colors, 3).
+ """
+ from .palettes import ADE20K_PALETTE
+
+ if num_colors < len(ADE20K_PALETTE):
+ logger.debug(f"Using ADE20K palette with {num_colors} colors")
+ return np.array(ADE20K_PALETTE[:num_colors], dtype=np.uint8)
+ else:
+ logger.debug(f"Generating custom palette for {num_colors} colors")
+ return np.array(
+ [
+ [(i * 100) % 255, (i * 150) % 255, (i * 200) % 255]
+ for i in range(num_colors)
+ ],
+ dtype=np.uint8,
+ )
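+
+
+# Illustrative example: blending one dummy frame with its segmentation map.
+#
+#     frame = np.full((4, 4, 3), 128, dtype=np.uint8)  # dummy RGB image
+#     seg = np.random.randint(0, 3, (4, 4))            # three classes
+#     overlay = VisualizationHandler.visualize_segmentation(frame, seg)
+#     colored = VisualizationHandler.visualize_segmentation(
+#         frame, seg, colored_only=True
+#     )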
diff --git a/tests/test_file_handler.py b/tests/test_file_handler.py
new file mode 100644
index 0000000..1af1c29
--- /dev/null
+++ b/tests/test_file_handler.py
@@ -0,0 +1,105 @@
+import json
+from unittest.mock import MagicMock, patch
+
+import h5py
+import numpy as np
+import pytest
+
+from cityseg.config import Config
+from cityseg.file_handler import FileHandler
+
+
+@pytest.fixture
+def temp_hdf_file(tmp_path):
+ file_path = tmp_path / "test.hdf5"
+ yield file_path
+ if file_path.exists():
+ file_path.unlink()
+
+
+@pytest.fixture
+def temp_video_file(tmp_path):
+ file_path = tmp_path / "test.mp4"
+ file_path.touch()
+ yield file_path
+ if file_path.exists():
+ file_path.unlink()
+
+
+def test_saves_hdf_file_correctly(temp_hdf_file):
+ segmentation_data = np.random.rand(10, 10)
+ metadata = {"frame_step": 1, "palette": np.array([1, 2, 3])}
+ FileHandler.save_hdf_file(temp_hdf_file, segmentation_data, metadata)
+ with h5py.File(temp_hdf_file, "r") as f:
+ assert "segmentation" in f
+ assert "metadata" in f
+ assert np.array_equal(f["segmentation"], segmentation_data)
+ loaded_metadata = json.loads(f["metadata"][()])
+ assert loaded_metadata["frame_step"] == 1
+ assert loaded_metadata["palette"] == [1, 2, 3]
+
+
+def test_loads_hdf_file_correctly(temp_hdf_file):
+ segmentation_data = np.random.rand(10, 10)
+ metadata = {"frame_step": 1, "palette": [1, 2, 3]}
+ with h5py.File(temp_hdf_file, "w") as f:
+ f.create_dataset("segmentation", data=segmentation_data)
+ f.create_dataset("metadata", data=json.dumps(metadata))
+ hdf_file, loaded_metadata = FileHandler.load_hdf_file(temp_hdf_file)
+ assert np.array_equal(hdf_file["segmentation"], segmentation_data)
+ assert loaded_metadata["frame_step"] == 1
+ assert np.array_equal(loaded_metadata["palette"], np.array([1, 2, 3]))
+
+
+def test_verifies_hdf_file_correctly(temp_hdf_file):
+ segmentation_data = np.random.rand(10, 10)
+ metadata = {"frame_step": 1}
+ with h5py.File(temp_hdf_file, "w") as f:
+ f.create_dataset("segmentation", data=segmentation_data)
+ f.create_dataset("metadata", data=json.dumps(metadata))
+ mock_config = MagicMock(spec=Config)
+ mock_config.frame_step = 1
+ assert FileHandler.verify_hdf_file(temp_hdf_file, mock_config) is True
+
+
+def test_fails_verification_for_invalid_hdf_file(temp_hdf_file):
+ segmentation_data = np.random.rand(10, 10)
+ metadata = {"frame_step": 2}
+ with h5py.File(temp_hdf_file, "w") as f:
+ f.create_dataset("segmentation", data=segmentation_data)
+ f.create_dataset("metadata", data=json.dumps(metadata))
+ mock_config = MagicMock(spec=Config)
+ mock_config.frame_step = 1
+ assert FileHandler.verify_hdf_file(temp_hdf_file, mock_config) is False
+
+
+def test_verifies_video_file_correctly(temp_video_file):
+ with patch("cv2.VideoCapture") as mock_capture:
+ mock_capture.return_value.isOpened.return_value = True
+ mock_capture.return_value.read.side_effect = [
+ (True, np.zeros((10, 10, 3))),
+ (True, np.zeros((10, 10, 3))),
+ ]
+ assert FileHandler.verify_video_file(temp_video_file) is True
+
+
+def test_fails_verification_for_invalid_video_file(temp_video_file):
+ with patch("cv2.VideoCapture") as mock_capture:
+ mock_capture.return_value.isOpened.return_value = False
+ assert FileHandler.verify_video_file(temp_video_file) is False
+
+
+def test_verifies_analysis_files_correctly(tmp_path):
+ counts_file = tmp_path / "counts.txt"
+ percentages_file = tmp_path / "percentages.txt"
+ counts_file.write_text("data")
+ percentages_file.write_text("data")
+ assert FileHandler.verify_analysis_files(counts_file, percentages_file) is True
+
+
+def test_fails_verification_for_empty_analysis_files(tmp_path):
+ counts_file = tmp_path / "counts.txt"
+ percentages_file = tmp_path / "percentages.txt"
+ counts_file.touch()
+ percentages_file.touch()
+ assert FileHandler.verify_analysis_files(counts_file, percentages_file) is False
diff --git a/tests/test_processing_plan.py b/tests/test_processing_plan.py
new file mode 100644
index 0000000..b1bc16c
--- /dev/null
+++ b/tests/test_processing_plan.py
@@ -0,0 +1,101 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from cityseg.config import Config
+from cityseg.processing_plan import ProcessingPlan
+
+
+@pytest.fixture
+def mock_config():
+ config = MagicMock(spec=Config)
+ config.force_reprocess = False
+ config.save_colored_segmentation = True
+ config.save_overlay = True
+ config.analyze_results = True
+ config.get_output_path.return_value = MagicMock()
+ return config
+
+
+def test_force_reprocess_enabled(mock_config):
+ """
+ Given force reprocessing is enabled
+ When creating a processing plan
+ Then all processing steps should be executed
+ """
+ mock_config.force_reprocess = True
+ processing_plan = ProcessingPlan(mock_config)
+
+ expected_plan = {
+ "process_video": True,
+ "generate_hdf": True,
+ "generate_colored_video": True,
+ "generate_overlay_video": True,
+ "analyze_results": True,
+ }
+
+ assert processing_plan.plan == expected_plan
+
+
+def test_existing_outputs_invalid(mock_config):
+ """
+ Given force reprocessing is disabled
+ And existing outputs are invalid
+ When creating a processing plan
+ Then all processing steps should be executed
+ """
+ mock_config.force_reprocess = False
+
+    # Patch _check_existing_outputs before the plan is built; patch.object
+    # restores the real method on exit, so the mock cannot leak into other
+    # tests (plain class-attribute assignment would).
+    with patch.object(
+        ProcessingPlan,
+        "_check_existing_outputs",
+        return_value={
+            "hdf_file_valid": False,
+            "colored_video_valid": False,
+            "overlay_video_valid": False,
+            "analysis_files_valid": False,
+        },
+    ):
+        processing_plan = ProcessingPlan(mock_config)
+
+ expected_plan = {
+ "process_video": True,
+ "generate_hdf": True,
+ "generate_colored_video": True,
+ "generate_overlay_video": True,
+ "analyze_results": True,
+ }
+
+ assert processing_plan.plan == expected_plan
+
+
+def test_existing_outputs_valid(mock_config):
+ """
+ Given force reprocessing is disabled
+ And existing outputs are valid
+ When creating a processing plan
+ Then no processing steps should be executed
+ """
+ mock_config.force_reprocess = False
+
+    # As above, patch _check_existing_outputs via patch.object so the real
+    # method is restored after the test.
+    with patch.object(
+        ProcessingPlan,
+        "_check_existing_outputs",
+        return_value={
+            "hdf_file_valid": True,
+            "colored_video_valid": True,
+            "overlay_video_valid": True,
+            "analysis_files_valid": True,
+        },
+    ):
+        processing_plan = ProcessingPlan(mock_config)
+
+ expected_plan = {
+ "process_video": False,
+ "generate_hdf": False,
+ "generate_colored_video": False,
+ "generate_overlay_video": False,
+ "analyze_results": False,
+ }
+
+ assert processing_plan.plan == expected_plan
diff --git a/tests/test_segmentation_analyzer.py b/tests/test_segmentation_analyzer.py
new file mode 100644
index 0000000..f1e4bd0
--- /dev/null
+++ b/tests/test_segmentation_analyzer.py
@@ -0,0 +1,172 @@
+from typing import Any, Dict
+from unittest.mock import mock_open, patch
+
+import numpy as np
+import pandas as pd
+import pytest
+from cityseg.segmentation_analyzer import SegmentationAnalyzer
+
+
+class TestSegmentationAnalyzer:
+ """
+ Test suite for the SegmentationAnalyzer class.
+
+ This class contains tests to verify the correct behavior of the SegmentationAnalyzer
+ methods, including segmentation map analysis, results analysis, and statistics generation.
+ """
+
+ @pytest.fixture
+ def sample_segmentation_map(self) -> np.ndarray:
+ """
+ Fixture that provides a sample segmentation map for testing.
+
+ Returns:
+ np.ndarray: A 2D numpy array representing a sample segmentation map.
+ """
+ return np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
+
+ @pytest.fixture
+ def sample_metadata(self) -> Dict[str, Any]:
+ """
+ Fixture that provides sample metadata for testing.
+
+ Returns:
+ Dict[str, Any]: A dictionary containing sample metadata.
+ """
+ return {
+ "label_ids": {0: "background", 1: "foreground", 2: "edge"},
+ "frame_step": 5,
+ }
+
+ def test_analyze_segmentation_map(self, sample_segmentation_map):
+ """
+ Test the analyze_segmentation_map method.
+
+ This test verifies that the method correctly computes pixel counts
+ and percentages for each category in the segmentation map.
+
+ Args:
+ sample_segmentation_map (np.ndarray): The sample segmentation map fixture.
+ """
+ result = SegmentationAnalyzer.analyze_segmentation_map(
+ sample_segmentation_map, 3
+ )
+ expected = {0: (3, 100 / 3), 1: (3, 100 / 3), 2: (3, 100 / 3)}
+
+ assert len(result) == len(expected), "Number of categories does not match"
+ for category in expected:
+ assert category in result, f"Category {category} is missing from the result"
+ assert (
+ result[category][0] == expected[category][0]
+ ), f"Pixel count for category {category} does not match"
+ assert np.isclose(
+ result[category][1], expected[category][1], rtol=1e-9
+ ), f"Percentage for category {category} is not close enough"
+
+ @patch("cityseg.segmentation_analyzer.get_segmentation_data_batch")
+ @patch("cityseg.segmentation_analyzer.open", new_callable=mock_open)
+ @patch("cityseg.segmentation_analyzer.csv.writer")
+ @patch("cityseg.segmentation_analyzer.logger")
+ @patch("cityseg.segmentation_analyzer.pd.read_csv")
+ @patch("cityseg.segmentation_analyzer.SegmentationAnalyzer.generate_category_stats")
+ def test_analyze_results(
+ self,
+ mock_generate_stats,
+ mock_read_csv,
+ mock_logger,
+ mock_csv_writer,
+ mock_open,
+ mock_get_data,
+ sample_metadata,
+ tmp_path,
+ ):
+ """
+ Test the analyze_results method.
+
+ This test verifies that the method correctly processes segmentation data,
+ writes results to CSV files, and generates statistics.
+
+ Args:
+ mock_generate_stats: Mocked generate_category_stats method.
+ mock_read_csv: Mocked pandas read_csv function.
+ mock_logger: Mocked logger object.
+ mock_csv_writer: Mocked CSV writer object.
+ mock_open: Mocked open function.
+ mock_get_data: Mocked get_segmentation_data_batch function.
+ sample_metadata (Dict[str, Any]): The sample metadata fixture.
+ tmp_path (Path): Pytest fixture for a temporary directory path.
+ """
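+        # The dataset stand-in only needs a __len__ for the chunk loop; the
+        # per-chunk frame data comes from the mocked
+        # get_segmentation_data_batch below.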
+ mock_segmentation_data = mock_get_data.return_value
+ mock_segmentation_data.__len__.return_value = 10
+ mock_get_data.return_value = np.array([[[0, 1], [1, 2]]] * 10)
+
+ # Mock the DataFrame that would be read from the CSV
+ mock_df = pd.DataFrame(
+ {
+ "Frame": [0, 5],
+ "background": [50, 60],
+ "foreground": [30, 25],
+ "edge": [20, 15],
+ }
+ )
+ mock_read_csv.return_value = mock_df
+
+ output_path = tmp_path / "test_output.h5"
+ SegmentationAnalyzer.analyze_results(
+ mock_segmentation_data, sample_metadata, output_path
+ )
+
+ assert mock_open.call_count == 2, "Should open two files for writing"
+ assert mock_csv_writer.call_count == 2, "Should create two CSV writers"
+ assert (
+ mock_generate_stats.call_count == 2
+ ), "Should call generate_category_stats twice"
+
+ @patch("cityseg.segmentation_analyzer.pd.read_csv")
+ @patch("cityseg.segmentation_analyzer.logger")
+ def test_generate_category_stats(self, mock_logger, mock_read_csv, tmp_path):
+ """
+ Test the generate_category_stats method.
+
+ This test verifies that the method correctly reads input data,
+ computes statistics, and saves the results.
+
+ Args:
+ mock_logger: Mocked logger object.
+ mock_read_csv: Mocked pandas read_csv function.
+ tmp_path (Path): Pytest fixture for a temporary directory path.
+ """
+ mock_df = pd.DataFrame(
+ {
+ "Frame": [0, 5, 10],
+ "background": [50, 60, 70],
+ "foreground": [30, 25, 20],
+ "edge": [20, 15, 10],
+ }
+ )
+ mock_read_csv.return_value = mock_df
+
+ input_file = tmp_path / "input.csv"
+ output_file = tmp_path / "output.csv"
+
+ SegmentationAnalyzer.generate_category_stats(input_file, output_file)
+
+ mock_read_csv.assert_called_once_with(input_file)
+ assert output_file.exists(), "Output file should be created"
+ mock_logger.info.assert_called_once()
+
+ def test_generate_category_stats_error_handling(self, tmp_path):
+ """
+ Test error handling in the generate_category_stats method.
+
+ This test verifies that the method properly handles and logs errors
+ when processing fails.
+
+ Args:
+ tmp_path (Path): Pytest fixture for a temporary directory path.
+ """
+ non_existent_file = tmp_path / "non_existent.csv"
+ output_file = tmp_path / "output.csv"
+
+ with pytest.raises(Exception):
+ SegmentationAnalyzer.generate_category_stats(non_existent_file, output_file)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..6506187
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,317 @@
+import os
+import sys
+import tempfile
+from typing import Generator
+from unittest.mock import patch
+
+import h5py
+import numpy as np
+import pytest
+from cityseg.utils import get_segmentation_data_batch, setup_logging, tqdm_context
+from tqdm.auto import tqdm
+
+
+@pytest.fixture
+def temp_hdf5_file() -> Generator[str, None, None]:
+ """
+ Fixture to create a temporary HDF5 file for testing.
+
+ Yields:
+ str: Path to the temporary HDF5 file.
+ """
+ with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as tmp:
+ tmp_path = tmp.name
+ yield tmp_path
+ os.unlink(tmp_path)
+
+
+class TestGetSegmentationDataBatch:
+ """Tests for the get_segmentation_data_batch function."""
+
+ def test_retrieves_correct_batch_of_segmentation_data(
+ self, temp_hdf5_file: str
+ ) -> None:
+ """
+ Test that the function retrieves the correct batch of segmentation data.
+
+ Args:
+ temp_hdf5_file (str): Path to temporary HDF5 file.
+ """
+ data = np.random.rand(100, 100)
+ with h5py.File(temp_hdf5_file, "w") as f:
+ dset = f.create_dataset("segmentation", data=data)
+ batch = get_segmentation_data_batch(dset, 10, 20)
+ assert batch.shape == (10, 100)
+ assert np.array_equal(batch, data[10:20])
+
+ def test_handles_empty_segmentation_data_batch(self, temp_hdf5_file: str) -> None:
+ """
+ Test that the function correctly handles an empty batch request.
+
+ Args:
+ temp_hdf5_file (str): Path to temporary HDF5 file.
+ """
+ data = np.random.rand(100, 100)
+ with h5py.File(temp_hdf5_file, "w") as f:
+ dset = f.create_dataset("segmentation", data=data)
+ batch = get_segmentation_data_batch(dset, 0, 0)
+ assert batch.shape == (0, 100)
+
+ def test_handles_out_of_bounds_segmentation_data_batch(
+ self, temp_hdf5_file: str
+ ) -> None:
+ """
+ Test that the function correctly handles out-of-bounds batch requests.
+
+ Args:
+ temp_hdf5_file (str): Path to temporary HDF5 file.
+ """
+ data = np.random.rand(100, 100)
+ with h5py.File(temp_hdf5_file, "w") as f:
+ dset = f.create_dataset("segmentation", data=data)
+ batch = get_segmentation_data_batch(dset, 90, 110)
+ assert batch.shape == (10, 100)
+ assert np.array_equal(batch, data[90:100])
+
+ def test_with_different_data_types(self, temp_hdf5_file: str) -> None:
+ """
+ Test the function with different data types (int and float).
+
+ Args:
+ temp_hdf5_file (str): Path to temporary HDF5 file.
+ """
+ int_data = np.random.randint(0, 100, (100, 100))
+ float_data = np.random.rand(100, 100)
+
+ with h5py.File(temp_hdf5_file, "w") as f:
+ int_dset = f.create_dataset("int_segmentation", data=int_data)
+ float_dset = f.create_dataset("float_segmentation", data=float_data)
+
+ int_batch = get_segmentation_data_batch(int_dset, 10, 20)
+ float_batch = get_segmentation_data_batch(float_dset, 10, 20)
+
+ assert int_batch.dtype == np.int64
+ assert float_batch.dtype == np.float64
+ assert np.array_equal(int_batch, int_data[10:20])
+ assert np.array_equal(float_batch, float_data[10:20])
+
+ def test_with_multi_dimensional_data(self, temp_hdf5_file: str) -> None:
+ """
+ Test the function with multi-dimensional data.
+
+ Args:
+ temp_hdf5_file (str): Path to temporary HDF5 file.
+ """
+ data = np.random.rand(
+ 100, 50, 50, 3
+ ) # 4D data (frames, height, width, channels)
+ with h5py.File(temp_hdf5_file, "w") as f:
+ dset = f.create_dataset("multi_dim_segmentation", data=data)
+ batch = get_segmentation_data_batch(dset, 10, 20)
+ assert batch.shape == (10, 50, 50, 3)
+ assert np.array_equal(batch, data[10:20])
+
+ def test_single_element_batch(self, temp_hdf5_file: str) -> None:
+ """
+ Test the function with a single-element batch.
+
+ Args:
+ temp_hdf5_file (str): Path to temporary HDF5 file.
+ """
+ data = np.random.rand(100, 100)
+ with h5py.File(temp_hdf5_file, "w") as f:
+ dset = f.create_dataset("segmentation", data=data)
+ batch = get_segmentation_data_batch(dset, 10, 11)
+ assert batch.shape == (1, 100)
+ assert np.array_equal(batch, data[10:11])
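+
+    def test_full_range_batch(self, temp_hdf5_file: str) -> None:
+        """
+        Sketch: requesting the full range should return the entire dataset.
+
+        This assumes get_segmentation_data_batch is equivalent to slicing
+        dset[start:end], which the assertions above already rely on.
+        """
+        data = np.random.rand(100, 100)
+        with h5py.File(temp_hdf5_file, "w") as f:
+            dset = f.create_dataset("segmentation", data=data)
+            batch = get_segmentation_data_batch(dset, 0, 100)
+            assert batch.shape == data.shape
+            assert np.array_equal(batch, data)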
+
+
+class TestTqdmContext:
+ """Tests for the tqdm_context function."""
+
+ def test_handles_empty_progress_bar(self) -> None:
+ """Test that the function correctly handles an empty progress bar."""
+ with tqdm_context(total=0) as progress_bar:
+ assert isinstance(progress_bar, tqdm)
+ assert progress_bar.total == 0
+
+ def test_handles_non_empty_progress_bar(self) -> None:
+ """Test that the function correctly handles a non-empty progress bar."""
+ with tqdm_context(total=100) as progress_bar:
+ assert isinstance(progress_bar, tqdm)
+ assert progress_bar.total == 100
+
+ def test_handles_progress_bar_with_updates(self) -> None:
+ """Test that the function correctly handles progress bar updates."""
+ with tqdm_context(total=100) as progress_bar:
+ progress_bar.update(10)
+ assert progress_bar.n == 10
+ progress_bar.update(20)
+ assert progress_bar.n == 30
+
+ def test_handles_progress_bar_with_exception(self) -> None:
+ """Test that the function correctly handles exceptions within the context."""
+ with pytest.raises(ValueError, match="Test exception"):
+ with tqdm_context(total=100):
+ raise ValueError("Test exception")
+
+ def test_with_large_total_value(self) -> None:
+ """Test that the function handles a very large total value."""
+ large_total = 10**10
+ with tqdm_context(total=large_total) as progress_bar:
+ assert progress_bar.total == large_total
+ progress_bar.update(large_total // 2)
+ assert progress_bar.n == large_total // 2
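+
+    def test_closes_progress_bar_on_exit(self) -> None:
+        """
+        Sketch: the progress bar should be closed once the context exits.
+
+        This assumes tqdm_context follows the usual contextmanager pattern
+        of calling close() in a finally block; tqdm.close() sets the
+        instance's disable flag, which is what is checked here.
+        """
+        with tqdm_context(total=10) as progress_bar:
+            progress_bar.update(5)
+        assert progress_bar.disable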
+
+
+class TestSetupLogging:
+ """
+ Test suite for the setup_logging function in utils.py.
+
+ This class contains tests to verify the correct behavior of the logging setup,
+ including console and file logging configurations, different log levels,
+ and formatting options.
+ """
+
+    @pytest.fixture
+    def mock_logger(self):
+        """
+        Fixture that provides a hand-rolled mock logger.
+
+        Note: the tests below patch cityseg.utils.logger directly, so they
+        receive a unittest.mock.MagicMock rather than this fixture. It is
+        kept as a minimal reference for the logger interface (remove, add,
+        info) that setup_logging is expected to use.
+
+        Returns:
+            MockLogger: A mock object that simulates the behavior of the
+                loguru logger.
+        """
+
+        class MockLogger:
+            def __init__(self):
+                self.handlers = []
+
+            def remove(self, handler_id=None):
+                # Loguru's remove() without an id drops all handlers.
+                self.handlers.clear()
+
+            def add(self, sink, **kwargs):
+                self.handlers.append((sink, kwargs))
+
+            def info(self, message):
+                pass
+
+        return MockLogger()
+
+ @patch("cityseg.utils.logger")
+ def test_basic_setup(self, mock_logger):
+ """
+ Test the basic setup of logging with default parameters.
+
+ Args:
+ mock_logger: The mocked logger object.
+ """
+ setup_logging("INFO")
+ assert mock_logger.remove.called, "logger.remove() should be called"
+ assert (
+ len(mock_logger.add.call_args_list) == 2
+ ), "Should add console and file handlers"
+
+ # Verify console handler
+ console_call = mock_logger.add.call_args_list[0]
+ assert console_call[0][0] == sys.stderr, "Console handler should use sys.stderr"
+ assert console_call[1]["level"] == "INFO", "Console level should be INFO"
+
+ # Verify file handler
+ file_call = mock_logger.add.call_args_list[1]
+ assert (
+ file_call[0][0] == "segmentation.log"
+ ), "File handler should use 'segmentation.log'"
+ assert file_call[1]["level"] == "INFO", "File level should be INFO"
+
+ @patch("cityseg.utils.logger")
+ def test_verbose_mode(self, mock_logger):
+ """
+ Test the logging setup in verbose mode.
+
+ Args:
+ mock_logger: The mocked logger object.
+ """
+ setup_logging("INFO", verbose=True)
+ console_call = mock_logger.add.call_args_list[0]
+ assert (
+ console_call[1]["level"] == "DEBUG"
+ ), "Console level should be DEBUG in verbose mode"
+
+ @patch("cityseg.utils.logger")
+ def test_different_log_levels(self, mock_logger):
+ """
+ Test the logging setup with different log levels.
+
+ Args:
+ mock_logger: The mocked logger object.
+ """
+ setup_logging("DEBUG")
+ console_call = mock_logger.add.call_args_list[0]
+ file_call = mock_logger.add.call_args_list[1]
+ assert console_call[1]["level"] == "DEBUG", "Console level should be DEBUG"
+ assert (
+ file_call[1]["level"] == "DEBUG"
+ ), "File level should be DEBUG (minimum of input and INFO)"
+
+ setup_logging("WARNING")
+ console_call = mock_logger.add.call_args_list[2]
+ file_call = mock_logger.add.call_args_list[3]
+ assert console_call[1]["level"] == "WARNING", "Console level should be WARNING"
+ assert (
+ file_call[1]["level"] == "INFO"
+ ), "File level should be INFO (minimum of input and INFO)"
+
+ @patch("cityseg.utils.logger")
+ def test_file_logging_config(self, mock_logger):
+ """
+ Test the file logging configuration.
+
+ Args:
+ mock_logger: The mocked logger object.
+ """
+ setup_logging("INFO")
+ file_call = mock_logger.add.call_args_list[1]
+ assert file_call[1]["rotation"] == "100 MB", "File should rotate at 100 MB"
+ assert (
+ file_call[1]["retention"] == "1 week"
+ ), "File should be retained for 1 week"
+ assert file_call[1]["serialize"] is True, "File logging should be serialized"
+
+ @patch("cityseg.utils.logger")
+ def test_console_logging_format(self, mock_logger):
+ """
+ Test the console logging format.
+
+ Args:
+ mock_logger: The mocked logger object.
+ """
+ setup_logging("INFO")
+ console_call = mock_logger.add.call_args_list[0]
+ format_string = console_call[1]["format"]
+ assert "{time:YYYY-MM-DD HH:mm:ss}" in format_string
+ assert "{level: <8}" in format_string
+        assert "{name}:{function}:{line}" in format_string
+
+
+def test_integration_segmentation_with_tqdm(temp_hdf5_file: str) -> None:
+ """
+ Integration test for using get_segmentation_data_batch within a tqdm_context.
+
+ Args:
+ temp_hdf5_file (str): Path to temporary HDF5 file.
+ """
+ data = np.random.rand(100, 100)
+ with h5py.File(temp_hdf5_file, "w") as f:
+ dset = f.create_dataset("segmentation", data=data)
+
+ with tqdm_context(total=len(data), desc="Processing") as pbar:
+ for i in range(0, len(data), 10):
+ batch = get_segmentation_data_batch(dset, i, min(i + 10, len(data)))
+ assert batch.shape[0] <= 10
+ pbar.update(len(batch))
+
+ assert pbar.n == len(data)
diff --git a/tests/test_video_file_iterator.py b/tests/test_video_file_iterator.py
new file mode 100644
index 0000000..48e9a23
--- /dev/null
+++ b/tests/test_video_file_iterator.py
@@ -0,0 +1,56 @@
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from cityseg.video_file_iterator import VideoFileIterator
+
+
+@pytest.fixture
+def mock_video_files():
+ # Create mock Path objects for video files
+ return [Path("video1.mp4"), Path("video2.mp4")]
+
+
+@pytest.fixture
+def mock_directory():
+ # Create a mock Path object for the directory
+ return Path("/mock/directory")
+
+
+def test_video_file_iterator_initializes_with_video_files(
+ mock_video_files, mock_directory
+):
+ with patch.object(Path, "glob", return_value=mock_video_files):
+ iterator = VideoFileIterator(mock_directory)
+ assert iterator.video_files == mock_video_files
+
+
+def test_video_file_iterator_yields_video_files(mock_video_files, mock_directory):
+ with patch.object(Path, "glob", return_value=mock_video_files):
+ iterator = VideoFileIterator(mock_directory)
+ video_files_list = list(iterator)
+ assert video_files_list == mock_video_files
+
+
+def test_video_file_iterator_handles_empty_directory(mock_directory):
+ with patch.object(Path, "glob", return_value=[]):
+ iterator = VideoFileIterator(mock_directory)
+ assert iterator.video_files == []
+ assert list(iterator) == []
+
+
+def test_video_file_iterator_handles_non_video_files(mock_directory):
+ non_video_files = [Path("file1.txt"), Path("file2.doc")]
+ with patch.object(Path, "glob", return_value=non_video_files):
+ iterator = VideoFileIterator(mock_directory)
+ assert iterator.video_files == []
+ assert list(iterator) == []
+
+
+def test_video_file_iterator_handles_mixed_files(mock_directory):
+ mixed_files = [Path("video1.mp4"), Path("file1.txt"), Path("video2.avi")]
+ with patch.object(Path, "glob", return_value=mixed_files):
+ iterator = VideoFileIterator(mock_directory)
+ assert iterator.video_files == [Path("video1.mp4"), Path("video2.avi")]
+ assert list(iterator) == [Path("video1.mp4"), Path("video2.avi")]
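+
+
+def test_video_file_iterator_supports_repeated_iteration(
+    mock_video_files, mock_directory
+):
+    # Sketch: assuming VideoFileIterator collects its matches into a list at
+    # construction time (as the video_files attribute checks above imply),
+    # iterating over the same instance twice should yield the same files.
+    with patch.object(Path, "glob", return_value=mock_video_files):
+        iterator = VideoFileIterator(mock_directory)
+    assert list(iterator) == mock_video_files
+    assert list(iterator) == mock_video_files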
diff --git a/tests/test_visualization_handler.py b/tests/test_visualization_handler.py
new file mode 100644
index 0000000..16ebdcd
--- /dev/null
+++ b/tests/test_visualization_handler.py
@@ -0,0 +1,73 @@
+import numpy as np
+
+from cityseg.visualization_handler import VisualizationHandler
+
+
+def test_visualize_single_image_with_default_palette():
+ image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+ seg_map = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
+ result = VisualizationHandler.visualize_segmentation(image, seg_map)
+ assert result.shape == image.shape
+
+
+def test_visualize_single_image_with_custom_palette():
+ image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+ seg_map = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
+ palette = np.random.randint(0, 255, (256, 3), dtype=np.uint8)
+ result = VisualizationHandler.visualize_segmentation(image, seg_map, palette)
+ assert result.shape == image.shape
+
+
+def test_visualize_single_image_colored_only():
+ image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+ seg_map = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
+ result = VisualizationHandler.visualize_segmentation(
+ image, seg_map, colored_only=True
+ )
+ assert result.shape == image.shape
+
+
+def test_visualize_multiple_images_with_default_palette():
+ images = [
+ np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) for _ in range(3)
+ ]
+ seg_maps = [np.random.randint(0, 256, (100, 100), dtype=np.uint8) for _ in range(3)]
+ results = VisualizationHandler.visualize_segmentation(images, seg_maps)
+ assert len(results) == 3
+ for result, image in zip(results, images):
+ assert result.shape == image.shape
+
+
+def test_visualize_multiple_images_with_custom_palette():
+ images = [
+ np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) for _ in range(3)
+ ]
+ seg_maps = [np.random.randint(0, 256, (100, 100), dtype=np.uint8) for _ in range(3)]
+ palette = np.random.randint(0, 255, (256, 3), dtype=np.uint8)
+ results = VisualizationHandler.visualize_segmentation(images, seg_maps, palette)
+ assert len(results) == 3
+ for result, image in zip(results, images):
+ assert result.shape == image.shape
+
+
+def test_visualize_multiple_images_colored_only():
+ images = [
+ np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) for _ in range(3)
+ ]
+ seg_maps = [np.random.randint(0, 256, (100, 100), dtype=np.uint8) for _ in range(3)]
+ results = VisualizationHandler.visualize_segmentation(
+ images, seg_maps, colored_only=True
+ )
+ assert len(results) == 3
+ for result, image in zip(results, images):
+ assert result.shape == image.shape
+
+
+def test_generate_palette_with_fewer_colors_than_default():
+ palette = VisualizationHandler._generate_palette(10)
+ assert palette.shape == (10, 3)
+
+
+def test_generate_palette_with_more_colors_than_default():
+ palette = VisualizationHandler._generate_palette(300)
+ assert palette.shape == (300, 3)
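+
+
+def test_generate_palette_values_are_valid_rgb():
+    # Sketch: assuming _generate_palette returns integer RGB triples, every
+    # channel value should fall within the 0-255 range regardless of the
+    # number of colors requested.
+    palette = VisualizationHandler._generate_palette(300)
+    assert palette.min() >= 0
+    assert palette.max() <= 255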