Skip to content

Commit

Permalink
Merge pull request #96 from punch-mission/mhughes-nov13
Browse files Browse the repository at this point in the history
Nov 17 Mega update
  • Loading branch information
jmbhughes authored Nov 18, 2024
2 parents 7e80201 + 22e009a commit 1eee1b6
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 67 deletions.
46 changes: 46 additions & 0 deletions create_distortion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Creates disortion models."""
import os
from datetime import datetime

import numpy as np
from astropy.io import fits
from astropy.time import Time
from astropy.wcs import WCS, DistortionLookupTable

from simpunch.level1 import generate_spacecraft_wcs

CURRENT_DIR = os.path.dirname(__file__)

now = datetime(2024, 1, 1, 1, 1, 1)
now_str = now.strftime("%Y%m%d%H%M%S")

for spacecraft_id in ["1", "2", "3", "4"]:

filename_distortion = (
os.path.join(CURRENT_DIR, "simpunch/data/distortion_NFI.fits")
if spacecraft_id == "4"
else os.path.join(CURRENT_DIR, "simpunch/data/distortion_WFI.fits")
)

spacecraft_wcs = generate_spacecraft_wcs(spacecraft_id, 0, Time.now())

with fits.open(filename_distortion) as hdul:
err_x = hdul[1].data
err_y = hdul[2].data

crpix = err_x.shape[1] / 2 + 0.5, err_x.shape[0] / 2 + 0.5
crval = 1024.5, 1024.5
cdelt = (spacecraft_wcs.wcs.cdelt[1] * err_x.shape[1] / 2048,
spacecraft_wcs.wcs.cdelt[0] * err_x.shape[0] / 2048)

cpdis1 = DistortionLookupTable(
-err_x.astype(np.float32), crpix, crval, cdelt,
)
cpdis2 = DistortionLookupTable(
-err_y.astype(np.float32), crpix, crval, cdelt,
)

w = WCS(naxis=2)
w.cpdis1 = cpdis1
w.cpdis2 = cpdis2
w.to_fits().writeto(f"PUNCH_DD{spacecraft_id}_{now_str}_v1.fits", overwrite=True)
36 changes: 32 additions & 4 deletions create_quartic_fit_coeffs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa
from datetime import datetime

import numpy as np
from astropy.io.fits import CompImageHDU, HDUList, ImageHDU, PrimaryHDU
from astropy.wcs import WCS
from astropy.wcs.docstrings import naxis
Expand All @@ -9,17 +10,44 @@
from punchbowl.data.io import load_ndcube_from_fits
from punchbowl.level1.quartic_fit import create_constant_quartic_coefficients

# backward
wfi_vignetting_model_path = "./build_3_review_files/PUNCH_L1_GM1_20240817174727_v2.fits"
nfi_vignetting_model_path = "./build_3_review_files/PUNCH_L1_GM4_20240819045110_v1.fits"

wfi_vignette = load_ndcube_from_fits(wfi_vignetting_model_path).data[...] + 1E-8
nfi_vignette = load_ndcube_from_fits(nfi_vignetting_model_path).data[...] + 1E-8
wfi_vignette = load_ndcube_from_fits(wfi_vignetting_model_path).data[...]
nfi_vignette = load_ndcube_from_fits(nfi_vignetting_model_path).data[...]

wfi_quartic = create_constant_quartic_coefficients((2048, 2048))
nfi_quartic = create_constant_quartic_coefficients((2048, 2048))

wfi_quartic[-2, :, :] /= wfi_vignette
nfi_quartic[-2, :, :] /= nfi_vignette
wfi_quartic[-2, :, :] = wfi_vignette
nfi_quartic[-2, :, :] = nfi_vignette

meta = NormalizedMetadata.load_template("FQ1", "1")
meta['DATE-OBS'] = str(datetime.now())

wfi_cube = NDCube(data=wfi_quartic, meta=meta, wcs=WCS(naxis=3))
nfi_cube = NDCube(data=nfi_quartic, meta=meta, wcs=WCS(naxis=3))

write_ndcube_to_fits(wfi_cube, "wfi_quartic_backward_coeffs.fits")
write_ndcube_to_fits(nfi_cube, "nfi_quartic_backward_coeffs.fits")

# forward
wfi_vignetting_model_path = "./build_3_review_files/PUNCH_L1_GM1_20240817174727_v2.fits"
nfi_vignetting_model_path = "./build_3_review_files/PUNCH_L1_GM4_20240819045110_v1.fits"

wfi_vignette = load_ndcube_from_fits(wfi_vignetting_model_path).data[...]
nfi_vignette = load_ndcube_from_fits(nfi_vignetting_model_path).data[...]

wfi_quartic = create_constant_quartic_coefficients((2048, 2048))
nfi_quartic = create_constant_quartic_coefficients((2048, 2048))

wfi_quartic[-2, :, :] = 1/wfi_vignette
nfi_quartic[-2, :, :] = 1/nfi_vignette
wfi_quartic[np.isinf(wfi_quartic)] = 0
wfi_quartic[np.isnan(wfi_quartic)] = 0
nfi_quartic[np.isinf(nfi_quartic)] = 0
nfi_quartic[np.isnan(nfi_quartic)] = 0

meta = NormalizedMetadata.load_template("FQ1", "1")
meta['DATE-OBS'] = str(datetime.now())
Expand Down
3 changes: 2 additions & 1 deletion create_synthetic_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import numpy as np
from matplotlib import pyplot as plt
from regularizepsf import (ArrayPSFTransform, simple_functional_psf,
varied_functional_psf)
from regularizepsf.util import calculate_covering
Expand Down Expand Up @@ -53,7 +54,7 @@ def target_psf(row,
@varied_functional_psf(target_psf)
def synthetic_psf(row, col):
return {"tail_angle": -np.arctan2(row - img_size//2, col - img_size//2),
"tail_separation": np.sqrt((row - img_size//2) ** 2 + (col - img_size//2) ** 2)/500 * 2 + 1E-3,
"tail_separation": np.sqrt((row - img_size//2) ** 2 + (col - img_size//2) ** 2)/1200 * 2 + 1E-3,
"core_sigma_x": initial_sigma,
"core_sigma_y": initial_sigma}

Expand Down
15 changes: 9 additions & 6 deletions simpunch/cli.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
"""Command line interface."""

import click
import toml
from prefect import serve

from .flow import generate_flow
from simpunch.flow import generate_flow


@click.group()
#@click.group()
def main():
"""Simulate PUNCH data with simpunch."""

@main.command()
@click.argument("configuration_path", type=click.Path(exists=True))
#@main.command()
#@click.argument("configuration_path", type=click.Path(exists=True))
def generate(configuration_path):
"""Run a single instance of the pipeline."""
configuration = load_configuration(configuration_path)
generate_flow(**configuration)

@main.command()
#@main.command()
def automate():
"""Automate the data generation using Prefect."""
serve(generate_flow.to_deployment(name="simulator-deployment",
Expand All @@ -28,3 +27,7 @@ def automate():
def load_configuration(configuration_path: str) -> dict:
"""Load a configuration file."""
return toml.load(configuration_path)


if __name__ == "__main__":
generate("/home/marcus.hughes/build4/punch190_simpunch_config.toml")
45 changes: 32 additions & 13 deletions simpunch/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,79 @@
import shutil
from datetime import datetime

import numpy as np
from asyncpg.pgproto.pgproto import timedelta
from prefect import flow

from simpunch.level0 import generate_l0_all
from simpunch.level1 import generate_l1_all
from simpunch.level2 import generate_l2_all
from simpunch.level3 import generate_l3_all
from simpunch.level3 import generate_l3_all, generate_l3_all_fixed


@flow(log_prints=True)
def generate_flow(gamera_directory: str,
output_directory: str,
psf_model_path: str,
forward_psf_model_path: str,
backward_psf_model_path: str,
wfi_quartic_backward_model_path: str,
nfi_quartic_backward_model_path: str,
wfi_quartic_model_path: str,
nfi_quartic_model_path: str,
num_repeats: int = 1,
start_time: datetime | None = None,
transient_probability: float = 0.03,
shift_pointing: bool = False,
generate_new: bool = True,
update_database: bool = True) -> None:
"""Generate all the products in the reverse pipeline."""
if start_time is None:
start_time = datetime.now() - timedelta(days=3) # noqa: DTZ005
time_str = start_time.strftime("%Y%m%d%H%M%S")
start_time = datetime.now() # noqa: DTZ005
start_time = datetime(2024, 11, 13, 15, 20, 23)

if generate_new:
time_delta = timedelta(days=0.25)
files_tb = sorted(glob.glob(gamera_directory + "/synthetic_cme/*_TB.fits"))
files_pb = sorted(glob.glob(gamera_directory + "/synthetic_cme/*_PB.fits"))

previous_month = np.linspace(1, -30, int(timedelta(days=30)/time_delta)) * time_delta + start_time
generate_l3_all_fixed(gamera_directory, previous_month, files_pb[0], files_tb[0])

next_month = np.linspace(1, 30, int(timedelta(days=30)/time_delta)) * time_delta + start_time
generate_l3_all_fixed(gamera_directory, next_month, files_pb[-1], files_tb[-1])

generate_l3_all(gamera_directory, start_time, num_repeats=num_repeats)
generate_l2_all(gamera_directory)
generate_l1_all(gamera_directory)
generate_l0_all(gamera_directory,
psf_model_path,
wfi_quartic_model_path,
nfi_quartic_model_path,
backward_psf_model_path,
wfi_quartic_backward_model_path,
nfi_quartic_backward_model_path,
shift_pointing=shift_pointing,
transient_probability=transient_probability)

model_time = start_time - timedelta(days=35)
model_time_str = model_time.strftime("%Y%m%d%H%M%S")

# duplicate the psf model to all required versions
for type_code in ["RM", "RZ", "RP", "RC"]:
for obs_code in ["1", "2", "3", "4"]:
new_name = f"PUNCH_L1_{type_code}{obs_code}_{time_str}_v1.fits"
shutil.copy(psf_model_path, os.path.join(gamera_directory, f"synthetic_l0/{new_name}"))
new_name = f"PUNCH_L1_{type_code}{obs_code}_{model_time_str}_v1.fits"
shutil.copy(forward_psf_model_path, os.path.join(gamera_directory, f"synthetic_l0/{new_name}"))

# duplicate the quartic model
type_code = "FQ"
for obs_code in ["1", "2", "3"]:
new_name = f"PUNCH_L1_{type_code}{obs_code}_{time_str}_v1.fits"
new_name = f"PUNCH_L1_{type_code}{obs_code}_{model_time_str}_v1.fits"
shutil.copy(wfi_quartic_model_path, os.path.join(gamera_directory, f"synthetic_l0/{new_name}"))
obs_code = "4"
new_name = f"PUNCH_L1_{type_code}{obs_code}_{time_str}_v1.fits"
new_name = f"PUNCH_L1_{type_code}{obs_code}_{model_time_str}_v1.fits"
shutil.copy(nfi_quartic_model_path, os.path.join(gamera_directory, f"synthetic_l0/{new_name}"))

if update_database:
from punchpipe import __version__
from punchpipe.controlsegment.db import File
from punchpipe.controlsegment.util import get_database_session
from punchpipe.control.db import File
from punchpipe.control.util import get_database_session
db_session = get_database_session()
for file_path in sorted(glob.glob(os.path.join(gamera_directory, "synthetic_l0/*v[0-9].fits")),
key=lambda s: os.path.basename(s)[13:27]):
Expand Down
66 changes: 39 additions & 27 deletions simpunch/level0.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from punchbowl.level1.initial_uncertainty import compute_noise
from punchbowl.level1.sqrt import encode_sqrt
from regularizepsf import ArrayPSFTransform
from tqdm import tqdm

from simpunch.spike import generate_spike_image
from simpunch.util import update_spacecraft_location, write_array_to_fits
Expand All @@ -26,8 +25,8 @@
def perform_photometric_uncalibration(input_data: NDCube, coefficient_array: np.ndarray) -> NDCube:
"""Undo quartic fit calibration."""
num_coefficients = coefficient_array.shape[0]
new_data = np.sum(
[coefficient_array[i, ...] / np.power(input_data.data, num_coefficients - i - 1)
new_data = np.nansum(
[coefficient_array[i, ...] * np.power(input_data.data, num_coefficients - i - 1)
for i in range(num_coefficients)], axis=0)
input_data.data[...] = new_data[...]
return input_data
Expand Down Expand Up @@ -111,12 +110,16 @@ def starfield_misalignment(input_data: NDCube,
@task
def generate_l0_pmzp(input_file: NDCube,
path_output: str,
psf_model: ArrayPSFTransform,
wfi_quartic_coefficients: np.ndarray,
nfi_quartic_coefficients: np.ndarray,
transient_probability: float=0.03) -> None:
psf_model_path: str, # ArrayPSFTransform,
wfi_quartic_coeffs_path: str, # np.ndarray,
nfi_quartic_coeffs_path: str, # np.ndarray,
transient_probability: float=0.03,
shift_pointing: bool=False) -> None:
"""Generate level 0 polarized synthetic data."""
input_data = load_ndcube_from_fits(input_file)
psf_model = ArrayPSFTransform.load(Path(psf_model_path))
wfi_quartic_coefficients = load_ndcube_from_fits(wfi_quartic_coeffs_path).data
nfi_quartic_coefficients = load_ndcube_from_fits(nfi_quartic_coeffs_path).data

# Define the output data product
product_code = input_data.meta["TYPECODE"].value + input_data.meta["OBSCODE"].value
Expand All @@ -134,7 +137,11 @@ def generate_l0_pmzp(input_file: NDCube,
output_meta[key] = input_data.meta[key].value

input_data = NDCube(data=input_data.data, meta=output_meta, wcs=input_data.wcs)
output_data, original_wcs = starfield_misalignment(input_data)
if shift_pointing:
output_data, original_wcs = starfield_misalignment(input_data)
else:
output_data = input_data
original_wcs = input_data.wcs.copy()
output_data, transient = add_transients(output_data, transient_probability=transient_probability)
output_data = uncorrect_psf(output_data, psf_model)

Expand Down Expand Up @@ -186,12 +193,16 @@ def generate_l0_pmzp(input_file: NDCube,

@task
def generate_l0_cr(input_file: NDCube, path_output: str,
psf_model: ArrayPSFTransform,
wfi_quartic_coefficients: np.ndarray,
nfi_quartic_coefficients: np.ndarray,
transient_probability: float = 0.03) -> None:
psf_model_path: str, # ArrayPSFTransform,
wfi_quartic_coeffs_path: str, # np.ndarray,
nfi_quartic_coeffs_path: str, # np.ndarray,
transient_probability: float = 0.03,
shift_pointing: bool=False) -> None:
"""Generate level 0 clear synthetic data."""
input_data = load_ndcube_from_fits(input_file)
psf_model = ArrayPSFTransform.load(Path(psf_model_path))
wfi_quartic_coefficients = load_ndcube_from_fits(wfi_quartic_coeffs_path).data
nfi_quartic_coefficients = load_ndcube_from_fits(nfi_quartic_coeffs_path).data

# Define the output data product
product_code = input_data.meta["TYPECODE"].value + input_data.meta["OBSCODE"].value
Expand All @@ -209,7 +220,11 @@ def generate_l0_cr(input_file: NDCube, path_output: str,
output_meta[key] = input_data.meta[key].value

input_data = NDCube(data=input_data.data, meta=output_meta, wcs=input_data.wcs)
output_data, original_wcs = starfield_misalignment(input_data)
if shift_pointing:
output_data, original_wcs = starfield_misalignment(input_data)
else:
output_data = input_data
original_wcs = input_data.wcs.copy()
output_data, transient = add_transients(output_data, transient_probability=transient_probability)
output_data = uncorrect_psf(output_data, psf_model)
output_data = add_stray_light(output_data)
Expand Down Expand Up @@ -258,11 +273,12 @@ def generate_l0_cr(input_file: NDCube, path_output: str,
original_wcs.to_header().tofile(path_output + get_base_file_name(output_data) + "_original_wcs.txt")

@flow(log_prints=True,
task_runner=DaskTaskRunner(cluster_kwargs={"n_workers": 4, "threads_per_worker": 2},
task_runner=DaskTaskRunner(cluster_kwargs={"n_workers": 32, "threads_per_worker": 2},
))
def generate_l0_all(datadir: str, psf_model_path: str,
wfi_quartic_coeffs_path: str, nfi_quartic_coeffs_path: str,
transient_probability: float = 0.03) -> None:
transient_probability: float = 0.03,
shift_pointing: bool = False) -> None:
"""Generate all level 0 synthetic data."""
print(f"Running from {datadir}")
outdir = os.path.join(datadir, "synthetic_l0/")
Expand All @@ -272,19 +288,15 @@ def generate_l0_all(datadir: str, psf_model_path: str,
# Parse list of level 1 model data
files_l1 = glob.glob(datadir + "/synthetic_l1/*L1_P*_v1.fits")
files_cr = glob.glob(datadir + "/synthetic_l1/*CR*_v1.fits")
print(f"Generating based on {len(files_l1)} files.")
print(f"Generating based on {len(files_l1)+len(files_cr)} files.")
files_l1.sort()

psf_model = ArrayPSFTransform.load(Path(psf_model_path))
wfi_quartic_coeffs = load_ndcube_from_fits(wfi_quartic_coeffs_path).data
nfi_quartic_coeffs = load_ndcube_from_fits(nfi_quartic_coeffs_path).data
files_cr.sort()

futures = []
for file_l1 in tqdm(files_l1, total=len(files_l1)):
futures.append(generate_l0_pmzp.submit(file_l1, outdir, psf_model, # noqa: PERF401
wfi_quartic_coeffs, nfi_quartic_coeffs, transient_probability))

for file_cr in tqdm(files_cr, total=len(files_cr)):
futures.append(generate_l0_cr.submit(file_cr, outdir, psf_model, # noqa: PERF401
wfi_quartic_coeffs, nfi_quartic_coeffs, transient_probability))
futures.extend(generate_l0_pmzp.map(files_l1, outdir, psf_model_path,
wfi_quartic_coeffs_path, nfi_quartic_coeffs_path,
transient_probability, shift_pointing))
futures.extend(generate_l0_cr.map(files_cr, outdir, psf_model_path,
wfi_quartic_coeffs_path, nfi_quartic_coeffs_path,
transient_probability, shift_pointing))
wait(futures)
Loading

0 comments on commit 1eee1b6

Please sign in to comment.