diff --git a/pyodi/apps/ground_truth.py b/pyodi/apps/ground_truth.py index e20d28d..bfaf1d7 100644 --- a/pyodi/apps/ground_truth.py +++ b/pyodi/apps/ground_truth.py @@ -48,15 +48,12 @@ from pathlib import Path from typing import Optional, Tuple -from loguru import logger - from pyodi.core.boxes import add_centroids from pyodi.core.utils import coco_ground_truth_to_df from pyodi.plots.boxes import get_centroids_heatmap, plot_heatmap from pyodi.plots.common import plot_scatter_with_histograms -@logger.catch def ground_truth( ground_truth_file: str, show: bool = True, @@ -87,7 +84,7 @@ def ground_truth( df_images, x="img_width", y="img_height", - title=f"{Path(ground_truth_file).stem}: Image Shapes", + title="Image_Shapes", show=show, output=output, output_size=output_size, @@ -108,7 +105,7 @@ def ground_truth( df_annotations, x="absolute_width", y="absolute_height", - title=f"{Path(ground_truth_file).stem}: Bounding Box Shapes", + title="Bounding_Box_Shapes", show=show, output=output, output_size=output_size, @@ -120,7 +117,7 @@ def ground_truth( plot_heatmap( get_centroids_heatmap(df_annotations), - title=f"{Path(ground_truth_file).stem}: Bounding Box Centers", + title="Bounding_Box_Centers", show=show, output=output, output_size=output_size, diff --git a/setup.cfg b/setup.cfg index b2608d6..7255e9d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,9 +17,9 @@ install_requires = numba==0.52.0 pandas==1.2.1 pillow==8.1.0 - plotly==4.14.3 + plotly==5.2.2 pycocotools==2.0.2 - kaleido==v0.1.0 + kaleido==v0.2.1 scikit-learn==0.24.1 fire==0.4.0 diff --git a/tests/apps/test_ground_truth.py b/tests/apps/test_ground_truth.py new file mode 100644 index 0000000..04fe2f8 --- /dev/null +++ b/tests/apps/test_ground_truth.py @@ -0,0 +1,27 @@ +import json +from pathlib import Path + +from pyodi.apps.ground_truth import ground_truth + + +def test_ground_truth_saves_output_to_files(tmpdir): + output = tmpdir.mkdir("results") + + categories = [{"id": 1, "name": "drone"}] + images = [{"id": 0, "file_name": "image.jpg", "height": 10, "width": 10}] + annotations = [ + {"image_id": 0, "category_id": 1, "id": 0, "bbox": [0, 0, 5, 5], "area": 25} + ] + coco_data = dict( + images=images, + annotations=annotations, + categories=categories, + info={}, + licenses={}, + ) + with open(tmpdir / "data.json", "w") as f: + json.dump(coco_data, f) + + ground_truth(tmpdir / "data.json", show=False, output=output) + + assert len(list(Path(output / "data").iterdir())) == 3