From c89afe299ea80b52b7b8bd8747b22280d198ba6a Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 17 Oct 2024 12:58:04 -0400 Subject: [PATCH 1/5] Copy over skymatch from romancal and combine with jwst version --- src/stcal/skymatch/__init__.py | 0 src/stcal/skymatch/region.py | 425 +++++++++++ src/stcal/skymatch/skyimage.py | 1020 +++++++++++++++++++++++++++ src/stcal/skymatch/skymatch.py | 539 ++++++++++++++ src/stcal/skymatch/skystatistics.py | 136 ++++ 5 files changed, 2120 insertions(+) create mode 100644 src/stcal/skymatch/__init__.py create mode 100644 src/stcal/skymatch/region.py create mode 100644 src/stcal/skymatch/skyimage.py create mode 100644 src/stcal/skymatch/skymatch.py create mode 100644 src/stcal/skymatch/skystatistics.py diff --git a/src/stcal/skymatch/__init__.py b/src/stcal/skymatch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/stcal/skymatch/region.py b/src/stcal/skymatch/region.py new file mode 100644 index 000000000..252e8454d --- /dev/null +++ b/src/stcal/skymatch/region.py @@ -0,0 +1,425 @@ +""" +Polygon filling algorithm. + +NOTE: Algorithm description can be found, e.g., here: + http://www.cs.rit.edu/~icss571/filling/how_to.html + http://www.cs.uic.edu/~jbell/CourseNotes/ComputerGraphics/PolygonFilling.html +""" + +from collections import OrderedDict + +import numpy as np + +__all__ = ["Region", "Edge", "Polygon"] + + +class ValidationError(Exception): + def __init__(self, message): + self._message = message + + def __str__(self): + return self._message + + +class Region: + """ + Base class for regions. + + Parameters + ------------- + rid : int or string + region ID + coordinate_system : astropy.wcs.CoordinateSystem instance or a string + in the context of WCS this would be an instance of wcs.CoordinateSysem + """ + + def __init__(self, rid, coordinate_system): + self._coordinate_system = coordinate_system + self._rid = rid + + def __contains__(self, xy: tuple[float]) -> bool: + """ + Determines if a pixel is within a region. + + Parameters + ---------- + xy : tuple[float] + x , y values of a pixel + + Returns + ------- + True or False + + Subclasses must define this method. + """ + raise NotImplementedError("__contains__") + + def scan(self, mask): + """ + Sets mask values to region id for all pixels within the region. + Subclasses must define this method. + + Parameters + ---------- + mask : ndarray + a byte array with the shape of the observation to be used as a mask + + Returns + ------- + mask : array where the value of the elements is the region ID or 0 (for + pixels which are not included in any region). + """ + raise NotImplementedError("scan") + + +class Polygon(Region): + """ + Represents a 2D polygon region with multiple vertices + + Parameters + ---------- + rid : string + polygon id + vertices : list of (x,y) tuples or lists + The list is ordered in such a way that when traversed in a + counterclockwise direction, the enclosed area is the polygon. + The last vertex must coincide with the first vertex, minimum + 4 vertices are needed to define a triangle + coord_system : string + coordinate system + + """ + + def __init__(self, rid, vertices, coord_system="Cartesian"): + assert len(vertices) >= 4, ( + "Expected vertices to be " "a list of minimum 4 tuples (x,y)" + ) + super().__init__(rid, coord_system) + + # self._shiftx & self._shifty are introduced to shift the bottom-left + # corner of the polygon's bounding box to (0,0) as a (hopefully + # temporary) workaround to a limitation of the original code that the + # polygon must be completely contained in the image. It seems that the + # code works fine if we make sure that the bottom-left corner of the + # polygon's bounding box has non-negative coordinates. + self._shiftx = 0 + self._shifty = 0 + for vertex in vertices: + x, y = vertex + if x < self._shiftx: + self._shiftx = x + if y < self._shifty: + self._shifty = y + v = [(i - self._shiftx, j - self._shifty) for i, j in vertices] + + # convert to integer coordinates: + self._vertices = np.asarray(list(map(_round_vertex, v))) + self._shiftx = int(round(self._shiftx)) + self._shifty = int(round(self._shifty)) + + self._bbox = self._get_bounding_box() + self._scan_line_range = list( + range(self._bbox[1], self._bbox[3] + self._bbox[1] + 1) + ) + # constructs a Global Edge Table (GET) in bbox coordinates + self._GET = self._construct_ordered_GET() + + def _get_bounding_box(self): + x = self._vertices[:, 0].min() + y = self._vertices[:, 1].min() + w = self._vertices[:, 0].max() - x + h = self._vertices[:, 1].max() - y + return x, y, w, h + + def _construct_ordered_GET(self): + """ + Construct a Global Edge Table (GET) + + The GET is an OrderedDict. Keys are scan line numbers, + ordered from bbox.ymin to bbox.ymax, where bbox is the + bounding box of the polygon. + Values are lists of edges for which edge.ymin==scan_line_number. + + Returns + ------- + GET: OrderedDict + {scan_line: [edge1, edge2]} + """ + # edges is a list of Edge objects which define a polygon + # with these vertices + edges = self.get_edges() + GET = OrderedDict.fromkeys(self._scan_line_range) + ymin = np.asarray([e._ymin for e in edges]) + for i in self._scan_line_range: + ymin_ind = (ymin == i).nonzero()[0] + # a hack for incomplete filling .any() fails if 0 is in ymin_ind + # if ymin_ind.any(): + (yminindlen,) = ymin_ind.shape + if yminindlen: + GET[i] = [edges[ymin_ind[0]]] + for j in ymin_ind[1:]: + GET[i].append(edges[j]) + return GET + + def get_edges(self): + """ + Create a list of Edge objects from vertices + """ + edges = [] + for i in range(1, len(self._vertices)): + name = "E" + str(i - 1) + edges.append( + Edge(name=name, start=self._vertices[i - 1], stop=self._vertices[i]) + ) + return edges + + def scan(self, data): + """ + This is the main function which scans the polygon and creates the mask + + Parameters + ---------- + data : array + the mask array + it has all zeros initially, elements within a region are set to + the region's ID + + Algorithm: + - Set the Global Edge Table (GET) + - Set y to be the smallest y coordinate that has an entry in GET + - Initialize the Active Edge Table (AET) to be empty + - For each scan line: + 1. Add edges from GET to AET for which ymin==y + 2. Remove edges from AET fro which ymax==y + 3. Compute the intersection of the current scan line with all edges in the AET + 4. Sort on X of intersection point + 5. Set elements between pairs of X in the AET to the Edge's ID + + """ + # TODO: 1.This algorithm does not mark pixels in the top + # row and left most column. Pad the initial pixel description on + # top and left with 1 px to prevent this. 2. Currently it uses + # intersection of the scan line with edges. If this is too slow it + # should use the 1/m increment (replace 3 above) (or the increment + # should be removed from the GET entry). + + # see comments in the __init__ function for the reason of introducing + # polygon shifts (self._shiftx & self._shifty). Here we need to shift + # it back. + + (ny, nx) = data.shape + + y = np.min(list(self._GET.keys())) + + AET = [] + scline = self._scan_line_range[-1] + + while y <= scline: + if y < scline: + AET = self.update_AET(y, AET) + + if self._bbox[2] <= 0: + y += 1 + continue + + scan_line = Edge( + "scan_line", + start=[self._bbox[0], y], + stop=[self._bbox[0] + self._bbox[2], y], + ) + x = [ + int(np.ceil(e.compute_AET_entry(scan_line)[1])) + for e in AET + if e is not None + ] + xnew = np.sort(x) + ysh = y + self._shifty + + if ysh < 0 or ysh >= ny: + y += 1 + continue + + for i, j in zip(xnew[::2], xnew[1::2]): + xstart = max(0, i + self._shiftx) + xend = min(j + self._shiftx, nx - 1) + data[ysh][xstart : xend + 1] = self._rid + + y += 1 + + return data + + def update_AET(self, y, AET): + """ + Update the Active Edge Table (AET) + + Add edges from GET to AET for which ymin of the edge is + equal to the y of the scan line. + Remove edges from AET for which ymax of the edge is + equal to y of the scan line. + + """ + edge_cont = self._GET[y] + if edge_cont is not None: + for edge in edge_cont: + if edge._start[1] != edge._stop[1] and edge._ymin == y: + AET.append(edge) + for edge in AET[::-1]: + if edge is not None: + if edge._ymax == y: + AET.remove(edge) + return AET + + def __contains__(self, px): + """even-odd algorithm or something else better should be used""" + # minx = self._vertices[:,0].min() + # maxx = self._vertices[:,0].max() + # miny = self._vertices[:,1].min() + # maxy = self._vertices[:,1].max() + return ( + px[0] >= self._bbox[0] + and px[0] <= self._bbox[0] + self._bbox[2] + and px[1] >= self._bbox[1] + and px[1] <= self._bbox[1] + self._bbox[3] + ) + + +class Edge: + """ + Edge representation + + An edge has "start" and "stop" (x,y) vertices and an entry in the + GET table of a polygon. The GET entry is a list of these values: + + [ymax, x_at_ymin, delta_x/delta_y] + + """ + + def __init__(self, name=None, start=None, stop=None, next=None): # noqa: A002 + self._start = None + if start is not None: + self._start = np.asarray(start) + self._name = name + self._stop = stop + if stop is not None: + self._stop = np.asarray(stop) + self._next = next + + if self._stop is not None and self._start is not None: + if self._start[1] < self._stop[1]: + self._ymin = self._start[1] + self._yminx = self._start[0] + else: + self._ymin = self._stop[1] + self._yminx = self._stop[0] + self._ymax = max(self._start[1], self._stop[1]) + self._xmin = min(self._start[0], self._stop[0]) + self._xmax = max(self._start[0], self._stop[1]) + else: + self._ymin = None + self._yminx = None + self._ymax = None + self._xmin = None + self._xmax = None + self.GET_entry = self.compute_GET_entry() + + @property + def ymin(self): + return self._ymin + + @property + def start(self): + return self._start + + @property + def stop(self): + return self._stop + + @property + def ymax(self): + return self._ymax + + def compute_GET_entry(self): + """ + Compute the entry in the Global Edge Table + + [ymax, x@ymin, 1/m] + + """ + if self._start is None: + entry = None + else: + earr = np.asarray([self._start, self._stop]) + if np.diff(earr[:, 1]).item() == 0: + return None + else: + entry = [ + self._ymax, + self._yminx, + (np.diff(earr[:, 0]) / np.diff(earr[:, 1])).item(), + None, + ] + return entry + + def compute_AET_entry(self, edge): + """ + Compute the entry for an edge in the current Active Edge Table + + [ymax, x_intersect, 1/m] + note: currently 1/m is not used + """ + x = self.intersection(edge)[0] + return [self._ymax, x, self.GET_entry[2]] + + def __repr__(self): + fmt = "" + if self._name is not None: + fmt += self._name + next_edge = self.next + while next_edge is not None: + fmt += "-->" + fmt += next_edge._name + next_edge = next_edge.next + return fmt + + @property + def next(self): # noqa: A003 + return self._next + + @next.setter + def next(self, edge): # noqa: A003 + if self._name is None: + self._name = edge._name + self._stop = edge._stop + self._start = edge._start + self._next = edge.next + else: + self._next = edge + + def intersection(self, edge): + u = self._stop - self._start + v = edge._stop - edge._start + w = self._start - edge._start + + # Find the determinant of the matrix formed by the vectors u and v + # Note: Originally this was computed using a numpy "2D" cross product, + # however, this functionality has been deprecated and slated for + # removal. + D = np.linalg.det([u, v]) + + if np.allclose(D, 0, rtol=0, atol=1e2 * np.finfo(float).eps): + return np.array(self._start) + + # See note above + return np.linalg.det([v, w]) / D * u + self._start + + def is_parallel(self, edge): + u = self._stop - self._start + v = edge._stop - edge._start + return np.allclose( + np.linalg.det([u, v]), 0, rtol=0, atol=1e2 * np.finfo(float).eps + ) + + +def _round_vertex(v): + x, y = v + return int(round(x)), int(round(y)) diff --git a/src/stcal/skymatch/skyimage.py b/src/stcal/skymatch/skyimage.py new file mode 100644 index 000000000..35df56ae3 --- /dev/null +++ b/src/stcal/skymatch/skyimage.py @@ -0,0 +1,1020 @@ +""" +The ``skyimage`` module contains algorithms that are used by +``skymatch`` to manage all of the information for footprints (image outlines) +on the sky as well as perform useful operations on these outlines such as +computing intersections and statistics in the overlap regions. +""" + +import abc +import tempfile + +import numpy as np +from spherical_geometry.polygon import SphericalPolygon + +from . import region +from .skystatistics import SkyStats + +__all__ = [ + "SkyImage", + "SkyGroup", + "DataAccessor", + "NDArrayInMemoryAccessor", + "NDArrayMappedAccessor", +] + + +class DataAccessor(abc.ABC): + """Base class for all data accessors. Provides a common interface to + access data. + """ + + @abc.abstractmethod + def get_data(self): + pass + + @abc.abstractmethod + def set_data(self, data): + """Sets data. + + Parameters + ---------- + data : numpy.ndarray + Data array to be set. + + """ + pass + + @abc.abstractmethod + def get_data_shape(self): + pass + + +class NDArrayInMemoryAccessor(DataAccessor): + """Acessor for in-memory `numpy.ndarray` data.""" + + def __init__(self, data): + super().__init__() + self._data = data + + def get_data(self): + return self._data + + def set_data(self, data): + self._data = data + + def get_data_shape(self): + return np.shape(self._data) + + +class NDArrayMappedAccessor(DataAccessor): + """Data accessor for arrays stored in temporary files.""" + + def __init__( + self, data, tmpfile=None, prefix="tmp_skymatch_", suffix=".npy", tmpdir="" + ): + super().__init__() + if tmpfile is None: + self._close = True + self._tmp = tempfile.NamedTemporaryFile( + prefix=prefix, suffix=suffix, dir=tmpdir + ) + if not self._tmp: + raise RuntimeError("Unable to create temporary file.") + else: + # temp file managed by the caller + self._close = False + self._tmp = tmpfile + + self.set_data(data) + + def get_data(self): + self._tmp.seek(0) + return np.load(self._tmp) + + def set_data(self, data): + data = np.asanyarray(data) + self._data_shape = data.shape + self._tmp.seek(0) + np.save(self._tmp, data) + + def __del__(self): + if self._close: + self._tmp.close() + + def get_data_shape(self): + return self._data_shape + + +class SkyImage: + """ + Container that holds information about properties of a *single* + image such as: + + * image data; + * WCS of the chip image; + * bounding spherical polygon; + * id; + * pixel area; + * sky background value; + * sky statistics parameters; + * mask associated image data indicating "good" (1) data. + + """ + + def __init__( + self, + image, + wcs_fwd, + wcs_inv, + pix_area=1.0, + convf=1.0, + mask=None, + id=None, # noqa: A002 + skystat=None, + stepsize=None, + meta=None, + reduce_memory_usage=True, + ): + """Initializes the SkyImage object. + + Parameters + ---------- + image : numpy.ndarray, NDArrayDataAccessor + A 2D array of image data or a `NDArrayDataAccessor`. + + wcs_fwd : function + "forward" pixel-to-world transformation function. + + wcs_inv : function + "inverse" world-to-pixel transformation function. + + pix_area : float, optional + Average pixel's sky area. + + convf : float, optional + Conversion factor that when multiplied to `image` data converts + the data to "uniform" (across multiple images) surface + brightness units. + + .. note:: + + The functionality to support this conversion is not yet + implemented and at this moment `convf` is ignored. + + mask : numpy.ndarray, NDArrayDataAccessor + A 2D array or `NDArrayDataAccessor` of a 2D array that indicates + which pixels in the input `image` should be used for sky + computations (``1``) and which pixels should **not** be used + for sky computations (``0``). + + id : anything + The value of this parameter is simply stored within the `SkyImage` + object. While it can be of any type, it is preferable that `id` be + of a type with nice string representation. + + skystat : callable, None, optional + A callable object that takes a either a 2D image (2D + `numpy.ndarray`) or a list of pixel values (an Nx1 array) and + returns a tuple of two values: some statistics (e.g., mean, + median, etc.) and number of pixels/values from the input image + used in computing that statistics. + + When `skystat` is not set, `SkyImage` will use + :py:class:`~stcal.skymatch.skystatistics.SkyStats` object + to perform sky statistics on image data. + + stepsize : int, None, optional + Spacing between vertices of the image's bounding polygon. Default + value of `None` creates bounding polygons with four vertices + corresponding to the corners of the image. + + meta : dict, None, optional + A dictionary of various items to be stored within the `SkyImage` + object. + + reduce_memory_usage : bool, optional + Indicates whether to attempt to minimize memory usage by attaching + input ``image`` and/or ``mask`` `numpy.ndarray` arrays to + file-mapped accessor. This has no effect when input parameters + ``image`` and/or ``mask`` are already of `NDArrayDataAccessor` + objects. + + """ + self._image = None + self._mask = None + self._image_shape = None + self._mask_shape = None + self._reduce_memory_usage = reduce_memory_usage + + self.image = image + + self.convf = convf + self.meta = meta + self._id = id + self._pix_area = pix_area + + # WCS + self.wcs_fwd = wcs_fwd + self.wcs_inv = wcs_inv + + # initial sky value: + self._sky = 0.0 + self._sky_is_valid = False + + self.mask = mask + + # create spherical polygon bounding the image + if image is None or wcs_fwd is None or wcs_inv is None: + self._radec = [(np.array([]), np.array([]))] + self._polygon = SphericalPolygon([]) + self._poly_area = 0.0 + + else: + self.calc_bounding_polygon(stepsize) + + # set sky statistics function (NOTE: it must return statistics and + # the number of pixels used after clipping) + if skystat is None: + self.set_builtin_skystat() + else: + self.skystat = skystat + + @property + def mask(self): + """Set or get `SkyImage`'s ``mask`` data array or `None`.""" + if self._mask is None: + return None + else: + return self._mask.get_data() + + @mask.setter + def mask(self, mask): + if mask is None: + self._mask = None + self._mask_shape = None + + elif isinstance(mask, DataAccessor): + if self._image is None: + raise ValueError("'mask' must be None when 'image' is None") + + self._mask = mask + self._mask_shape = mask.get_data_shape() + + # check that mask has the same shape as image: + if self._mask_shape != self.image_shape: + raise ValueError("'mask' must have the same shape as 'image'.") + + else: + if self._image is None: + raise ValueError("'mask' must be None when 'image' is None") + + mask = np.asanyarray(mask, dtype=bool) + self._mask_shape = mask.shape + + # check that mask has the same shape as image: + if self._mask_shape != self.image_shape: + raise ValueError("'mask' must have the same shape as 'image'.") + + if self._mask is None: + if self._reduce_memory_usage: + self._mask = NDArrayMappedAccessor( + mask, prefix="tmp_skymatch_mask_" + ) + else: + self._mask = NDArrayInMemoryAccessor(mask) + else: + self._mask.set_data(mask) + + @property + def image(self): + """Set or get `SkyImage`'s ``image`` data array.""" + if self._image is None: + return None + else: + return self._image.get_data() + + @image.setter + def image(self, image): + if image is None: + self._image = None + self._image_shape = None + self.mask = None + + if isinstance(image, DataAccessor): + self._image = image + self._image_shape = image.get_data_shape() + + else: + image = np.asanyarray(image) + self._image_shape = image.shape + if self._image is None: + if self._reduce_memory_usage: + self._image = NDArrayMappedAccessor( + image, prefix="tmp_skymatch_image_" + ) + else: + self._image = NDArrayInMemoryAccessor(image) + else: + self._image.set_data(image) + + @property + def image_shape(self): + """Get `SkyImage`'s ``image`` data shape.""" + if self._image_shape is None and self._image is not None: + self._image_shape = self._image.get_data_shape() + return self._image_shape + + @property + def id(self): # noqa: A003 + """Set or get `SkyImage`'s `id`. + + While `id` can be of any type, it is preferable that `id` be + of a type with nice string representation. + + """ + return self._id + + @id.setter + def id(self, value): # noqa: A003 + self._id = value + + @property + def pix_area(self): + """Set or get mean pixel area.""" + return self._pix_area + + @pix_area.setter + def pix_area(self, pix_area): + self._pix_area = pix_area + + @property + def poly_area(self): + """Get bounding polygon area in srad units.""" + return self._poly_area + + @property + def sky(self): + """Sky background value. See `calc_sky` for more details.""" + return self._sky + + @sky.setter + def sky(self, sky): + self._sky = sky + + @property + def is_sky_valid(self): + """ + Indicates whether sky value was successfully computed. + Must be set externally. + """ + return self._sky_is_valid + + @is_sky_valid.setter + def is_sky_valid(self, valid): + self._sky_is_valid = valid + + @property + def radec(self): + """ + Get RA and DEC of the vertices of the bounding polygon as a + `~numpy.ndarray` of shape (N, 2) where N is the number of vertices + 1. + """ + return self._radec + + @property + def polygon(self): + """Get image's bounding polygon.""" + return self._polygon + + def intersection(self, skyimage): + """ + Compute intersection of this `SkyImage` object and another + `SkyImage`, `SkyGroup`, or + :py:class:`~spherical_geometry.polygon.SphericalPolygon` + object. + + Parameters + ---------- + skyimage : SkyImage, SkyGroup, SphericalPolygon + Another object that should be intersected with this `SkyImage`. + + Returns + ------- + polygon : SphericalPolygon + A :py:class:`~spherical_geometry.polygon.SphericalPolygon` that is + the intersection of this `SkyImage` and `skyimage`. + + """ + if isinstance(skyimage, (SkyImage, SkyGroup)): + other = skyimage.polygon + else: + other = skyimage + + pts1 = np.sort(list(self._polygon.points)[0], axis=0) + pts2 = np.sort(list(other.points)[0], axis=0) + if np.allclose(pts1, pts2, rtol=0, atol=5e-9): + intersect_poly = self._polygon.copy() + else: + intersect_poly = self._polygon.intersection(other) + return intersect_poly + + def calc_bounding_polygon(self, stepsize=None): + """Compute image's bounding polygon. + + Parameters + ---------- + stepsize : int, None, optional + Indicates the maximum separation between two adjacent vertices + of the bounding polygon along each side of the image. Corners + of the image are included automatically. If `stepsize` is `None`, + bounding polygon will contain only vertices of the image. + + """ + ny, nx = self.image_shape + + if stepsize is None: + nintx = 2 + ninty = 2 + else: + nintx = max(2, int(np.ceil((nx + 1.0) / stepsize))) + ninty = max(2, int(np.ceil((ny + 1.0) / stepsize))) + + xs = np.linspace(-0.5, nx - 0.5, nintx, dtype=float) + ys = np.linspace(-0.5, ny - 0.5, ninty, dtype=float)[1:-1] + nptx = xs.size + npty = ys.size + + npts = 2 * (nptx + npty) + + borderx = np.empty((npts + 1,), dtype=float) + bordery = np.empty((npts + 1,), dtype=float) + + # "bottom" points: + borderx[:nptx] = xs + bordery[:nptx] = -0.5 + # "right" + sl = np.s_[nptx : nptx + npty] + borderx[sl] = nx - 0.5 + bordery[sl] = ys + # "top" + sl = np.s_[nptx + npty : 2 * nptx + npty] + borderx[sl] = xs[::-1] + bordery[sl] = ny - 0.5 + # "left" + sl = np.s_[2 * nptx + npty : -1] + borderx[sl] = -0.5 + bordery[sl] = ys[::-1] + + # close polygon: + borderx[-1] = borderx[0] + bordery[-1] = bordery[0] + + ra, dec = self.wcs_fwd(borderx, bordery, with_bounding_box=False) + # TODO: for strange reasons, occasionally ra[0] != ra[-1] and/or + # dec[0] != dec[-1] (even though we close the polygon in the + # previous two lines). Then SphericalPolygon fails because + # points are not closed. Therefore we force it to be closed: + ra[-1] = ra[0] + dec[-1] = dec[0] + + self._radec = [(ra, dec)] + self._polygon = SphericalPolygon.from_radec(ra, dec) + self._poly_area = np.fabs(self._polygon.area()) + + @property + def skystat(self): + """Stores/retrieves a callable object that takes a either a 2D image + (2D `numpy.ndarray`) or a list of pixel values (an Nx1 array) and + returns a tuple of two values: some statistics + (e.g., mean, median, etc.) and number of pixels/values from the input + image used in computing that statistics. + + When `skystat` is not set, `SkyImage` will use + :py:class:`~stcal.skymatch.skystatistics.SkyStats` object + to perform sky statistics on image data. + + """ + return self._skystat + + @skystat.setter + def skystat(self, skystat): + self._skystat = skystat + + def set_builtin_skystat( + self, + skystat="median", + lower=None, + upper=None, + nclip=5, + lsigma=4.0, + usigma=4.0, + binwidth=0.1, + ): + """ + Replace already set `skystat` with a "built-in" version of a + statistics callable object used to measure sky background. + + See :py:class:`~stcal.skymatch.skystatistics.SkyStats` for the + parameter description. + + """ + self._skystat = SkyStats( + skystat=skystat, + lower=lower, + upper=upper, + nclip=nclip, + lsig=lsigma, + usig=usigma, + binwidth=binwidth, + ) + + def calc_sky(self, overlap=None, delta=True): + """ + Compute sky background value. + + Parameters + ---------- + overlap : SkyImage, SkyGroup, SphericalPolygon, list of tuples, \ +None, optional + Another `SkyImage`, `SkyGroup`, + :py:class:`spherical_geometry.polygons.SphericalPolygon`, or + a list of tuples of (RA, DEC) of vertices of a spherical + polygon. This parameter is used to indicate that sky statistics + should computed only in the region of intersection of *this* + image with the polygon indicated by `overlap`. When `overlap` is + `None`, sky statistics will be computed over the entire image. + + delta : bool, optional + Should this function return absolute sky value or the difference + between the computed value and the value of the sky stored in the + `sky` property. + + Returns + ------- + skyval : float, None + Computed sky value (absolute or relative to the `sky` attribute). + If there are no valid data to perform this computations (e.g., + because this image does not overlap with the image indicated by + `overlap`), `skyval` will be set to `None`. + + npix : int + Number of pixels used to compute sky statistics. + + polyarea : float + Area (in srad) of the polygon that bounds data used to compute + sky statistics. + + """ + if overlap is None: + if self._mask is None: + data = self.image + else: + data = self.image[self._mask.get_data()] + + polyarea = self.poly_area + + else: + fill_mask = np.zeros(self.image_shape, dtype=bool) + + if isinstance(overlap, SkyImage): + intersection = self.intersection(overlap) + polyarea = np.fabs(intersection.area()) + radec = list(intersection.to_radec()) + + elif isinstance(overlap, SkyGroup): + radec = [] + polyarea = 0.0 + for im in overlap: + intersection = self.intersection(im) + polyarea1 = np.fabs(intersection.area()) + if polyarea1 == 0.0: + continue + polyarea += polyarea1 + radec += list(intersection.to_radec()) + + elif isinstance(overlap, SphericalPolygon): + radec = [] + polyarea = 0.0 + for p in overlap._polygons: + intersection = self.intersection(SphericalPolygon([p])) + polyarea1 = np.fabs(intersection.area()) + if polyarea1 == 0.0: + continue + polyarea += polyarea1 + radec += list(intersection.to_radec()) + + else: # assume a list of (ra, dec) tuples: + radec = [] + polyarea = 0.0 + for r, d in overlap: + poly = SphericalPolygon.from_radec(r, d) + polyarea1 = np.fabs(poly.area()) + if polyarea1 == 0.0 or len(r) < 4: + continue + polyarea += polyarea1 + radec.append(self.intersection(poly).to_radec()) + + if polyarea == 0.0: + return None, 0, 0.0 + + for ra, dec in radec: + if len(ra) < 4: + continue + + # set pixels in 'fill_mask' that are inside a polygon to True: + x, y = self.wcs_inv(ra, dec) + poly_vert = list(zip(*[x, y])) + + polygon = region.Polygon(True, poly_vert) + fill_mask = polygon.scan(fill_mask) + + if self._mask is not None: + fill_mask &= self._mask.get_data() + + data = self.image[fill_mask] + + if data.size < 1: + return None, 0, 0.0 + + # Calculate sky + try: + skyval, npix = self._skystat(data) + except ValueError: + return None, 0, 0.0 + + if not np.isfinite(skyval): + return None, 0, 0.0 + + if delta: + skyval -= self._sky + + return skyval, npix, polyarea + + def _calc_sky_orig(self, overlap=None, delta=True): + """ + Compute sky background value. + + Parameters + ---------- + overlap : SkyImage, SkyGroup, SphericalPolygon, list of tuples, \ +None, optional + Another `SkyImage`, `SkyGroup`, + :py:class:`spherical_geometry.polygons.SphericalPolygon`, or + a list of tuples of (RA, DEC) of vertices of a spherical + polygon. This parameter is used to indicate that sky statistics + should computed only in the region of intersection of *this* + image with the polygon indicated by `overlap`. When `overlap` is + `None`, sky statistics will be computed over the entire image. + + delta : bool, optional + Should this function return absolute sky value or the difference + between the computed value and the value of the sky stored in the + `sky` property. + + Returns + ------- + skyval : float, None + Computed sky value (absolute or relative to the `sky` attribute). + If there are no valid data to perform this computations (e.g., + because this image does not overlap with the image indicated by + `overlap`), `skyval` will be set to `None`. + + npix : int + Number of pixels used to compute sky statistics. + + polyarea : float + Area (in srad) of the polygon that bounds data used to compute + sky statistics. + + """ + + if overlap is None: + if self._mask is None: + data = self.image + else: + data = self.image[self._mask.get_data()] + + polyarea = self.poly_area + + else: + fill_mask = np.zeros(self.image_shape, dtype=bool) + + if isinstance(overlap, (SkyImage, SkyGroup, SphericalPolygon)): + intersection = self.intersection(overlap) + polyarea = np.fabs(intersection.area()) + radec = intersection.to_radec() + + else: # assume a list of (ra, dec) tuples: + radec = [] + polyarea = 0.0 + for r, d in overlap: + poly = SphericalPolygon.from_radec(r, d) + polyarea1 = np.fabs(poly.area()) + if polyarea1 == 0.0 or len(r) < 4: + continue + polyarea += polyarea1 + radec.append(self.intersection(poly).to_radec()) + + if polyarea == 0.0: + return None, 0, 0.0 + + for ra, dec in radec: + if len(ra) < 4: + continue + + # set pixels in 'fill_mask' that are inside a polygon to True: + x, y = self.wcs_inv(ra, dec) + poly_vert = list(zip(*[x, y])) + + polygon = region.Polygon(True, poly_vert) + fill_mask = polygon.scan(fill_mask) + + if self._mask is not None: + fill_mask &= self._mask.get_data() + + data = self.image[fill_mask] + + if data.size < 1: + return None, 0, 0.0 + + # Calculate sky + try: + skyval, npix = self._skystat(data) + + except ValueError: + return None, 0, 0.0 + + if delta: + skyval -= self._sky + + return skyval, npix, polyarea + + def copy(self): + """ + Return a shallow copy of the `SkyImage` object. + """ + si = SkyImage( + image=None, + wcs_fwd=self.wcs_fwd, + wcs_inv=self.wcs_inv, + pix_area=self.pix_area, + convf=self.convf, + mask=None, + id=self.id, + stepsize=None, + meta=self.meta, + ) + + si._image = self._image + si._mask = self._mask + si._image_shape = self._image_shape + si._mask_shape = self._mask_shape + si._reduce_memory_usage = self._reduce_memory_usage + + si._radec = self._radec + si._polygon = self._polygon + si._poly_area = self._poly_area + si.sky = self.sky + return si + + +class SkyGroup: + """ + Holds multiple :py:class:`SkyImage` objects whose sky background values + must be adjusted together. + + `SkyGroup` provides methods for obtaining bounding polygon of the group + of :py:class:`SkyImage` objects and to compute sky value of the group. + + """ + + def __init__(self, images, id=None, sky=0.0): # noqa: A002 + if isinstance(images, SkyImage): + self._images = [images] + + elif hasattr(images, "__iter__"): + self._images = [] + for im in images: + if not isinstance(im, SkyImage): + raise TypeError( + "Each element of the 'images' parameter " + "must be an 'SkyImage' object." + ) + self._images.append(im) + + else: + raise TypeError( + "Parameter 'images' must be either a single " + "'SkyImage' object or a list of 'SkyImage' objects" + ) + + self._id = id + self._update_bounding_polygon() + self._sky = sky + for im in self._images: + im.sky += sky + + @property + def id(self): # noqa: A003 + """Set or get `SkyImage`'s `id`. + + While `id` can be of any type, it is preferable that `id` be + of a type with nice string representation. + + """ + return self._id + + @id.setter + def id(self, value): # noqa: A003 + self._id = value + + @property + def sky(self): + """Sky background value. See `calc_sky` for more details.""" + return self._sky + + @sky.setter + def sky(self, sky): + delta_sky = sky - self._sky + self._sky = sky + for im in self._images: + im.sky += delta_sky + + @property + def radec(self): + """ + Get RA and DEC of the vertices of the bounding polygon as a + `~numpy.ndarray` of shape (N, 2) where N is the number of vertices + 1. + + """ + return self._radec + + @property + def polygon(self): + """Get image's bounding polygon.""" + return self._polygon + + def intersection(self, skyimage): + """ + Compute intersection of this `SkyImage` object and another + `SkyImage`, `SkyGroup`, or + :py:class:`~spherical_geometry.polygon.SphericalPolygon` + object. + + Parameters + ---------- + skyimage : SkyImage, SkyGroup, SphericalPolygon + Another object that should be intersected with this `SkyImage`. + + Returns + ------- + intersect_poly : SphericalPolygon + A :py:class:`~spherical_geometry.polygon.SphericalPolygon` that is + the intersection of this `SkyImage` and `skyimage`. + + """ + if isinstance(skyimage, (SkyImage, SkyGroup)): + other = skyimage.polygon + else: + other = skyimage + + pts1 = np.sort(list(self._polygon.points)[0], axis=0) + pts2 = np.sort(list(other.points)[0], axis=0) + if np.allclose(pts1, pts2, rtol=0, atol=1e-8): + intersect_poly = self._polygon.copy() + else: + intersect_poly = self._polygon.intersection(other) + return intersect_poly + + def _update_bounding_polygon(self): + polygons = [im.polygon for im in self._images] + if len(polygons) == 0: + self._polygon = SphericalPolygon([]) + self._radec = [] + else: + self._polygon = SphericalPolygon.multi_union(polygons) + self._radec = list(self._polygon.to_radec()) + + def __len__(self): + return len(self._images) + + def __getitem__(self, idx): + return self._images[idx] + + def __setitem__(self, idx, value): + if not isinstance(value, SkyImage): + raise TypeError("Item must be of 'SkyImage' type") + value.sky += self._sky + self._images[idx] = value + self._update_bounding_polygon() + + def __delitem__(self, idx): + del self._images[idx] + if len(self._images) == 0: + self._sky = 0.0 + self._id = None + self._update_bounding_polygon() + + def __iter__(self): + yield from self._images + + def insert(self, idx, value): + """Inserts a `SkyImage` into the group.""" + if not isinstance(value, SkyImage): + raise TypeError("Item must be of 'SkyImage' type") + value.sky += self._sky + self._images.insert(idx, value) + self._update_bounding_polygon() + + def append(self, value): + """Appends a `SkyImage` to the group.""" + if not isinstance(value, SkyImage): + raise TypeError("Item must be of 'SkyImage' type") + value.sky += self._sky + self._images.append(value) + self._update_bounding_polygon() + + def calc_sky(self, overlap=None, delta=True): + """ + Compute sky background value. + + Parameters + ---------- + overlap : SkyImage, SkyGroup, SphericalPolygon, list of tuples, \ +None, optional + Another `SkyImage`, `SkyGroup`, + :py:class:`spherical_geometry.polygons.SphericalPolygon`, or + a list of tuples of (RA, DEC) of vertices of a spherical + polygon. This parameter is used to indicate that sky statistics + should computed only in the region of intersection of *this* + image with the polygon indicated by `overlap`. When `overlap` is + `None`, sky statistics will be computed over the entire image. + + delta : bool, optional + Should this function return absolute sky value or the difference + between the computed value and the value of the sky stored in the + `sky` property. + + Returns + ------- + skyval : float, None + Computed sky value (absolute or relative to the `sky` attribute). + If there are no valid data to perform this computations (e.g., + because this image does not overlap with the image indicated by + `overlap`), `skyval` will be set to `None`. + + npix : int + Number of pixels used to compute sky statistics. + + polyarea : float + Area (in srad) of the polygon that bounds data used to compute + sky statistics. + + """ + + if len(self._images) == 0: + return None, 0, 0.0 + + wght = 0 + area = 0.0 + + if overlap is None: + # compute minimum sky across all images in the group: + wsky = None + + for image in self._images: + # make sure all images have the same background: + image.background = self._sky + + sky, npix, imarea = image.calc_sky(overlap=None, delta=delta) + + if sky is None: + continue + + if wsky is None or wsky > sky: + wsky = sky + wght = npix + area = imarea + + return wsky, wght, area + + # compute weighted sky in various overlaps: + wsky = 0.0 + + for image in self._images: + # make sure all images have the same background: + image.background = self._sky + + sky, npix, area1 = image.calc_sky(overlap=overlap, delta=delta) + + area += area1 + + if sky is not None and npix > 0: + pix_area = npix * image.pix_area + wsky += sky * pix_area + wght += pix_area + + if wght == 0.0 or area == 0.0: + return None, wght, area + else: + return wsky / wght, wght, area diff --git a/src/stcal/skymatch/skymatch.py b/src/stcal/skymatch/skymatch.py new file mode 100644 index 000000000..d428fdaa0 --- /dev/null +++ b/src/stcal/skymatch/skymatch.py @@ -0,0 +1,539 @@ +""" +A module that provides functions for matching sky in overlapping images. +""" + +import logging +from datetime import datetime + +import numpy as np + +from .skyimage import SkyGroup, SkyImage + +__all__ = ["match"] + +__local_debug__ = True + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +def match(images, skymethod="global+match", match_down=True, subtract=False): + """ + A function to compute and/or "equalize" sky background in input images. + + .. note:: + Sky matching ("equalization") is possible only for **overlapping** + images. + + Parameters + ---------- + images : list of SkyImage or SkyGroup + A list of :py:class:`~stcal.skymatch.skyimage.SkyImage` or + :py:class:`~stcal.skymatch.skyimage.SkyGroup` objects. + + skymethod : {'local', 'global+match', 'global', 'match'}, optional + Select the algorithm for sky computation: + + * **'local'** : compute sky background values of each input image or + group of images (members of the same "exposure"). A single sky value + is computed for each group of images. + + .. note:: + This setting is recommended when regions of overlap between images + are dominated by "pure" sky (as opposed to extended, diffuse + sources). + + * **'global'** : compute a common sky value for all input images and + groups of images. With this setting `local` will compute + sky values for each input image/group, find the minimum sky value, + and then it will set (and/or subtract) the sky value of each input image + to this minimum value. This method *may* be + useful when the input images have been already matched. + + * **'match'** : compute differences in sky values between images + and/or groups in (pair-wise) common sky regions. In this case + the computed sky values will be relative (delta) to the sky computed + in one of the input images whose sky value will be set to + (reported to be) 0. This setting will "equalize" sky values between + the images in large mosaics. However, this method is not recommended + when used in conjunction with + `astrodrizzle + `_ + because it computes relative sky values while `astrodrizzle` needs + "absolute" sky values for median image generation and CR rejection. + + * **'global+match'** : first use the **'match'** method to + equalize sky values between images and then find a minimum + "global" sky value amongst all input images. + + .. note:: + This is the *recommended* setting for images + containing diffuse sources (e.g., galaxies, nebulae) + covering significant parts of the image. + + match_down : bool, optional + Specifies whether the sky *differences* should be subtracted from + images with higher sky values (`match_down` = `True`) to match the + image with the lowest sky or sky differences should be added to the + images with lower sky values to match the sky of the image with the + highest sky value (`match_down` = `False`). + + .. note:: + This setting applies *only* when the `skymethod` parameter is + either `'match'` or `'global+match'`. + + subtract : bool (Default = False) + Subtract computed sky value from image data. + + + Raises + ------ + + TypeError + The `images` argument must be a Python list of + :py:class:`~stcal.skymatch.skyimage.SkyImage` and/or + :py:class:`~stcal.skymatch.skyimage.SkyGroup` objects. + + + Notes + ----- + + :py:func:`match` provides new algorithms for sky value computations + and enhances previously available algorithms used by, e.g., + `astrodrizzle + `_. + + Two new methods of sky subtraction have been introduced (compared to the + standard ``'local'``): ``'global'`` and ``'match'``, as well as a + combination of the two -- ``'global+match'``. + + - The ``'global'`` method computes the minimum sky value across *all* + input images and/or groups. That sky value is then considered to be + the background in all input images. + + - The ``'match'`` algorithm is somewhat similar to the traditional sky + subtraction method (`skymethod` = `'local'`) in the sense that it + measures the sky independently in input images (or groups). The major + differences are that, unlike the traditional method, + + #. ``'match'`` algorithm computes *relative* (delta) sky values with + regard to the sky in a reference image chosen from the input list + of images; *and* + + #. Sky statistics are computed only in the part of the image + that intersects other images. + + This makes the ``'match'`` sky computation algorithm particularly useful + for "equalizing" sky values in large mosaics in which one may have + only (at least) pair-wise intersection of images without having + a common intersection region (on the sky) in all images. + + The `'match'` method works in the following way: for each pair + of intersecting images, an equation is written that + requires that average surface brightness in the overlapping part of + the sky be equal in both images. The final system of equations is then + solved for unknown background levels. + + .. warning:: + + The current algorithm is not capable of detecting cases where some subsets + of intersecting images (from the input list of images) do not intersect + at all with other subsets of intersecting images (except for the simple + case when *single* images do not intersect any other images). In these + cases the algorithm will find equalizing sky values for each + intersecting subset of images and/or groups of images. + However since these subsets of images do not intersect each other, + sky will be matched only within each subset and the "inter-subset" + sky mismatch could be significant. + + Users are responsible for detecting such cases and adjusting processing + accordingly. + + - The ``'global+match'`` algorithm combines the ``'match'`` and + ``'global'`` methods in order to overcome the limitation of the + ``'match'`` method described in the note above: it uses the ``'global'`` + algorithm to find a baseline sky value common to all input images + and the ``'match'`` algorithm to "equalize" sky values in the mosaic. + Thus, the sky value of the "reference" image will be equal to the + baseline sky value (instead of 0 in ``'match'`` algorithm alone). + + **Remarks:** + * :py:func:`match` works directly on *geometrically distorted* + flat-fielded images thus avoiding the need to perform distortion + correction on the input images. + + Initially, the footprint of a chip in an image is approximated by a + 2D planar rectangle representing the borders of chip's distorted + image. After applying distortion model to this rectangle and + projecting it onto the celestial sphere, it is approximated by + spherical polygons. Footprints of exposures and mosaics are + computed as unions of such spherical polygons while overlaps + of image pairs are found by intersecting these spherical polygons. + + **Limitations and Discussions:** + Primary reason for introducing "sky match" algorithm was to try to + equalize the sky in large mosaics in which computation of the + "absolute" sky is difficult due to the presence of large diffuse + sources in the image. As discussed above, :py:func:`match` + accomplishes this by comparing "sky values" in a pair of images in the + overlap region (that is common to both images). Quite obviously the + quality of sky "matching" will depend on how well these "sky values" + can be estimated. We use quotation marks around *sky values* because + for some image "true" background may not be present at all and the + measured sky may be the surface brightness of large galaxy, nebula, etc. + + In the discussion below we will refer to parameter names in + :py:class:`~stcal.skymatch.skystatistics.SkyStats` and these + parameter names may differ from the parameters of the actual `skystat` + object passed to initializer of the + :py:class:`~stcal.skymatch.skyimage.SkyImage`. + + Here is a brief list of possible limitations/factors that can affect + the outcome of the matching (sky subtraction in general) algorithm: + + * Since sky subtraction is performed on *flat-fielded* but + *not distortion corrected* images, it is important to keep in mind + that flat-fielding is performed to obtain uniform surface brightness + and not flux. This distinction is important for images that have + not been distortion corrected. As a consequence, it is advisable that + point-like sources be masked through the user-supplied mask files. + Values different from zero in user-supplied masks indicate "good" data + pixels. Alternatively, one can use `upper` parameter to limit the use + of bright objects in sky computations. + + * Normally, distorted flat-fielded images contain cosmic rays. This + algorithm does not perform CR cleaning. A possible way of minimizing + the effect of the cosmic rays on sky computations is to use + clipping (`nclip` > 0) and/or set `upper` parameter to a value + larger than most of the sky background (or extended source) but + lower than the values of most CR pixels. + + * In general, clipping is a good way of eliminating "bad" pixels: + pixels affected by CR, hot/dead pixels, etc. However, for + images with complicated backgrounds (extended galaxies, nebulae, + etc.), affected by CR and noise, clipping process may mask different + pixels in different images. If variations in the background are + too strong, clipping may converge to different sky values in + different images even when factoring in the "true" difference + in the sky background between the two images. + + * In general images can have different "true" background values + (we could measure it if images were not affected by large diffuse + sources). However, arguments such as `lower` and `upper` will + apply to all images regardless of the intrinsic differences + in sky levels. + + """ + function_name = match.__name__ + + # Time it + runtime_begin = datetime.now() + + log.info(" ") + log.info(f"***** {__name__:s}.{function_name:s}() started on {runtime_begin}") + log.info(" ") + + # check sky method: + skymethod = skymethod.lower() + if skymethod not in ["local", "global", "match", "global+match"]: + raise ValueError( + "Unsupported 'skymethod'. Valid values are: " + "'local', 'global', 'match', or 'global+match'" + ) + do_match = "match" in skymethod + do_global = "global" in skymethod + show_old = subtract + + log.info(f"Sky computation method: '{skymethod}'") + if do_match: + log.info("Sky matching direction: {:s}".format("DOWN" if match_down else "UP")) + + log.info( + "Sky subtraction from image data: {:s}".format("ON" if subtract else "OFF") + ) + + # check that input file name is a list of either SkyImage or SkyGroup: + nimages = 0 + for img in images: + if isinstance(img, SkyImage): + nimages += 1 + elif isinstance(img, SkyGroup): + nimages += len(img) + else: + raise TypeError( + "Each element of the 'images' must be either a " + "'SkyImage' or a 'SkyGroup'" + ) + + if nimages == 0: + raise ValueError("Argument 'images' must contain at least one image") + + log.debug( + "Total number of images to be sky-subtracted and/or matched: {:d}".format( + nimages + ) + ) + + # Print conversion factors + log.debug(" ") + log.debug("---- Image data conversion factors:") + + for img in images: + img_type = "Image" if isinstance(img, SkyImage) else "Group" + + if img_type == "Group": + log.debug(f" * Group ID={img.id}. Conversion factors:") + for im in img: + log.debug( + " - Image ID={}. Conversion factor = {:G}".format( + im.id, im.convf + ) + ) + else: + log.debug(f" * Image ID={img.id}. Conversion factor = {img.convf:G}") + + # 1. Method: "match" (or "global+match"). + # Find sky "deltas" that will match sky across all + # (intersecting) images. + if do_match: + log.info(" ") + log.info("---- Computing differences in sky values in " "overlapping regions.") + + # find "optimum" sky changes: + sky_deltas = _find_optimum_sky_deltas(images, apply_sky=not subtract) + sky_good = np.isfinite(sky_deltas) + + if np.any(sky_good): + # match sky "Up" or "Down": + if match_down: + refsky = np.amin(sky_deltas[sky_good]) + else: + refsky = np.amax(sky_deltas[sky_good]) + sky_deltas[sky_good] -= refsky + + # convert to Python list and replace numpy.nan with None + sky_deltas = [skd if np.isfinite(skd) else None for skd in sky_deltas] + + _apply_sky(images, sky_deltas, False, subtract, show_old) + show_old = True + + # 2. Method: "local". Compute the minimum sky background + # value in each sky group/image. + # This is an improved (use of masks) replacement + # for the classical 'subtract' used by astrodrizzle. + # + # NOTE: incompatible with "match"-containing + # 'skymethod' modes. + # + # 3. Method: "global". Compute the minimum sky background + # value *across* *all* sky line members. + if do_global or not do_match: + log.info(" ") + if do_global: + minsky = None + log.info( + '---- Computing "global" sky - smallest sky value ' + "across *all* input images." + ) + else: + log.info("---- Sky values computed per image and/or image " "groups.") + + sky_deltas = [] + for img in images: + sky = img.calc_sky(delta=not subtract)[0] + sky_deltas.append(sky) + if do_global and (minsky is None or sky < minsky): + minsky = sky + + if do_global: + log.info(" ") + if minsky is None: + log.warning(' Unable to compute "global" sky value') + sky_deltas = len(sky_deltas) * [minsky] + log.info( + ' "Global" sky value correction: {} ' "[not converted]".format(minsky) + ) + + if do_match: + log.info(" ") + log.info("---- Final (match+global) sky for:") + + _apply_sky(images, sky_deltas, do_global, subtract, show_old) + + # log running time: + runtime_end = datetime.now() + log.info(" ") + log.info(f"***** {__name__:s}.{function_name:s}() ended on {runtime_end}") + log.info( + "***** {:s}.{:s}() TOTAL RUN TIME: {}".format( + __name__, function_name, runtime_end - runtime_begin + ) + ) + log.info(" ") + + +def _apply_sky(images, sky_deltas, do_global, do_skysub, show_old): + for img, sky in zip(images, sky_deltas): + is_group = not isinstance(img, SkyImage) + + if do_global: + if sky is None: + valid = img[0].is_sky_valid if is_group else img.is_sky_valid + sky = 0.0 + else: + valid = True + + else: + valid = sky is not None + if not valid: + log.warning( + " * {:s} ID={}: Unable to compute sky value".format( + "Group" if is_group else "Image", img.id + ) + ) + sky = 0.0 + + if is_group: + # apply sky change: + old_img_sky = [im.sky for im in img] + if do_skysub: + for im in img: + im._image.set_data(im._image.get_data() - sky) + img.sky += sky + new_img_sky = [im.sky for im in img] + + # log sky values: + log.info( + " * Group ID={}. Sky background of " + "component images:".format(img.id) + ) + + for im, old_sky, new_sky in zip(img, old_img_sky, new_img_sky): + c = 1.0 / im.convf + if show_old: + log.info( + " - Image ID={}. Sky background: {:G} " + "(old={:G}, delta={:G})".format( + im.id, c * new_sky, c * old_sky, c * sky + ) + ) + else: + log.info( + " - Image ID={}. Sky background: {:G}".format( + im.id, c * new_sky + ) + ) + + im.is_sky_valid = valid + + else: + # apply sky change: + old_sky = img.sky if img.sky is not None else 0 + if do_skysub: + img._image.set_data(img._image.get_data() - sky) + if img.sky is None: + img.sky = 0 + + img.sky += sky + new_sky = img.sky + + # log sky values: + c = 1.0 / img.convf + if show_old: + log.info( + " * Image ID={}. Sky background: {:G} " + "(old={:G}, delta={:G})".format( + img.id, c * new_sky, c * old_sky, c * sky + ) + ) + else: + log.info( + " * Image ID={}. Sky background: {:G}".format( + img.id, c * new_sky + ) + ) + + img.is_sky_valid = valid + + +def _overlap_matrix(images, apply_sky=True): + # TODO: to improve performance, the nested loops could be parallelized + # since _calc_sky() here can be called independently from previous steps. + ns = len(images) + A = np.zeros((ns, ns), dtype=float) + W = np.zeros((ns, ns), dtype=float) + for i in range(ns): + for j in range(i + 1, ns): + s1, w1, area1 = images[i].calc_sky(overlap=images[j], delta=apply_sky) + + s2, w2, area2 = images[j].calc_sky(overlap=images[i], delta=apply_sky) + if area1 == 0.0 or area2 == 0.0 or s1 is None or s2 is None: + continue + + A[j, i] = s1 + W[j, i] = w1 + A[i, j] = s2 + W[i, j] = w2 + + return A, W + + +def _find_optimum_sky_deltas(images, apply_sky=True): + ns = len(images) + A, W = _overlap_matrix(images, apply_sky=apply_sky) + + def is_valid(i, j): + return W[i, j] > 0 and W[j, i] > 0 + + # We need to know how many "non-trivial" (at least for now... - we will + # compute rank later) equations can be built so that we know the + # shape of the arrays that need to be created... + # NOTE: for now use only pairs that *both* have weights > 0 (but a + # different scenario when only one image has a valid weight can be + # considered): + neq = 0 + for i in range(ns): + for j in range(i + 1, ns): + if is_valid(i, j): + neq += 1 + + # average weights: + Wm = 0.5 * (W + W.T) + + # create arrays for coefficients and free terms: + K = np.zeros((neq, ns), dtype=float) + F = np.zeros(neq, dtype=float) + invalid = (ns) * [True] + + # now process intersections between the rest of the images: + ieq = 0 + for i in range(0, ns): + for j in range(i + 1, ns): + if is_valid(i, j): + K[ieq, i] = Wm[i, j] + K[ieq, j] = -Wm[i, j] + F[ieq] = Wm[i, j] * (A[j, i] - A[i, j]) + invalid[i] = False + invalid[j] = False + ieq += 1 + + try: + rank = np.linalg.matrix_rank(K, 1.0e-12) + except np.linalg.LinAlgError: + log.warning("Unable to compute sky: No valid data in common " "image areas") + deltas = np.full(ns, np.nan, dtype=float) + return deltas + + if rank < ns - 1: + log.warning(f"There are more unknown sky values ({ns}) to be solved for") + log.warning( + "than there are independent equations available " + "(matrix rank={}).".format(rank) + ) + log.warning("Sky matching (delta) values will be computed only for") + log.warning("a subset (or more independent subsets) of input images.") + invK = np.linalg.pinv(K, rcond=1.0e-12) + + deltas = np.dot(invK, F) + deltas[np.asarray(invalid, dtype=bool)] = np.nan + return deltas diff --git a/src/stcal/skymatch/skystatistics.py b/src/stcal/skymatch/skystatistics.py new file mode 100644 index 000000000..b93a99ed7 --- /dev/null +++ b/src/stcal/skymatch/skystatistics.py @@ -0,0 +1,136 @@ +""" +The `skystatistics` module provides statistics computation class used by +:py:func:`~stcal.skymatch.skymatch.match` +and :py:class:`~stcal.skymatch.skyimage.SkyImage`. +""" + +from copy import deepcopy + +# THIRD PARTY +from stsci.imagestats import ImageStats + +__all__ = ["SkyStats"] + + +class SkyStats: + """ + This is a superclass build on top of + :py:class:`stsci.imagestats.ImageStats`. Compared to + :py:class:`stsci.imagestats.ImageStats`, `SkyStats` has + "persistent settings" in the sense that object's parameters need to be + set once and these settings will be applied to all subsequent + computations on different data. + """ + + def __init__( + self, + skystat="mean", + lower=None, + upper=None, + nclip=5, + lsig=4.0, + usig=4.0, + binwidth=0.1, + **kwargs, + ): + """Initializes the SkyStats object. + + Parameters + ----------- + skystat : {'mode', 'median', 'mode', 'midpt'}, optional + Sets the statistics that will be returned by `~SkyStats.calc_sky`. + The following statistics are supported: 'mean', 'mode', 'midpt', + and 'median'. First three statistics have the same meaning as in + `stsdas.toolbox.imgtools.gstatistics `_ + while 'median' will compute the median of the distribution. + + lower : float, None, optional + Lower limit of usable pixel values for computing the sky. + This value should be specified in the units of the input image(s). + + upper : float, None, optional + Upper limit of usable pixel values for computing the sky. + This value should be specified in the units of the input image(s). + + nclip : int, optional + A non-negative number of clipping iterations to use when computing + the sky value. + + lsig : float, optional + Lower clipping limit, in sigma, used when computing the sky value. + + usig : float, optional + Upper clipping limit, in sigma, used when computing the sky value. + + binwidth : float, optional + Bin width, in sigma, used to sample the distribution of pixel + brightness values in order to compute the sky background + statistics. + + kwargs : dict + A dictionary of optional arguments to be passed to `ImageStats`. + + """ + self.npix = None + self.skyval = None + + self._fields = f"npix,{skystat}" + + self._kwargs = deepcopy(kwargs) + if "fields" in self._kwargs: + del self._kwargs["fields"] + if "image" in self._kwargs: + del self._kwargs["image"] + self._kwargs["lower"] = lower + self._kwargs["upper"] = upper + self._kwargs["nclip"] = nclip + self._kwargs["lsig"] = lsig + self._kwargs["usig"] = usig + self._kwargs["binwidth"] = binwidth + + self._skystat = { + "mean": self._extract_mean, + "mode": self._extract_mode, + "median": self._extract_median, + "midpt": self._extract_midpt, + }[skystat] + + def _extract_mean(self, imstat): + return imstat.mean + + def _extract_median(self, imstat): + return imstat.median + + def _extract_mode(self, imstat): + return imstat.mode + + def _extract_midpt(self, imstat): + return imstat.midpt + + def calc_sky(self, data): + """Computes statistics on data. + + Parameters + ----------- + data : numpy.ndarray + A numpy array of values for which the statistics needs to be computed. + + Returns + -------- + statistics : tuple + A tuple of two values: (`skyvalue`, `npix`), where `skyvalue` + is the statistics specified by the `skystat` parameter during + the initialization of the `SkyStats` object and `npix` is the + number of pixels used in computing the statistics reported + in `skyvalue`. + + """ + imstat = ImageStats(image=data, fields=self._fields, **(self._kwargs)) + self.skyval = self._skystat(imstat) # dict or scalar + + self.npix = imstat.npix + return self.skyval, self.npix + + def __call__(self, data): + return self.calc_sky(data) From 2a84121201356ade1d153341cf836003a43d8f74 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 17 Oct 2024 13:42:22 -0400 Subject: [PATCH 2/5] Update changes --- changes/310.general.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/310.general.rst diff --git a/changes/310.general.rst b/changes/310.general.rst new file mode 100644 index 000000000..e20c43583 --- /dev/null +++ b/changes/310.general.rst @@ -0,0 +1 @@ +Move common parts of skymatch shared by both jwst and romancal into stcal. From c36de6e940a94904eba396b9bec03465a3743180 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 19 Dec 2024 14:25:31 -0500 Subject: [PATCH 3/5] Update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 5faca45e0..80d8127c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "gwcs >=0.22.0", "tweakwcs >=0.8.8", "requests >=2.22", + "spherical-geometry>=1.2.22" ] dynamic = [ "version", From b4a3745e004fabb9047709752f3a059b4081c29c Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 19 Dec 2024 14:44:26 -0500 Subject: [PATCH 4/5] Fix check types --- src/stcal/skymatch/skyimage.py | 4 ++-- src/stcal/skymatch/skystatistics.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stcal/skymatch/skyimage.py b/src/stcal/skymatch/skyimage.py index 35df56ae3..96fd53f9e 100644 --- a/src/stcal/skymatch/skyimage.py +++ b/src/stcal/skymatch/skyimage.py @@ -9,10 +9,10 @@ import tempfile import numpy as np -from spherical_geometry.polygon import SphericalPolygon +from spherical_geometry.polygon import SphericalPolygon # type: ignore[import-untyped] from . import region -from .skystatistics import SkyStats +from .skystatistics import SkyStats # type: ignore[import-untyped] __all__ = [ "SkyImage", diff --git a/src/stcal/skymatch/skystatistics.py b/src/stcal/skymatch/skystatistics.py index b93a99ed7..ffecd43be 100644 --- a/src/stcal/skymatch/skystatistics.py +++ b/src/stcal/skymatch/skystatistics.py @@ -7,7 +7,7 @@ from copy import deepcopy # THIRD PARTY -from stsci.imagestats import ImageStats +from stsci.imagestats import ImageStats # type: ignore[import-untyped] __all__ = ["SkyStats"] From 6202d388f86315c32044bb160bcc06347f6b493e Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 19 Dec 2024 14:47:20 -0500 Subject: [PATCH 5/5] Move to using gwcs for region --- src/stcal/skymatch/region.py | 425 --------------------------------- src/stcal/skymatch/skyimage.py | 2 +- 2 files changed, 1 insertion(+), 426 deletions(-) delete mode 100644 src/stcal/skymatch/region.py diff --git a/src/stcal/skymatch/region.py b/src/stcal/skymatch/region.py deleted file mode 100644 index 252e8454d..000000000 --- a/src/stcal/skymatch/region.py +++ /dev/null @@ -1,425 +0,0 @@ -""" -Polygon filling algorithm. - -NOTE: Algorithm description can be found, e.g., here: - http://www.cs.rit.edu/~icss571/filling/how_to.html - http://www.cs.uic.edu/~jbell/CourseNotes/ComputerGraphics/PolygonFilling.html -""" - -from collections import OrderedDict - -import numpy as np - -__all__ = ["Region", "Edge", "Polygon"] - - -class ValidationError(Exception): - def __init__(self, message): - self._message = message - - def __str__(self): - return self._message - - -class Region: - """ - Base class for regions. - - Parameters - ------------- - rid : int or string - region ID - coordinate_system : astropy.wcs.CoordinateSystem instance or a string - in the context of WCS this would be an instance of wcs.CoordinateSysem - """ - - def __init__(self, rid, coordinate_system): - self._coordinate_system = coordinate_system - self._rid = rid - - def __contains__(self, xy: tuple[float]) -> bool: - """ - Determines if a pixel is within a region. - - Parameters - ---------- - xy : tuple[float] - x , y values of a pixel - - Returns - ------- - True or False - - Subclasses must define this method. - """ - raise NotImplementedError("__contains__") - - def scan(self, mask): - """ - Sets mask values to region id for all pixels within the region. - Subclasses must define this method. - - Parameters - ---------- - mask : ndarray - a byte array with the shape of the observation to be used as a mask - - Returns - ------- - mask : array where the value of the elements is the region ID or 0 (for - pixels which are not included in any region). - """ - raise NotImplementedError("scan") - - -class Polygon(Region): - """ - Represents a 2D polygon region with multiple vertices - - Parameters - ---------- - rid : string - polygon id - vertices : list of (x,y) tuples or lists - The list is ordered in such a way that when traversed in a - counterclockwise direction, the enclosed area is the polygon. - The last vertex must coincide with the first vertex, minimum - 4 vertices are needed to define a triangle - coord_system : string - coordinate system - - """ - - def __init__(self, rid, vertices, coord_system="Cartesian"): - assert len(vertices) >= 4, ( - "Expected vertices to be " "a list of minimum 4 tuples (x,y)" - ) - super().__init__(rid, coord_system) - - # self._shiftx & self._shifty are introduced to shift the bottom-left - # corner of the polygon's bounding box to (0,0) as a (hopefully - # temporary) workaround to a limitation of the original code that the - # polygon must be completely contained in the image. It seems that the - # code works fine if we make sure that the bottom-left corner of the - # polygon's bounding box has non-negative coordinates. - self._shiftx = 0 - self._shifty = 0 - for vertex in vertices: - x, y = vertex - if x < self._shiftx: - self._shiftx = x - if y < self._shifty: - self._shifty = y - v = [(i - self._shiftx, j - self._shifty) for i, j in vertices] - - # convert to integer coordinates: - self._vertices = np.asarray(list(map(_round_vertex, v))) - self._shiftx = int(round(self._shiftx)) - self._shifty = int(round(self._shifty)) - - self._bbox = self._get_bounding_box() - self._scan_line_range = list( - range(self._bbox[1], self._bbox[3] + self._bbox[1] + 1) - ) - # constructs a Global Edge Table (GET) in bbox coordinates - self._GET = self._construct_ordered_GET() - - def _get_bounding_box(self): - x = self._vertices[:, 0].min() - y = self._vertices[:, 1].min() - w = self._vertices[:, 0].max() - x - h = self._vertices[:, 1].max() - y - return x, y, w, h - - def _construct_ordered_GET(self): - """ - Construct a Global Edge Table (GET) - - The GET is an OrderedDict. Keys are scan line numbers, - ordered from bbox.ymin to bbox.ymax, where bbox is the - bounding box of the polygon. - Values are lists of edges for which edge.ymin==scan_line_number. - - Returns - ------- - GET: OrderedDict - {scan_line: [edge1, edge2]} - """ - # edges is a list of Edge objects which define a polygon - # with these vertices - edges = self.get_edges() - GET = OrderedDict.fromkeys(self._scan_line_range) - ymin = np.asarray([e._ymin for e in edges]) - for i in self._scan_line_range: - ymin_ind = (ymin == i).nonzero()[0] - # a hack for incomplete filling .any() fails if 0 is in ymin_ind - # if ymin_ind.any(): - (yminindlen,) = ymin_ind.shape - if yminindlen: - GET[i] = [edges[ymin_ind[0]]] - for j in ymin_ind[1:]: - GET[i].append(edges[j]) - return GET - - def get_edges(self): - """ - Create a list of Edge objects from vertices - """ - edges = [] - for i in range(1, len(self._vertices)): - name = "E" + str(i - 1) - edges.append( - Edge(name=name, start=self._vertices[i - 1], stop=self._vertices[i]) - ) - return edges - - def scan(self, data): - """ - This is the main function which scans the polygon and creates the mask - - Parameters - ---------- - data : array - the mask array - it has all zeros initially, elements within a region are set to - the region's ID - - Algorithm: - - Set the Global Edge Table (GET) - - Set y to be the smallest y coordinate that has an entry in GET - - Initialize the Active Edge Table (AET) to be empty - - For each scan line: - 1. Add edges from GET to AET for which ymin==y - 2. Remove edges from AET fro which ymax==y - 3. Compute the intersection of the current scan line with all edges in the AET - 4. Sort on X of intersection point - 5. Set elements between pairs of X in the AET to the Edge's ID - - """ - # TODO: 1.This algorithm does not mark pixels in the top - # row and left most column. Pad the initial pixel description on - # top and left with 1 px to prevent this. 2. Currently it uses - # intersection of the scan line with edges. If this is too slow it - # should use the 1/m increment (replace 3 above) (or the increment - # should be removed from the GET entry). - - # see comments in the __init__ function for the reason of introducing - # polygon shifts (self._shiftx & self._shifty). Here we need to shift - # it back. - - (ny, nx) = data.shape - - y = np.min(list(self._GET.keys())) - - AET = [] - scline = self._scan_line_range[-1] - - while y <= scline: - if y < scline: - AET = self.update_AET(y, AET) - - if self._bbox[2] <= 0: - y += 1 - continue - - scan_line = Edge( - "scan_line", - start=[self._bbox[0], y], - stop=[self._bbox[0] + self._bbox[2], y], - ) - x = [ - int(np.ceil(e.compute_AET_entry(scan_line)[1])) - for e in AET - if e is not None - ] - xnew = np.sort(x) - ysh = y + self._shifty - - if ysh < 0 or ysh >= ny: - y += 1 - continue - - for i, j in zip(xnew[::2], xnew[1::2]): - xstart = max(0, i + self._shiftx) - xend = min(j + self._shiftx, nx - 1) - data[ysh][xstart : xend + 1] = self._rid - - y += 1 - - return data - - def update_AET(self, y, AET): - """ - Update the Active Edge Table (AET) - - Add edges from GET to AET for which ymin of the edge is - equal to the y of the scan line. - Remove edges from AET for which ymax of the edge is - equal to y of the scan line. - - """ - edge_cont = self._GET[y] - if edge_cont is not None: - for edge in edge_cont: - if edge._start[1] != edge._stop[1] and edge._ymin == y: - AET.append(edge) - for edge in AET[::-1]: - if edge is not None: - if edge._ymax == y: - AET.remove(edge) - return AET - - def __contains__(self, px): - """even-odd algorithm or something else better should be used""" - # minx = self._vertices[:,0].min() - # maxx = self._vertices[:,0].max() - # miny = self._vertices[:,1].min() - # maxy = self._vertices[:,1].max() - return ( - px[0] >= self._bbox[0] - and px[0] <= self._bbox[0] + self._bbox[2] - and px[1] >= self._bbox[1] - and px[1] <= self._bbox[1] + self._bbox[3] - ) - - -class Edge: - """ - Edge representation - - An edge has "start" and "stop" (x,y) vertices and an entry in the - GET table of a polygon. The GET entry is a list of these values: - - [ymax, x_at_ymin, delta_x/delta_y] - - """ - - def __init__(self, name=None, start=None, stop=None, next=None): # noqa: A002 - self._start = None - if start is not None: - self._start = np.asarray(start) - self._name = name - self._stop = stop - if stop is not None: - self._stop = np.asarray(stop) - self._next = next - - if self._stop is not None and self._start is not None: - if self._start[1] < self._stop[1]: - self._ymin = self._start[1] - self._yminx = self._start[0] - else: - self._ymin = self._stop[1] - self._yminx = self._stop[0] - self._ymax = max(self._start[1], self._stop[1]) - self._xmin = min(self._start[0], self._stop[0]) - self._xmax = max(self._start[0], self._stop[1]) - else: - self._ymin = None - self._yminx = None - self._ymax = None - self._xmin = None - self._xmax = None - self.GET_entry = self.compute_GET_entry() - - @property - def ymin(self): - return self._ymin - - @property - def start(self): - return self._start - - @property - def stop(self): - return self._stop - - @property - def ymax(self): - return self._ymax - - def compute_GET_entry(self): - """ - Compute the entry in the Global Edge Table - - [ymax, x@ymin, 1/m] - - """ - if self._start is None: - entry = None - else: - earr = np.asarray([self._start, self._stop]) - if np.diff(earr[:, 1]).item() == 0: - return None - else: - entry = [ - self._ymax, - self._yminx, - (np.diff(earr[:, 0]) / np.diff(earr[:, 1])).item(), - None, - ] - return entry - - def compute_AET_entry(self, edge): - """ - Compute the entry for an edge in the current Active Edge Table - - [ymax, x_intersect, 1/m] - note: currently 1/m is not used - """ - x = self.intersection(edge)[0] - return [self._ymax, x, self.GET_entry[2]] - - def __repr__(self): - fmt = "" - if self._name is not None: - fmt += self._name - next_edge = self.next - while next_edge is not None: - fmt += "-->" - fmt += next_edge._name - next_edge = next_edge.next - return fmt - - @property - def next(self): # noqa: A003 - return self._next - - @next.setter - def next(self, edge): # noqa: A003 - if self._name is None: - self._name = edge._name - self._stop = edge._stop - self._start = edge._start - self._next = edge.next - else: - self._next = edge - - def intersection(self, edge): - u = self._stop - self._start - v = edge._stop - edge._start - w = self._start - edge._start - - # Find the determinant of the matrix formed by the vectors u and v - # Note: Originally this was computed using a numpy "2D" cross product, - # however, this functionality has been deprecated and slated for - # removal. - D = np.linalg.det([u, v]) - - if np.allclose(D, 0, rtol=0, atol=1e2 * np.finfo(float).eps): - return np.array(self._start) - - # See note above - return np.linalg.det([v, w]) / D * u + self._start - - def is_parallel(self, edge): - u = self._stop - self._start - v = edge._stop - edge._start - return np.allclose( - np.linalg.det([u, v]), 0, rtol=0, atol=1e2 * np.finfo(float).eps - ) - - -def _round_vertex(v): - x, y = v - return int(round(x)), int(round(y)) diff --git a/src/stcal/skymatch/skyimage.py b/src/stcal/skymatch/skyimage.py index 96fd53f9e..155d0e385 100644 --- a/src/stcal/skymatch/skyimage.py +++ b/src/stcal/skymatch/skyimage.py @@ -9,9 +9,9 @@ import tempfile import numpy as np +from gwcs import region from spherical_geometry.polygon import SphericalPolygon # type: ignore[import-untyped] -from . import region from .skystatistics import SkyStats # type: ignore[import-untyped] __all__ = [