diff --git a/bin/create_designid_status_replace_designs.py b/bin/create_designid_status_replace_designs.py index 401301a..61e01a1 100644 --- a/bin/create_designid_status_replace_designs.py +++ b/bin/create_designid_status_replace_designs.py @@ -4,7 +4,9 @@ import numpy as np import glob import datetime -from tqdm import trange +from tqdm import trange, tqdm +from multiprocessing import Pool +from functools import partial from astropy.io import fits from astropy.table import Table @@ -16,7 +18,7 @@ from sdssdb.peewee.sdss5db import targetdb -def get_designid_status(file, field_id): +def get_designid_status(file, field_id, Ncores): """ get the designid_status for a manual design """ @@ -24,17 +26,24 @@ def get_designid_status(file, field_id): from mugatu.designmode import find_designid_status from mugatu.designs_to_targetdb import assignment_hash - def designid_status(design_file, obsTime, exp, field_id): + def create_des_object(design_file, obsTime, exp): des = FPSDesign(design_pk=-1, obsTime=obsTime, design_file=design_file, manual_design=True, exp=exp) des.build_design_manual() + return des + + def designid_status(design_file, obsTime, exp, fexp, field_id, des_objs=None): + if des_objs is None: + des = create_des_object(design_file, obsTime, exp) + else: + des = des_objs[fexp] assign_hash = assignment_hash(des.design['catalogID'][des.design['robotID'] != -1], des.design['holeID'][des.design['robotID'] != -1]) - designid_status = find_designid_status(field_id, exp, assign_hash=assign_hash) + designid_status = find_designid_status(field_id, fexp, assign_hash=assign_hash) return designid_status with fits.open(file) as hdu: @@ -56,10 +65,20 @@ def designid_status(design_file, obsTime, exp, field_id): status = np.zeros(n_exp, dtype='S20') if n_exp == 1: exp = 0 - designid[exp], status[exp] = designid_status(file, obsTime, exp, field_id) + designid[exp], status[exp] = designid_status(file, obsTime, exp, exp, field_id) else: + if Ncores > 1: + with Pool(processes=Ncores) as pool: + des_objs = tqdm(pool.imap(partial(create_des_object, design_file=file, + obsTime=obsTime), + range(1, n_exp + 1)), + total=len(n_exp)) + des_objs = [r for r in des_objs] + else: + des_objs = None for exp in trange(n_exp): - designid[exp], status[exp] = designid_status(file, obsTime, exp + 1, field_id) + designid[exp], status[exp] = designid_status(file, obsTime, exp + 1, exp, + field_id, des_objs=des_objs) return designid, status @@ -76,11 +95,15 @@ def designid_status(design_file, obsTime, exp, field_id): parser.add_argument('-f', '--fieldids', dest='fieldids', nargs='+', help='field_ids to replace)', type=int, required=True) + parser.add_argument('-n', '--Ncores', dest='Ncores', + type=int, help='number of cores to use. If Ncores=1, then not run in parallal.', + default=1, nargs='?') args = parser.parse_args() loc = args.loc plan = args.plan fieldids = args.fieldids + Ncores = args.Ncores if loc == 'local': targetdb.database.connect_from_parameters(user='sdss', @@ -122,7 +145,7 @@ def designid_status(design_file, obsTime, exp, field_id): targetdb.Field.field_id >= 100000) if len(same_field) > 0: field_id = same_field[0].field_id - designid, status = get_designid_status(file, field_id) + designid, status = get_designid_status(file, field_id, Ncores) else: field_id = -1 designid = np.zeros(n_exp, dtype='>i4') - 1