From 8c67ff57ba89da5d3087433d2f889b358e4d05b6 Mon Sep 17 00:00:00 2001 From: Andy Casey Date: Tue, 10 Oct 2023 22:23:29 -0600 Subject: [PATCH] notebook for targeting flags creation --- .../20231011_sdss5_targeting_creation.ipynb | 187 ++++++++++++++++++ python/sdss_semaphore/__init__.py | 19 +- setup.cfg | 2 +- 3 files changed, 203 insertions(+), 5 deletions(-) create mode 100644 notebooks/20231011_sdss5_targeting_creation.ipynb diff --git a/notebooks/20231011_sdss5_targeting_creation.ipynb b/notebooks/20231011_sdss5_targeting_creation.ipynb new file mode 100644 index 0000000..3d9931a --- /dev/null +++ b/notebooks/20231011_sdss5_targeting_creation.ipynb @@ -0,0 +1,187 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating a `TargetingFlags` column for flat files\n", + "\n", + "In this notebook we will query the SDSS database and create a FITS file with two columns: `SDSS_ID`, and `SDSS5_TARGET_FLAGS`.\n", + "The `SDSS5_TARGET_FLAGS` will be an array of integers, created using `semaphore`. \n", + "\n", + "Andy Casey (andrew.casey@monash.edu)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can install `semaphore` with:\n", + "\n", + "```pip install sdss-semaphore==0.2.3```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[1;33m[WARNING]: \u001b[0m\u001b[0;39mcatalog_to_allstar_dr17_synspec_rev1: cannot find related table 'allstar_dr17_synspec_rev1'\u001b[0m \u001b[0;36m(SdssdbUserWarning)\u001b[0m\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "from sdssdb.peewee.sdss5db import database\n", + "\n", + "database.set_profile(\"operations\")\n", + "\n", + "from sdssdb.peewee.sdss5db.targetdb import Assignment, CartonToTarget, Target\n", + "from sdssdb.peewee.sdss5db.catalogdb import CatalogdbModel\n", + "\n", + "from sdss_semaphore.targeting import TargetingFlags\n", + "\n", + "# At the time of writing, this peewee model did not exist in my version of sdssdb.\n", + "# In future you can probably just import directly: `from sdssdb.peewee.sdss5db.catalogdb import SDSS_ID_Flat`\n", + "\n", + "class SDSS_ID_Flat(CatalogdbModel):\n", + " class Meta:\n", + " table_name = \"sdss_id_flat\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we knew how many sources there would be then we could just create one `TargetingFlags` object for N sources, and then use the source identifier (e.g., SDSS_ID) to look up which index we need to set each bit for. But that lookup can be expensive for many objects, and we want to avoid pre-computing how many sources there are.\n", + "\n", + "Instead we will create a dictionary of `TargetingFlags` objects, keyed by the source identifier, and merge them together into one `TargetingFlags` object at the end.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Assignments to `TargetingFlags`" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "10000it [00:00, 18325.14it/s] \n" + ] + } + ], + "source": [ + "\n", + "# This is a YUGE query. Let's limit to 10,000 for testing purposes\n", + "q = (\n", + " SDSS_ID_Flat\n", + " .select(\n", + " SDSS_ID_Flat.sdss_id,\n", + " CartonToTarget.carton_pk,\n", + " )\n", + " .join(Target, on=(SDSS_ID_Flat.catalogid == Target.catalogid))\n", + " .join(CartonToTarget, on=(Target.pk == CartonToTarget.target_pk))\n", + " .join(Assignment, on=(Assignment.carton_to_target_pk == CartonToTarget.pk))\n", + " .tuples()\n", + " .limit(10_000)\n", + " .iterator()\n", + ")\n", + "\n", + "manual_counts = {}\n", + "\n", + "flags_dict = {}\n", + "for sdss_id, carton_pk in tqdm(q, total=1): # total=1 prevents tqdm from executing the count() query\n", + " try:\n", + " flags_dict[sdss_id]\n", + " except KeyError:\n", + " flags_dict[sdss_id] = TargetingFlags()\n", + " \n", + " flags_dict[sdss_id].set_bit_by_carton_pk(0, carton_pk) # 0 since this is the only object\n", + " manual_counts.setdefault(carton_pk, set())\n", + " manual_counts[carton_pk].add(sdss_id)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Now we will create two columns:\n", + "# - one for all our source identifiers\n", + "# - one for all our targeting flags\n", + "\n", + "sdss_ids = list(flags_dict.keys())\n", + "flags = TargetingFlags(list(flags_dict.values()))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# A sanity check.\n", + "for carton_pk, count in flags.count_by_attribute(\"carton_pk\", skip_empty=True).items():\n", + " assert count == len(manual_counts[carton_pk])\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Now let's write out our columns to a fits file.\n", + "from astropy.io import fits\n", + "\n", + "N, F = flags.array.shape\n", + "hdul = fits.HDUList([\n", + " fits.PrimaryHDU(),\n", + " fits.BinTableHDU.from_columns([\n", + " fits.Column(name=\"SDSS_ID\", array=sdss_ids, format=\"K\"),\n", + " fits.Column(name=\"SDSS5_TARGET_FLAGS\", array=flags.array, format=f\"{F}B\", dim=f\"({F})\")\n", + " ])\n", + "])\n", + "hdul.writeto(\"output.fits\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "astra_base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/sdss_semaphore/__init__.py b/python/sdss_semaphore/__init__.py index 5d0a5a7..78ba31c 100644 --- a/python/sdss_semaphore/__init__.py +++ b/python/sdss_semaphore/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.2" +__version__ = "0.2.3" import numpy as np import warnings @@ -9,7 +9,7 @@ class BaseFlags: """A base class for communicating with flags.""" - def __init__(self, array: Optional[Union[np.ndarray, Iterable[Iterable[int]], Iterable[bytearray]]] = None) -> None: + def __init__(self, array: Optional[Union[np.ndarray, Iterable[Iterable[int]], Iterable[bytearray], Iterable['BaseFlags']]] = None) -> None: if array is None: # Assume single object flag. self.array = np.zeros((1, 0), dtype=self.dtype) @@ -22,6 +22,18 @@ def __init__(self, array: Optional[Union[np.ndarray, Iterable[Iterable[int]], It self.array = np.zeros((N, F), dtype=self.dtype) for i, item in enumerate(array): self.array[i, :len(item)] = np.frombuffer(item, dtype=self.dtype) + elif len(array) > 0 and isinstance(array[0], BaseFlags): + # need to pad the array to the maximum size + N, F = (0, 0) + for item in array: + n, f = item.array.shape + F = max(F, f) + N += n + if N != len(array): + raise ValueError("All items must have only one row when building from an iterable of BaseFlags") + self.array = np.zeros((N, F), dtype=self.dtype) + for i, item in enumerate(array): + self.array[i, :item.array.shape[1]] = item.array else: self.array = np.atleast_2d(array).astype(self.dtype) return None @@ -85,9 +97,8 @@ def _count(self, bits_per_attribute, skip_empty: bool = False) -> dict: Skip flags with no items assigned to them. """ counts = {} - flags = self.as_boolean_array() # (N, F) shape for attribute, bits in bits_per_attribute.items(): - count = np.sum(np.any(flags[:, bits], axis=1)) + count = np.sum(self.are_any_bits_set(bits)) if count > 0 or not skip_empty: counts[attribute] = count return counts diff --git a/setup.cfg b/setup.cfg index 9b26817..7236b4b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = sdss_semaphore -version = 0.2.2 +version = 0.2.3 author = Andy Casey author_email = andrew.casey@monash.edu description = Fun with SDSS flags