forked from JTT94/diffusion_schrodinger_bridge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
40 lines (26 loc) · 963 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os,sys
import argparse
# Command-line interface of this script: `--data` names the dataset and
# `--data_dir` gives the download location. Both are optional strings that
# default to None when omitted.
parser = argparse.ArgumentParser(description='Download data.')
parser.add_argument(
    '--data',
    help='mnist, celeba',
    type=str,
)
parser.add_argument(
    '--data_dir',
    help='download location',
    type=str,
)
sys.path.append('..')
from bridge.data.stackedmnist import Stacked_MNIST
from bridge.data.emnist import EMNIST
from bridge.data.celeba import CelebA
# Entry point: read the command-line parameters and download the selected dataset.
def main():
    """Download the dataset selected on the command line.

    Reads ``--data`` and ``--data_dir`` from the module-level ``parser`` and
    uses ``<data_dir>/<data>`` as the dataset root:

    * ``mnist``  -- constructs ``Stacked_MNIST`` with ``load=False`` and
      ``source_root`` set to the same root.
    * ``celeba`` -- constructs ``CelebA`` for the ``train`` split with
      ``download=True``.

    Invalid invocations are reported through ``parser.error`` (usage message
    on stderr, exit status 2) instead of either silently doing nothing
    (unknown or missing ``--data``) or failing with a ``TypeError`` from
    ``os.path.join`` (missing ``--data_dir``).
    """
    args = parser.parse_args()

    # An unrecognised dataset name used to fall through both branches and
    # exit successfully without downloading anything.
    if args.data not in ('mnist', 'celeba'):
        parser.error(
            'unknown --data value %r (expected one of: mnist, celeba)'
            % (args.data,)
        )

    # os.path.join(None, ...) raises an opaque TypeError; fail early with a
    # usage message instead.
    if args.data_dir is None:
        parser.error('--data_dir is required')

    # Each dataset lives in its own sub-directory named after the dataset.
    root = os.path.join(args.data_dir, args.data)

    if args.data == 'mnist':
        # NOTE(review): presumably constructing the dataset with load=False
        # fetches/prepares the files under `root` as a side effect -- confirm
        # in bridge.data.stackedmnist.
        Stacked_MNIST(
            root,
            load=False,
            source_root=root,
            train=True,
            num_channels=1,
            imageSize=28,
            device='cpu',
        )
    else:
        CelebA(root, split='train', download=True)
# Run the download only when this file is executed as a script; importing it
# as a module must not trigger main().
if __name__ == '__main__':
    main()