HiPRGen-BonDNet lmdb worker level
Wenbin Xu committed Dec 8, 2023
1 parent 309599c commit 8adf0b4
Showing 6 changed files with 637 additions and 79 deletions.
189 changes: 147 additions & 42 deletions HiPRGen/lmdb_dataset.py
@@ -31,22 +31,50 @@ def __init__(self, config, transform=None):
self.config = config
self.path = Path(self.config["src"])

if not self.path.is_file():
db_paths = sorted(self.path.glob("*.lmdb"))
assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'"
#self.metadata_path = self.path / "metadata.npz"

self._keys = []
self.envs = []
for db_path in db_paths:
cur_env = self.connect_db(db_path)
self.envs.append(cur_env)

# If "length" encoded as ascii is present, use that
length_entry = cur_env.begin().get("length".encode("ascii"))
if length_entry is not None:
num_entries = pickle.loads(length_entry)
else:
                    # Infer the sample count from the LMDB entry count
                    num_entries = cur_env.stat()["entries"]

                # Append this LMDB's keys (0 .. num_entries-1) as a list
                self._keys.append(list(range(num_entries)))

keylens = [len(k) for k in self._keys]
self._keylen_cumulative = np.cumsum(keylens).tolist()
self.num_samples = sum(keylens)


else:
            # Single LMDB file: connect to it directly
            # self.metadata_path = self.path.parent / "metadata.npz"
self.env = self.connect_db(self.path)

# If "length" encoded as ascii is present, use that
# If there are additional properties, there must be length.
length_entry = self.env.begin().get("length".encode("ascii"))
if length_entry is not None:
num_entries = pickle.loads(length_entry)
else:
                # Infer the sample count from the LMDB entry count
                num_entries = self.env.stat()["entries"]

self._keys = list(range(num_entries))
self.num_samples = num_entries

# Get portion of total dataset
self.sharded = False
@@ -71,15 +99,34 @@ def __getitem__(self, idx):
# if sharding, remap idx to appropriate idx of the sharded set
if self.sharded:
idx = self.available_indices[idx]

if not self.path.is_file():
# Figure out which db this should be indexed from.
db_idx = bisect.bisect(self._keylen_cumulative, idx)
# Extract index of element within that db.
el_idx = idx
if db_idx != 0:
el_idx = idx - self._keylen_cumulative[db_idx - 1]
assert el_idx >= 0
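            # Worked example of the index math (hypothetical shard sizes): with
            # per-db key lengths [100, 50], _keylen_cumulative is [100, 150];
            # a global idx of 120 gives db_idx = bisect.bisect([100, 150], 120) = 1
            # and el_idx = 120 - 100 = 20, i.e. entry 20 of the second LMDB.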

# Return features.
datapoint_pickled = (
self.envs[db_idx]
.begin()
.get(f"{self._keys[db_idx][el_idx]}".encode("ascii"))
)
data_object = pickle.loads(datapoint_pickled)
#data_object.id = f"{db_idx}_{el_idx}"

        else:
            #!CHECK: _keys may hold fewer entries than the total number of LMDB
            # keys, since extra properties are stored alongside the samples.
            datapoint_pickled = self.env.begin().get(f"{self._keys[idx]}".encode("ascii"))

            data_object = pickle.loads(datapoint_pickled)

        # TODO
        if self.transform is not None:
            data_object = self.transform(data_object)

        return data_object
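
A minimal usage sketch (the "src" path here is hypothetical): the same config dict works whether "src" points at a single .lmdb file or at a directory of *.lmdb shards, and indexing is routed to the correct shard transparently.

    dataset = LmdbReactionDataset(config={"src": "data/reaction_lmdbs/"})
    print(dataset.num_samples)   # total sample count across all shards
    sample = dataset[0]          # fetched from the appropriate shard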

@@ -109,56 +156,114 @@ def get_metadata(self, num_samples=100):
class LmdbMoleculeDataset(LmdbBaseDataset):
def __init__(self, config, transform=None):
super(LmdbMoleculeDataset, self).__init__(config=config, transform=transform)

if not self.path.is_file():
self.env_ = self.envs[0]
            raise NotImplementedError("Multi-LMDB molecule datasets are not implemented yet")

else:
self.env_ = self.env
@property
def charges(self):
charges = self.env_.begin().get("charges".encode("ascii"))
return pickle.loads(charges)

@property
def ring_sizes(self):
ring_sizes = self.env_.begin().get("ring_sizes".encode("ascii"))
return pickle.loads(ring_sizes)

@property
def elements(self):
elements = self.env_.begin().get("elements".encode("ascii"))
return pickle.loads(elements)

@property
def feature_info(self):
feature_info = self.env_.begin().get("feature_info".encode("ascii"))
return pickle.loads(feature_info)
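
A minimal usage sketch (path and printed values hypothetical) of the molecule-level metadata properties, which read whole-dataset entries stored alongside the per-sample records:

    mol_ds = LmdbMoleculeDataset(config={"src": "data/molecule.lmdb"})
    print(mol_ds.charges)    # e.g. [-1, 0, 1]
    print(mol_ds.elements)   # e.g. ["C", "H", "O", "Li"]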


class LmdbReactionDataset(LmdbBaseDataset):
def __init__(self, config, transform=None):
super(LmdbReactionDataset, self).__init__(config=config, transform=transform)

if not self.path.is_file():
self.env_ = self.envs[0]
            # Check that shared metadata agrees across all the LMDBs
for i in range(1, len(self.envs)):
for key in ["feature_size", "dtype", "feature_name"]: #, "mean", "std"]:
assert self.envs[i].begin().get(key.encode("ascii")) == self.envs[0].begin().get(key.encode("ascii"))
            #! mean and std are not equal across the different datasets at this
            # time, so combine them from the per-dataset statistics:
            mean_list = [pickle.loads(env.begin().get("mean".encode("ascii"))) for env in self.envs]
            std_list = [pickle.loads(env.begin().get("std".encode("ascii"))) for env in self.envs]
            count_list = [pickle.loads(env.begin().get("length".encode("ascii"))) for env in self.envs]
self._mean, self._std = combined_mean_std(mean_list, std_list, count_list)

else:
self.env_ = self.env
self._mean = pickle.loads(self.env_.begin().get("mean".encode("ascii")))
self._std = pickle.loads(self.env_.begin().get("std".encode("ascii")))

@property
def dtype(self):
dtype = self.env_.begin().get("dtype".encode("ascii"))
return pickle.loads(dtype)

@property
def feature_size(self):
feature_size = self.env_.begin().get("feature_size".encode("ascii"))
return pickle.loads(feature_size)

@property
def feature_name(self):
feature_name = self.env_.begin().get("feature_name".encode("ascii"))
return pickle.loads(feature_name)

@property
def mean(self):
        return self._mean

@property
def std(self):
        #std = self.env_.begin().get("std".encode("ascii"))
        return self._std



def combined_mean_std(mean_list, std_list, count_list):
"""
Calculate the combined mean and standard deviation of multiple datasets.
:param mean_list: List of means of the datasets.
:param std_list: List of standard deviations of the datasets.
:param count_list: List of number of data points in each dataset.
:return: Combined mean and standard deviation.
"""
# Calculate total number of data points
total_count = sum(count_list)

# Calculate combined mean
combined_mean = sum(mean * count for mean, count in zip(mean_list, count_list)) / total_count

    # Combined variance of the concatenated data: within-dataset spread plus
    # between-dataset spread, with Bessel's correction (ddof=1)
    combined_variance = sum(
        (std ** 2) * (count - 1) + count * (mean - combined_mean) ** 2
        for mean, std, count in zip(mean_list, std_list, count_list)
    ) / (total_count - 1)

# Calculate combined standard deviation
combined_std = (combined_variance ** 0.5)

return combined_mean, combined_std
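
A quick numerical check (synthetic data, not from the repo) that combined_mean_std reproduces the mean and ddof=1 standard deviation of the concatenated data, assuming the stored per-dataset std values are themselves sample (ddof=1) standard deviations:

    import numpy as np

    rng = np.random.default_rng(0)
    a, b = rng.normal(size=100), rng.normal(loc=2.0, size=50)
    m, s = combined_mean_std(
        [a.mean(), b.mean()],
        [a.std(ddof=1), b.std(ddof=1)],
        [len(a), len(b)],
    )
    full = np.concatenate([a, b])
    assert np.isclose(m, full.mean()) and np.isclose(s, full.std(ddof=1))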



@@ -442,10 +547,10 @@ def write_to_lmdb(new_samples, current_length, lmdb_update, db_path):
map_async=True,
)

# pbar = tqdm(
# total=len(new_samples),
# desc=f"Adding new samples into LMDBs",
# )

#write indexed samples
idx = current_length
@@ -456,7 +561,7 @@ def write_to_lmdb(new_samples, current_length, lmdb_update, db_path):
                pickle.dumps(sample, protocol=-1),  # protocol=-1: highest available pickle protocol
)
idx += 1
#pbar.update(1)
txn.commit()

#write properties
(diff truncated; the remaining changed files are not shown)