Skip to content

Commit

Permalink
Merge pull request #53 from choderalab/sample_weight
Browse files Browse the repository at this point in the history
Sample weight
  • Loading branch information
yuanqing-wang authored Oct 27, 2020
2 parents 54a50fd + 3e1664d commit 1c46403
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 2 deletions.
42 changes: 40 additions & 2 deletions espaloma/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ def __getitem__(self, idx):

return graph

elif isinstance(idx, slice): # implement slicing
elif isinstance(idx, slice):
# implement slicing
if self.transforms is None:
# return a Dataset object rather than list
return self.__class__(graphs=self.graphs[idx])
else:
graphs = []
for graph in self.graphs:
for graph in self.graphs[idx]:

# nested transforms
for transform in self.transforms:
Expand All @@ -76,6 +77,25 @@ def __getitem__(self, idx):

return self.__class__(graphs=graphs)

elif isinstance(idx, list):
# implement slicing
if self.transforms is None:
# return a Dataset object rather than list
return self.__class__(
graphs=[self.graphs[_idx] for _idx in idx]
)
else:
graphs = []
for _idx in idx:
graph = self[_idx]
# nested transforms
for transform in self.transforms:
graph = transform(graph)
graphs.append(graph)

return self.__class__(graphs=graphs)


def __iter__(self):
if self.transforms is None:
return iter(self.graphs)
Expand Down Expand Up @@ -147,6 +167,24 @@ def split(self, partition):

return ds

def subsample(self, ratio):
    """ Subsample the dataset according to some ratio.

    Parameters
    ----------
    ratio : float
        Ratio between the size of the subsampled dataset and the
        original dataset.

    Returns
    -------
    Dataset
        A new dataset of ``int(len(self) * ratio)`` graphs drawn at
        random from this dataset.
    """
    import random

    n_data = len(self)
    # NOTE(review): random.choices samples WITH replacement, so the
    # subsample may contain duplicates; random.sample would avoid that
    # (for ratio <= 1). Kept as-is to preserve existing behavior.
    idxs = random.choices(range(n_data), k=int(n_data * ratio))
    # Indexing with a list dispatches to __getitem__'s list branch,
    # which returns a Dataset instance.
    return self[idxs]

def save(self, path):
""" Save dataset to path.
Expand Down
5 changes: 5 additions & 0 deletions espaloma/data/qcarchive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,8 @@ def make_batch_size_consistent(ds, batch_size=32):
)
)
)


def weight_by_snapshots(g, key="weight"):
    """ Attach a per-graph weight inversely proportional to the number of
    conformational snapshots in the graph.

    Parameters
    ----------
    g : graph object
        Heterograph whose ``n1`` node data holds an ``xyz`` tensor with the
        snapshot count on axis 1, and whose ``g`` node data receives the
        weight.
    key : str, default="weight"
        Name under which the weight tensor is stored on the ``g`` node.

    Returns
    -------
    None
        The graph is modified in place.
    """
    n_snapshots = g.nodes['n1'].data['xyz'].shape[1]
    # Store a (1, 1)-shaped tensor so it aligns with per-graph node data.
    # The original built a 0-dim tensor and indexed it with [None, :],
    # which raises IndexError (a slice needs an existing dimension).
    g.nodes['g'].data[key] = torch.tensor([[1.0 / n_snapshots]])
4 changes: 4 additions & 0 deletions espaloma/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,7 @@ def test_no_return(ds):
fn = lambda x: x + 1
ds.apply(fn).apply(fn)
assert all(x == x_ + 2 for (x, x_) in zip(ds, range(5)))

def test_subsample(ds):
    """Subsampling the 5-graph fixture at ratio 0.2 yields one graph."""
    subsampled = ds.subsample(0.2)
    assert len(subsampled) == 1
18 changes: 18 additions & 0 deletions espaloma/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ def _std(input, target, metric=metric, weight=weight, dim=dim):

return _std

def weighted(metric, weight, reduction="mean"):
    """ Wrap a metric so per-sample losses are scaled by a fixed weight.

    Parameters
    ----------
    metric : callable
        Takes ``(input, target)`` and returns a per-sample loss tensor.
    weight : torch.Tensor
        Per-sample weights, broadcast against the loss by unsqueezing
        trailing dimensions.
    reduction : str, default="mean"
        Name of a torch reduction (e.g. ``"mean"``, ``"sum"``) applied to
        the weighted loss.

    Returns
    -------
    callable
        Loss function with signature ``(input, target)``.
    """
    def _weighted(
        input, target, metric=metric, weight=weight, reduction=reduction
    ):
        _loss = metric(input, target)
        # Pad weight with trailing singleton dims so it broadcasts
        # elementwise against the loss. torch.Tensor has .dim(), not
        # .dims() (the original raised AttributeError).
        for _ in range(_loss.dim() - 1):
            weight = weight.unsqueeze(-1)
        # Reduce the *weighted loss*; the original reduced only `weight`,
        # discarding the computed loss entirely.
        return getattr(torch, reduction)(weight * _loss)

    return _weighted

def weighted_with_key(metric, key="weight", reduction="mean"):
    """ Wrap a metric so per-sample losses are scaled by weights read from
    the target graph's ``g`` node data.

    Parameters
    ----------
    metric : callable
        Takes ``(input, target)`` and returns a per-sample loss tensor.
    key : str, default="weight"
        Key under ``target.nodes["g"].data`` holding the weight tensor.
    reduction : str, default="mean"
        Name of a torch reduction (e.g. ``"mean"``, ``"sum"``) applied to
        the weighted loss.

    Returns
    -------
    callable
        Loss function with signature ``(input, target)``, where ``target``
        is a graph carrying the weights.
    """
    def _weighted(input, target, metric=metric, key=key, reduction=reduction):
        weight = target.nodes["g"].data[key].flatten()
        _loss = metric(input, target)
        # Pad weight with trailing singleton dims so it broadcasts
        # elementwise against the loss. torch.Tensor has .dim(), not
        # .dims() (the original raised AttributeError).
        for _ in range(_loss.dim() - 1):
            weight = weight.unsqueeze(-1)
        # Reduce the *weighted loss*; the original reduced only `weight`,
        # discarding the computed loss entirely.
        return getattr(torch, reduction)(weight * _loss)

    return _weighted

def bootstrap(metric, n_samples=1000, ci=0.95):
def _bootstrap(input, target, metric=metric, n_samples=n_samples, ci=0.95):
Expand Down

0 comments on commit 1c46403

Please sign in to comment.