diff --git a/.gitignore b/.gitignore
index a4ea5ff..a78a9e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,9 @@ wheels/
 # venv
 .venv
 example_inputs
+.idea
+*.log
+.vscode
 
 # IDE files
 .idea/
@@ -19,3 +22,4 @@ src/.DS_Store
 
 # in progress / trials
 scripts/
+/.vscode/settings.json
diff --git a/README.md b/README.md
index d54485a..c6589c1 100644
--- a/README.md
+++ b/README.md
@@ -86,33 +86,55 @@ visualization:
 
 ## Models
 
-CitySeg currently supports OneFormer models. The verified models include:
+CitySeg currently supports Mask2Former, MaskFormer, BEiT, and SegFormer models. The verified models include:
 
-- `shi-labs/oneformer_ade20k_swin_large`
-- `shi-labs/oneformer_cityscapes_swin_large`
-- `shi-labs/oneformer_ade20k_dinat_large`
-- `shi-labs/oneformer_cityscapes_dinat_large`
+- `facebook/mask2former-swin-large-cityscapes-semantic`
+- `facebook/mask2former-swin-large-mapillary-vistas-semantic`
+- `facebook/maskformer-swin-small-ade` (partially: this model often leads to segfaults; we recommend setting `disable_tqdm` in the config)
+- `microsoft/beit-large-finetuned-ade-640-640`
+- `nvidia/segformer-b5-finetuned-cityscapes-1024-1024`
+- `zoheb/mit-b5-finetuned-sidewalk-semantic`
+- `nickmuchi/segformer-b4-finetuned-segments-sidewalk`
+
+The Mask2Former models are by far the most stable.
+
+Some models seem to load correctly but consistently produce segfault errors on my machine:
+
+- `facebook/maskformer-swin-large-ade`
+- `nvidia/segformer-b5-finetuned-ade-640-640`
+- `nvidia/segformer-b0-finetuned-cityscapes-1024-1024`
+- `zoheb/mit-b5-finetuned-sidewalk-semantic` (use `model_type: segformer` in the config)
+
+Confirmed not to work due to issues with the Hugging Face pipeline:
+
+- `shi-labs/oneformer_ade20k_dinat_large`
 
 **Note on `dinat` models:** The `dinat` backbone models require the `natten` package, which may have installation issues on some systems. These models are also significantly slower than the `swin` backbone models, especially when forced to run on CPU. However, they may produce better quality outputs in some cases.
 
 ## Project Structure
 
-The project is organized into several Python modules:
+The project is organized into several Python modules, each serving a specific purpose within the CitySeg pipeline:
+
+- `main.py`: Entry point of the application, responsible for initializing and running the segmentation pipeline.
+- `config.py`: Defines configuration classes and handles loading and validating configuration settings.
+- `pipeline.py`: Implements the core segmentation pipeline, including model loading and inference.
+- `processors.py`: Contains classes for processing images, videos, and directories, managing the segmentation workflow.
+- `segmentation_analyzer.py`: Provides functionality for analyzing segmentation results, including computing statistics and generating reports.
+- `video_file_iterator.py`: Implements an iterator for efficiently processing multiple video files in a directory.
+- `visualization_handler.py`: Handles the visualization of segmentation results using color palettes.
+- `file_handler.py`: Manages file operations related to saving and loading segmentation data and metadata.
+- `utils.py`: Provides utility functions for various tasks, including data handling and logging.
+- `palettes.py`: Defines color palettes for different datasets used in segmentation.
+- `exceptions.py`: Custom exception classes for error handling throughout the pipeline.
 
-- `main.py`: Entry point of the application
-- `config.py`: Defines configuration classes for the pipeline
-- `pipeline.py`: Implements the core segmentation pipeline
-- `processors.py`: Contains classes for processing images, videos, and directories
-- `utils.py`: Provides utility functions for analysis, file operations, and logging
-- `palettes.py`: Defines color palettes for different datasets
-- `exceptions.py`: Custom exception classes for error handling
+This modular structure allows for easy maintenance and extension of the CitySeg pipeline, facilitating the addition of new features and models.
 
 ## Logging
 
 The pipeline uses the `loguru` library for flexible and configurable logging. You can set the log level and enable verbose output using command-line arguments:
 
 ```
-python main.py --config path/to/your/config.yaml --log-level INFO --verbose
+python main.py --config path/to/your/config.yaml --log-level INFO
 ```
 
+Valid log levels are DEBUG, INFO, WARNING, ERROR, and CRITICAL; add `--verbose` for more detailed output.
+
 Logs are output to both the console and a file (`segmentation.log`). The file log is in JSON format for easy parsing and analysis.
@@ -136,4 +158,4 @@ CitySeg is released under the BSD 3-Clause License. See the `LICENSE` file for d
 
 ## Contact
 
-For support or inquiries, please open an issue on the GitHub repository or contact [Your Name/Email].
\ No newline at end of file
+For support or inquiries, please open an issue on the GitHub repository or contact Andrew Mitchell.
\ No newline at end of file
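Reviewer note: to orient readers of this patch, here is a minimal end-to-end sketch of the refactored public API. It is illustrative only and not part of the diff; the `setup_logging` arguments, the `create_processor` signature, and the local `config.yaml` path are assumptions.

```python
from pathlib import Path

from cityseg import Config, create_processor, setup_logging

# Configure loguru sinks, then build the right processor for the input
# (image, video, or directory) and run the full pipeline.
setup_logging("INFO", verbose=False)  # exact signature is an assumption
config = Config.from_yaml(Path("config.yaml"))
processor = create_processor(config)
processor.process()
```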
diff --git a/docs/api/handlers.md b/docs/api/handlers.md
new file mode 100644
index 0000000..1f5884f
--- /dev/null
+++ b/docs/api/handlers.md
@@ -0,0 +1,10 @@
+# Handlers
+
+::: cityseg.file_handler
+
+
+::: cityseg.visualization_handler
+    options:
+        show_root_heading: true
+        members: true
+        parameter_headings: true
diff --git a/docs/api/processors.md b/docs/api/processors.md
index 10f3e05..adb361c 100644
--- a/docs/api/processors.md
+++ b/docs/api/processors.md
@@ -1,6 +1,13 @@
 # Processors Module
 
 ::: cityseg.processors
-    options:
-        members: true
-        parameter_headings: true
+
+::: cityseg.processing_plan.ProcessingPlan
+    options:
+        show_root_heading: true
+        show_root_full_path: false
+
+::: cityseg.video_file_iterator.VideoFileIterator
+    options:
+        show_root_heading: true
+        show_root_full_path: false
\ No newline at end of file
diff --git a/docs/api/segments_analysis.md b/docs/api/segments_analysis.md
new file mode 100644
index 0000000..a81a25a
--- /dev/null
+++ b/docs/api/segments_analysis.md
@@ -0,0 +1 @@
+::: cityseg.segmentation_analyzer
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 76abae4..7dfb1a3 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -32,12 +32,19 @@ For more detailed information on how to use CitySeg, check out our [Getting Star
 
 ## Project Structure
 
-CitySeg is organized into several Python modules:
+The project is organized into several Python modules, each serving a specific purpose within the CitySeg pipeline:
+
+- `main.py`: Entry point of the application, responsible for initializing and running the segmentation pipeline.
+- `config.py`: Defines configuration classes and handles loading and validating configuration settings.
+- `pipeline.py`: Implements the core segmentation pipeline, including model loading and inference.
+- `processors.py`: Contains classes for processing images, videos, and directories, managing the segmentation workflow.
+- `segmentation_analyzer.py`: Provides functionality for analyzing segmentation results, including computing statistics and generating reports.
+- `video_file_iterator.py`: Implements an iterator for efficiently processing multiple video files in a directory.
+- `visualization_handler.py`: Handles the visualization of segmentation results using color palettes.
+- `file_handler.py`: Manages file operations related to saving and loading segmentation data and metadata.
+- `utils.py`: Provides utility functions for various tasks, including data handling and logging.
+- `palettes.py`: Defines color palettes for different datasets used in segmentation.
+- `exceptions.py`: Custom exception classes for error handling throughout the pipeline.
 
-- `config.py`: Configuration classes for the pipeline
-- `pipeline.py`: Core segmentation pipeline implementation
-- `processors.py`: Classes for processing images, videos, and directories
-- `utils.py`: Utility functions for analysis, file operations, and logging
-- `exceptions.py`: Custom exception classes for error handling
-
 For detailed API documentation, visit our [API Reference](api/config.md) section.
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 4445aa2..88122bb 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -42,6 +42,8 @@ nav:
       - Config: api/config.md
       - Pipeline: api/pipeline.md
      - Processors: api/processors.md
+      - Segments Analysis: api/segments_analysis.md
+      - Handlers: api/handlers.md
       - Utils: api/utils.md
       - Palettes: api/palettes.md
       - Exceptions: api/exceptions.md
@@ -104,6 +106,7 @@ plugins:
             merge_init_into_class: true
             ignore_init_summary: true
             show_labels: false
+            parameter_headings: true
             show_if_no_docstring: false
             docstring_section_style: spacy
diff --git a/pyproject.toml b/pyproject.toml
index e110bff..e0c1b74 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "cityseg"
-version = "0.2.11"
+version = "0.3.0"
 description = "A flexible and efficient semantic segmentation pipeline for processing images and videos"
 authors = [
     { name = "Andrew Mitchell", email = "mitchellacoustics15@gmail.com" }
 ]
@@ -42,9 +42,8 @@ dev-dependencies = [
     "jupyter>=1.0.0",
     "pytest>=8.3.2",
     "rich[jupyter]>=13.7.1",
-    "natten==0.17.1",
+    # "natten==0.17.1",
     "matplotlib>=3.9.1",
-    "yappi>=1.6.0",
 ]
 
 [tool.hatch.metadata]
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 546a618..da34caa 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -230,7 +230,6 @@ mpmath==1.3.0
   # via sympy
 mypy-extensions==1.0.0
   # via black
-natten==0.17.1
 nbclient==0.10.0
   # via nbconvert
 nbconvert==7.16.4
@@ -280,7 +279,6 @@ packaging==24.1
   # via jupyterlab-server
   # via matplotlib
   # via mkdocs
-  # via natten
   # via nbconvert
   # via pytest
   # via qtconsole
@@ -422,7 +420,6 @@ tokenizers==0.19.1
   # via transformers
 torch==2.4.0
   # via cityseg
-  # via natten
   # via torchvision
 torchvision==0.19.0
   # via cityseg
@@ -479,4 +476,3 @@ websocket-client==1.8.0
   # via jupyter-server
 widgetsnbextension==4.0.11
   # via ipywidgets
-yappi==1.6.0
"13": "vehicle-tramtrain", "14": "vehicle-motorcycle", "15": "vehicle-bicycle", "16": "vehicle-caravan", "17": "vehicle-cartrailer", "18": "construction-building", "19": "construction-door", "20": "construction-wall", "21": "construction-fenceguardrail", "22": "construction-bridge", "23": "construction-tunnel", "24": "construction-stairs", "25": "object-pole", "26": "object-trafficsign", "27": "object-trafficlight", "28": "nature-vegetation", "29": "nature-terrain", "30": "sky", "31": "void-ground", "32": "void-dynamic", "33": "void-static", "34": "void-unclear"} \ No newline at end of file diff --git a/src/cityseg/__init__.py b/src/cityseg/__init__.py index 5ef8d97..74437b4 100644 --- a/src/cityseg/__init__.py +++ b/src/cityseg/__init__.py @@ -3,8 +3,7 @@ This package provides a flexible and efficient semantic segmentation pipeline for processing images and videos. It supports multiple segmentation models -and datasets, with capabilities for tiling large images, mixed-precision -processing, and comprehensive result analysis. +and datasets. Main components: - Config: Configuration class for the pipeline @@ -20,27 +19,36 @@ For detailed usage instructions, please refer to the package documentation. """ -__version__ = "0.2.0" +__version__ = "0.2.12" from . import palettes from .config import Config from .exceptions import ConfigurationError, InputError, ModelError, ProcessingError +from .file_handler import FileHandler from .pipeline import SegmentationPipeline, create_segmentation_pipeline +from .processing_plan import ProcessingPlan from .processors import DirectoryProcessor, SegmentationProcessor, create_processor -from .utils import analyze_segmentation_map, setup_logging +from .segmentation_analyzer import SegmentationAnalyzer +from .utils import setup_logging +from .video_file_iterator import VideoFileIterator +from .visualization_handler import VisualizationHandler __all__ = [ "Config", "SegmentationPipeline", "create_segmentation_pipeline", "SegmentationProcessor", + "SegmentationAnalyzer", "DirectoryProcessor", "create_processor", "ConfigurationError", "InputError", "ModelError", "ProcessingError", - "analyze_segmentation_map", "setup_logging", "palettes", + "FileHandler", + "VisualizationHandler", + "ProcessingPlan", + "VideoFileIterator", ] diff --git a/src/cityseg/config.py b/src/cityseg/config.py index 8cb08df..b465c01 100644 --- a/src/cityseg/config.py +++ b/src/cityseg/config.py @@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional, Union import yaml +from loguru import logger class InputType(Enum): @@ -36,14 +37,45 @@ class ModelConfig: """ name: str - model_type: Optional[str] = ( - None # Can be 'oneformer', 'mask2former', or None for auto-detection - ) + model_type: Optional[str] = None max_size: Optional[int] = None device: Optional[str] = None + dataset: Optional[str] = None + num_workers: Optional[int] = 8 + pipe_batch: Optional[int] = 1 - # TODO: impelement model_type auto-detection - # TODO: implement device auto-detection + def __post_init__(self): + """ + Post-initialization method to set up the model type if not provided. + """ + self.auto_detect_model_type() + if self.device == "mps" and self.num_workers > 0 or self.num_workers is None: + logger.warning( + "MPS is not compatible with multiple workers in pytorch. Setting num_workers to 0." + ) + self.num_workers = 0 + + def auto_detect_model_type(self): + """ + Automatically detect the model type from the model name if not provided. 
diff --git a/src/cityseg/config.yaml b/src/cityseg/config.yaml
index 40fa552..197857b 100644
--- a/src/cityseg/config.yaml
+++ b/src/cityseg/config.yaml
@@ -6,21 +6,24 @@ ignore_files: null # Optional: list of file names to ignore
 
 # Model configuration
 model:
-  name: "facebook/mask2former-swin-large-mapillary-vistas-semantic"
-  model_type: "mask2former" # Optional: can be 'oneformer', 'mask2former', or null for auto-detection
+  name: "facebook/mask2former-swin-large-cityscapes-semantic"
+  model_type: null # Optional: can be 'beit', 'mask2former', or null for auto-detection
   max_size: null # Optional: maximum size for input images/frames
   device: "mps" # Options: "cuda", "cpu", "mps", or null for auto-detection
+  dataset: "sidewalk-semantic" # Optional: dataset name for model-specific postprocessing; must match the check in pipeline.py
+  num_workers: 0 # Number of workers for data loading
+  pipe_batch: 5 # Number of frames to process in each batch. Recommend setting this equal to batch_size below.
 
 # Processing configuration
-frame_step: 10 # Process every 5th frame
+frame_step: 1 # Process every frame (set to N to process every Nth frame)
 batch_size: 5 # Number of frames to process in each batch
 output_fps: null # Optional: FPS for output video (if different from input)
 
 # Output options
-save_raw_segmentation: false
+save_raw_segmentation: true
 save_colored_segmentation: true
 save_overlay: true
-analyze_results: false
+analyze_results: true
 
 # Visualization configuration
 visualization:
 
 # Advanced options
 force_reprocess: false # Set to true to reprocess even if output files exist
+disable_tqdm: false # Set to true to disable progress bars. In some cases, tqdm seems to lead to segfaults.
\ No newline at end of file
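Reviewer note: a hedged sketch of loading this file through the new `Config.from_yaml` path (illustrative only; it assumes the YAML also carries the input/output keys that `from_yaml` reads, which are only partially visible in this diff):

```python
from pathlib import Path

from cityseg.config import Config

config = Config.from_yaml(Path("src/cityseg/config.yaml"))
print(config.model.name, config.model.model_type)  # model_type auto-detected when null
print(config.disable_tqdm)                         # new flag, defaults to False
```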
\ No newline at end of file
diff --git a/src/cityseg/file_handler.py b/src/cityseg/file_handler.py
new file mode 100644
index 0000000..1317020
--- /dev/null
+++ b/src/cityseg/file_handler.py
@@ -0,0 +1,186 @@
+"""
+This module provides a class for handling file operations related to segmentation data.
+
+It includes functionality for saving and loading segmentation data in HDF files,
+verifying the integrity of HDF and video files, and checking the validity of analysis files.
+
+Classes:
+    FileHandler: A class for handling file operations related to segmentation data and metadata.
+"""
+
+import json
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+import cv2
+import h5py
+import numpy as np
+from loguru import logger
+
+from .config import Config
+
+
+class FileHandler:
+    """
+    A class for handling file operations related to segmentation data and metadata.
+
+    This class provides methods for saving and loading segmentation data in HDF files,
+    verifying the integrity of HDF and video files, and checking analysis files.
+
+    Methods:
+        save_hdf_file: Saves segmentation data and metadata to an HDF file.
+        load_hdf_file: Loads segmentation data and metadata from an HDF file.
+        verify_hdf_file: Verifies the integrity of an HDF file.
+        verify_video_file: Verifies the integrity of a video file.
+        verify_analysis_files: Verifies the analysis files for counts and percentages.
+    """
+
+    @staticmethod
+    def save_hdf_file(
+        file_path: Path, segmentation_data: np.ndarray, metadata: Dict[str, Any]
+    ) -> None:
+        """
+        Saves segmentation data and metadata to an HDF file.
+
+        Args:
+            file_path (Path): Path to the HDF file.
+            segmentation_data (np.ndarray): Segmentation data to be saved.
+            metadata (Dict[str, Any]): Metadata associated with the segmentation data.
+        """
+        with h5py.File(file_path, "w") as f:
+            f.create_dataset("segmentation", data=segmentation_data, compression="gzip")
+            if "palette" in metadata and isinstance(metadata["palette"], np.ndarray):
+                metadata["palette"] = metadata["palette"].tolist()
+            json_metadata = json.dumps(metadata)
+            f.create_dataset("metadata", data=json_metadata)
+
+    @staticmethod
+    def load_hdf_file(file_path: Path) -> Tuple[h5py.File, Dict[str, Any]]:
+        """
+        Loads segmentation data and metadata from an HDF file.
+
+        Args:
+            file_path (Path): Path to the HDF file.
+
+        Returns:
+            Tuple[h5py.File, Dict[str, Any]]: Loaded HDF file and metadata.
+        """
+        hdf_file = h5py.File(file_path, "r")
+        json_metadata = hdf_file["metadata"][()]
+        metadata = json.loads(json_metadata)
+        if "palette" in metadata and isinstance(metadata["palette"], list):
+            metadata["palette"] = np.array(metadata["palette"], np.uint8)
+        return hdf_file, metadata
+
+    @staticmethod
+    def verify_hdf_file(file_path: Path, config: Config) -> bool:
+        """
+        Verifies the integrity of an HDF file.
+
+        Args:
+            file_path (Path): Path to the HDF file.
+            config (Config): Configuration object for comparison.
+
+        Returns:
+            bool: True if the HDF file is valid and up-to-date, False otherwise.
+        """
+        try:
+            with h5py.File(file_path, "r") as f:
+                if "segmentation" not in f or "metadata" not in f:
+                    logger.warning(
+                        f"HDF file at {file_path} is missing required datasets"
+                    )
+                    return False
+
+                json_metadata = f["metadata"][()]
+                metadata = json.loads(json_metadata)
+
+                if metadata.get("frame_step") != config.frame_step:
+                    logger.warning(
+                        f"HDF file frame step ({metadata.get('frame_step')}) does not match current config ({config.frame_step})"
+                    )
+                    return False
+
+                segmentation_data = f["segmentation"]
+                if len(segmentation_data) == 0:
+                    logger.warning(
+                        f"HDF file at {file_path} contains no segmentation data"
+                    )
+                    return False
+
+                first_frame = segmentation_data[0]
+                last_frame = segmentation_data[-1]
+                if first_frame.shape != last_frame.shape:
+                    logger.warning(
+                        f"Inconsistent frame shapes in HDF file at {file_path}"
+                    )
+                    return False
+
+                logger.debug(f"HDF file at {file_path} is valid and up-to-date")
+                return True
+        except Exception as e:
+            logger.error(f"Error verifying HDF file at {file_path}: {str(e)}")
+            return False
+
+    @staticmethod
+    def verify_video_file(file_path: Path) -> bool:
+        """
+        Verifies the integrity of a video file.
+
+        Args:
+            file_path (Path): Path to the video file.
+
+        Returns:
+            bool: True if the video file is valid and up-to-date, False otherwise.
+        """
+        try:
+            cap = cv2.VideoCapture(str(file_path))
+            if not cap.isOpened():
+                logger.warning(f"Unable to open video file at {file_path}")
+                return False
+
+            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+            ret, first_frame = cap.read()
+            if not ret:
+                logger.warning(
+                    f"Unable to read first frame from video file at {file_path}"
+                )
+                return False
+
+            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
+            ret, last_frame = cap.read()
+            if not ret:
+                logger.warning(
+                    f"Unable to read last frame from video file at {file_path}"
+                )
+                return False
+
+            cap.release()
+            logger.debug(f"Video file at {file_path} is valid and up-to-date")
+            return True
+        except Exception as e:
+            logger.error(f"Error verifying video file at {file_path}: {str(e)}")
+            return False
+
+    @staticmethod
+    def verify_analysis_files(counts_file: Path, percentages_file: Path) -> bool:
+        """
+        Verifies the analysis files for counts and percentages.
+
+        Args:
+            counts_file (Path): Path to the counts file.
+            percentages_file (Path): Path to the percentages file.
+
+        Returns:
+            bool: True if the analysis files are valid, False otherwise.
+        """
+        try:
+            if counts_file.stat().st_size == 0 or percentages_file.stat().st_size == 0:
+                logger.info("One or both analysis files are empty")
+                return False
+            return True
+        except Exception as e:
+            logger.error(f"Error verifying analysis files: {str(e)}")
+            return False
""" +import json +import logging +import warnings from typing import Any, Dict, List, Optional, Union import numpy as np import torch +from loguru import logger from PIL.Image import Image from transformers import ( AutoImageProcessor, AutoModelForSemanticSegmentation, + AutoProcessor, + BeitForSemanticSegmentation, ImageSegmentationPipeline, Mask2FormerForUniversalSegmentation, - OneFormerProcessor, + MaskFormerForInstanceSegmentation, + OneFormerForUniversalSegmentation, + SegformerForSemanticSegmentation, ) +from .config import ModelConfig + class SegmentationPipeline(ImageSegmentationPipeline): """ @@ -86,8 +96,9 @@ def create_single_segmentation_map( "palette": self.palette, } + @staticmethod def _is_single_image_result( - self, result: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]] + result: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], ) -> bool: """ Determine if the result is for a single image or multiple images. @@ -127,8 +138,9 @@ def __call__( Returns: List[Dict[str, Any]]: A list of dictionaries containing segmentation maps and metadata. """ - result = super().__call__(images, **kwargs) - + # logger.debug("Pass image(s) up to HF pipeline...") + result = super().__call__(images, subtask="semantic", **kwargs) + # logger.debug("Received result from HF pipeline") if self._is_single_image_result(result): return [ self.create_single_segmentation_map( @@ -144,8 +156,9 @@ def __call__( ] +@logger.catch def create_segmentation_pipeline( - model_name: str, device: Optional[str] = None, **kwargs: Any + config: ModelConfig, **kwargs: Any ) -> SegmentationPipeline: """ Create and return a SegmentationPipeline instance based on the specified model. @@ -154,23 +167,63 @@ def create_segmentation_pipeline( model name, and creates a SegmentationPipeline instance with these components. Args: - model_name (str): The name or path of the pre-trained model to use. - device (Optional[str]): The device to use for processing (e.g., "cpu", "cuda"). If None, it will be automatically determined. + config: **kwargs: Additional keyword arguments to pass to the SegmentationPipeline constructor. Returns: SegmentationPipeline: An instance of the SegmentationPipeline class. 
""" + model_name = config.name + model_type = config.model_type + device = config.device + dataset = config.dataset + if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" + device = ( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" + ) # Initialize the appropriate model and image processor based on the model name - if "oneformer" in model_name.lower(): - model = AutoModelForSemanticSegmentation.from_pretrained(model_name) - image_processor = OneFormerProcessor.from_pretrained(model_name) - elif "mask2former" in model_name.lower(): + if "oneformer" == model_type: + warnings.warn( + "OneFormer models are experimental and may not be fully supported" + ) + try: + model = OneFormerForUniversalSegmentation.from_pretrained(model_name) + image_processor = AutoProcessor.from_pretrained(model_name) + except ValueError as e: + logger.error(f"Error loading model: {e}") + + elif "mask2former" == model_type: model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name) image_processor = AutoImageProcessor.from_pretrained(model_name) + + elif "maskformer" == model_type: + model = MaskFormerForInstanceSegmentation.from_pretrained(model_name) + image_processor = AutoImageProcessor.from_pretrained(model_name) + + elif "beit" == model_type: + if device != "cpu": + logger.warning( + "Beit models are not supported on GPU and will be loaded on CPU" + ) + device = "cpu" + model = BeitForSemanticSegmentation.from_pretrained(model_name) + image_processor = AutoImageProcessor.from_pretrained(model_name) + + elif "segformer" == model_type: + model = SegformerForSemanticSegmentation.from_pretrained(model_name) + image_processor = AutoImageProcessor.from_pretrained(model_name) + + if dataset == "sidewalk-semantic": + logging.debug("Loading Sidewalk Semantic dataset label mappings...") + with open("SemanticSidewalk_id2label.json") as f: + id2label = json.load(f) + model.config.id2label = id2label else: model = AutoModelForSemanticSegmentation.from_pretrained(model_name) image_processor = AutoImageProcessor.from_pretrained(model_name) @@ -180,5 +233,6 @@ def create_segmentation_pipeline( image_processor=image_processor, device=device, subtask="semantic", + num_workers=config.num_workers, **kwargs, ) diff --git a/src/cityseg/processing_plan.py b/src/cityseg/processing_plan.py new file mode 100644 index 0000000..c5f0d55 --- /dev/null +++ b/src/cityseg/processing_plan.py @@ -0,0 +1,123 @@ +""" +This module defines a class for creating and managing the processing plan for video segmentation tasks. + +It determines which processing steps need to be executed based on the configuration +and the existence of previously generated outputs. + +Classes: + ProcessingPlan: A class to create and manage the processing plan for video segmentation tasks. +""" + +from typing import Dict + +from loguru import logger + +from .config import Config +from .file_handler import FileHandler + + +class ProcessingPlan: + """ + A class to create and manage the processing plan for video segmentation tasks. + + This class determines which processing steps need to be executed based on the + configuration and the existence of previously generated outputs. + + Attributes: + config (Config): Configuration object containing processing parameters. + plan (Dict[str, bool]): A dictionary indicating which processing steps to execute. + """ + + def __init__(self, config: Config): + """ + Initializes the ProcessingPlan with the given configuration. 
diff --git a/src/cityseg/processing_plan.py b/src/cityseg/processing_plan.py
new file mode 100644
index 0000000..c5f0d55
--- /dev/null
+++ b/src/cityseg/processing_plan.py
@@ -0,0 +1,123 @@
+"""
+This module defines a class for creating and managing the processing plan for video segmentation tasks.
+
+It determines which processing steps need to be executed based on the configuration
+and the existence of previously generated outputs.
+
+Classes:
+    ProcessingPlan: A class to create and manage the processing plan for video segmentation tasks.
+"""
+
+from typing import Dict
+
+from loguru import logger
+
+from .config import Config
+from .file_handler import FileHandler
+
+
+class ProcessingPlan:
+    """
+    A class to create and manage the processing plan for video segmentation tasks.
+
+    This class determines which processing steps need to be executed based on the
+    configuration and the existence of previously generated outputs.
+
+    Attributes:
+        config (Config): Configuration object containing processing parameters.
+        plan (Dict[str, bool]): A dictionary indicating which processing steps to execute.
+    """
+
+    def __init__(self, config: Config):
+        """
+        Initializes the ProcessingPlan with the given configuration.
+
+        Args:
+            config (Config): Configuration object for the processing plan.
+        """
+        self.config = config
+        self.plan = self._create_processing_plan()
+
+    def _create_processing_plan(self) -> Dict[str, bool]:
+        """
+        Creates a processing plan based on the configuration and existing outputs.
+
+        This method checks if force reprocessing is enabled or if existing outputs
+        are valid to determine which processing steps should be executed.
+
+        Returns:
+            Dict[str, bool]: A dictionary indicating which processing steps to execute.
+        """
+        if self.config.force_reprocess:
+            logger.info("Force reprocessing enabled. All steps will be executed.")
+            return {
+                "process_video": True,
+                "generate_hdf": True,
+                "generate_colored_video": self.config.save_colored_segmentation,
+                "generate_overlay_video": self.config.save_overlay,
+                "analyze_results": self.config.analyze_results,
+            }
+
+        existing_outputs = self._check_existing_outputs()
+
+        plan = {
+            "process_video": not existing_outputs["hdf_file_valid"],
+            "generate_hdf": not existing_outputs["hdf_file_valid"],
+            "generate_colored_video": self.config.save_colored_segmentation
+            and not existing_outputs["colored_video_valid"],
+            "generate_overlay_video": self.config.save_overlay
+            and not existing_outputs["overlay_video_valid"],
+            "analyze_results": self.config.analyze_results
+            and not existing_outputs["analysis_files_valid"],
+        }
+
+        logger.debug(f"Created processing plan: {plan}")
+        return plan
+
+    def _check_existing_outputs(self) -> Dict[str, bool]:
+        """
+        Checks the validity of existing output files.
+
+        This method verifies the existence and validity of the HDF file, colored video,
+        overlay video, and analysis files.
+
+        Returns:
+            Dict[str, bool]: A dictionary indicating the validity of existing outputs.
+        """
+        output_path = self.config.get_output_path()
+        hdf_path = output_path.with_name(f"{output_path.stem}_segmentation.h5")
+        colored_video_path = output_path.with_name(f"{output_path.stem}_colored.mp4")
+        overlay_video_path = output_path.with_name(f"{output_path.stem}_overlay.mp4")
+        counts_file = output_path.with_name(f"{output_path.stem}_category_counts.csv")
+        percentages_file = output_path.with_name(
+            f"{output_path.stem}_category_percentages.csv"
+        )
+
+        results = {
+            "hdf_file_valid": False,
+            "colored_video_valid": False,
+            "overlay_video_valid": False,
+            "analysis_files_valid": False,
+        }
+
+        if hdf_path.exists():
+            results["hdf_file_valid"] = FileHandler.verify_hdf_file(
+                hdf_path, self.config
+            )
+
+        if colored_video_path.exists():
+            results["colored_video_valid"] = FileHandler.verify_video_file(
+                colored_video_path
+            )
+
+        if overlay_video_path.exists():
+            results["overlay_video_valid"] = FileHandler.verify_video_file(
+                overlay_video_path
+            )
+
+        if counts_file.exists() and percentages_file.exists():
+            results["analysis_files_valid"] = FileHandler.verify_analysis_files(
+                counts_file, percentages_file
+            )
+
+        return results
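Reviewer note: a short usage sketch of `ProcessingPlan` (illustrative only; the config path is a placeholder):

```python
from pathlib import Path

from cityseg.config import Config
from cityseg.processing_plan import ProcessingPlan

config = Config.from_yaml(Path("config.yaml"))
plan = ProcessingPlan(config).plan
# Each flag marks a step that still needs to run; valid outputs from a
# previous run (with force_reprocess=False) let steps be skipped.
for step, needed in plan.items():
    print(f"{step}: {'run' if needed else 'skip'}")
```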
diff --git a/src/cityseg/processors.py b/src/cityseg/processors.py
index fbf2753..f165073 100644
--- a/src/cityseg/processors.py
+++ b/src/cityseg/processors.py
@@ -13,465 +13,248 @@
 import time
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Tuple, Union
 
 import cv2
 import h5py
 import numpy as np
-import pandas as pd
 from loguru import logger
 from PIL import Image
 
 from .config import Config, ConfigHasher, InputType
 from .exceptions import InputError, ProcessingError
+from .file_handler import FileHandler
 from .pipeline import create_segmentation_pipeline
-from .utils import (
-    get_video_files,
-)
+from .processing_plan import ProcessingPlan
+from .segmentation_analyzer import SegmentationAnalyzer
+from .utils import get_segmentation_data_batch, tqdm_context
+from .video_file_iterator import VideoFileIterator
+from .visualization_handler import VisualizationHandler
 
 
-class ProcessingHistory:
+class ImageProcessor:
     """
-    A class to manage and persist the processing history of video segmentation tasks.
+    Processes individual images using semantic segmentation models.
 
-    This class keeps track of individual processing runs, including timestamps,
-    configuration hashes, and which outputs were generated in each run.
+    This class handles the segmentation of single images, including saving results
+    and analyzing the segmentation output.
 
     Attributes:
-        runs (List[Dict]): A list of dictionaries, each representing a processing run.
+        config (Config): Configuration object containing processing parameters.
+        pipeline: Segmentation pipeline for processing images.
+        file_handler (FileHandler): Handles file operations.
+        visualizer (VisualizationHandler): Handles visualization of segmentation results.
+        analyzer (SegmentationAnalyzer): Analyzes segmentation results.
     """
 
-    def __init__(self):
-        """Initialize an empty ProcessingHistory."""
-        self.runs = []
-
-    def add_run(
-        self, timestamp: str, config_hash: str, outputs_generated: Dict[str, bool]
-    ) -> None:
-        """
-        Add a new processing run to the history.
-
-        Args:
-            timestamp (str): The timestamp of the processing run.
-            config_hash (str): A hash of the relevant configuration used for the run.
-            outputs_generated (Dict[str, bool]): A dictionary indicating which outputs were generated.
-        """
-        self.runs.append(
-            {
-                "timestamp": timestamp,
-                "config_hash": config_hash,
-                "outputs_generated": outputs_generated,
-            }
-        )
-        logger.debug(f"Added new processing run to history. Timestamp: {timestamp}")
-
-    def save(self, file_path: Path) -> None:
+    def __init__(self, config: Config):
         """
-        Save the processing history to a JSON file.
+        Initializes the ImageProcessor with the given configuration.
 
         Args:
-            file_path (Path): The path where the history file will be saved.
+            config (Config): Configuration object for the processor.
         """
-        with file_path.open("w") as f:
-            json.dump({"runs": self.runs}, f)
-        logger.debug(f"Saved processing history to {file_path}")
+        self.config = config
+        self.pipeline = create_segmentation_pipeline(config.model)
+        self.file_handler = FileHandler()
+        self.visualizer = VisualizationHandler()
+        self.analyzer = SegmentationAnalyzer()
 
-    @classmethod
-    def load(cls, file_path: Path) -> "ProcessingHistory":
+    def process(self) -> None:
         """
-        Load a processing history from a JSON file.
+        Processes the input image according to the configuration.
 
-        Args:
-            file_path (Path): The path to the history file.
+        This method handles the entire image processing pipeline, including
+        segmentation, result saving, and analysis.
 
-        Returns:
-            ProcessingHistory: A ProcessingHistory object populated with the loaded data.
+        Raises:
+            ProcessingError: If an error occurs during image processing.
         """
-        history = cls()
-        if file_path.exists():
-            with file_path.open("r") as f:
-                data = json.load(f)
-            history.runs = data["runs"]
-            logger.debug(f"Loaded processing history from {file_path}")
-        else:
-            logger.info(f"No processing history file found at {file_path}")
-        return history
-
+        logger.info(f"Processing image: {self.config.input}")
+        try:
+            image = Image.open(self.config.input).convert("RGB")
+            if self.config.model.max_size:
+                image.thumbnail(
+                    (self.config.model.max_size, self.config.model.max_size)
+                )
 
-class SegmentationProcessor:
-    """
-    A processor for semantic segmentation of images and videos.
+            result = self.pipeline([image])[0]
 
-    This class handles the segmentation process for individual image and video files,
-    including caching of results, visualization generation, and statistical analysis.
+            self._save_results(image, result)
+            self._analyze_results(result["seg_map"])
 
-    Attributes:
-        config (Config): Configuration object containing processing parameters.
-        pipeline (SegmentationPipeline): The segmentation pipeline used for processing.
-        logger (Logger): Logger instance for tracking processing events.
-        processing_history (ProcessingHistory): Object to track processing history.
-        processing_plan (Dict[str, bool]): Plan determining which processing steps to execute.
-    """
+            logger.info("Image processing complete")
+        except Exception as e:
+            logger.exception(f"Error during image processing: {str(e)}")
+            raise ProcessingError(f"Error during image processing: {str(e)}")
 
-    def __init__(self, config: Config):
+    def _save_results(self, image: Image.Image, result: Dict[str, Any]) -> None:
         """
-        Initialize the SegmentationProcessor.
+        Saves the segmentation results based on the configuration.
 
         Args:
-            config (Config): Configuration object containing processing parameters.
+            image (Image.Image): The original input image.
+            result (Dict[str, Any]): The segmentation result dictionary.
         """
-        self.config = config
-        self.pipeline = create_segmentation_pipeline(
-            model_name=config.model.name,
-            device=config.model.device,
-        )
-        self.palette = (
-            self.pipeline.palette
-            if self.pipeline.palette is not None
-            else self._generate_palette(255)
-        )  # Pre-compute palette for up to 256 classes
-        self.logger = logger.bind(
-            processor_type=self.__class__.__name__,
-            input_type=self.config.input_type.value,
-        )
-        self.processing_history = self._load_processing_history()
-        self.processing_plan = self._create_processing_plan()
-        self.logger.debug(
-            "SegmentationProcessor initialized for input video.",
-            video_input=str(self.config.input),
-        )
+        output_path = self.config.get_output_path()
 
-    def _load_processing_history(self) -> ProcessingHistory:
-        """
-        Load the processing history from a file or create a new one if not found.
+        # Save raw segmentation
+        if self.config.save_raw_segmentation:
+            raw_seg_path = output_path.with_name(
+                f"{output_path.stem}_raw_segmentation.png"
+            )
+            Image.fromarray(result["seg_map"].astype(np.uint8)).save(raw_seg_path)
+            logger.info(f"Raw segmentation saved to {raw_seg_path}")
 
-        Returns:
-            ProcessingHistory: The loaded or newly created processing history.
-        """
-        history_file = self._get_history_file_path()
-        try:
-            history = ProcessingHistory.load(history_file)
-            self.logger.debug("Processing history loaded successfully")
-            return history
-        except Exception as e:
-            self.logger.info(
-                f"Failed to load processing history: {str(e)}. Starting with a new history."
-            )
-            return ProcessingHistory()
+        # Save colored segmentation
+        if self.config.save_colored_segmentation:
+            colored_seg_path = output_path.with_name(
+                f"{output_path.stem}_colored_segmentation.png"
+            )
+            colored_seg = self.visualizer.visualize_segmentation(
+                np.array(image), result["seg_map"], result["palette"], colored_only=True
+            )
+            Image.fromarray(colored_seg).save(colored_seg_path)
+            logger.info(f"Colored segmentation saved to {colored_seg_path}")
+
+        # Save overlay
+        if self.config.save_overlay:
+            overlay_path = output_path.with_name(f"{output_path.stem}_overlay.png")
+            overlay = self.visualizer.visualize_segmentation(
+                np.array(image),
+                result["seg_map"],
+                result["palette"],
+                colored_only=False,
+            )
+            Image.fromarray(overlay).save(overlay_path)
+            logger.info(f"Overlay saved to {overlay_path}")
 
-    def _get_history_file_path(self) -> Path:
+    def _analyze_results(self, seg_map: np.ndarray) -> None:
         """
-        Get the file path for the processing history JSON file.
+        Analyzes the segmentation results and saves the analysis.
 
-        Returns:
-            Path: The path to the processing history file.
+        Args:
+            seg_map (np.ndarray): The segmentation map to analyze.
         """
         output_path = self.config.get_output_path()
-        return output_path.with_name(f"{output_path.stem}_processing_history.json")
-
-    def _create_processing_plan(self) -> Dict[str, bool]:
-        """
-        Create a processing plan based on the current configuration and existing outputs.
-
-        Returns:
-            Dict[str, bool]: A dictionary representing the processing plan.
-        """
-        if self.config.force_reprocess:
-            self.logger.info("Force reprocessing enabled. All steps will be executed.")
-            return {
-                "process_video": True,
-                "generate_hdf": True,
-                "generate_colored_video": self.config.save_colored_segmentation,
-                "generate_overlay_video": self.config.save_overlay,
-                "analyze_results": self.config.analyze_results,
-            }
-
-        existing_outputs = self._check_existing_outputs()
-
-        plan = {
-            "process_video": not existing_outputs["hdf_file_valid"],
-            "generate_hdf": not existing_outputs["hdf_file_valid"],
-            "generate_colored_video": self.config.save_colored_segmentation
-            and not existing_outputs["colored_video_valid"],
-            "generate_overlay_video": self.config.save_overlay
-            and not existing_outputs["overlay_video_valid"],
-            "analyze_results": self.config.analyze_results
-            and not existing_outputs["analysis_files_valid"],
-        }
+        # ModelConfig does not define num_classes; derive the category count
+        # from the model's id2label mapping instead.
+        num_categories = len(self.pipeline.model.config.id2label)
 
-        self.logger.debug(f"Created processing plan: {plan}")
-        return plan
+        analysis = self.analyzer.analyze_segmentation_map(seg_map, num_categories)
 
-    def _check_existing_outputs(self) -> Dict[str, bool]:
-        """
-        Check the validity of existing output files.
-
-        Returns:
-            Dict[str, bool]: A dictionary indicating the validity of each output type.
-        """
-        output_path = self.config.get_output_path()
-        hdf_path = output_path.with_name(f"{output_path.stem}_segmentation.h5")
-        colored_video_path = output_path.with_name(f"{output_path.stem}_colored.mp4")
-        overlay_video_path = output_path.with_name(f"{output_path.stem}_overlay.mp4")
         counts_file = output_path.with_name(f"{output_path.stem}_category_counts.csv")
         percentages_file = output_path.with_name(
             f"{output_path.stem}_category_percentages.csv"
         )
 
-        results = {
-            "hdf_file_valid": False,
-            "colored_video_valid": False,
-            "overlay_video_valid": False,
-            "analysis_files_valid": False,
-        }
+        with open(counts_file, "w", newline="") as f:
+            writer = csv.writer(f)
+            writer.writerow(["category_id", "pixel_count"])
+            for category_id, (pixel_count, _) in analysis.items():
+                writer.writerow([category_id, pixel_count])
 
-        if hdf_path.exists():
-            results["hdf_file_valid"] = self._verify_hdf_file(hdf_path)
-            self.logger.debug(f"HDF file validity: {results['hdf_file_valid']}")
+        with open(percentages_file, "w", newline="") as f:
+            writer = csv.writer(f)
+            writer.writerow(["category_id", "percentage"])
+            for category_id, (_, percentage) in analysis.items():
+                writer.writerow([category_id, percentage])
 
-        if colored_video_path.exists():
-            results["colored_video_valid"] = self._verify_video_file(colored_video_path)
-            self.logger.debug(
-                f"Colored video validity: {results['colored_video_valid']}"
-            )
+        logger.info(f"Category counts saved to {counts_file}")
+        logger.info(f"Category percentages saved to {percentages_file}")
 
-        if overlay_video_path.exists():
-            results["overlay_video_valid"] = self._verify_video_file(overlay_video_path)
-            self.logger.debug(
-                f"Overlay video validity: {results['overlay_video_valid']}"
-            )
-
-        if counts_file.exists() and percentages_file.exists():
-            results["analysis_files_valid"] = self._verify_analysis_files(
-                counts_file, percentages_file
-            )
-            self.logger.debug(
-                f"Analysis files validity: {results['analysis_files_valid']}"
-            )
-
-        return results
-
-    def _verify_analysis_files(self, counts_file: Path, percentages_file: Path) -> bool:
-        """
-        Verify the integrity of analysis files.
-
-        Args:
-            counts_file (Path): Path to the category counts CSV file.
-            percentages_file (Path): Path to the category percentages CSV file.
-
-        Returns:
-            bool: True if both files are valid, False otherwise.
-        """
-        try:
-            # Perform basic checks on the files
-            if counts_file.stat().st_size == 0 or percentages_file.stat().st_size == 0:
-                self.logger.warning("One or both analysis files are empty")
-                return False
-
-            # You could add more sophisticated checks here, such as:
-            # - Verifying the number of rows matches the expected frame count
-            # - Checking that the headers are correct
-            # - Validating that the data is within expected ranges
-
-            return True
-        except Exception as e:
-            self.logger.error(f"Error verifying analysis files: {str(e)}")
-            return False
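Reviewer note: the new `ImageProcessor` replaces the image half of the old `SegmentationProcessor`. A hedged usage sketch (illustrative; the config path is a placeholder):

```python
from pathlib import Path

from cityseg.config import Config
from cityseg.processors import ImageProcessor

# Assumes config.yaml points `input` at a single image file, so Config
# resolves input_type to SINGLE_IMAGE.
config = Config.from_yaml(Path("config.yaml"))
ImageProcessor(config).process()
# Expected outputs next to the configured output path:
#   *_raw_segmentation.png, *_colored_segmentation.png, *_overlay.png,
#   *_category_counts.csv, *_category_percentages.csv
```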
""" - try: - # Perform basic checks on the files - if counts_file.stat().st_size == 0 or percentages_file.stat().st_size == 0: - self.logger.warning("One or both analysis files are empty") - return False - - # You could add more sophisticated checks here, such as: - # - Verifying the number of rows matches the expected frame count - # - Checking that the headers are correct - # - Validating that the data is within expected ranges - - return True - except Exception as e: - self.logger.error(f"Error verifying analysis files: {str(e)}") - return False + self.config = config + self.pipeline = create_segmentation_pipeline(config.model) + self.processing_plan = ProcessingPlan(config) + self.file_handler = FileHandler() + self.visualizer = VisualizationHandler() + self.analyzer = SegmentationAnalyzer() + logger.debug(f"VideoProcessor initialized with config: {config}") def process(self) -> None: """ - Process the input based on its type (image or video). - - Raises: - ValueError: If the input type is not supported. - """ - if self.config.input_type == InputType.SINGLE_IMAGE: - self.process_image() - elif self.config.input_type == InputType.SINGLE_VIDEO: - self.process_video() - else: - raise ValueError(f"Unsupported input type: {self.config.input_type}") - - def process_image(self) -> None: - """ - Process a single image file. + Processes the input video according to the configuration and processing plan. - This method handles loading the image, running it through the segmentation pipeline, - saving the results, and analyzing the segmentation map. + This method handles the entire video processing pipeline, including + frame segmentation, result saving, video generation, and analysis. Raises: - ProcessingError: If an error occurs during image processing. - """ - self.logger.info(f"Processing image: {self.config.input}") - try: - # Load and preprocess the image - image = Image.open(self.config.input).convert("RGB") - if self.config.model.max_size: - image.thumbnail( - (self.config.model.max_size, self.config.model.max_size) - ) - - # Run segmentation - result = self.pipeline([image])[0] - - # Save and analyze results - self.save_results(image, result) - self.analyze_results(result["seg_map"]) - - self.logger.info("Image processing complete") - except Exception as e: - self.logger.exception(f"Error during image processing: {str(e)}") - raise ProcessingError(f"Error during image processing: {str(e)}") - - def process_video(self) -> None: - """ - Process a single video file according to the current processing plan. + ProcessingError: If an error occurs during video processing. 
""" - self.logger.info(f"Processing video: {self.config.input.name}") + logger.info(f"Processing video: {self.config.input.name}") try: output_path = self.config.get_output_path() hdf_path = output_path.with_name(f"{output_path.stem}_segmentation.h5") - if self.processing_plan["process_video"]: - self.logger.debug("Executing video frame processing") - segmentation_data, metadata = self.process_video_frames() - if self.processing_plan["generate_hdf"]: - self.logger.debug( - f"Saving segmentation data to HDF file: {hdf_path}" + if self.processing_plan.plan["process_video"]: + # logger.debug("Executing video frame processing") + segmentation_data, metadata = self._process_video_frames() + if self.processing_plan.plan["generate_hdf"]: + logger.debug(f"Saving segmentation data to HDF file: {hdf_path}") + self.file_handler.save_hdf_file( + hdf_path, segmentation_data, metadata ) - self.save_hdf_file(hdf_path, segmentation_data, metadata) else: - self.logger.info( + logger.info( f"Loading existing segmentation data from HDF file: {hdf_path.name}" ) - hdf_file, metadata = self.load_hdf_file(hdf_path) + hdf_file, metadata = self.file_handler.load_hdf_file(hdf_path) segmentation_data = hdf_file[ "segmentation" ] # This is now a h5py.Dataset # Generate videos based on the processing plan if ( - self.processing_plan["generate_colored_video"] - or self.processing_plan["generate_overlay_video"] + self.processing_plan.plan["generate_colored_video"] + or self.processing_plan.plan["generate_overlay_video"] ): self.generate_videos(segmentation_data, metadata) - # Analyze results if needed - if self.processing_plan["analyze_results"]: - self.logger.debug("Analyzing segmentation results") - self.analyze_results(segmentation_data, metadata) + if self.processing_plan.plan["analyze_results"]: + logger.debug("Analyzing segmentation results") + SegmentationAnalyzer.analyze_results( + segmentation_data, metadata, output_path + ) - # Update processing history self._update_processing_history() - self.logger.info("Video processing complete") + logger.info("Video processing complete") except Exception as e: - self.logger.exception(f"Error during video processing: {str(e)}") + logger.exception(f"Error during video processing: {str(e)}") raise ProcessingError(f"Error during video processing: {str(e)}") finally: if hasattr(self, "hdf_file"): self.hdf_file.close() - def _verify_hdf_file(self, file_path: Path) -> bool: - """ - Verify the integrity and relevance of an existing HDF file. - - Args: - file_path (Path): Path to the HDF file. - - Returns: - bool: True if the file is valid and up-to-date, False otherwise. 
 
-    def _verify_hdf_file(self, file_path: Path) -> bool:
-        """
-        Verify the integrity and relevance of an existing HDF file.
-
-        Args:
-            file_path (Path): Path to the HDF file.
-
-        Returns:
-            bool: True if the file is valid and up-to-date, False otherwise.
-        """
-        try:
-            with h5py.File(file_path, "r") as f:
-                if "segmentation" not in f or "metadata" not in f:
-                    self.logger.warning(
-                        f"HDF file at {file_path} is missing required datasets"
-                    )
-                    return False
-
-                json_metadata = f["metadata"][()]
-                metadata = json.loads(json_metadata)
-
-                if metadata.get("frame_step") != self.config.frame_step:
-                    self.logger.warning(
-                        f"HDF file frame step ({metadata.get('frame_step')}) does not match current config ({self.config.frame_step})"
-                    )
-                    return False
-
-                segmentation_data = f["segmentation"]
-                if len(segmentation_data) == 0:
-                    self.logger.warning(
-                        f"HDF file at {file_path} contains no segmentation data"
-                    )
-                    return False
-
-                first_frame = segmentation_data[0]
-                last_frame = segmentation_data[-1]
-                if first_frame.shape != last_frame.shape:
-                    self.logger.warning(
-                        f"Inconsistent frame shapes in HDF file at {file_path}"
-                    )
-                    return False
-
-                self.logger.debug(f"HDF file at {file_path} is valid and up-to-date")
-                return True
-        except Exception as e:
-            self.logger.error(f"Error verifying HDF file at {file_path}: {str(e)}")
-            return False
-
-    def _verify_video_file(self, file_path: Path) -> bool:
+    def _process_video_frames(self) -> Tuple[np.ndarray, Dict[str, Any]]:
         """
-        Verify the integrity and relevance of an existing video file.
-
-        Args:
-            file_path (Path): Path to the video file.
+        Processes video frames in batches and returns segmentation data and metadata.
 
         Returns:
-            bool: True if the file is valid and up-to-date, False otherwise.
+            Tuple[np.ndarray, Dict[str, Any]]: Segmentation data and metadata.
         """
-        try:
-            cap = cv2.VideoCapture(str(file_path))
-            if not cap.isOpened():
-                self.logger.warning(f"Unable to open video file at {file_path}")
-                return False
-
-            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
-            ret, first_frame = cap.read()
-            if not ret:
-                self.logger.warning(
-                    f"Unable to read first frame from video file at {file_path}"
-                )
-                return False
-
-            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
-            ret, last_frame = cap.read()
-            if not ret:
-                self.logger.warning(
-                    f"Unable to read last frame from video file at {file_path}"
-                )
-                return False
-
-            cap.release()
-            self.logger.debug(f"Video file at {file_path} is valid and up-to-date")
-            return True
-        except Exception as e:
-            self.logger.error(f"Error verifying video file at {file_path}: {str(e)}")
-            return False
-
-    def process_video_frames(self) -> Tuple[np.ndarray, Dict[str, Any]]:
-        """
-        Process video frames to generate segmentation maps.
-
-        Returns:
-            Tuple[np.ndarray, Dict[str, Any]]: A tuple containing the segmentation data
-            and metadata.
-        """
-        # Lazy import of tqdm_context to avoid circular imports
-        from .utils import tqdm_context
-
         cap = cv2.VideoCapture(str(self.config.input))
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
         fps = cap.get(cv2.CAP_PROP_FPS)
@@ -479,22 +262,28 @@ def process_video_frames(self) -> Tuple[np.ndarray, Dict[str, Any]]:
 
         segmentation_data = []
 
+        logger.info(f"Processing video frames in batches of {self.config.batch_size}")
         with tqdm_context(
-            total=total_frames // self.config.frame_step, desc="Processing frames"
+            total=total_frames // self.config.frame_step,
+            desc="Processing frames",
+            disable=self.config.disable_tqdm,
        ) as pbar:
             for batch in self._frame_generator(
                 cv2.VideoCapture(str(self.config.input))
             ):
-                batch_results = self._process_batch(batch)
+                # logger.debug("Loading batch into pipeline...")
+                batch_results = self.pipeline(batch)
+                # logger.debug("Adding batch results to segmentation data...")
                 segmentation_data.extend(
                     [result["seg_map"] for result in batch_results]
                 )
                 pbar.update(len(batch))
 
+        cap.release()
         metadata = {
             "model_name": self.config.model.name,
             "original_video": str(self.config.input.name),
-            "palette": self.pipeline.palette.tolist()
+            "palette": np.array(self.pipeline.palette.tolist(), np.uint8)
             if self.pipeline.palette is not None
             else None,
             "label_ids": self.pipeline.model.config.id2label,
@@ -506,159 +295,47 @@ def process_video_frames(self) -> Tuple[np.ndarray, Dict[str, Any]]:
 
         return np.array(segmentation_data), metadata
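Reviewer note: the interaction between `frame_step` and `batch_size` in the frame generator (added just below) is easy to misread, so here is a standalone sketch of the frame indices it effectively keeps (illustrative only, not part of the patch):

```python
def batch_indices(total_frames: int, frame_step: int, batch_size: int):
    """Yield the frame indices each batch would contain.

    Mirrors _frame_generator: each kept frame is the last of a group of
    `frame_step` sequential reads, batched `batch_size` at a time.
    """
    kept = list(range(frame_step - 1, total_frames, frame_step))
    for i in range(0, len(kept), batch_size):
        yield kept[i : i + batch_size]

# 20 frames, frame_step=5, batch_size=2 -> [[4, 9], [14, 19]]
print(list(batch_indices(20, 5, 2)))
```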
- """ - return self.pipeline(batch) - - def _write_output_frames( - self, - batch: List[Image.Image], - batch_results: List[Dict[str, Any]], - video_writers: Dict[str, cv2.VideoWriter], - ) -> None: - """ - Write processed frames to output video files. - - Args: - batch (List[Image.Image]): A list of original PIL Image objects. - batch_results (List[Dict[str, Any]]): A list of segmentation results. - video_writers (Dict[str, cv2.VideoWriter]): A dictionary of video writers. - """ - for pil_image, result in zip(batch, batch_results): - if "colored" in video_writers: - colored_seg = self.visualize_segmentation( - pil_image, result["seg_map"], result["palette"], colored_only=True - ) - video_writers["colored"].write( - cv2.cvtColor(np.array(colored_seg), cv2.COLOR_RGB2BGR) - ) - if "overlay" in video_writers: - overlay = self.visualize_segmentation( - pil_image, result["seg_map"], result["palette"], colored_only=False - ) - video_writers["overlay"].write( - cv2.cvtColor(np.array(overlay), cv2.COLOR_RGB2BGR) - ) - - def save_hdf_file( - self, file_path: Path, segmentation_data: np.ndarray, metadata: Dict[str, Any] - ) -> None: - """ - Save segmentation data and metadata to an HDF5 file. - - Args: - file_path (Path): Path to save the HDF5 file. - segmentation_data (np.ndarray): Array of segmentation maps. - metadata (Dict[str, Any]): Metadata dictionary. - """ - with h5py.File(file_path, "w") as f: - f.create_dataset("segmentation", data=segmentation_data, compression="gzip") - - # Convert all metadata to JSON-compatible format - json_metadata = json.dumps(metadata) - f.create_dataset("metadata", data=json_metadata) - - def load_hdf_file(self, file_path: Path) -> Tuple[h5py.File, Dict[str, Any]]: - """ - Load and return the HDF5 file handle and metadata. - - Args: - file_path (Path): Path to the HDF5 file. - - Returns: - Tuple[h5py.File, Dict[str, Any]]: A tuple containing the HDF5 file handle - and metadata. - """ - self.hdf_file = h5py.File(file_path, "r") - json_metadata = self.hdf_file["metadata"][()] - metadata = json.loads(json_metadata) - if "palette" in metadata: - metadata["palette"] = np.array(metadata["palette"], np.uint8) - return self.hdf_file, metadata - - def get_segmentation_data_batch(self, start: int, end: int) -> np.ndarray: + def _frame_generator(self, cap: cv2.VideoCapture) -> Iterator[List[Image.Image]]: """ - Get a batch of segmentation data from the HDF5 file. + Generates batches of frames from a video capture object. Args: - start (int): Start index of the batch. - end (int): End index of the batch. + cap (cv2.VideoCapture): The video capture object. - Returns: - np.ndarray: A batch of segmentation data. + Yields: + Iterator[List[Image.Image]]: Batches of frames as PIL Image objects. """ - return self.hdf_file["segmentation"][start:end] + while True: + frames = [] + for _ in range(self.config.batch_size): + for _ in range(self.config.frame_step): + ret, frame = cap.read() + if not ret: + break + if not ret: + break + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(rgb_frame) + frames.append(pil_image) + if not frames: + break + yield frames def generate_videos( self, segmentation_data: h5py.Dataset, metadata: Dict[str, Any] ) -> None: """ - Generate output videos based on the processing plan, using batched processing. + Generates output videos based on the processing plan, using batched processing. + + Args: + segmentation_data (h5py.Dataset): The segmentation data for all frames. 
+ metadata (Dict[str, Any]): Metadata about the video and segmentation. """ if not ( - self.processing_plan.get("generate_colored_video", False) - or self.processing_plan.get("generate_overlay_video", False) + self.processing_plan.plan.get("generate_colored_video", False) + or self.processing_plan.plan.get("generate_overlay_video", False) ): - self.logger.info( - "No video generation required according to the processing plan." - ) + logger.info("No video generation required according to the processing plan") return start_time = time.time() @@ -670,47 +347,82 @@ def generate_videos( output_base = self.config.get_output_path() video_writers = self._initialize_video_writers(width, height, fps) - chunk_size = 100 # Adjust this value based on your memory constraints and performance needs + chunk_size = 100 # Adjust this value based on available memory for chunk_start in range(0, len(segmentation_data), chunk_size): chunk_end = min(chunk_start + chunk_size, len(segmentation_data)) - seg_chunk = self.get_segmentation_data_batch(chunk_start, chunk_end) + seg_chunk = get_segmentation_data_batch( + segmentation_data, chunk_start, chunk_end + ) frames = self._get_video_frames_batch( cap, chunk_start, chunk_end, metadata["frame_step"] ) - if self.processing_plan.get("generate_colored_video", False): - colored_frames = self.visualize_segmentation( - frames, seg_chunk, metadata["palette"], colored_only=True + if self.processing_plan.plan.get("generate_colored_video", False): + colored_frames = self.visualizer.visualize_segmentation( + frames, seg_chunk, metadata["palette"], colored_only=True + ) + for colored_frame in colored_frames: + video_writers["colored"].write( + cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR) ) - for colored_frame in colored_frames: - video_writers["colored"].write( - cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR) - ) - if self.processing_plan.get("generate_overlay_video", False): - overlay_frames = self.visualize_segmentation( - frames, seg_chunk, metadata["palette"], colored_only=False + if self.processing_plan.plan.get("generate_overlay_video", False): + overlay_frames = self.visualizer.visualize_segmentation( + frames, seg_chunk, metadata["palette"], colored_only=False + ) + for overlay_frame in overlay_frames: + video_writers["overlay"].write( + cv2.cvtColor(overlay_frame, cv2.COLOR_RGB2BGR) ) - for overlay_frame in overlay_frames: - video_writers["overlay"].write( - cv2.cvtColor(overlay_frame, cv2.COLOR_RGB2BGR) - ) for writer in video_writers.values(): writer.release() - cap.release() - self.logger.debug( - f"Video generation took {time.time() - start_time:.4f} seconds" + cap.release() + logger.debug( + f"Video generation completed in {time.time() - start_time:.2f} seconds" ) - self.logger.debug(f"Videos saved to: {output_base}") + logger.debug(f"Videos saved to: {output_base}") + + def _initialize_video_writers( + self, width: int, height: int, fps: float + ) -> Dict[str, cv2.VideoWriter]: + """ + Initializes video writers for output videos. + + Args: + width (int): Width of the video frame. + height (int): Height of the video frame. + fps (float): Frames per second of the output video. + + Returns: + Dict[str, cv2.VideoWriter]: A dictionary of initialized video writers. 
+ """ + writers = {} + output_base = self.config.get_output_path() + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + + if self.processing_plan.plan.get("generate_colored_video", False): + colored_path = output_base.with_name(f"{output_base.stem}_colored.mp4") + writers["colored"] = cv2.VideoWriter( + str(colored_path), fourcc, fps, (width, height) + ) + + if self.processing_plan.plan.get("generate_overlay_video", False): + overlay_path = output_base.with_name(f"{output_base.stem}_overlay.mp4") + writers["overlay"] = cv2.VideoWriter( + str(overlay_path), fourcc, fps, (width, height) + ) + + return writers + @staticmethod def _get_video_frames_batch( - self, cap: cv2.VideoCapture, start: int, end: int, frame_step: int + cap: cv2.VideoCapture, start: int, end: int, frame_step: int ) -> List[np.ndarray]: """ - Get a batch of video frames. + Gets a batch of video frames. Args: cap (cv2.VideoCapture): Video capture object. @@ -730,231 +442,159 @@ def _get_video_frames_batch( frames.append(frame) return frames - def _update_processing_history(self) -> None: - """ - Update the processing history with the current run's information. - """ - config_hash = ConfigHasher.calculate_hash(self.config) - self.processing_history.add_run( - timestamp=datetime.now().isoformat(), - config_hash=config_hash, - outputs_generated=self.processing_plan, - ) - self.processing_history.save(self._get_history_file_path()) - - def analyze_results( - self, segmentation_data: h5py.Dataset, metadata: Dict[str, Any] + def _create_video( + self, + cap: cv2.VideoCapture, + segmentation_data: h5py.Dataset, + metadata: Dict[str, Any], + output_path: Path, + colored_only: bool, ) -> None: """ - Analyze segmentation results and generate statistics using chunked processing. + Creates a video from segmentation data. Args: - segmentation_data (h5py.Dataset): Memory-mapped segmentation data. - metadata (Dict[str, Any]): Metadata dictionary. + cap (cv2.VideoCapture): Video capture object of the original video. + segmentation_data (h5py.Dataset): Segmentation data for all frames. + metadata (Dict[str, Any]): Metadata about the video and segmentation. + output_path (Path): Path to save the output video. + colored_only (bool): If True, create colored segmentation; if False, create overlay. 
""" - output_path = self.config.get_output_path() - counts_file = output_path.with_name(f"{output_path.stem}_category_counts.csv") - percentages_file = output_path.with_name( - f"{output_path.stem}_category_percentages.csv" + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter( + str(output_path), + fourcc, + metadata["fps"], + (metadata["width"], metadata["height"]), ) - id2label = metadata["label_ids"] - headers = ["Frame"] + [id2label[i] for i in sorted(id2label.keys())] - - chunk_size = 100 # Adjust based on memory constraints + frame_index = 0 + seg_index = 0 + palette = np.array(metadata["palette"], dtype=np.uint8) - with open(counts_file, "w", newline="") as cf, open( - percentages_file, "w", newline="" - ) as pf: - counts_writer = csv.writer(cf) - percentages_writer = csv.writer(pf) - counts_writer.writerow(headers) - percentages_writer.writerow(headers) - - for chunk_start in range(0, len(segmentation_data), chunk_size): - chunk_end = min(chunk_start + chunk_size, len(segmentation_data)) - seg_chunk = self.get_segmentation_data_batch(chunk_start, chunk_end) - - for frame_idx, seg_map in enumerate(seg_chunk, start=chunk_start): - analysis = self.analyze_segmentation_map(seg_map, len(id2label)) - frame_number = frame_idx * metadata["frame_step"] + with tqdm_context( + total=metadata["frame_count"], + desc=f"Generating {'colored' if colored_only else 'overlay'} video", + disable=self.config.disable_tqdm, + ) as pbar: + while True: + ret, frame = cap.read() + if not ret: + break - counts_row = [frame_number] + [ - analysis[i][0] for i in sorted(analysis.keys()) - ] - percentages_row = [frame_number] + [ - analysis[i][1] for i in sorted(analysis.keys()) - ] + if frame_index % metadata["frame_step"] == 0: + seg_map = segmentation_data[seg_index] + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + visualized = self.visualizer.visualize_segmentation( + frame_rgb, seg_map, palette, colored_only=colored_only + ) + out.write(cv2.cvtColor(visualized, cv2.COLOR_RGB2BGR)) + seg_index += 1 + else: + out.write(frame) - counts_writer.writerow(counts_row) - percentages_writer.writerow(percentages_row) + frame_index += 1 + pbar.update(1) - self._generate_category_stats( - counts_file, output_path.with_name(f"{output_path.stem}_counts_stats.csv") - ) - self._generate_category_stats( - percentages_file, - output_path.with_name(f"{output_path.stem}_percentages_stats.csv"), + out.release() + logger.info( + f"{'Colored' if colored_only else 'Overlay'} video saved to {output_path}" ) - @staticmethod - def analyze_segmentation_map( - seg_map: np.ndarray, num_categories: int - ) -> Dict[int, tuple[int, float]]: - """ - Analyze a segmentation map to compute pixel counts and percentages for each category. - - Args: - seg_map (np.ndarray): The segmentation map to analyze. - num_categories (int): The total number of categories in the segmentation. - - Returns: - Dict[int, tuple[int, float]]: A dictionary where keys are category IDs and values - are tuples of (pixel count, percentage) for each category. 
+ def _update_processing_history(self) -> None: """ - unique, counts = np.unique(seg_map, return_counts=True) - total_pixels = seg_map.size - category_analysis = {i: (0, 0.0) for i in range(num_categories)} - - for category_id, pixel_count in zip(unique, counts): - percentage = (pixel_count / total_pixels) * 100 - category_analysis[int(category_id)] = (int(pixel_count), float(percentage)) - - return category_analysis - - def _generate_category_stats(self, input_file: Path, output_file: Path) -> None: + Updates the processing history JSON file with the current processing information. """ - Generate category statistics from input CSV file. + output_path = self.config.get_output_path() + history_file = output_path.with_name( + f"{output_path.stem}_processing_history.json" + ) - Args: - input_file (Path): Path to the input CSV file (counts or percentages). - output_file (Path): Path to save the output statistics CSV file. - """ try: - # Read the entire CSV file - df = pd.read_csv(input_file) + if history_file.exists(): + with open(history_file, "r") as f: + history = json.load(f) + else: + history = [] - # Exclude the 'Frame' column and calculate statistics - category_columns = df.columns[1:] - stats = df[category_columns].agg(["mean", "median", "std", "min", "max"]) + current_entry = { + "timestamp": datetime.now().isoformat(), + "config_hash": ConfigHasher.calculate_hash(self.config), + "input_file": str(self.config.input), + "output_file": str(output_path), + } - # Transpose the results for a more readable output - stats = stats.transpose() + history.append(current_entry) - # Save the statistics to the output file - stats.to_csv(output_file) + with open(history_file, "w") as f: + json.dump(history, f, indent=2) - self.logger.info(f"Category statistics saved to {output_file}") + logger.info(f"Processing history updated in {history_file}") except Exception as e: - self.logger.error(f"Error generating category stats: {str(e)}") - raise - - def visualize_segmentation( - self, - images: Union[np.ndarray, List[np.ndarray]], - seg_maps: Union[np.ndarray, List[np.ndarray]], - palette: Optional[np.ndarray], - colored_only: bool = False, - ) -> Union[np.ndarray, List[np.ndarray]]: - """ - Visualize the segmentation maps for multiple images. - - Args: - images (Union[np.ndarray, List[np.ndarray]]): The original images or a single image. - seg_maps (Union[np.ndarray, List[np.ndarray]]): The segmentation maps or a single segmentation map. - palette (Optional[np.ndarray]): The color palette for visualization. - colored_only (bool): If True, return only the colored segmentation maps. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The visualized segmentation maps or overlays. - """ - if palette is None: - palette = self._generate_palette(256) # Assuming max 256 classes + logger.error(f"Error updating processing history: {str(e)}") - # Convert single image/seg_map to list for uniform processing - if isinstance(images, np.ndarray) and images.ndim == 3: - images = [images] - seg_maps = [seg_maps] - results = [] - for image, seg_map in zip(images, seg_maps): - # Vectorized color application - color_seg = palette[seg_map] +class SegmentationProcessor: + """ + Handles segmentation processing for both images and videos. - if colored_only: - results.append(color_seg) - else: - img = image * 0.5 + color_seg * 0.5 - results.append(img.astype(np.uint8)) + This class serves as a facade for ImageProcessor and VideoProcessor, + delegating the processing based on the input type. 
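To make the facade concrete, a hypothetical driver is sketched below. The `Config` keyword arguments mirror those used by `_create_video_config` later in this diff; the input path and the elided model settings are placeholders, not part of the API surface confirmed here.

```python
from pathlib import Path

from cityseg.config import Config
from cityseg.processors import SegmentationProcessor

config = Config(
    input=Path("street.mp4"),  # a single video, so VideoProcessor handles it
    output_dir=Path("outputs"),
    output_prefix=None,
    model=...,  # model settings object, e.g. a Mask2Former checkpoint
    frame_step=5,
)
SegmentationProcessor(config).process()
```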
- return results[0] if len(results) == 1 else results + Attributes: + config (Config): Configuration object containing processing parameters. + image_processor (ImageProcessor): Processor for handling image inputs. + video_processor (VideoProcessor): Processor for handling video inputs. + """ - def _frame_generator(self, cap: cv2.VideoCapture) -> Iterator[List[Image.Image]]: + def __init__(self, config: Config): """ - Generate batches of frames from a video capture object. + Initializes the SegmentationProcessor with the given configuration. Args: - cap (cv2.VideoCapture): The video capture object. - - Yields: - Iterator[List[Image.Image]]: Batches of frames as PIL Image objects. + config (Config): Configuration object for the processor. """ - while True: - frames = [] - for _ in range(self.config.batch_size): - for _ in range(self.config.frame_step): - ret, frame = cap.read() - if not ret: - break - if not ret: - break - rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - pil_image = Image.fromarray(rgb_frame) - frames.append(pil_image) - if not frames: - break - yield frames + self.config = config + self.image_processor = ImageProcessor(config) + self.video_processor = VideoProcessor(config) + logger.debug(f"SegmentationProcessor initialized with config: {config}") - def _generate_palette(self, num_colors: int) -> np.ndarray: + def process(self): """ - Generate a color palette for visualization. - - Args: - num_colors (int): The number of colors to generate. + Processes the input based on its type (image or video). - Returns: - np.ndarray: The generated color palette. + Raises: + ValueError: If the input type is not supported. """ - - return np.array( - [ - [(i * 100) % 255, (i * 150) % 255, (i * 200) % 255] - for i in range(num_colors) - ], - dtype=np.uint8, - ) + if self.config.input_type == InputType.SINGLE_IMAGE: + self.image_processor.process() + elif self.config.input_type == InputType.SINGLE_VIDEO: + self.video_processor.process() + else: + raise ValueError(f"Unsupported input type: {self.config.input_type}") class DirectoryProcessor: """ - A processor for handling multiple video files in a directory. + Processes multiple video files in a directory. - This class manages the processing of multiple video files within a specified - directory, utilizing the SegmentationProcessor for individual video processing. + This class handles the batch processing of video files found in a specified directory. Attributes: config (Config): Configuration object containing processing parameters. - logger (Logger): Logger instance for tracking processing events. + video_iterator (VideoFileIterator): Iterator for video files in the directory. + logger: Logger instance for this processor. """ def __init__(self, config: Config): """ - Initialize the DirectoryProcessor. + Initializes the DirectoryProcessor with the given configuration. Args: - config (Config): Configuration object containing processing parameters. + config (Config): Configuration object for the processor. """ self.config = config + self.video_iterator = VideoFileIterator(config.input) self.logger = logger.bind( processor_type=self.__class__.__name__, input_type=self.config.input_type.value, @@ -965,24 +605,19 @@ def __init__(self, config: Config): def process(self) -> None: """ - Process all video files in the specified directory. + Processes all video files in the specified directory. 
- This method identifies all video files in the input directory, - processes each video using SegmentationProcessor, and handles any errors - that occur during processing. + This method iterates through all video files, processing each one + according to the configuration. Raises: - InputError: If no video files are found in the specified directory. + InputError: If no video files are found in the directory. """ - # Lazy import of tqdm_context - from .utils import tqdm_context - self.logger.debug( "Starting directory processing", input_path=str(self.config.input) ) - video_files = self.get_video_files() - if not video_files: + if not self.video_iterator.video_files: self.logger.error("No video files found") raise InputError(f"No video files found in directory: {self.config.input}") @@ -991,8 +626,12 @@ def process(self) -> None: f"Output directory set: {str(output_dir)}", output_dir=str(output_dir) ) - with tqdm_context(total=len(video_files), desc="Processing videos") as pbar: - for video_file in video_files: + with tqdm_context( + total=len(self.video_iterator.video_files), + desc="Processing videos", + disable=self.config.disable_tqdm, + ) as pbar: + for video_file in self.video_iterator: if video_file.name in self.config.ignore_files: self.logger.info( f"Ignoring video file: {str(video_file.name)}", @@ -1001,7 +640,7 @@ def process(self) -> None: pbar.update(1) continue try: - self.process_single_video(video_file, output_dir) + self._process_single_video(video_file, output_dir) except Exception as e: self.logger.error( "Error processing video", @@ -1016,20 +655,9 @@ def process(self) -> None: "Finished processing all videos", input_directory=str(self.config.input) ) - def get_video_files(self) -> List[Path]: - """ - Get a list of video files in the input directory. - - Returns: - List[Path]: A list of paths to video files. - """ - video_files = get_video_files(self.config.input) - self.logger.info(f"Found {len(video_files)} video files in {self.config.input}") - return video_files - - def process_single_video(self, video_file: Path, output_dir: Path) -> None: + def _process_single_video(self, video_file: Path, output_dir: Path) -> None: """ - Process a single video file. + Processes a single video file. Args: video_file (Path): Path to the video file to process. @@ -1038,7 +666,9 @@ def process_single_video(self, video_file: Path, output_dir: Path) -> None: Raises: ProcessingError: If an error occurs during video processing. """ - video_config = self.create_video_config(video_file, output_dir) + logger.debug("Creating video config...", video_file=str(video_file)) + video_config = self._create_video_config(video_file, output_dir) + logger.debug("Video config created", video_config=video_config) try: processor = SegmentationProcessor(video_config) @@ -1049,16 +679,16 @@ def process_single_video(self, video_file: Path, output_dir: Path) -> None: ) raise ProcessingError(f"Error processing video {video_file}: {str(e)}") - def create_video_config(self, video_file: Path, output_dir: Path) -> Config: + def _create_video_config(self, video_file: Path, output_dir: Path) -> Config: """ - Create a configuration object for processing a single video. + Creates a configuration object for processing a single video. Args: video_file (Path): Path to the video file. output_dir (Path): Directory to save the processing results. Returns: - Config: A configuration object for the video processing. + Config: Configuration object for the video processor. 
""" return Config( input=video_file, @@ -1066,11 +696,13 @@ def create_video_config(self, video_file: Path, output_dir: Path) -> Config: output_prefix=None, model=self.config.model, frame_step=self.config.frame_step, + batch_size=self.config.batch_size, save_raw_segmentation=self.config.save_raw_segmentation, save_colored_segmentation=self.config.save_colored_segmentation, save_overlay=self.config.save_overlay, visualization=self.config.visualization, force_reprocess=self.config.force_reprocess, + disable_tqdm=self.config.disable_tqdm, ) @@ -1078,14 +710,13 @@ def create_processor( config: Config, ) -> Union[SegmentationProcessor, DirectoryProcessor]: """ - Create and return the appropriate processor based on the input type. + Creates and returns the appropriate processor based on the input type. Args: config (Config): Configuration object containing processing parameters. Returns: - Union[SegmentationProcessor, DirectoryProcessor]: An instance of either - SegmentationProcessor or DirectoryProcessor, depending on the input type. + Union[SegmentationProcessor, DirectoryProcessor]: The appropriate processor instance. """ if config.input_type == InputType.DIRECTORY: return DirectoryProcessor(config) diff --git a/src/cityseg/segmentation_analyzer.py b/src/cityseg/segmentation_analyzer.py new file mode 100644 index 0000000..46e3a2b --- /dev/null +++ b/src/cityseg/segmentation_analyzer.py @@ -0,0 +1,148 @@ +""" +This module provides a class for analyzing segmentation results. + +It includes methods to analyze segmentation maps, compute pixel counts and percentages +for each category, and generate statistics for the analysis results. + +Classes: + SegmentationAnalyzer: A class for analyzing segmentation results. +""" + +import csv +from pathlib import Path +from typing import Any, Dict + +import h5py +import numpy as np +import pandas as pd +from loguru import logger + +from cityseg.utils import get_segmentation_data_batch + + +class SegmentationAnalyzer: + """ + A class for analyzing segmentation results. + + This class provides methods to analyze segmentation maps, compute pixel counts and + percentages for each category, and generate statistics for the analysis results. + + Methods: + analyze_segmentation_map: Analyzes a segmentation map to compute pixel counts and percentages. + analyze_results: Analyzes segmentation data and saves counts and percentages to CSV files. + generate_category_stats: Generates statistics for category counts or percentages. + """ + + @staticmethod + def analyze_segmentation_map( + seg_map: np.ndarray, num_categories: int + ) -> Dict[int, tuple[int, float]]: + """ + Analyzes a segmentation map to compute pixel counts and percentages for each category. + + Args: + seg_map (np.ndarray): The segmentation map to analyze. + num_categories (int): The total number of categories in the segmentation. + + Returns: + Dict[int, Tuple[int, float]]: A dictionary where keys are category IDs and values + are tuples of (pixel count, percentage) for each category. 
+ """ + unique, counts = np.unique(seg_map, return_counts=True) + total_pixels = seg_map.size + category_analysis = {i: (0, 0.0) for i in range(num_categories)} + + for category_id, pixel_count in zip(unique, counts): + percentage = (pixel_count / total_pixels) * 100 + category_analysis[int(category_id)] = (int(pixel_count), float(percentage)) + + return category_analysis + + @staticmethod + def analyze_results( + segmentation_data: h5py.Dataset, metadata: Dict[str, Any], output_path: Path + ) -> None: + """ + Analyzes segmentation data and saves counts and percentages to CSV files. + + This method processes the segmentation data in chunks, computes the analysis for each + frame, and writes the results to separate CSV files for counts and percentages. + + Args: + segmentation_data (h5py.Dataset): The segmentation data to analyze. + metadata (Dict[str, Any]): Metadata containing label IDs and frame step. + output_path (Path): The path where the output CSV files will be saved. + """ + counts_file = output_path.with_name(f"{output_path.stem}_category_counts.csv") + percentages_file = output_path.with_name( + f"{output_path.stem}_category_percentages.csv" + ) + + id2label = metadata["label_ids"] + headers = ["Frame"] + [id2label[i] for i in sorted(id2label.keys())] + + chunk_size = 100 # Adjust based on memory constraints + + with open(counts_file, "w", newline="") as cf, open( + percentages_file, "w", newline="" + ) as pf: + counts_writer = csv.writer(cf) + percentages_writer = csv.writer(pf) + counts_writer.writerow(headers) + percentages_writer.writerow(headers) + + for chunk_start in range(0, len(segmentation_data), chunk_size): + chunk_end = min(chunk_start + chunk_size, len(segmentation_data)) + seg_chunk = get_segmentation_data_batch( + segmentation_data, chunk_start, chunk_end + ) + + for frame_idx, seg_map in enumerate(seg_chunk, start=chunk_start): + analysis = SegmentationAnalyzer.analyze_segmentation_map( + seg_map, len(id2label) + ) + frame_number = frame_idx * metadata["frame_step"] + + counts_row = [frame_number] + [ + analysis[i][0] for i in sorted(analysis.keys()) + ] + percentages_row = [frame_number] + [ + analysis[i][1] for i in sorted(analysis.keys()) + ] + + counts_writer.writerow(counts_row) + percentages_writer.writerow(percentages_row) + + logger.info(f"Category counts saved to {counts_file}") + logger.info(f"Category percentages saved to {percentages_file}") + + SegmentationAnalyzer.generate_category_stats( + counts_file, output_path.with_name(f"{output_path.stem}_counts_stats.csv") + ) + SegmentationAnalyzer.generate_category_stats( + percentages_file, + output_path.with_name(f"{output_path.stem}_percentages_stats.csv"), + ) + + @staticmethod + def generate_category_stats(input_file: Path, output_file: Path) -> None: + """ + Generates statistics for category counts or percentages. + + This method reads the input CSV file, computes statistics (mean, median, std, min, max) + for each category, and saves the results to the specified output file. + + Args: + input_file (Path): Path to the input CSV file containing category data. + output_file (Path): Path to save the generated statistics. 
+ """ + try: + df = pd.read_csv(input_file) + category_columns = df.columns[1:] + stats = df[category_columns].agg(["mean", "median", "std", "min", "max"]) + stats = stats.transpose() + stats.to_csv(output_file) + logger.info(f"Category statistics saved to {output_file}") + except Exception as e: + logger.error(f"Error generating category stats: {str(e)}") + raise diff --git a/src/cityseg/utils.py b/src/cityseg/utils.py index 806c0fd..fa627cd 100644 --- a/src/cityseg/utils.py +++ b/src/cityseg/utils.py @@ -8,184 +8,29 @@ import sys from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import Any, Iterator -import cv2 +import h5py import numpy as np -import pandas as pd from loguru import logger -from tqdm.asyncio import tqdm +from tqdm.auto import tqdm -def analyze_segmentation_map( - seg_map: np.ndarray, num_categories: int -) -> Dict[int, Tuple[int, float]]: +def get_segmentation_data_batch( + segmentation_data: h5py.Dataset, start: int, end: int +) -> np.ndarray: """ - Analyze a segmentation map to compute pixel counts and percentages for each category. + Get a batch of segmentation data from the HDF5 file. Args: - seg_map (np.ndarray): The segmentation map to analyze. - num_categories (int): The total number of categories in the segmentation. + segmentation_data: + start (int): Start index of the batch. + end (int): End index of the batch. Returns: - Dict[int, Tuple[int, float]]: A dictionary where keys are category IDs and values - are tuples of (pixel count, percentage) for each category. + np.ndarray: A batch of segmentation data. """ - unique, counts = np.unique(seg_map, return_counts=True) - total_pixels = seg_map.size - category_analysis = {i: (0, 0.0) for i in range(num_categories)} - - for category_id, pixel_count in zip(unique, counts): - percentage = (pixel_count / total_pixels) * 100 - category_analysis[int(category_id)] = (int(pixel_count), float(percentage)) - - return category_analysis - - -def generate_category_stats(csv_file_path: Path) -> pd.DataFrame: - """ - Generate statistical summaries for each category from a CSV file. - - Args: - csv_file_path (Path): Path to the CSV file containing category data. - - Returns: - pd.DataFrame: A DataFrame containing statistical summaries for each category. - """ - df = pd.read_csv(csv_file_path) - category_columns = df.columns[1:] - - stats = [] - for category in category_columns: - category_data = df[category] - stats.append( - { - "Category": category, - "Mean": category_data.mean(), - "Median": category_data.median(), - "Std Dev": category_data.std(), - "Min": category_data.min(), - "Max": category_data.max(), - } - ) - - return pd.DataFrame(stats) - - -def initialize_csv_files( - output_prefix: Path, category_names: Dict[int, str] -) -> Tuple[Path, Path]: - """ - Initialize CSV files for storing category counts and percentages. - - Args: - output_prefix (Path): The prefix for output file names. - category_names (Dict[int, str]): A dictionary mapping category IDs to names. - - Returns: - Tuple[Path, Path]: Paths to the created counts and percentages CSV files. 
- """ - counts_file = output_prefix.with_name(f"{output_prefix.stem}_category_counts.csv") - percentages_file = output_prefix.with_name( - f"{output_prefix.stem}_category_percentages.csv" - ) - - headers = ["Frame"] + [category_names[i] for i in sorted(category_names.keys())] - - pd.DataFrame(columns=headers).to_csv(counts_file, index=False) - pd.DataFrame(columns=headers).to_csv(percentages_file, index=False) - - return counts_file, percentages_file - - -def append_to_csv_files( - counts_file: Path, - percentages_file: Path, - frame_count: int, - analysis: Dict[int, Tuple[int, float]], -) -> None: - """ - Append analysis results to the counts and percentages CSV files. - - Args: - counts_file (Path): Path to the counts CSV file. - percentages_file (Path): Path to the percentages CSV file. - frame_count (int): The current frame count. - analysis (Dict[int, Tuple[int, float]]): Analysis results for the current frame. - """ - counts_row = [frame_count] + [analysis[i][0] for i in sorted(analysis.keys())] - percentages_row = [frame_count] + [analysis[i][1] for i in sorted(analysis.keys())] - - pd.DataFrame([counts_row]).to_csv(counts_file, mode="a", header=False, index=False) - pd.DataFrame([percentages_row]).to_csv( - percentages_file, mode="a", header=False, index=False - ) - - -def get_video_files(directory: Path) -> List[Path]: - """ - Get a list of video files in the specified directory. - - Args: - directory (Path): The directory to search for video files. - - Returns: - List[Path]: A list of paths to video files found in the directory. - """ - video_extensions = [".mp4", ".avi", ".mov"] - video_files = [] - for ext in video_extensions: - video_files.extend(directory.glob(f"*{ext}")) - return video_files - - -def save_segmentation_map( - seg_map: np.ndarray, output_prefix: Path, frame_count: Optional[int] = None -) -> None: - """ - Save a segmentation map as a NumPy array file. - - Args: - seg_map (np.ndarray): The segmentation map to save. - output_prefix (Path): The prefix for the output file name. - frame_count (Optional[int]): The frame count, if applicable. - """ - filename = f"{output_prefix.stem}_segmap{'_' + str(frame_count) if frame_count is not None else ''}.npy" - np.save(output_prefix.with_name(filename), seg_map) - - -def save_colored_segmentation( - colored_seg: np.ndarray, output_prefix: Path, frame_count: Optional[int] = None -) -> None: - """ - Save a colored segmentation map as an image file. - - Args: - colored_seg (np.ndarray): The colored segmentation map to save. - output_prefix (Path): The prefix for the output file name. - frame_count (Optional[int]): The frame count, if applicable. - """ - filename = f"{output_prefix.stem}_colored{'_' + str(frame_count) if frame_count is not None else ''}.png" - cv2.imwrite( - str(output_prefix.with_name(filename)), - cv2.cvtColor(colored_seg, cv2.COLOR_RGB2BGR), - ) - - -def save_overlay( - overlay: np.ndarray, output_prefix: Path, frame_count: Optional[int] = None -) -> None: - """ - Save an overlay image as a file. - - Args: - overlay (np.ndarray): The overlay image to save. - output_prefix (Path): The prefix for the output file name. - frame_count (Optional[int]): The frame count, if applicable. 
- """ - filename = f"{output_prefix.stem}_overlay{'_' + str(frame_count) if frame_count is not None else ''}.png" - cv2.imwrite(str(output_prefix.with_name(filename)), overlay) + return segmentation_data[start:end] def setup_logging(log_level: str, verbose: bool = False) -> None: diff --git a/src/cityseg/video_file_iterator.py b/src/cityseg/video_file_iterator.py new file mode 100644 index 0000000..a3300eb --- /dev/null +++ b/src/cityseg/video_file_iterator.py @@ -0,0 +1,66 @@ +""" +This module provides an iterator class for iterating over video files in a specified directory. + +It retrieves and stores video files from the given input path and provides an iterator +interface to access these files. + +Classes: + VideoFileIterator: An iterator class for iterating over video files in a specified directory. +""" + +from pathlib import Path +from typing import Iterator, List + +from loguru import logger + + +class VideoFileIterator: + """ + An iterator class for iterating over video files in a specified directory. + + This class retrieves and stores video files from the given input path and + provides an iterator interface to access these files. + + Attributes: + input_path (Path): The path to the directory containing video files. + video_files (List[Path]): A list of video file paths found in the input directory. + """ + + def __init__(self, input_path: Path): + """ + Initializes the VideoFileIterator with the specified input path. + + Args: + input_path (Path): The path to the directory containing video files. + """ + self.input_path = input_path + self.video_files = self._get_video_files() + + def _get_video_files(self) -> List[Path]: + """ + Retrieves a list of video files from the input directory. + + This method uses the utility function `get_video_files` to find all video files + in the specified input path and logs the number of files found. + + Returns: + List[Path]: A list of paths to the video files found in the input directory. + """ + video_extensions = [".mp4", ".avi", ".mov"] + video_files = [ + f for f in self.input_path.glob("*") if f.suffix.lower() in video_extensions + ] + logger.info(f"Found {len(video_files)} video files in {self.input_path}") + return list(video_files) + + def __iter__(self) -> Iterator[Path]: + """ + Returns an iterator over the video files. + + This method allows the VideoFileIterator to be used in a for-loop or any context + that requires an iterable. + + Returns: + Iterator[Path]: An iterator over the video file paths. + """ + return iter(self.video_files) diff --git a/src/cityseg/visualization_handler.py b/src/cityseg/visualization_handler.py new file mode 100644 index 0000000..6da0af0 --- /dev/null +++ b/src/cityseg/visualization_handler.py @@ -0,0 +1,102 @@ +""" +This module provides a class for visualizing segmentation results using color palettes. + +It includes methods to visualize segmentation maps with color palettes and options for +displaying colored or blended results. + +Classes: + VisualizationHandler: A class for visualizing segmentation results using color palettes. +""" + +from typing import List, Optional, Union + +import numpy as np +from loguru import logger + + +class VisualizationHandler: + """ + A class for visualizing segmentation results using color palettes. + + This class provides methods to visualize segmentation maps with color palettes + and options for displaying colored or blended results. + + Methods: + visualize_segmentation: Visualizes segmentation results with color palettes. 
+ _generate_palette: Generates a color palette for visualization. + """ + + @staticmethod + def visualize_segmentation( + images: Union[np.ndarray, List[np.ndarray]], + seg_maps: Union[np.ndarray, List[np.ndarray]], + palette: Optional[np.ndarray] = None, + colored_only: bool = False, + ) -> Union[np.ndarray, List[np.ndarray]]: + """ + Visualizes segmentation results using color palettes. + + This method takes input images and their corresponding segmentation maps, + applies the specified color palette, and returns the visualized results. + + Args: + images (Union[np.ndarray, List[np.ndarray]]): Input images or a list of images. + seg_maps (Union[np.ndarray, List[np.ndarray]]): Segmentation maps or a list of maps. + palette (Optional[np.ndarray]): Color palette for visualization. If None, a default palette is generated. + colored_only (bool): Flag to indicate if only colored results are desired (True) or blended with the original images (False). + + Returns: + Union[np.ndarray, List[np.ndarray]]: Visualized segmentation results, either as a single array or a list of arrays. + """ + logger.debug( + f"Visualizing segmentation for {len(images) if isinstance(images, list) else 1} images" + ) + if palette is None: + palette = VisualizationHandler._generate_palette(256) + if isinstance(palette, list): + palette = np.array(palette, dtype=np.uint8) + + if isinstance(images, np.ndarray) and images.ndim == 3: + images = [images] + seg_maps = [seg_maps] + + results = [] + for image, seg_map in zip(images, seg_maps): + color_seg = palette[seg_map] + + if colored_only: + results.append(color_seg) + else: + img = image * 0.5 + color_seg * 0.5 + results.append(img.astype(np.uint8)) + + return results[0] if len(results) == 1 else results + + @staticmethod + def _generate_palette(num_colors: int) -> np.ndarray: + """ + Generates a color palette for visualization. + + This method creates a color palette with a specified number of colors, + which can be used to visualize segmentation results. + + Args: + num_colors (int): Number of colors to generate in the palette. + + Returns: + np.ndarray: Color palette array for visualization, with shape (num_colors, 3). 
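The core of `visualize_segmentation` is a vectorized palette lookup (`palette[seg_map]`) followed by an equal-weight blend with the original image. A standalone worked example of that arithmetic, using a toy three-color palette:

```python
import numpy as np

palette = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=np.uint8)
seg_map = np.array([[0, 1], [2, 1]])  # one class ID per pixel
color_seg = palette[seg_map]          # advanced indexing -> (2, 2, 3) RGB

image = np.full((2, 2, 3), 200, dtype=np.uint8)
overlay = (image * 0.5 + color_seg * 0.5).astype(np.uint8)
print(color_seg.shape, overlay.shape)  # (2, 2, 3) (2, 2, 3)
```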
+ """ + from .palettes import ADE20K_PALETTE + + if num_colors < len(ADE20K_PALETTE): + logger.debug(f"Using ADE20K palette with {num_colors} colors") + return np.array(ADE20K_PALETTE[:num_colors], dtype=np.uint8) + else: + logger.debug(f"Generating custom palette for {num_colors} colors") + return np.array( + [ + [(i * 100) % 255, (i * 150) % 255, (i * 200) % 255] + for i in range(num_colors) + ], + dtype=np.uint8, + ) diff --git a/tests/test_file_handler.py b/tests/test_file_handler.py new file mode 100644 index 0000000..1af1c29 --- /dev/null +++ b/tests/test_file_handler.py @@ -0,0 +1,105 @@ +import json +from unittest.mock import MagicMock, patch + +import h5py +import numpy as np +import pytest + +from cityseg.config import Config +from cityseg.file_handler import FileHandler + + +@pytest.fixture +def temp_hdf_file(tmp_path): + file_path = tmp_path / "test.hdf5" + yield file_path + if file_path.exists(): + file_path.unlink() + + +@pytest.fixture +def temp_video_file(tmp_path): + file_path = tmp_path / "test.mp4" + file_path.touch() + yield file_path + if file_path.exists(): + file_path.unlink() + + +def test_saves_hdf_file_correctly(temp_hdf_file): + segmentation_data = np.random.rand(10, 10) + metadata = {"frame_step": 1, "palette": np.array([1, 2, 3])} + FileHandler.save_hdf_file(temp_hdf_file, segmentation_data, metadata) + with h5py.File(temp_hdf_file, "r") as f: + assert "segmentation" in f + assert "metadata" in f + assert np.array_equal(f["segmentation"], segmentation_data) + loaded_metadata = json.loads(f["metadata"][()]) + assert loaded_metadata["frame_step"] == 1 + assert loaded_metadata["palette"] == [1, 2, 3] + + +def test_loads_hdf_file_correctly(temp_hdf_file): + segmentation_data = np.random.rand(10, 10) + metadata = {"frame_step": 1, "palette": [1, 2, 3]} + with h5py.File(temp_hdf_file, "w") as f: + f.create_dataset("segmentation", data=segmentation_data) + f.create_dataset("metadata", data=json.dumps(metadata)) + hdf_file, loaded_metadata = FileHandler.load_hdf_file(temp_hdf_file) + assert np.array_equal(hdf_file["segmentation"], segmentation_data) + assert loaded_metadata["frame_step"] == 1 + assert np.array_equal(loaded_metadata["palette"], np.array([1, 2, 3])) + + +def test_verifies_hdf_file_correctly(temp_hdf_file): + segmentation_data = np.random.rand(10, 10) + metadata = {"frame_step": 1} + with h5py.File(temp_hdf_file, "w") as f: + f.create_dataset("segmentation", data=segmentation_data) + f.create_dataset("metadata", data=json.dumps(metadata)) + mock_config = MagicMock(spec=Config) + mock_config.frame_step = 1 + assert FileHandler.verify_hdf_file(temp_hdf_file, mock_config) is True + + +def test_fails_verification_for_invalid_hdf_file(temp_hdf_file): + segmentation_data = np.random.rand(10, 10) + metadata = {"frame_step": 2} + with h5py.File(temp_hdf_file, "w") as f: + f.create_dataset("segmentation", data=segmentation_data) + f.create_dataset("metadata", data=json.dumps(metadata)) + mock_config = MagicMock(spec=Config) + mock_config.frame_step = 1 + assert FileHandler.verify_hdf_file(temp_hdf_file, mock_config) is False + + +def test_verifies_video_file_correctly(temp_video_file): + with patch("cv2.VideoCapture") as mock_capture: + mock_capture.return_value.isOpened.return_value = True + mock_capture.return_value.read.side_effect = [ + (True, np.zeros((10, 10, 3))), + (True, np.zeros((10, 10, 3))), + ] + assert FileHandler.verify_video_file(temp_video_file) is True + + +def test_fails_verification_for_invalid_video_file(temp_video_file): + with 
patch("cv2.VideoCapture") as mock_capture: + mock_capture.return_value.isOpened.return_value = False + assert FileHandler.verify_video_file(temp_video_file) is False + + +def test_verifies_analysis_files_correctly(tmp_path): + counts_file = tmp_path / "counts.txt" + percentages_file = tmp_path / "percentages.txt" + counts_file.write_text("data") + percentages_file.write_text("data") + assert FileHandler.verify_analysis_files(counts_file, percentages_file) is True + + +def test_fails_verification_for_empty_analysis_files(tmp_path): + counts_file = tmp_path / "counts.txt" + percentages_file = tmp_path / "percentages.txt" + counts_file.touch() + percentages_file.touch() + assert FileHandler.verify_analysis_files(counts_file, percentages_file) is False diff --git a/tests/test_processing_plan.py b/tests/test_processing_plan.py new file mode 100644 index 0000000..b1bc16c --- /dev/null +++ b/tests/test_processing_plan.py @@ -0,0 +1,101 @@ +from unittest.mock import MagicMock + +import pytest + +from cityseg.config import Config +from cityseg.processing_plan import ProcessingPlan + + +@pytest.fixture +def mock_config(): + config = MagicMock(spec=Config) + config.force_reprocess = False + config.save_colored_segmentation = True + config.save_overlay = True + config.analyze_results = True + config.get_output_path.return_value = MagicMock() + return config + + +def test_force_reprocess_enabled(mock_config): + """ + Given force reprocessing is enabled + When creating a processing plan + Then all processing steps should be executed + """ + mock_config.force_reprocess = True + processing_plan = ProcessingPlan(mock_config) + + expected_plan = { + "process_video": True, + "generate_hdf": True, + "generate_colored_video": True, + "generate_overlay_video": True, + "analyze_results": True, + } + + assert processing_plan.plan == expected_plan + + +def test_existing_outputs_invalid(mock_config): + """ + Given force reprocessing is disabled + And existing outputs are invalid + When creating a processing plan + Then all processing steps should be executed + """ + mock_config.force_reprocess = False + + # Mock the _check_existing_outputs method before creating the ProcessingPlan object + ProcessingPlan._check_existing_outputs = MagicMock( + return_value={ + "hdf_file_valid": False, + "colored_video_valid": False, + "overlay_video_valid": False, + "analysis_files_valid": False, + } + ) + + processing_plan = ProcessingPlan(mock_config) + + expected_plan = { + "process_video": True, + "generate_hdf": True, + "generate_colored_video": True, + "generate_overlay_video": True, + "analyze_results": True, + } + + assert processing_plan.plan == expected_plan + + +def test_existing_outputs_valid(mock_config): + """ + Given force reprocessing is disabled + And existing outputs are valid + When creating a processing plan + Then no processing steps should be executed + """ + mock_config.force_reprocess = False + + # Mock the _check_existing_outputs method before creating the ProcessingPlan object + ProcessingPlan._check_existing_outputs = MagicMock( + return_value={ + "hdf_file_valid": True, + "colored_video_valid": True, + "overlay_video_valid": True, + "analysis_files_valid": True, + } + ) + + processing_plan = ProcessingPlan(mock_config) + + expected_plan = { + "process_video": False, + "generate_hdf": False, + "generate_colored_video": False, + "generate_overlay_video": False, + "analyze_results": False, + } + + assert processing_plan.plan == expected_plan diff --git a/tests/test_segmentation_analyzer.py 
b/tests/test_segmentation_analyzer.py new file mode 100644 index 0000000..f1e4bd0 --- /dev/null +++ b/tests/test_segmentation_analyzer.py @@ -0,0 +1,172 @@ +from typing import Any, Dict +from unittest.mock import mock_open, patch + +import numpy as np +import pandas as pd +import pytest +from cityseg.segmentation_analyzer import SegmentationAnalyzer + + +class TestSegmentationAnalyzer: + """ + Test suite for the SegmentationAnalyzer class. + + This class contains tests to verify the correct behavior of the SegmentationAnalyzer + methods, including segmentation map analysis, results analysis, and statistics generation. + """ + + @pytest.fixture + def sample_segmentation_map(self) -> np.ndarray: + """ + Fixture that provides a sample segmentation map for testing. + + Returns: + np.ndarray: A 2D numpy array representing a sample segmentation map. + """ + return np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]]) + + @pytest.fixture + def sample_metadata(self) -> Dict[str, Any]: + """ + Fixture that provides sample metadata for testing. + + Returns: + Dict[str, Any]: A dictionary containing sample metadata. + """ + return { + "label_ids": {0: "background", 1: "foreground", 2: "edge"}, + "frame_step": 5, + } + + def test_analyze_segmentation_map(self, sample_segmentation_map): + """ + Test the analyze_segmentation_map method. + + This test verifies that the method correctly computes pixel counts + and percentages for each category in the segmentation map. + + Args: + sample_segmentation_map (np.ndarray): The sample segmentation map fixture. + """ + result = SegmentationAnalyzer.analyze_segmentation_map( + sample_segmentation_map, 3 + ) + expected = {0: (3, 100 / 3), 1: (3, 100 / 3), 2: (3, 100 / 3)} + + assert len(result) == len(expected), "Number of categories does not match" + for category in expected: + assert category in result, f"Category {category} is missing from the result" + assert ( + result[category][0] == expected[category][0] + ), f"Pixel count for category {category} does not match" + assert np.isclose( + result[category][1], expected[category][1], rtol=1e-9 + ), f"Percentage for category {category} is not close enough" + + @patch("cityseg.segmentation_analyzer.get_segmentation_data_batch") + @patch("cityseg.segmentation_analyzer.open", new_callable=mock_open) + @patch("cityseg.segmentation_analyzer.csv.writer") + @patch("cityseg.segmentation_analyzer.logger") + @patch("cityseg.segmentation_analyzer.pd.read_csv") + @patch("cityseg.segmentation_analyzer.SegmentationAnalyzer.generate_category_stats") + def test_analyze_results( + self, + mock_generate_stats, + mock_read_csv, + mock_logger, + mock_csv_writer, + mock_open, + mock_get_data, + sample_metadata, + tmp_path, + ): + """ + Test the analyze_results method. + + This test verifies that the method correctly processes segmentation data, + writes results to CSV files, and generates statistics. + + Args: + mock_generate_stats: Mocked generate_category_stats method. + mock_read_csv: Mocked pandas read_csv function. + mock_logger: Mocked logger object. + mock_csv_writer: Mocked CSV writer object. + mock_open: Mocked open function. + mock_get_data: Mocked get_segmentation_data_batch function. + sample_metadata (Dict[str, Any]): The sample metadata fixture. + tmp_path (Path): Pytest fixture for a temporary directory path. 
+ """ + mock_segmentation_data = mock_get_data.return_value + mock_segmentation_data.__len__.return_value = 10 + mock_get_data.return_value = np.array([[[0, 1], [1, 2]]] * 10) + + # Mock the DataFrame that would be read from the CSV + mock_df = pd.DataFrame( + { + "Frame": [0, 5], + "background": [50, 60], + "foreground": [30, 25], + "edge": [20, 15], + } + ) + mock_read_csv.return_value = mock_df + + output_path = tmp_path / "test_output.h5" + SegmentationAnalyzer.analyze_results( + mock_segmentation_data, sample_metadata, output_path + ) + + assert mock_open.call_count == 2, "Should open two files for writing" + assert mock_csv_writer.call_count == 2, "Should create two CSV writers" + assert ( + mock_generate_stats.call_count == 2 + ), "Should call generate_category_stats twice" + + @patch("cityseg.segmentation_analyzer.pd.read_csv") + @patch("cityseg.segmentation_analyzer.logger") + def test_generate_category_stats(self, mock_logger, mock_read_csv, tmp_path): + """ + Test the generate_category_stats method. + + This test verifies that the method correctly reads input data, + computes statistics, and saves the results. + + Args: + mock_logger: Mocked logger object. + mock_read_csv: Mocked pandas read_csv function. + tmp_path (Path): Pytest fixture for a temporary directory path. + """ + mock_df = pd.DataFrame( + { + "Frame": [0, 5, 10], + "background": [50, 60, 70], + "foreground": [30, 25, 20], + "edge": [20, 15, 10], + } + ) + mock_read_csv.return_value = mock_df + + input_file = tmp_path / "input.csv" + output_file = tmp_path / "output.csv" + + SegmentationAnalyzer.generate_category_stats(input_file, output_file) + + mock_read_csv.assert_called_once_with(input_file) + assert output_file.exists(), "Output file should be created" + mock_logger.info.assert_called_once() + + def test_generate_category_stats_error_handling(self, tmp_path): + """ + Test error handling in the generate_category_stats method. + + This test verifies that the method properly handles and logs errors + when processing fails. + + Args: + tmp_path (Path): Pytest fixture for a temporary directory path. + """ + non_existent_file = tmp_path / "non_existent.csv" + output_file = tmp_path / "output.csv" + + with pytest.raises(Exception): + SegmentationAnalyzer.generate_category_stats(non_existent_file, output_file) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..6506187 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,317 @@ +import os +import sys +import tempfile +from typing import Generator +from unittest.mock import patch + +import h5py +import numpy as np +import pytest +from cityseg.utils import get_segmentation_data_batch, setup_logging, tqdm_context +from tqdm.auto import tqdm + + +@pytest.fixture +def temp_hdf5_file() -> Generator[str, None, None]: + """ + Fixture to create a temporary HDF5 file for testing. + + Yields: + str: Path to the temporary HDF5 file. + """ + with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as tmp: + tmp_path = tmp.name + yield tmp_path + os.unlink(tmp_path) + + +class TestGetSegmentationDataBatch: + """Tests for the get_segmentation_data_batch function.""" + + def test_retrieves_correct_batch_of_segmentation_data( + self, temp_hdf5_file: str + ) -> None: + """ + Test that the function retrieves the correct batch of segmentation data. + + Args: + temp_hdf5_file (str): Path to temporary HDF5 file. 
+ """ + data = np.random.rand(100, 100) + with h5py.File(temp_hdf5_file, "w") as f: + dset = f.create_dataset("segmentation", data=data) + batch = get_segmentation_data_batch(dset, 10, 20) + assert batch.shape == (10, 100) + assert np.array_equal(batch, data[10:20]) + + def test_handles_empty_segmentation_data_batch(self, temp_hdf5_file: str) -> None: + """ + Test that the function correctly handles an empty batch request. + + Args: + temp_hdf5_file (str): Path to temporary HDF5 file. + """ + data = np.random.rand(100, 100) + with h5py.File(temp_hdf5_file, "w") as f: + dset = f.create_dataset("segmentation", data=data) + batch = get_segmentation_data_batch(dset, 0, 0) + assert batch.shape == (0, 100) + + def test_handles_out_of_bounds_segmentation_data_batch( + self, temp_hdf5_file: str + ) -> None: + """ + Test that the function correctly handles out-of-bounds batch requests. + + Args: + temp_hdf5_file (str): Path to temporary HDF5 file. + """ + data = np.random.rand(100, 100) + with h5py.File(temp_hdf5_file, "w") as f: + dset = f.create_dataset("segmentation", data=data) + batch = get_segmentation_data_batch(dset, 90, 110) + assert batch.shape == (10, 100) + assert np.array_equal(batch, data[90:100]) + + def test_with_different_data_types(self, temp_hdf5_file: str) -> None: + """ + Test the function with different data types (int and float). + + Args: + temp_hdf5_file (str): Path to temporary HDF5 file. + """ + int_data = np.random.randint(0, 100, (100, 100)) + float_data = np.random.rand(100, 100) + + with h5py.File(temp_hdf5_file, "w") as f: + int_dset = f.create_dataset("int_segmentation", data=int_data) + float_dset = f.create_dataset("float_segmentation", data=float_data) + + int_batch = get_segmentation_data_batch(int_dset, 10, 20) + float_batch = get_segmentation_data_batch(float_dset, 10, 20) + + assert int_batch.dtype == np.int64 + assert float_batch.dtype == np.float64 + assert np.array_equal(int_batch, int_data[10:20]) + assert np.array_equal(float_batch, float_data[10:20]) + + def test_with_multi_dimensional_data(self, temp_hdf5_file: str) -> None: + """ + Test the function with multi-dimensional data. + + Args: + temp_hdf5_file (str): Path to temporary HDF5 file. + """ + data = np.random.rand( + 100, 50, 50, 3 + ) # 4D data (frames, height, width, channels) + with h5py.File(temp_hdf5_file, "w") as f: + dset = f.create_dataset("multi_dim_segmentation", data=data) + batch = get_segmentation_data_batch(dset, 10, 20) + assert batch.shape == (10, 50, 50, 3) + assert np.array_equal(batch, data[10:20]) + + def test_single_element_batch(self, temp_hdf5_file: str) -> None: + """ + Test the function with a single-element batch. + + Args: + temp_hdf5_file (str): Path to temporary HDF5 file. 
+ """ + data = np.random.rand(100, 100) + with h5py.File(temp_hdf5_file, "w") as f: + dset = f.create_dataset("segmentation", data=data) + batch = get_segmentation_data_batch(dset, 10, 11) + assert batch.shape == (1, 100) + assert np.array_equal(batch, data[10:11]) + + +class TestTqdmContext: + """Tests for the tqdm_context function.""" + + def test_handles_empty_progress_bar(self) -> None: + """Test that the function correctly handles an empty progress bar.""" + with tqdm_context(total=0) as progress_bar: + assert isinstance(progress_bar, tqdm) + assert progress_bar.total == 0 + + def test_handles_non_empty_progress_bar(self) -> None: + """Test that the function correctly handles a non-empty progress bar.""" + with tqdm_context(total=100) as progress_bar: + assert isinstance(progress_bar, tqdm) + assert progress_bar.total == 100 + + def test_handles_progress_bar_with_updates(self) -> None: + """Test that the function correctly handles progress bar updates.""" + with tqdm_context(total=100) as progress_bar: + progress_bar.update(10) + assert progress_bar.n == 10 + progress_bar.update(20) + assert progress_bar.n == 30 + + def test_handles_progress_bar_with_exception(self) -> None: + """Test that the function correctly handles exceptions within the context.""" + with pytest.raises(ValueError, match="Test exception"): + with tqdm_context(total=100): + raise ValueError("Test exception") + + def test_with_large_total_value(self) -> None: + """Test that the function handles a very large total value.""" + large_total = 10**10 + with tqdm_context(total=large_total) as progress_bar: + assert progress_bar.total == large_total + progress_bar.update(large_total // 2) + assert progress_bar.n == large_total // 2 + + +class TestSetupLogging: + """ + Test suite for the setup_logging function in utils.py. + + This class contains tests to verify the correct behavior of the logging setup, + including console and file logging configurations, different log levels, + and formatting options. + """ + + @pytest.fixture + def mock_logger(self): + """ + Fixture that provides a mock logger for testing. + + Returns: + MockLogger: A mock object that simulates the behavior of the loguru logger. + """ + + class MockLogger: + def __init__(self): + self.handlers = [] + + def remove(self, handler_id=None): + print("MockLogger: remove() called") + + def add(self, sink, **kwargs): + print(f"MockLogger: add() called with sink={sink}, kwargs={kwargs}") + self.handlers.append((sink, kwargs)) + + def info(self, message): + print(f"MockLogger: info() called with message: {message}") + + return MockLogger() + + @patch("cityseg.utils.logger") + def test_basic_setup(self, mock_logger): + """ + Test the basic setup of logging with default parameters. + + Args: + mock_logger: The mocked logger object. 
+ """ + setup_logging("INFO") + assert mock_logger.remove.called, "logger.remove() should be called" + assert ( + len(mock_logger.add.call_args_list) == 2 + ), "Should add console and file handlers" + + # Verify console handler + console_call = mock_logger.add.call_args_list[0] + assert console_call[0][0] == sys.stderr, "Console handler should use sys.stderr" + assert console_call[1]["level"] == "INFO", "Console level should be INFO" + + # Verify file handler + file_call = mock_logger.add.call_args_list[1] + assert ( + file_call[0][0] == "segmentation.log" + ), "File handler should use 'segmentation.log'" + assert file_call[1]["level"] == "INFO", "File level should be INFO" + + @patch("cityseg.utils.logger") + def test_verbose_mode(self, mock_logger): + """ + Test the logging setup in verbose mode. + + Args: + mock_logger: The mocked logger object. + """ + setup_logging("INFO", verbose=True) + console_call = mock_logger.add.call_args_list[0] + assert ( + console_call[1]["level"] == "DEBUG" + ), "Console level should be DEBUG in verbose mode" + + @patch("cityseg.utils.logger") + def test_different_log_levels(self, mock_logger): + """ + Test the logging setup with different log levels. + + Args: + mock_logger: The mocked logger object. + """ + setup_logging("DEBUG") + console_call = mock_logger.add.call_args_list[0] + file_call = mock_logger.add.call_args_list[1] + assert console_call[1]["level"] == "DEBUG", "Console level should be DEBUG" + assert ( + file_call[1]["level"] == "DEBUG" + ), "File level should be DEBUG (minimum of input and INFO)" + + setup_logging("WARNING") + console_call = mock_logger.add.call_args_list[2] + file_call = mock_logger.add.call_args_list[3] + assert console_call[1]["level"] == "WARNING", "Console level should be WARNING" + assert ( + file_call[1]["level"] == "INFO" + ), "File level should be INFO (minimum of input and INFO)" + + @patch("cityseg.utils.logger") + def test_file_logging_config(self, mock_logger): + """ + Test the file logging configuration. + + Args: + mock_logger: The mocked logger object. + """ + setup_logging("INFO") + file_call = mock_logger.add.call_args_list[1] + assert file_call[1]["rotation"] == "100 MB", "File should rotate at 100 MB" + assert ( + file_call[1]["retention"] == "1 week" + ), "File should be retained for 1 week" + assert file_call[1]["serialize"] is True, "File logging should be serialized" + + @patch("cityseg.utils.logger") + def test_console_logging_format(self, mock_logger): + """ + Test the console logging format. + + Args: + mock_logger: The mocked logger object. + """ + setup_logging("INFO") + console_call = mock_logger.add.call_args_list[0] + format_string = console_call[1]["format"] + assert "{time:YYYY-MM-DD HH:mm:ss}" in format_string + assert "{level: <8}" in format_string + assert ( + "{name}:{function}:{line}" + in format_string + ) + + +def test_integration_segmentation_with_tqdm(temp_hdf5_file: str) -> None: + """ + Integration test for using get_segmentation_data_batch within a tqdm_context. + + Args: + temp_hdf5_file (str): Path to temporary HDF5 file. 
+ """ + data = np.random.rand(100, 100) + with h5py.File(temp_hdf5_file, "w") as f: + dset = f.create_dataset("segmentation", data=data) + + with tqdm_context(total=len(data), desc="Processing") as pbar: + for i in range(0, len(data), 10): + batch = get_segmentation_data_batch(dset, i, min(i + 10, len(data))) + assert batch.shape[0] <= 10 + pbar.update(len(batch)) + + assert pbar.n == len(data) diff --git a/tests/test_video_file_iterator.py b/tests/test_video_file_iterator.py new file mode 100644 index 0000000..48e9a23 --- /dev/null +++ b/tests/test_video_file_iterator.py @@ -0,0 +1,56 @@ +from pathlib import Path +from unittest.mock import patch + +import pytest + +from cityseg.video_file_iterator import VideoFileIterator + + +@pytest.fixture +def mock_video_files(): + # Create mock Path objects for video files + return [Path("video1.mp4"), Path("video2.mp4")] + + +@pytest.fixture +def mock_directory(): + # Create a mock Path object for the directory + return Path("/mock/directory") + + +def test_video_file_iterator_initializes_with_video_files( + mock_video_files, mock_directory +): + with patch.object(Path, "glob", return_value=mock_video_files): + iterator = VideoFileIterator(mock_directory) + assert iterator.video_files == mock_video_files + + +def test_video_file_iterator_yields_video_files(mock_video_files, mock_directory): + with patch.object(Path, "glob", return_value=mock_video_files): + iterator = VideoFileIterator(mock_directory) + video_files_list = list(iterator) + assert video_files_list == mock_video_files + + +def test_video_file_iterator_handles_empty_directory(mock_directory): + with patch.object(Path, "glob", return_value=[]): + iterator = VideoFileIterator(mock_directory) + assert iterator.video_files == [] + assert list(iterator) == [] + + +def test_video_file_iterator_handles_non_mp4_files(mock_directory): + non_video_files = [Path("file1.txt"), Path("file2.doc")] + with patch.object(Path, "glob", return_value=non_video_files): + iterator = VideoFileIterator(mock_directory) + assert iterator.video_files == [] + assert list(iterator) == [] + + +def test_video_file_iterator_handles_mixed_files(mock_directory): + mixed_files = [Path("video1.mp4"), Path("file1.txt"), Path("video2.avi")] + with patch.object(Path, "glob", return_value=mixed_files): + iterator = VideoFileIterator(mock_directory) + assert iterator.video_files == [Path("video1.mp4"), Path("video2.avi")] + assert list(iterator) == [Path("video1.mp4"), Path("video2.avi")] diff --git a/tests/test_visualization_handler.py b/tests/test_visualization_handler.py new file mode 100644 index 0000000..16ebdcd --- /dev/null +++ b/tests/test_visualization_handler.py @@ -0,0 +1,73 @@ +import numpy as np + +from cityseg.visualization_handler import VisualizationHandler + + +def test_visualize_single_image_with_default_palette(): + image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + seg_map = np.random.randint(0, 256, (100, 100), dtype=np.uint8) + result = VisualizationHandler.visualize_segmentation(image, seg_map) + assert result.shape == image.shape + + +def test_visualize_single_image_with_custom_palette(): + image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + seg_map = np.random.randint(0, 256, (100, 100), dtype=np.uint8) + palette = np.random.randint(0, 255, (256, 3), dtype=np.uint8) + result = VisualizationHandler.visualize_segmentation(image, seg_map, palette) + assert result.shape == image.shape + + +def test_visualize_single_image_colored_only(): + image = np.random.randint(0, 255, 
diff --git a/tests/test_visualization_handler.py b/tests/test_visualization_handler.py
new file mode 100644
index 0000000..16ebdcd
--- /dev/null
+++ b/tests/test_visualization_handler.py
@@ -0,0 +1,73 @@
+import numpy as np
+
+from cityseg.visualization_handler import VisualizationHandler
+
+
+def test_visualize_single_image_with_default_palette():
+    image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    seg_map = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
+    result = VisualizationHandler.visualize_segmentation(image, seg_map)
+    assert result.shape == image.shape
+
+
+def test_visualize_single_image_with_custom_palette():
+    image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    seg_map = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
+    palette = np.random.randint(0, 255, (256, 3), dtype=np.uint8)
+    result = VisualizationHandler.visualize_segmentation(image, seg_map, palette)
+    assert result.shape == image.shape
+
+
+def test_visualize_single_image_colored_only():
+    image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    seg_map = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
+    result = VisualizationHandler.visualize_segmentation(
+        image, seg_map, colored_only=True
+    )
+    assert result.shape == image.shape
+
+
+def test_visualize_multiple_images_with_default_palette():
+    images = [
+        np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) for _ in range(3)
+    ]
+    seg_maps = [np.random.randint(0, 256, (100, 100), dtype=np.uint8) for _ in range(3)]
+    results = VisualizationHandler.visualize_segmentation(images, seg_maps)
+    assert len(results) == 3
+    for result, image in zip(results, images):
+        assert result.shape == image.shape
+
+
+def test_visualize_multiple_images_with_custom_palette():
+    images = [
+        np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) for _ in range(3)
+    ]
+    seg_maps = [np.random.randint(0, 256, (100, 100), dtype=np.uint8) for _ in range(3)]
+    palette = np.random.randint(0, 255, (256, 3), dtype=np.uint8)
+    results = VisualizationHandler.visualize_segmentation(images, seg_maps, palette)
+    assert len(results) == 3
+    for result, image in zip(results, images):
+        assert result.shape == image.shape
+
+
+def test_visualize_multiple_images_colored_only():
+    images = [
+        np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) for _ in range(3)
+    ]
+    seg_maps = [np.random.randint(0, 256, (100, 100), dtype=np.uint8) for _ in range(3)]
+    results = VisualizationHandler.visualize_segmentation(
+        images, seg_maps, colored_only=True
+    )
+    assert len(results) == 3
+    for result, image in zip(results, images):
+        assert result.shape == image.shape
+
+
+def test_generate_palette_with_less_colors_than_default():
+    palette = VisualizationHandler._generate_palette(10)
+    assert palette.shape == (10, 3)
+
+
+def test_generate_palette_with_more_colors_than_default():
+    palette = VisualizationHandler._generate_palette(300)
+    assert palette.shape == (300, 3)
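Taken together, these tests impose a small contract on `VisualizationHandler`: `visualize_segmentation` accepts one image or a list of images, an optional `(num_classes, 3)` palette, and a `colored_only` flag, and always preserves each image's shape, while `_generate_palette(n)` returns an `(n, 3)` array. The sketch below meets that contract; the 50/50 alpha blend and the seeded random palette are assumptions, not the actual `cityseg.visualization_handler` implementation:

```python
from typing import List, Optional, Union

import numpy as np


class VisualizationHandler:
    """Turns segmentation maps into colour overlays."""

    @staticmethod
    def visualize_segmentation(
        images: Union[np.ndarray, List[np.ndarray]],
        seg_maps: Union[np.ndarray, List[np.ndarray]],
        palette: Optional[np.ndarray] = None,
        colored_only: bool = False,
    ) -> Union[np.ndarray, List[np.ndarray]]:
        # Batch input: recurse per image and return a list of results.
        if isinstance(images, list):
            return [
                VisualizationHandler.visualize_segmentation(
                    img, seg, palette, colored_only
                )
                for img, seg in zip(images, seg_maps)
            ]
        if palette is None:
            # 256 rows covers every value a uint8 class map can hold.
            palette = VisualizationHandler._generate_palette(256)
        # Index the palette with the class map: (H, W) -> (H, W, 3).
        colored = palette[seg_maps].astype(np.uint8)
        if colored_only:
            return colored
        # Overlay the colour map on the source image (blend weights assumed).
        return (0.5 * images + 0.5 * colored).astype(np.uint8)

    @staticmethod
    def _generate_palette(num_colors: int) -> np.ndarray:
        # Deterministic pseudo-random palette, one RGB row per class id.
        rng = np.random.default_rng(seed=42)
        return rng.integers(0, 256, size=(num_colors, 3), dtype=np.uint8)
```

Fancy-indexing the palette with the whole class map keeps the colouring fully vectorised, which is why both the default and custom-palette tests can assert on shape alone.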