Skip to content

Commit

Permalink
add paralell option
Browse files Browse the repository at this point in the history
  • Loading branch information
imedan committed Sep 6, 2024
1 parent 3695bfb commit 9ac1e19
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions bin/create_designid_status_replace_designs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import numpy as np
import glob
import datetime
from tqdm import trange
from tqdm import trange, tqdm
from multiprocessing import Pool
from functools import partial

from astropy.io import fits
from astropy.table import Table
Expand All @@ -16,25 +18,32 @@
from sdssdb.peewee.sdss5db import targetdb


def get_designid_status(file, field_id):
def get_designid_status(file, field_id, Ncores):
"""
get the designid_status for a manual design
"""
from mugatu.fpsdesign import FPSDesign
from mugatu.designmode import find_designid_status
from mugatu.designs_to_targetdb import assignment_hash

def designid_status(design_file, obsTime, exp, field_id):
def create_des_object(design_file, obsTime, exp):
des = FPSDesign(design_pk=-1,
obsTime=obsTime,
design_file=design_file,
manual_design=True,
exp=exp)
des.build_design_manual()
return des

def designid_status(design_file, obsTime, exp, fexp, field_id, des_objs=None):
if des_objs is None:
des = create_des_object(design_file, obsTime, exp)
else:
des = des_objs[fexp]

assign_hash = assignment_hash(des.design['catalogID'][des.design['robotID'] != -1],
des.design['holeID'][des.design['robotID'] != -1])
designid_status = find_designid_status(field_id, exp, assign_hash=assign_hash)
designid_status = find_designid_status(field_id, fexp, assign_hash=assign_hash)
return designid_status

with fits.open(file) as hdu:
Expand All @@ -56,10 +65,20 @@ def designid_status(design_file, obsTime, exp, field_id):
status = np.zeros(n_exp, dtype='S20')
if n_exp == 1:
exp = 0
designid[exp], status[exp] = designid_status(file, obsTime, exp, field_id)
designid[exp], status[exp] = designid_status(file, obsTime, exp, exp, field_id)
else:
if Ncores > 1:
with Pool(processes=Ncores) as pool:
des_objs = tqdm(pool.imap(partial(create_des_object, design_file=file,
obsTime=obsTime),
range(1, n_exp + 1)),
total=len(n_exp))
des_objs = [r for r in des_objs]
else:
des_objs = None
for exp in trange(n_exp):
designid[exp], status[exp] = designid_status(file, obsTime, exp + 1, field_id)
designid[exp], status[exp] = designid_status(file, obsTime, exp + 1, exp,
field_id, des_objs=des_objs)
return designid, status


Expand All @@ -76,11 +95,15 @@ def designid_status(design_file, obsTime, exp, field_id):
parser.add_argument('-f', '--fieldids', dest='fieldids', nargs='+',
help='field_ids to replace)',
type=int, required=True)
parser.add_argument('-n', '--Ncores', dest='Ncores',
type=int, help='number of cores to use. If Ncores=1, then not run in parallal.',
default=1, nargs='?')

args = parser.parse_args()
loc = args.loc
plan = args.plan
fieldids = args.fieldids
Ncores = args.Ncores

if loc == 'local':
targetdb.database.connect_from_parameters(user='sdss',
Expand Down Expand Up @@ -122,7 +145,7 @@ def designid_status(design_file, obsTime, exp, field_id):
targetdb.Field.field_id >= 100000)
if len(same_field) > 0:
field_id = same_field[0].field_id
designid, status = get_designid_status(file, field_id)
designid, status = get_designid_status(file, field_id, Ncores)
else:
field_id = -1
designid = np.zeros(n_exp, dtype='>i4') - 1
Expand Down

0 comments on commit 9ac1e19

Please sign in to comment.