Skip to content

Commit

Permalink
Added metadata to DictDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoTrizio committed Nov 13, 2024
1 parent 53eb43d commit a3555a0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 26 deletions.
44 changes: 21 additions & 23 deletions mlcolvar/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DictDataset(Dataset):
'weights' : np.asarray([0.5,1.5,1.5,0.5]) }
"""

def __init__(self, dictionary: dict = None, feature_names=None, **kwargs):
def __init__(self, dictionary: dict=None, feature_names = None, metadata: dict = None, **kwargs):
"""Create a Dataset from a dictionary or from a list of kwargs.
Parameters
Expand All @@ -32,7 +32,12 @@ def __init__(self, dictionary: dict = None, feature_names=None, **kwargs):
raise TypeError(
f"DictDataset requires a dictionary , not {type(dictionary)}."
)


if (metadata is not None) and (not isinstance(metadata, dict)):
raise TypeError(
f"DictDataset metadata requires a dictionary , not {type(metadata)}."
)

# Add kwargs to dict
if dictionary is None:
dictionary = {}
Expand All @@ -43,7 +48,7 @@ def __init__(self, dictionary: dict = None, feature_names=None, **kwargs):
# convert to torch.Tensors
for key, val in dictionary.items():
if not isinstance(val, torch.Tensor):
if key in ["data_list", "z_table", "cutoff"]:
if key =="data_list":
dictionary[key] = val
else:
dictionary[key] = torch.Tensor(val)
Expand All @@ -54,10 +59,13 @@ def __init__(self, dictionary: dict = None, feature_names=None, **kwargs):
# save feature names
self.feature_names = feature_names

# save metadata
self.metadata = metadata

# check that all elements of dict have same length
it = iter(dictionary.items())
self.length = len(next(it)[1])
if not all([len(v)==self.length for l,v in it if l not in ["cutoff", "z_table"]]):
it = iter(dictionary.values())
self.length = len(next(it))
if not all([len(l)==self.length for l in it]):
raise ValueError("not all arrays in dictionary have same length!")

def __getitem__(self, index):
Expand All @@ -67,28 +75,16 @@ def __getitem__(self, index):
else:
slice_dict = {}
for key, val in self._dictionary.items():
if key in ["z_table", "cutoff"]:
slice_dict[key] = val
else:
slice_dict[key] = val[index]
slice_dict[key] = val[index]
return slice_dict

def __setitem__(self, index, value):
if isinstance(index, str):
# check lengths
if len(value) != len(self):
if index not in ["z_table", "cutoff"]:
raise ValueError(
f"length of value ({len(value)}) != length of dataset ({len(self)})."
)
elif index=="z_table" and len(value) != len(self._dictionary[index]):
raise ValueError(
f"length of value ({len(value)}) != length of original dataset['{index}'] ({self._dictionary[index]})."
)
elif index=="cutoff" and not isinstance(value, float):
raise ValueError(
f" 'cutoff' must be type float, found {type(value)} "
)
raise ValueError(
f"length of value ({len(value)}) != length of dataset ({len(self)})."
)
self._dictionary[index] = value
else:
raise NotImplementedError(
Expand All @@ -115,12 +111,14 @@ def get_stats(self):
def __repr__(self) -> str:
    """Return a one-line summary of the dataset contents and its metadata.

    Every entry of the internal dictionary is rendered as ``"key": shape``,
    except ``"data_list"`` which is rendered as ``"key": length``. Metadata
    entries, if any, follow and are rendered as ``"key": value``.

    Returns
    -------
    str
        A string of the form ``DictDataset( "key": ..., "other": ... )``.
    """
    entries = []
    for key, val in self._dictionary.items():
        if key == "data_list":
            # data_list is stored as given rather than converted to a
            # tensor, so report its length instead of a shape.
            entries.append(f'"{key}": {len(val)}')
        else:
            entries.append(f'"{key}": {list(val.shape)}')

    # metadata is optional and defaults to None; treat a missing or None
    # value as "no metadata" instead of failing on `.items()`.
    metadata = getattr(self, "metadata", None) or {}
    entries.extend(f'"{key}": {val}' for key, val in metadata.items())

    # Joining avoids the trailing-comma slice, which would otherwise remove
    # the opening parenthesis when there is nothing to print.
    return "DictDataset( " + ", ".join(entries) + " )"

Expand Down
6 changes: 3 additions & 3 deletions mlcolvar/data/graph/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ def create_dataset_from_configurations(
data_list[i].cell = cell_list[i]

# dataset = GraphDataSet(data_list, z_table.zs, cutoff)
dataset = DictDataset({'data_list' : data_list,
'z_table' : z_table.zs,
'cutoff' : cutoff})
dataset = DictDataset(dictionary={'data_list' : data_list},
metadata={'z_table' : z_table.zs,
'cutoff' : cutoff})

return dataset

Expand Down

0 comments on commit a3555a0

Please sign in to comment.