forked from wpb-astro/MCSED
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_mcsed_parallel.py
142 lines (120 loc) · 4.4 KB
/
run_mcsed_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
""" Script for running MCSED in parallel
.. moduleauthor:: Greg Zeimann <[email protected]>
"""
import numpy as np
import os
import sys
from astropy.table import Table, vstack
from multiprocessing import cpu_count, Manager, Process
from distutils.dir_util import mkpath
import run_mcsed_fit
run_mcsed_ind = run_mcsed_fit.main
parse_args = run_mcsed_fit.parse_args
def worker(f, i, chunk, ssp_info, out_q, err_q, kwargs):
    '''Run ``f`` on one argument chunk, routing the outcome to a queue.

    Parameters
    ----------
    f : callable
        run_mcsed_fit's main function
    i : int
        task index; the parent uses it to restore result order
    chunk : list
        command-line style argument list for this task
    ssp_info : list or None
        SSP data passed straight through to ``f``
    out_q, err_q : queue-like
        process-safe destinations for (i, result) pairs and exceptions
    kwargs : dict
        accepted for interface compatibility; not forwarded to ``f``
    '''
    try:
        outcome = f(argv=chunk, ssp_info=ssp_info)
    except Exception as exc:
        # hand the failure back to the parent process and stop this task
        err_q.put(exc)
    else:
        # tag the result with the task index so the caller can re-order
        out_q.put((i, outcome))
def parallel_map(func, argv, args, ncpu, ssp_info, clean=True, **kwargs):
    '''
    Make multiple calls to run_mcsed_fit's main function for either
    test or real data to parallelize the computing effort. Collect the info
    at the end.

    Inputs
    ------
    func : callable function
        This is meant to be run_mcsed_fit's main function
    argv : list
        Arguments (command line or otherwise) list.
        python run_mcsed_fit.py -h
    args : class
        Built arguments from argv
    ncpu : int
        Number of parallelized cpus
    ssp_info : list
        SSP data for spectra, ages, metallicities, etc.
    clean : bool
        Remove temporary files
    kwargs : dict
        Extra keyword arguments; forwarded to ``func`` only on the
        single-cpu path (worker() accepts but does not forward them)

    Returns
    -------
    list
        One result per worker process, ordered by worker index
    '''
    # Serial shortcut: with a single cpu, run func directly in this process.
    # NOTE(review): this path calls func(0, argv, ...) positionally while
    # worker() calls f(argv=chunk, ssp_info=ssp_info) -- confirm that
    # run_mcsed_fit.main tolerates both calling conventions.
    if isinstance(ncpu, (int, np.integer)) and ncpu == 1:
        return [func(0, argv, **kwargs)]
    # Manager queues are process-safe channels for results and exceptions
    manager = Manager()
    out_q = manager.Queue()
    err_q = manager.Queue()
    jobs = []
    if args.test:
        # Test mode: split the requested mock objects evenly over the cpus;
        # each chunk gets its own --nobjects/--count pair
        ncpu = min(args.nobjects, ncpu)
        x = np.arange(args.nobjects)
        v = [len(i) for i in np.array_split(x, ncpu)]
        # counts[i] is the 1-based starting object index for chunk i
        counts = np.full(len(v), 1)
        counts[1:] += np.cumsum(v)[:-1]
        chunks = [argv + ['--nobjects', '%i' % vi, '--count', '%i' % cnt, '--already_parallel']
                  for vi, cnt in zip(v, counts)]
    else:
        # Real data: split the input catalog into NCPU temporary files.
        # NOTE(review): distutils is removed in Python 3.12; consider
        # os.makedirs('temp', exist_ok=True) as a drop-in replacement.
        mkpath('temp')
        ind = argv.index('-f')
        data = Table.read(argv[ind+1], format='ascii')
        # split up the input file into NCPU files and save as temporary files
        ncpu = min(len(data), ncpu)
        datachunks = np.array_split(data, ncpu)
        for i, chunk in enumerate(datachunks):
            T = Table(chunk)
            T.write('temp/temp_%i.dat' % i, format='ascii', overwrite=True)
        # create a separate argument list for each temporary file
        chunks = [argv + ['-f', 'temp/temp_%i.dat' % i, '--already_parallel']
                  for i, chunk in enumerate(datachunks)]
    # Launch one process per argument chunk
    for i, chunk in enumerate(chunks):
        p = Process(target=worker, args=(func, i, chunk, ssp_info, out_q,
                                         err_q, kwargs))
        jobs.append(p)
        p.start()
    # gather the results
    for proc in jobs:
        proc.join()
    if not err_q.empty():
        # kill all on any exception from any one worker
        raise err_q.get()
    # Processes finish in arbitrary order. Process IDs double
    # as index in the resultant array.
    results = [None] * len(jobs)
    while not out_q.empty():
        idx, result = out_q.get()
        results[idx] = result
    # Remove the temporary (divided) input files
    if (clean) & (not args.test):
        for i, chunk in enumerate(chunks):
            if os.path.exists('temp/temp_%i.dat' % i):
                os.remove('temp/temp_%i.dat' % i)
        # rmdir only succeeds once the directory is empty; a concurrent
        # run may still hold files in it, so failure is tolerated
        try:
            os.rmdir('temp/')
        except OSError:
            pass
    return results
def main_parallel(argv=None):
    '''Entry point: parse arguments, fan MCSED out over the cpus, and
    collect the per-worker tables into the final output products.

    Inputs
    ------
    argv : list, optional
        Command-line style argument list. When None, arguments are taken
        from ``sys.argv`` with the script-name entry dropped.
    '''
    # read command line arguments, if not already calling
    # from within run_mcsed_fit.py
    if argv is None:
        # Copy instead of mutating sys.argv, and drop argv[0] regardless
        # of the path used to invoke the script (the original
        # argv.remove('run_mcsed_parallel.py') raised ValueError when the
        # script was run via an absolute or relative path)
        argv = sys.argv[1:]
    argv = argv + ['--parallel']
    args = parse_args(argv=argv)
    ssp_info = None
    # leave args.reserved_cores cpus free, but always use at least one
    ncpu = max(1, cpu_count() - args.reserved_cores)
    results = parallel_map(run_mcsed_ind, argv, args, ncpu, ssp_info)
    # result[0] is the parameter table from each worker; stack them all
    table = vstack([result[0] for result in results])
    if args.output_dict['parameters']:
        table.write('output/%s' % args.output_filename,
                    format='ascii.fixed_width_two_line',
                    formats=results[0][1], overwrite=True)
    if args.output_dict['settings']:
        # record the run configuration alongside the results; the logger
        # attribute is dropped because it has no useful text form
        del args.log
        with open('output/%s.args' % args.output_filename, 'w') as settings_file:
            settings_file.write(str(vars(args)))
# Run the parallel driver when this module is executed as a script
if __name__ == '__main__':
    main_parallel()