diff --git a/atom3d/datasets/datasets.py b/atom3d/datasets/datasets.py index 9854a89..0e3ba82 100644 --- a/atom3d/datasets/datasets.py +++ b/atom3d/datasets/datasets.py @@ -438,12 +438,12 @@ def make_lmdb_dataset(dataset, output_lmdb, id_to_idx = {} i = 0 for x in tqdm.tqdm(dataset, total=num_examples): + if filter_fn is not None and filter_fn(x): + continue # Add an entry that stores the original types of all entries x['types'] = {key: str(type(val)) for key, val in x.items()} # ... including itself x['types']['types'] = str(type(x['types'])) - if filter_fn is not None and filter_fn(x): - continue buf = io.BytesIO() with gzip.GzipFile(fileobj=buf, mode="wb", compresslevel=6) as f: f.write(serialize(x, serialization_format)) @@ -562,13 +562,15 @@ def combine_datasets(dataset_list, output_lmdb, filter_fn=None, serialization_fo txn.put(b'id_to_idx', serialize(id_to_idx, serialization_format)) -def download_dataset(name, out_path): +def download_dataset(name, out_path, split=None): """Download an ATOM3D dataset in LMDB format. Available datasets are SMP, PIP, RES, MSP, LBA, LEP, PSR, RSR. Please see `FAQ `_ or `atom3d.ai `_ for more details on each dataset. :param name: Three-letter code for dataset (not case-sensitive). :type name: str :param out_path: Path to directory in which to save downloaded dataset. :type out_path: str + :param split: name of split data to download in LMDB format. Defaults to None, in which case raw (unsplit) dataset is downloaded. Please use :func:`download_split_indices` to get pre-computed split indices for raw datasets. + :type split: str """ def _hook(t): @@ -596,21 +598,65 @@ def update_to(b=1, bsize=1, tsize=None): name = name.lower() if name == 'smp': - link = '13MT_f86so0fm6TOtzhW2Qy9ubVQo6UiU' - elif name == 'pip': - link = '1D4gMdJEz-6hzSc7_QQ2CF1K-anR4mO8T' + if split is None: + link = '1Qj67Y3cmTZoo9NCnjjI9CIs1KXnRY36O' + elif split == 'random': + link = '1MtqouZsV_7nEb3CZMoaJejAHrS1_bFMN' + else: + logger.warning(f'specified split {split} not available. Possible values are "random".') + return + elif name == 'ppi': + if split is None: + link = '1QYAXy71s9oStaSBnaVIL0i62jNSpiGQB' + elif split == 'DIPS': + link = '1ddUdYTr5aqXJv0Ncz1TWloqiLCLPLO_K' + else: + logger.warning(f'specified split {split} not available. Possible values are "DIPS".') + return elif name == 'res': link = '1XgZ19YYwloHxEtZUk78PLVzHipFkqIm5' elif name == 'msp': - link = '15rojYF-UjNnqoD8BnNpFtoxVZu64Y7FL' + if split is None: + link = '1ACkgojNUKo_ck34F3VEvsjHtlqIs2ecx' + elif split == 'sequence-identity-30': + link = '1f2GUGRIxR82l5eb8r8OFX7QkST4zbuZ3' + else: + logger.warning(f'specified split {split} not available. Possible values are "sequence-identity-30".') + return elif name == 'lba': - link = '1CGCRj3IwbT0HNSHIqQ46-o2n1CmGOnwK' + if split is None: + link = '16U5imKQ9bZr2GQPbmOE6FlcKeXuUrETa' + elif split == 'sequence-identity-30': + link = '1WQERC8h3t2DSkKkg12xpoOaYfB9dzeCB' + elif split == 'sequence-identity-60': + link = '1pGOe_V-JL6Mn_qxXjFwpRTFxYODhBZMR' + else: + logger.warning(f'specified split {split} not available. Possible values are "sequence-identity-30", "sequence-identity-60".') + return elif name == 'lep': - link = '15A85q2h6C1WFKjVttv6sInFNnB5z7Ha7' + if split is None: + link = '1V0r_VutAKKfwHYdx_nPUZi84a_Ud-rri' + elif split == 'protein': + link = '1tCHqotbAqcHmgtEeVidSld3pOyY2MUIh' + else: + logger.warning(f'specified split {split} not available. Possible values are "protein".') + return elif name == 'psr': - link = '1rvxf9JKTq0OvU3QLkxNYomfyXg5sd2CO' + if split is None: + link = '1ahFkfqijbLSO9kelRrp6i8TcufXuSoVa' + elif split == 'year': + link = '1nmiqJLRZMTnbADzkcEUjfO7H9WwT13Ns' + else: + logger.warning(f'specified split {split} not available. Possible values are "year".') + return elif name == 'rsr': - link = '1rlQ8BmyamMud2TZkcFGy_raz9iI1-KMm' + if split is None: + link = '16sbDowF_IyAkJAZ_UyUAirbIug2Oi4EU' + elif split == 'year': + link = '1yI03LSslNrOaculxoM0sUbSC1z0hKnt6' + else: + logger.warning(f'specified split {split} not available. Possible values are "year".') + return else: print('Invalid dataset name specified. Possible values are {SMP, PIP, RES, MSP, LBA, LEP, PSR, RSR}') @@ -619,7 +665,9 @@ def update_to(b=1, bsize=1, tsize=None): # urllib.request.urlretrieve(link, filename=f_out, # reporthook=_hook(t), data=None) - cmd = f"wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={link}' -O- | sed -En 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p' | tr -d \"n\")&id={link}\" -O {name}.tar.gz" + if not os.path.exists(out_path): + os.makedirs(out_path) + cmd = f"wget --load-cookies /tmp/cookies.txt \"https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id={link}' -O- | sed -En 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p' | tr -d \"n\")&id={link}\" -O {out_path}/{name}.tar.gz" subprocess.call(cmd, shell=True) - cmd2 = f"tar xzvf {name}.tar.gz" + cmd2 = f"tar xzvf {out_path}/{name}.tar.gz" subprocess.call(cmd2, shell=True) diff --git a/atom3d/datasets/scripts/combine_lmdb.py b/atom3d/datasets/scripts/combine_lmdb.py index 6cfe21b..c25ccda 100644 --- a/atom3d/datasets/scripts/combine_lmdb.py +++ b/atom3d/datasets/scripts/combine_lmdb.py @@ -16,17 +16,22 @@ @click.command() @click.argument('lmdb_list', nargs=-1) @click.argument('output_lmdb', type=click.Path(exists=False)) -def main(lmdb_list, output_lmdb): +@click.option('--append', '-a', is_flag=True) +def main(lmdb_list, output_lmdb, append): logging.basicConfig(stream=sys.stdout, format='%(asctime)s %(levelname)s %(process)d: ' + '%(message)s', level=logging.INFO) env = lmdb.open(str(output_lmdb), map_size=int(1e12)) + max_i = 0 + if append: + for key, value in env.cursor(): + max_i = max(max_i, key) with env.begin(write=True) as txn: id_to_idx = {} - i = 0 - for db_idx, db in enumerate(lmdb_list): + i = max_i + 1 + for db_idx, db in enumerate(lmdb_list[16:]): logger.info(f'on database {db_idx + 1} of {len(lmdb_list)}') dataset = LMDBDataset(db) @@ -55,4 +60,4 @@ def main(lmdb_list, output_lmdb): txn.put(b'id_to_idx', serialize(id_to_idx, serialization_format)) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/atom3d/models/mlp.py b/atom3d/models/mlp.py index 42050b2..3246331 100644 --- a/atom3d/models/mlp.py +++ b/atom3d/models/mlp.py @@ -3,7 +3,7 @@ import torch.nn.functional as F class MLP(nn.Module): - """A basic feed-forward neural network (multi-layer perceptron), with tunable hidden layer number and dimension. + """A basic feed-forward neural network (MLP), with tunable hidden layer number and dimension. The number of layers is assumed to be equal to :math:`len(hidden\_dims) + 2`, including the input and output layers. Dropout can optionally be specified and is applied after every layer (except output).