From a3555a08eba1cff388032d025c888b8289f280ff Mon Sep 17 00:00:00 2001 From: EnricoTrizio Date: Wed, 13 Nov 2024 15:00:45 +0100 Subject: [PATCH] Added metadata to DictDataset --- mlcolvar/data/dataset.py | 44 ++++++++++++++++------------------ mlcolvar/data/graph/dataset.py | 6 ++--- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/mlcolvar/data/dataset.py b/mlcolvar/data/dataset.py index 73ef6ed..48913b4 100644 --- a/mlcolvar/data/dataset.py +++ b/mlcolvar/data/dataset.py @@ -16,7 +16,7 @@ class DictDataset(Dataset): 'weights' : np.asarray([0.5,1.5,1.5,0.5]) } """ - def __init__(self, dictionary: dict = None, feature_names=None, **kwargs): + def __init__(self, dictionary: dict=None, feature_names = None, metadata: dict = None, **kwargs): """Create a Dataset from a dictionary or from a list of kwargs. Parameters @@ -32,7 +32,12 @@ def __init__(self, dictionary: dict = None, feature_names=None, **kwargs): raise TypeError( f"DictDataset requires a dictionary , not {type(dictionary)}." ) - + + if (metadata is not None) and (not isinstance(metadata, dict)): + raise TypeError( + f"DictDataset metadata requires a dictionary , not {type(metadata)}." + ) + # Add kwargs to dict if dictionary is None: dictionary = {} @@ -43,7 +48,7 @@ def __init__(self, dictionary: dict = None, feature_names=None, **kwargs): # convert to torch.Tensors for key, val in dictionary.items(): if not isinstance(val, torch.Tensor): - if key in ["data_list", "z_table", "cutoff"]: + if key =="data_list": dictionary[key] = val else: dictionary[key] = torch.Tensor(val) @@ -54,10 +59,13 @@ def __init__(self, dictionary: dict = None, feature_names=None, **kwargs): # save feature names self.feature_names = feature_names + # save metadata + self.metadata = metadata + # check that all elements of dict have same length - it = iter(dictionary.items()) - self.length = len(next(it)[1]) - if not all([len(v)==self.length for l,v in it if l not in ["cutoff", "z_table"]]): + it = iter(dictionary.values()) + self.length = len(next(it)) + if not all([len(l)==self.length for l in it]): raise ValueError("not all arrays in dictionary have same length!") def __getitem__(self, index): @@ -67,28 +75,16 @@ def __getitem__(self, index): else: slice_dict = {} for key, val in self._dictionary.items(): - if key in ["z_table", "cutoff"]: - slice_dict[key] = val - else: - slice_dict[key] = val[index] + slice_dict[key] = val[index] return slice_dict def __setitem__(self, index, value): if isinstance(index, str): # check lengths if len(value) != len(self): - if index not in ["z_table", "cutoff"]: - raise ValueError( - f"length of value ({len(value)}) != length of dataset ({len(self)})." - ) - elif index=="z_table" and len(value) != len(self._dictionary[index]): - raise ValueError( - f"length of value ({len(value)}) != length of original dataset['{index}'] ({self._dictionary[index]})." - ) - elif index=="cutoff" and not isinstance(value, float): - raise ValueError( - f" 'cutoff' must be type float, found {type(value)} " - ) + raise ValueError( + f"length of value ({len(value)}) != length of dataset ({len(self)})." + ) self._dictionary[index] = value else: raise NotImplementedError( @@ -115,12 +111,14 @@ def get_stats(self): def __repr__(self) -> str: string = "DictDataset(" for key, val in self._dictionary.items(): - if key in ["data_list", "z_table"]: + if key=="data_list": string += f' "{key}": {len(val)},' elif key in ["cutoff"]: string += f' "{key}": {val},' else: string += f' "{key}": {list(val.shape)},' + for key, val in self.metadata.items(): + string += f' "{key}": {val},' string = string[:-1] + " )" return string diff --git a/mlcolvar/data/graph/dataset.py b/mlcolvar/data/graph/dataset.py index cf49c98..23b272e 100644 --- a/mlcolvar/data/graph/dataset.py +++ b/mlcolvar/data/graph/dataset.py @@ -258,9 +258,9 @@ def create_dataset_from_configurations( data_list[i].cell = cell_list[i] # dataset = GraphDataSet(data_list, z_table.zs, cutoff) - dataset = DictDataset({'data_list' : data_list, - 'z_table' : z_table.zs, - 'cutoff' : cutoff}) + dataset = DictDataset(dictionary={'data_list' : data_list}, + metadata={'z_table' : z_table.zs, + 'cutoff' : cutoff}) return dataset