Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DLC #1047

Merged
merged 3 commits into from
Sep 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 93 additions & 49 deletions src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,40 +163,23 @@ def _ensure_individuals_in_header(df, individual_name):
return df


def _get_pes_args(
*,
config_file: Path,
h5file: Path,
individual_name: str,
timestamps_available: bool = False,
infer_timestamps: bool = True,
):
config_file = Path(config_file)
h5file = Path(h5file)

if "DLC" not in h5file.name or not h5file.suffix == ".h5":
raise IOError("The file passed in is not a DeepLabCut h5 data file.")

cfg = _read_config(config_file)

vidname, scorer = h5file.stem.split("DLC")
scorer = "DLC" + scorer
video = None
def _get_graph_edges(metadata_file_path: Path):
"""
Extracts the part affinity field graph from the metadata pickle file.

df = _ensure_individuals_in_header(pd.read_hdf(h5file), individual_name)
Parameters
----------
metadata_file_path : Path
The path to the metadata pickle file.

# Fetch the corresponding metadata pickle file
Returns
-------
list
The part affinity field graph, which defines the edges between the keypoints in the pose estimation.
"""
paf_graph = []
filename = str(h5file.parent / h5file.stem)
for i, c in enumerate(filename[::-1]):
if c.isnumeric():
break
if i > 0:
filename = filename[:-i]
metadata_file = Path(filename + "_meta.pickle")

if metadata_file.exists():
with open(metadata_file, "rb") as file:
if metadata_file_path.exists():
with open(metadata_file_path, "rb") as file:
metadata = pickle.load(file)

test_cfg = metadata["data"]["DLC-model-config file"]
Expand All @@ -208,25 +191,64 @@ def _get_pes_args(
else:
warnings.warn("Metadata not found...")

return paf_graph


def _get_video_info_from_config_file(config_file_path: Path, vidname: str):
"""
Get the video information from the project config file.

Parameters
----------
config_file_path : Path
The path to the project config file.
vidname : str
The name of the video.

Returns
-------
tuple
A tuple containing the video file path and the image shape.
"""
config_file_path = Path(config_file_path)
cfg = _read_config(config_file_path)

video = None
for video_path, params in cfg["video_sets"].items():
if vidname in video_path:
video = video_path, params["crop"]
break

# find timestamps only if required:
if timestamps_available:
timestamps = None
else:
if video is None:
timestamps = df.index.tolist() # setting timestamps to dummy TODO: extract timestamps in DLC?
else:
timestamps = _get_movie_timestamps(video[0], infer_timestamps=infer_timestamps)

if video is None:
warnings.warn(f"The video file corresponding to {h5file} could not be found...")
video = "fake_path", "0, 0, 0, 0"
warnings.warn(f"The corresponding video file could not be found...")
video = None, "0, 0, 0, 0"

# The video in the config_file looks like this:
# video_sets:
# /Data/openfield-Pranav-2018-08-20/videos/m1s1.mp4:
# crop: 0, 640, 0, 480

video_file_path, image_shape = video

return video_file_path, image_shape


def _get_pes_args(
*,
h5file: Path,
individual_name: str,
):
h5file = Path(h5file)

if "DLC" not in h5file.name or not h5file.suffix == ".h5":
raise IOError("The file passed in is not a DeepLabCut h5 data file.")

_, scorer = h5file.stem.split("DLC")
scorer = "DLC" + scorer

df = _ensure_individuals_in_header(pd.read_hdf(h5file), individual_name)

return scorer, df, video, paf_graph, timestamps, cfg
return scorer, df


def _write_pes_to_nwbfile(
Expand Down Expand Up @@ -332,15 +354,37 @@ def add_subject_to_nwbfile(
nwbfile : pynwb.NWBFile
nwbfile with pes written in the behavior module
"""
timestamps_available = timestamps is not None
scorer, df, video, paf_graph, dlc_timestamps, _ = _get_pes_args(
config_file=config_file,
h5file = Path(h5file)

scorer, df = _get_pes_args(
h5file=h5file,
individual_name=individual_name,
timestamps_available=timestamps_available,
)
if timestamps is None:
timestamps = dlc_timestamps

# Note the video here is a tuple of the video path and the image shape
vidname, scorer = h5file.stem.split("DLC")
video = _get_video_info_from_config_file(config_file_path=config_file, vidname=vidname)

# Find timestamps only if required:
timestamps_available = timestamps is not None
video_file_path = video[0]
if not timestamps_available:
if video_file_path is None:
timestamps = df.index.tolist() # setting timestamps to dummy
else:
timestamps = _get_movie_timestamps(video_file_path, infer_timestamps=True)

# Fetch the corresponding metadata pickle file, we extract the edges graph from here
# TODO: This is how the original implementation derives the metadata file name, but it looks very brittle
filename = str(h5file.parent / h5file.stem)
for i, c in enumerate(filename[::-1]):
if c.isnumeric():
break
if i > 0:
filename = filename[:-i]

metadata_file_path = Path(filename + "_meta.pickle")
paf_graph = _get_graph_edges(metadata_file_path=metadata_file_path)

df_animal = df.xs(individual_name, level="individuals", axis=1)

Expand Down
Loading