Skip to content

Commit

Permalink
link FITS_rec instances to hdu extensions on save (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram authored Jul 10, 2023
1 parent 065ec6e commit 9281ec1
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ Other
TARGCAT and TARGDESC, which record the target category and description
as given by the user in the APT. [#179]

Bug Fixes
---------

- Link FITS_rec instances to created HDU on save to avoid data duplication. [#178]


1.7.0 (2023-06-29)
==================
Expand Down
33 changes: 23 additions & 10 deletions src/stdatamodels/fits_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def _fits_array_writer(fits_context, validator, _, instance, schema):
if instance is None:
return

instance_id = id(instance)

instance = np.asanyarray(instance)

if not len(instance.shape):
Expand All @@ -297,6 +299,10 @@ def _fits_array_writer(fits_context, validator, _, instance, schema):
index=index, hdu_type=hdu_type)

hdu.data = instance
if instance_id in fits_context.extension_array_links:
if fits_context.extension_array_links[instance_id]() is not hdu:
raise ValueError("Linking one array to multiple hdus is not supported")
fits_context.extension_array_links[instance_id] = weakref.ref(hdu)
hdu.ver = index + 1


Expand Down Expand Up @@ -331,6 +337,7 @@ def __init__(self, hdulist):
self.hdulist = weakref.ref(hdulist)
self.comment_stack = []
self.sequence_index = None
self.extension_array_links = {}


def _get_validators(hdulist):
Expand All @@ -350,7 +357,7 @@ def _get_validators(hdulist):
'type': partial(_fits_type, fits_context),
})

return validators
return validators, fits_context


def _save_from_schema(hdulist, tree, schema):
Expand All @@ -370,24 +377,30 @@ def datetime_callback(node, json_id):
else:
kwargs = {}

validator = asdf_schema.get_validator(
schema, None, _get_validators(hdulist), **kwargs)
validators, context = _get_validators(hdulist)
validator = asdf_schema.get_validator(schema, None, validators, **kwargs)

# This actually kicks off the saving
validator.validate(tree, _schema=schema)

# Replace arrays in the tree that are identical to HDU arrays
# with ndarray-1.0.0 tagged objects with special source values
# that represent links to the surrounding FITS file.
def ndarray_callback(node, json_id):
if (isinstance(node, (np.ndarray, NDArrayType))):
# Now link extensions to items in the tree

def callback(node, json_id):
if id(node) in context.extension_array_links:
hdu = context.extension_array_links[id(node)]()
return _create_tagged_dict_for_fits_array(hdu, hdulist.index(hdu))
elif isinstance(node, (np.ndarray, NDArrayType)):
# in addition to links generated during validation
# replace arrays in the tree that are identical to HDU arrays
# with ndarray-1.0.0 tagged objects with special source values
# that represent links to the surrounding FITS file.
# This is important for general ASDF-in-FITS support
for hdu_index, hdu in enumerate(hdulist):
if hdu.data is not None and node is hdu.data:
return _create_tagged_dict_for_fits_array(hdu, hdu_index)

return node

tree = treeutil.walk_and_modify(tree, ndarray_callback)
tree = treeutil.walk_and_modify(tree, callback)

return tree

Expand Down
48 changes: 48 additions & 0 deletions tests/test_fits.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest
from astropy.io import fits
import numpy as np
Expand Down Expand Up @@ -627,3 +629,49 @@ def test_resave_duplication_bug(tmp_path):

with fits.open(fn1) as ff1, fits.open(fn2) as ff2:
assert ff1['ASDF'].size == ff2['ASDF'].size


def test_table_linking(tmp_path):
file_path = tmp_path / "test.fits"

schema = {
'title': 'Test data model',
'$schema': 'http://stsci.edu/schemas/fits-schema/fits-schema',
'type': 'object',
'properties': {
'meta': {
'type': 'object',
'properties': {}
},
'test_table': {
'title': 'Test table',
'fits_hdu': 'TESTTABL',
'datatype': [
{'name': 'A_COL', 'datatype': 'int8'},
{'name': 'B_COL', 'datatype': 'int8'}
]
}
}
}

with DataModel(schema=schema) as dm:
test_array = np.array([(1, 2), (3, 4)], dtype=[('A_COL', 'i1'), ('B_COL', 'i1')])

# assigning to the model will convert the array to a FITS_rec
dm.test_table = test_array
assert isinstance(dm.test_table, fits.FITS_rec)

# save the model (with the table)
dm.save(file_path)

# open the model and confirm that the table was linked to an hdu
with fits.open(file_path) as ff:
# read the bytes for the embedded ASDF content
asdf_bytes = ff['ASDF'].data.tobytes()

# get only the bytes for the tree (not blocks) by splitting
# on the yaml end document marker '...'
# on the first block magic sequence
tree_string = asdf_bytes.split(b'...')[0].decode('ascii')
unlinked_arrays = re.findall(r'source:\s+[^f]', tree_string)
assert not len(unlinked_arrays), unlinked_arrays

0 comments on commit 9281ec1

Please sign in to comment.