Skip to content

Commit

Permalink
notebook for targeting flags creation
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Oct 11, 2023
1 parent 1b17910 commit 8c67ff5
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 5 deletions.
187 changes: 187 additions & 0 deletions notebooks/20231011_sdss5_targeting_creation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Creating a `TargetingFlags` column for flat files\n",
"\n",
"In this notebook we will query the SDSS database and create a FITS file with two columns: `SDSS_ID`, and `SDSS5_TARGET_FLAGS`.\n",
"The `SDSS5_TARGET_FLAGS` will be an array of integers, created using `semaphore`. \n",
"\n",
"Andy Casey ([email protected])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can install `semaphore` with:\n",
"\n",
"```pip install sdss-semaphore==0.2.3```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[1;33m[WARNING]: \u001b[0m\u001b[0;39mcatalog_to_allstar_dr17_synspec_rev1: cannot find related table 'allstar_dr17_synspec_rev1'\u001b[0m \u001b[0;36m(SdssdbUserWarning)\u001b[0m\n"
]
}
],
"source": [
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"from sdssdb.peewee.sdss5db import database\n",
"\n",
"database.set_profile(\"operations\")\n",
"\n",
"from sdssdb.peewee.sdss5db.targetdb import Assignment, CartonToTarget, Target\n",
"from sdssdb.peewee.sdss5db.catalogdb import CatalogdbModel\n",
"\n",
"from sdss_semaphore.targeting import TargetingFlags\n",
"\n",
"# At the time of writing, this peewee model did not exist in my version of sdssdb.\n",
"# In future you can probably just import directly: `from sdssdb.peewee.sdss5db.catalogdb import SDSS_ID_Flat`\n",
"\n",
"class SDSS_ID_Flat(CatalogdbModel):\n",
" class Meta:\n",
" table_name = \"sdss_id_flat\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we knew how many sources there would be then we could just create one `TargetingFlags` object for N sources, and then use the source identifier (e.g., SDSS_ID) to look up which index we need to set each bit for. But that lookup can be expensive for many objects, and we want to avoid pre-computing how many sources there are.\n",
"\n",
"Instead we will create a dictionary of `TargetingFlags` objects, keyed by the source identifier, and merge them together into one `TargetingFlags` object at the end.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Assignments to `TargetingFlags`"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"10000it [00:00, 18325.14it/s] \n"
]
}
],
"source": [
"\n",
"# This is a YUGE query. Let's limit to 10,000 for testing purposes\n",
"q = (\n",
" SDSS_ID_Flat\n",
" .select(\n",
" SDSS_ID_Flat.sdss_id,\n",
" CartonToTarget.carton_pk,\n",
" )\n",
" .join(Target, on=(SDSS_ID_Flat.catalogid == Target.catalogid))\n",
" .join(CartonToTarget, on=(Target.pk == CartonToTarget.target_pk))\n",
" .join(Assignment, on=(Assignment.carton_to_target_pk == CartonToTarget.pk))\n",
" .tuples()\n",
" .limit(10_000)\n",
" .iterator()\n",
")\n",
"\n",
"manual_counts = {}\n",
"\n",
"flags_dict = {}\n",
"for sdss_id, carton_pk in tqdm(q, total=1): # total=1 prevents tqdm from executing the count() query\n",
" try:\n",
" flags_dict[sdss_id]\n",
" except KeyError:\n",
" flags_dict[sdss_id] = TargetingFlags()\n",
" \n",
" flags_dict[sdss_id].set_bit_by_carton_pk(0, carton_pk) # 0 since this is the only object\n",
" manual_counts.setdefault(carton_pk, set())\n",
" manual_counts[carton_pk].add(sdss_id)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Now we will create two columns:\n",
"# - one for all our source identifiers\n",
"# - one for all our targeting flags\n",
"\n",
"sdss_ids = list(flags_dict.keys())\n",
"flags = TargetingFlags(list(flags_dict.values()))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# A sanity check.\n",
"for carton_pk, count in flags.count_by_attribute(\"carton_pk\", skip_empty=True).items():\n",
" assert count == len(manual_counts[carton_pk])\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Now let's write out our columns to a fits file.\n",
"from astropy.io import fits\n",
"\n",
"N, F = flags.array.shape\n",
"hdul = fits.HDUList([\n",
" fits.PrimaryHDU(),\n",
" fits.BinTableHDU.from_columns([\n",
" fits.Column(name=\"SDSS_ID\", array=sdss_ids, format=\"K\"),\n",
" fits.Column(name=\"SDSS5_TARGET_FLAGS\", array=flags.array, format=f\"{F}B\", dim=f\"({F})\")\n",
" ])\n",
"])\n",
"hdul.writeto(\"output.fits\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "astra_base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
19 changes: 15 additions & 4 deletions python/sdss_semaphore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.2"
__version__ = "0.2.3"

import numpy as np
import warnings
Expand All @@ -9,7 +9,7 @@ class BaseFlags:

"""A base class for communicating with flags."""

def __init__(self, array: Optional[Union[np.ndarray, Iterable[Iterable[int]], Iterable[bytearray]]] = None) -> None:
def __init__(self, array: Optional[Union[np.ndarray, Iterable[Iterable[int]], Iterable[bytearray], Iterable['BaseFlags']]] = None) -> None:
if array is None:
# Assume single object flag.
self.array = np.zeros((1, 0), dtype=self.dtype)
Expand All @@ -22,6 +22,18 @@ def __init__(self, array: Optional[Union[np.ndarray, Iterable[Iterable[int]], It
self.array = np.zeros((N, F), dtype=self.dtype)
for i, item in enumerate(array):
self.array[i, :len(item)] = np.frombuffer(item, dtype=self.dtype)
elif len(array) > 0 and isinstance(array[0], BaseFlags):
# need to pad the array to the maximum size
N, F = (0, 0)
for item in array:
n, f = item.array.shape
F = max(F, f)
N += n
if N != len(array):
raise ValueError("All items must have only one row when building from an iterable of BaseFlags")
self.array = np.zeros((N, F), dtype=self.dtype)
for i, item in enumerate(array):
self.array[i, :item.array.shape[1]] = item.array
else:
self.array = np.atleast_2d(array).astype(self.dtype)
return None
Expand Down Expand Up @@ -85,9 +97,8 @@ def _count(self, bits_per_attribute, skip_empty: bool = False) -> dict:
Skip flags with no items assigned to them.
"""
counts = {}
flags = self.as_boolean_array() # (N, F) shape
for attribute, bits in bits_per_attribute.items():
count = np.sum(np.any(flags[:, bits], axis=1))
count = np.sum(self.are_any_bits_set(bits))
if count > 0 or not skip_empty:
counts[attribute] = count
return counts
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = sdss_semaphore
version = 0.2.2
version = 0.2.3
author = Andy Casey
author_email = [email protected]
description = Fun with SDSS flags
Expand Down

0 comments on commit 8c67ff5

Please sign in to comment.