diff --git a/python/target_selection/cartons/base.py b/python/target_selection/cartons/base.py
index c4a26508..3a45a0c9 100644
--- a/python/target_selection/cartons/base.py
+++ b/python/target_selection/cartons/base.py
@@ -255,12 +255,12 @@ def run(self,
 
         if self.database.table_exists(self.table_name, schema=self.schema):
             if overwrite:
-                log.info(f'Dropping table {path!r}.')
+                self.log.info(f'Dropping table {path!r}.')
                 self.drop_table()
             else:
                 raise RuntimeError(f'Temporary table {path!r} already exists.')
 
-        log.info('Running query ...')
+        self.log.info('Running query ...')
 
         version_id = self.get_version_id()
 
         with Timer() as timer:
@@ -304,21 +304,24 @@ def run(self,
             query_str = cursor.mogrify(query_sql, params).decode()
 
             if not self._disable_query_log:
-                log.debug(color_text(f'CREATE TABLE IF NOT EXISTS {path} AS ' + query_str,
-                                     'darkgrey'))
+                log_message = f'CREATE TABLE IF NOT EXISTS {path} AS ' + query_str
+                if self.log.rich_console:
+                    self.log.debug(log_message, extra={"highlighter": None})
+                else:
+                    self.log.debug(color_text(log_message, 'darkgrey'))
             else:
-                log.debug('Not printing VERY long query.')
+                self.log.debug('Not printing VERY long query.')
 
             with self.database.atomic():
                 self.setup_transaction()
                 execute_sql(f'CREATE TABLE IF NOT EXISTS {path} AS ' + query_sql, params)
 
-        log.info(f'Created table {path!r} in {timer.interval:.3f} s.')
+        self.log.info(f'Created table {path!r} in {timer.interval:.3f} s.')
 
         self.RModel = self.get_model()
 
-        log.debug('Adding columns and indexes.')
+        self.log.debug('Adding columns and indexes.')
 
         columns = [
             col.name for col in self.database.get_columns(self.table_name, self.schema)
@@ -345,18 +348,18 @@ def run(self,
         execute_sql(f'ANALYZE {path};')
 
         n_rows = self.RModel.select().count()
-        log.debug(f'Table {path!r} contains {n_rows:,} rows.')
+        self.log.debug(f'Table {path!r} contains {n_rows:,} rows.')
 
-        log.debug('Running post-process.')
+        self.log.debug('Running post-process.')
         with self.database.atomic():
             self.setup_transaction()
             self.post_process(self.RModel, **post_process_kawrgs)
 
         n_selected = self.RModel.select().where(self.RModel.selected >> True).count()
-        log.debug(f'Selected {n_selected:,} rows after post-processing.')
+        self.log.debug(f'Selected {n_selected:,} rows after post-processing.')
 
         if add_optical_magnitudes:
-            log.debug('Adding optical magnitude columns.')
+            self.log.debug('Adding optical magnitude columns.')
            self.add_optical_magnitudes()
 
         self.has_run = True
@@ -380,12 +383,10 @@ def add_optical_magnitudes(self):
         if any([mag in Model._meta.columns for mag in magnitudes]):
             if not all([mag in Model._meta.columns for mag in magnitudes]):
                 raise TargetSelectionError(
-                    'Some optical magnitudes are defined in the query '
-                    'but not all of them.')
+                    'Some optical magnitudes are defined in the query but not all of them.')
             if 'optical_prov' not in Model._meta.columns:
                 raise TargetSelectionError('optical_prov column does not exist.')
-            warnings.warn('All optical magnitude columns are defined in the query.',
-                          TargetSelectionUserWarning)
+            self.log.warning('All optical magnitude columns are defined in the query.')
             return
 
         # First create the columns. Also create z to speed things up. We won't
@@ -716,7 +717,7 @@ def write_table(self, filename=None, mode='results', write=True):
             else:
                 filename = f'{self.name}_{self.plan}_targetdb.fits.gz'
 
-        log.debug(f'Writing table to {filename}.')
+        self.log.debug(f'Writing table to {filename}.')
 
         if not self.RModel:
             self.RModel = self.get_model()
@@ -793,11 +794,30 @@ def write_table(self, filename=None, mode='results', write=True):
 
         return carton_table
 
-    def load(self, overwrite=False):
-        """Loads the output of the intermediate table into targetdb."""
+    def load(self, mode='fail', overwrite=False):
+        """Loads the output of the intermediate table into targetdb.
+
+        Parameters
+        ----------
+        mode : str
+            The mode to use when loading the targets. If ``'fail'``, raises an
+            error if the carton already exists. If ``'overwrite'``, overwrites
+            the targets. If ``'append'``, appends the targets.
+        overwrite : bool
+            Equivalent to setting ``mode='overwrite'``. This option is deprecated
+            and will raise a warning.
+
+        """
+
+        if overwrite:
+            mode = 'overwrite'
+            warnings.warn(
+                'The `overwrite` option is deprecated and will be removed in a future version. '
+                'Use `mode="overwrite"` instead.',
+                TargetSelectionUserWarning)
 
         if self.check_targets():
-            if overwrite:
+            if mode == 'overwrite':
                 warnings.warn(
                     f'Carton {self.name!r} with plan {self.plan!r} '
                     f'already has targets loaded. '
@@ -805,12 +825,16 @@ def load(self, overwrite=False):
                     TargetSelectionUserWarning,
                 )
                 self.drop_carton()
-            else:
+            elif mode == 'append':
+                pass
+            elif mode == 'fail':
                 raise TargetSelectionError(
                     f'Found existing targets for '
                     f'carton {self.name!r} with plan '
                     f'{self.plan!r}.'
                 )
+            else:
+                raise ValueError(f'Invalid mode {mode!r}. Use "fail", "overwrite", or "append".')
 
         if self.RModel is None:
             RModel = self.get_model()
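(Note: a quick sketch of the new calling convention for ``load``. ``MyCarton`` and the plan string are hypothetical; the ``mode`` and ``overwrite`` arguments are the ones changed above.)

    carton = MyCarton('1.0.51')

    carton.load()                  # mode='fail' (default): error if targets already loaded
    carton.load(mode='append')     # keep existing targets, add the new ones
    carton.load(mode='overwrite')  # drop the carton's targets, then reload
    carton.load(overwrite=True)    # deprecated; warns and maps to mode='overwrite'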
@@ -872,7 +896,7 @@ def _create_carton_metadata(self):
             version_pk = version.pk
 
             if created:
-                log.info(
+                self.log.info(
                     f'Created record in targetdb.version for '
                     f'{self.plan!r} with tag {self.tag!r}.'
                 )
@@ -889,13 +913,13 @@ def _create_carton_metadata(self):
             mapper, created = tdb.Mapper.get_or_create(label=self.mapper)
             mapper_pk = mapper.pk
             if created:
-                log.debug(f'Created mapper {self.mapper!r}')
+                self.log.debug(f'Created mapper {self.mapper!r}')
 
         if self.category:
             category, created = tdb.Category.get_or_create(label=self.category)
             category_pk = category.pk
             if created:
-                log.debug(f'Created category {self.category!r}')
+                self.log.debug(f'Created category {self.category!r}')
 
         tdb.Carton.create(
             carton=self.name,
@@ -906,12 +930,12 @@ def _create_carton_metadata(self):
             run_on=datetime.datetime.now().isoformat().split('T')[0]
         ).save()
 
-        log.debug(f'Created carton {self.name!r}')
+        self.log.debug(f'Created carton {self.name!r}')
 
     def _load_targets(self, RModel):
         """Load data from the intermediate table to targetdb.target."""
 
-        log.debug('loading data into targetdb.target.')
+        self.log.debug('Loading data into targetdb.target.')
 
         n_inserted = (  # noqa: 841
             tdb.Target.insert_from(
@@ -947,14 +971,14 @@ def _load_targets(self, RModel):
             .execute()
         )
 
-        log.info('Inserted new rows into targetdb.target.')
+        self.log.info('Inserted new rows into targetdb.target.')
 
         return
 
     def _load_magnitudes(self, RModel):
         """Load magnitudes into targetdb.magnitude."""
 
-        log.debug('Loading data into targetdb.magnitude.')
+        self.log.debug('Loading data into targetdb.magnitude.')
 
         Magnitude = tdb.Magnitude
 
@@ -1035,12 +1059,12 @@ def _load_magnitudes(self, RModel):
 
         n_inserted = Magnitude.insert_from(select_from, fields).returning().execute()  # noqa: 841
 
-        log.info('Inserted new rows into targetdb.magnitude.')
+        self.log.info('Inserted new rows into targetdb.magnitude.')
 
     def _load_carton_to_target(self, RModel):
         """Populate targetdb.carton_to_target."""
 
-        log.debug('Loading data into targetdb.carton_to_target.')
+        self.log.debug('Loading data into targetdb.carton_to_target.')
 
         version_pk = tdb.Version.get(
             plan=self.plan,
@@ -1228,7 +1252,7 @@ def _load_carton_to_target(self, RModel):
             .execute()
         )
 
-        log.info('Inserted rows into targetdb.carton_to_target.')
+        self.log.info('Inserted rows into targetdb.carton_to_target.')
 
     def drop_carton(self):
         """Drops the entry in ``targetdb.carton``."""
diff --git a/python/target_selection/cartons/too.py b/python/target_selection/cartons/too.py
new file mode 100644
index 00000000..84ad974a
--- /dev/null
+++ b/python/target_selection/cartons/too.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
+# @Date: 2024-02-16
+# @Filename: too.py
+# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
+
+from __future__ import annotations
+
+from sdssdb.peewee.sdss5db.catalogdb import (CatalogToToO_Target,
+                                             ToO_Metadata, ToO_Target)
+from sdssdb.peewee.sdss5db.targetdb import (Carton, CartonToTarget,
+                                            Target, Version)
+
+from .base import BaseCarton
+
+
+__all__ = ['ToO_Carton']
+
+
+class ToO_Carton(BaseCarton):
+    """Target of opportunity carton.
+
+    Selects all the targets in ``catalogdb.too_target`` that don't yet exist in
+    the carton.
+ + """ + + name = 'too' + category = 'science' + cadence = 'bright_1x1' + priority = 3000 + program = 'too' + can_offset = True + + def build_query(self, version_id, query_region=None): + + C2TT = CatalogToToO_Target + + too_in_carton = (Target + .select(Target.catalogid) + .join(CartonToTarget) + .join(Carton) + .join(Version) + .where(Version.plan == self.plan, + Carton.carton == self.name)).alias('too_in_carton') + + query = (ToO_Target.select(C2TT.catalogid, + ToO_Target.fiber_type.alias('instrument'), + ToO_Metadata.g_mag.alias('g'), + ToO_Metadata.r_mag.alias('r'), + ToO_Metadata.i_mag.alias('i'), + ToO_Metadata.z_mag.alias('z'), + ToO_Metadata.optical_prov) + .join(C2TT, on=(ToO_Target.too_id == C2TT.target_id)) + .switch(ToO_Target) + .join(ToO_Metadata, on=(ToO_Target.too_id == ToO_Metadata.too_id)) + .where(C2TT.version_id == version_id, + C2TT.best >> True, + C2TT.catalogid.not_in(too_in_carton))) + + return query diff --git a/python/target_selection/config/target_selection.yml b/python/target_selection/config/target_selection.yml index 8082dc93..2a79216b 100644 --- a/python/target_selection/config/target_selection.yml +++ b/python/target_selection/config/target_selection.yml @@ -1,3 +1,18 @@ +'1.1.0': # This is a dedicated plan for ToO + xmatch_plan: 1.0.0 + cartons: + - too + schema: sandbox + magnitudes: + h: [catalog_to_twomass_psc, twomass_psc, twomass_psc.h_m] + j: [catalog_to_twomass_psc, twomass_psc, twomass_psc.j_m] + k: [catalog_to_twomass_psc, twomass_psc, twomass_psc.k_m] + bp: [catalog_to_gaia_dr3_source, gaia_dr3_source, gaia_dr3_source.phot_bp_mean_mag] + rp: [catalog_to_gaia_dr3_source, gaia_dr3_source, gaia_dr3_source.phot_rp_mean_mag] + gaia_g: [catalog_to_gaia_dr3_source, gaia_dr3_source, gaia_dr3_source.phot_g_mean_mag] + database_options: + work_mem: '5GB' + '1.0.51': xmatch_plan: 1.0.0 cartons: diff --git a/python/target_selection/xmatch.py b/python/target_selection/xmatch.py index f7539e9b..eccd3466 100644 --- a/python/target_selection/xmatch.py +++ b/python/target_selection/xmatch.py @@ -6,6 +6,7 @@ # @Filename: xmatch.py # @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause) +import copy import hashlib import inspect import os @@ -17,6 +18,7 @@ import networkx import numpy import peewee +import rich.markup import yaml from networkx.algorithms import shortest_path from peewee import SQL, Case, Model, fn @@ -360,10 +362,15 @@ class XMatchPlanner(object): Used in phase 2. Defaults to 1 arcsec. schema : str The schema in which all the tables to cross-match live (multiple - schemas are not supported), and the schema in which the output table + schemas are not supported), and the schema in which the output tables will be created. output_table : str The name of the output table. Defaults to ``catalog``. + temp_schema + The schema where the temporary ``catalog`` table will be initially created. + log + A logger to which to log messages. If not provided the ``target_selection`` + logger is used. log_path : str The path to which to log or `False` to disable file logging. debug : bool or int @@ -389,11 +396,6 @@ class XMatchPlanner(object): join_paths : list When using path_mode=``config_list`` is the list of paths to link tables to output catalog table in phase_1. - phase1_range : list - List to indicate which tables will be ingested in phase_1 with multiple - queries split by pk range instead of using a single query. 
-        Each element of this list is a list with the name of the table the
-        minimum pk value, the maximum pk value and the bin width.
     database_options : dict
         A dictionary of database configuration parameters to be set locally
         during each phase transaction, temporarily overriding the default
@@ -415,12 +417,12 @@ class XMatchPlanner(object):
 
     def __init__(self, database, models, plan, run_id, version_id=None,
                  extra_nodes=[], order='hierarchical', key='row_count',
                  epoch=EPOCH, start_node=None, query_radius=None,
-                 schema='catalogdb', output_table='catalog',
-                 log_path='./xmatch_{plan}.log', debug=False, show_sql=False,
+                 schema='catalogdb', temp_schema=TEMP_SCHEMA, output_table='catalog',
+                 log=None, log_path='./xmatch_{plan}.log', debug=False, show_sql=False,
                  sample_region=None, database_options=None, path_mode='full',
-                 join_paths=None, phase1_range=[]):
+                 join_paths=None):
 
-        self.log = target_selection.log
+        self.log = log or target_selection.log
         self.log.header = ''
 
         if log_path:
@@ -439,6 +441,7 @@ def __init__(self, database, models, plan, run_id, version_id=None,
             self.log.sh.setLevel(debug)
 
         self.schema = schema
+        self.temp_schema = temp_schema
         self.output_table = output_table
 
         self.md5 = hashlib.md5(plan.encode()).hexdigest()[0:16]
@@ -486,7 +489,6 @@ def __init__(self, database, models, plan, run_id, version_id=None,
         self._max_cid = self.run_id << RUN_ID_BIT_SHIFT
 
         self.path_mode = path_mode
-        self.phase1_range = phase1_range
 
         if path_mode == 'config_list':
             join_paths_warning = 'join_paths needed for path_mode=config_list'
@@ -562,7 +564,7 @@ def read(cls, in_models, plan, config_file=None, **kwargs):
             attribute.
         plan : str
             The cross-matching plan.
-        config_file : str
+        config_file : str or dict
             The path to the configuration file to use. Defaults to
             ``config/xmatch.yml``. The file must contain a hash with the
             cross-match plan.
@@ -641,7 +643,10 @@ def read(cls, in_models, plan, config_file=None, **kwargs):
         def _read_config(file_, plan):
             """Reads the configuration file, recursively."""
 
-            config = yaml.load(open(file_, 'r'), Loader=yaml.SafeLoader)
+            if isinstance(file_, dict):
+                config = copy.deepcopy(file_)
+            else:
+                config = yaml.load(open(file_, 'r'), Loader=yaml.SafeLoader)
 
             assert plan in config, f'plan {plan!r} not found in configuration.'
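(Note: with ``_read_config`` now accepting a dict, a cross-match plan can be defined programmatically instead of via ``config/xmatch.yml``. A minimal sketch; the plan name and option values are hypothetical, and the model module is assumed to be the usual ``catalogdb``.)

    from sdssdb.peewee.sdss5db import catalogdb
    from target_selection.xmatch import XMatchPlanner

    config = {
        '1.1.0': {              # hypothetical plan entry
            'run_id': 4,
            'query_radius': 1.0,
            'order': 'hierarchical',
        }
    }

    planner = XMatchPlanner.read(catalogdb, '1.1.0', config_file=config)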
@@ -721,7 +726,7 @@ def _log_db_configuration(self):
             log_str = f'{parameter} = {values[parameter]}'
             self.log.debug(log_str)
 
-    def update_model_graph(self, silent=False):
+    def update_model_graph(self):
         """Updates the model graph using models as nodes and fks as edges."""
 
         self.model_graph = networkx.Graph()
@@ -759,7 +764,7 @@ def update_model_graph(self, silent=False):
                                       join_weight=join_weight)
 
             if model in self.models.values():
-                rel_model = self._get_relational_model(model)
+                rel_model = self.get_relational_model(model, sandboxed=False)
                 rel_model._meta.schema = self.schema
                 rel_model_tname = rel_model._meta.table_name
                 self.model_graph.add_node(rel_model_tname, model=rel_model)
@@ -968,7 +973,7 @@ def _prepare_models(self):
         Catalog._meta.table_name = self.output_table
         Catalog._meta.set_database(self.database)
 
-        TempCatalog._meta.schema = TEMP_SCHEMA
+        TempCatalog._meta.schema = self.temp_schema
         TempCatalog._meta.table_name = self._temp_table
         TempCatalog._meta.set_database(self.database)
 
@@ -1047,16 +1052,20 @@ def _create_models(self, force=False):
 
             self._check_version(TempCatalog, force)
 
-            self._temp_count = int(get_row_count(self.database,
-                                                 self._temp_table,
-                                                 schema=self.schema,
-                                                 approximate=True))
+            try:
+                self._temp_count = int(get_row_count(self.database,
+                                                     self._temp_table,
+                                                     schema=self.schema,
+                                                     approximate=True))
+            except ValueError:
+                self._temp_count = 0
 
         else:
 
             # Add Q3C index for TempCatalog
-            TempCatalog.add_index(SQL(f'CREATE INDEX '
+            TempCatalog.add_index(SQL(f'CREATE INDEX IF NOT EXISTS '
                                       f'{self._temp_table}_q3c_idx ON '
-                                      f'{TEMP_SCHEMA}.{self._temp_table} '
+                                      f'{self.temp_schema}.{self._temp_table} '
                                       f'(q3c_ang2ipix(ra, dec))'))
 
             self.database.create_tables([TempCatalog])
@@ -1066,7 +1075,7 @@ def _create_models(self, force=False):
             self._temp_count = 0
 
     def run(self, vacuum=False, analyze=False, from_=None,
-            force=False, load_catalog=True, keep_temp=False):
+            force=False, dry_run=False, keep_temp=False):
         """Runs the cross-matching process.
 
         Parameters
         ----------
@@ -1082,15 +1091,24 @@ def run(self, vacuum=False, analyze=False, from_=None,
             Allows to continue even if the temporary table exists or the
             output table contains records for this version. ``force=True`` is
             assumed if ``from_`` is defined.
-        load_catalog : bool
-            If `True`, loads the temporary table into ``Catalog``. `False` implies
-            ``keep_temp=True``.
+        dry_run : bool
+            If `False`, loads the temporary tables into ``Catalog`` and ``CatalogToXXX``.
+            `True` implies ``keep_temp=True``; all the cross-matching steps will be run
+            but the original tables won't be modified. A dry run can only be executed for
+            a plan with a single catalogue since processing multiple catalogues requires
+            the final tables to have been updated for successive catalogues.
         keep_temp : bool
             Whether to keep the temporary table or to drop it after the cross
             matching is done.
""" + if len(self.process_order) > 1 and dry_run is True: + raise RuntimeError('Cannot dry run with a plan that includes more than one catalogue.') + + if dry_run: + keep_temp = True + if vacuum or analyze: cmd = ' '.join(('VACUUM' if vacuum else '', 'ANALYZE' if analyze else '')).strip() @@ -1124,20 +1142,22 @@ def run(self, vacuum=False, analyze=False, from_=None, with Timer() as timer: p_order = self.process_order - for table_name in p_order: - if (from_ and - p_order.index(table_name) < p_order.index(from_)): + for norder, table_name in enumerate(p_order): + if dry_run and norder > 0: + raise RuntimeError('Cannot dry run more than one catalogue.') + + if (from_ and p_order.index(table_name) < p_order.index(from_)): self.log.warning(f'Skipping table {table_name}.') continue model = self.models[table_name] - self.process_model(model) + self.process_model(model, force=force) - if load_catalog: - self._load_output_table(keep_temp=keep_temp) + if not dry_run: + self.load_output_tables(model, keep_temp=keep_temp) self.log.info(f'Cross-matching completed in {timer.interval:.3f} s.') - def process_model(self, model): + def process_model(self, model, force=False): """Processes a model, loading it into the output table.""" table_name = model._meta.table_name @@ -1150,6 +1170,16 @@ def process_model(self, model): raise TargetSelectionNotImplemented( 'handling of tables with duplicates is not implemented.') + rel_model = self.get_relational_model(model, sandboxed=True, create=False) + rel_model_table_name = rel_model._meta.table_name + if rel_model.table_exists(): + if force is False: + raise RuntimeError( + f'Sandboxed relational table {rel_model_table_name} exists. ' + 'Delete it manually before continuing.') + else: + self.log.warning(f'Sandboxed relational table {rel_model_table_name} exists.') + # Check if there are already records in catalog for this version. if self.process_order.index(model._meta.table_name) == 0 and not self.is_addendum: is_first_model = True @@ -1166,8 +1196,7 @@ def process_model(self, model): self._run_phase_2(model) self._run_phase_3(model) - self.log.info(f'Fully processed {table_name} ' - f'in {timer.elapsed:.0f} s.') + self.log.info(f'Fully processed {table_name} in {timer.elapsed:.0f} s.') self.update_model_graph() @@ -1263,11 +1292,19 @@ def _get_model_fields(self, model): return model_fields - def _get_relational_model(self, model, temp=False, create=False): + def get_output_model(self, temporary=False): + """Returns the temporary or final output model (``catalog``).""" + + if temporary: + return TempCatalog + + return Catalog + + def get_relational_model(self, model, sandboxed=False, temp=False, create=False): """Gets or creates a relational table for a given model. - Returns the relational model and `True` if the table was created or - `False` if it already existed. + When the relational model is ``sandboxed``, the table is created in the + temporary schema and suffixed with the same MD5 used for the run. 
""" @@ -1298,10 +1335,11 @@ class BaseModel(peewee.Model): distance = peewee.DoubleField(null=True) best = peewee.BooleanField(null=False) plan_id = peewee.TextField(null=True) + added_by_phase = peewee.SmallIntegerField(null=True) class Meta: database = meta.database - schema = meta.schema + schema = self.temp_schema if sandboxed else meta.schema primary_key = False model_prefix = ''.join(x.capitalize() or '_' @@ -1316,10 +1354,10 @@ class Meta: RelationalModel._meta.schema = None return RelationalModel - if meta.xmatch.relational_table is not None: - RelationalModel._meta.table_name = meta.xmatch.relational_table - else: - RelationalModel._meta.table_name = prefix + meta.table_name + table_name = prefix + meta.table_name + if sandboxed: + table_name += f'_{self.md5}' + RelationalModel._meta.table_name = table_name if create and not RelationalModel.table_exists(): RelationalModel.create_table() @@ -1362,7 +1400,9 @@ def _run_phase_1(self, model): """Runs the linking against matched catalogids stage.""" table_name = model._meta.table_name - rel_model = self._get_relational_model(model, create=True) + rel_model_sb = self.get_relational_model(model, create=True, sandboxed=True) + + rel_model = self.get_relational_model(model, create=False, sandboxed=False) model_pk = model._meta.primary_key @@ -1397,32 +1437,20 @@ def _run_phase_1(self, model): # a sequential scan. join_rel_model = join_models[-1] - # We can have join catalogues that produce more than - # one result for each target or different targets that - # are joined to the same catalogid. In the future we may - # want to order by something that selects the best - # candidate. - - partition = (fn.first_value(model_pk) - .over(partition_by=[join_rel_model.catalogid], - order_by=[model_pk.asc()])) - best = peewee.Value(partition == model_pk) - query = (self._build_join(join_models) .select(model_pk.alias('target_id'), join_rel_model.catalogid, - best.alias('best')) + peewee.Value(True).alias('best')) .where(join_rel_model.version_id == self.version_id, join_rel_model.best >> True) .where(~fn.EXISTS( - rel_model + rel_model_sb .select(SQL('1')) - .where(rel_model.version_id == self.version_id, - ((rel_model.target_id == model_pk) | - (rel_model.catalogid == - join_rel_model.catalogid))))) - # To avoid breaking the unique constraint in rel tables - .distinct(model_pk, join_rel_model.catalogid)) + .where(rel_model_sb.version_id == self.version_id, + ((rel_model_sb.target_id == model_pk) | + (rel_model_sb.catalogid == join_rel_model.catalogid))))) + # Select only one match per target in the catalogue with are cross-matching. + .distinct(model_pk)) # Deal with duplicates in LS8 if table_name == 'legacy_survey_dr8': @@ -1433,6 +1461,14 @@ def _run_phase_1(self, model): query = query.where(model.survey_primary >> True, fn.coalesce(model.ref_cat, '') != 'T2') + # If the real relational model exists, exclude any matches that already exist there. + if rel_model.table_exists(): + query = query.where( + ~fn.EXISTS(rel_model + .select(SQL('1')) + .where(rel_model.version_id == self.version_id, + rel_model.target_id == model_pk))) + # In query we do not include a Q3C where for the sample region because # TempCatalog for this plan should already be sample region limited. 
@@ -1440,140 +1476,65 @@ def _run_phase_1(self, model):
 
         with self.database.atomic():
 
-            self._setup_transaction(model, phase=1)
-            inter_name1 = 'phase1_content'
-            model_pk_class = model_pk.__class__
-
-            class Intermediate1(peewee.Model):
-                """Model for the intermediate results of phase_1."""
-                target_id = model_pk_class(null=False, index=True)
-                catalogid = peewee.BigIntegerField(index=True, null=False)
-                best = peewee.BooleanField(null=False)
-                distance = peewee.DoubleField(null=True)
-                version_id = peewee.SmallIntegerField(null=False, index=True)
-
-                class Meta:
-                    database = self.database
-                    schema = TEMP_SCHEMA
-                    table_name = inter_name1
-                    primary_key = False
-
-            query_str = self._get_sql(query, return_string=True)
-
-            range_info = self.phase1_range
-            range_tables = [el[0] for el in range_info]
-
-            if table_name not in range_tables:
-                self.log.debug(f'Selecting linked targets into '
-                               f'table {self.schema}.{inter_name1} IN PSQL '
-                               f'with join path {path}{self._get_sql(query)}')
-
-                sql_file = open('phase1_query.sql', 'w')
-
-                if self._options['database_options']:
-                    options = self._options['database_options'].copy()
-                    for param in options:
-                        if param == 'maintenance_work_mem':
-                            continue
-                        value = options[param]
-                        sql_file.write(f'SET {param}={value!r};\n')
-
-                sql_file.write(f'DROP TABLE IF EXISTS {TEMP_SCHEMA}.{inter_name1};')
-                sql_file.write(f'CREATE TABLE {TEMP_SCHEMA}.{inter_name1} AS (')
-                sql_file.write(query_str + ');')
-                sql_file.close()
-
-                os.system('psql -U sdss -d sdss5db -f '
-                          'phase1_query.sql -o phase1_output.log')
-
-            # If the config file indicates that we want to ingest phase_1 by PK range
-            # Then instead of doing a simple query to create the intermediate result table
-            # We first create the table and then do multiple queries each with a PK range
-            # each time inserting the result to the intermediate result table
+            temp_model = self.get_relational_model(model, temp=True, sandboxed=True)
+            temp_table = temp_model._meta.table_name
 
-            else:
-                range_info_table = range_info[range_tables.index(table_name)]
-                low_limit = range_info_table[1]
-                high_limit = range_info_table[2]
-                bin_range = range_info_table[3]
-                first_query = query.where(model_pk < low_limit)
-                first_query_str = self._get_sql(first_query, return_string=True)
-                self.log.debug(f'Selecting linked targets into '
-                               f'table {TEMP_SCHEMA}.{inter_name1} IN PSQL BY PK RANGE '
-                               f'with join path {path} and base (before splitting) query'
-                               f'{self._get_sql(query)}')
-
-                sql_file = open('phase1_query.sql', 'w')
-                sql_file.write(f'DROP TABLE IF EXISTS {TEMP_SCHEMA}.{inter_name1};')
-                sql_file.write(f'CREATE TABLE {TEMP_SCHEMA}.{inter_name1} AS (')
-                sql_file.write(first_query_str + ');')
-                sql_file.close()
-
-                os.system('psql -U sdss -d sdss5db -f phase1_query.sql '
-                          '-o phase1_output.log')
-
-                min_id, max_id = low_limit, low_limit + bin_range
-
-                while max_id <= high_limit:
-                    range_query = query.where(model_pk >= min_id, model_pk < max_id)
-                    range_query_str = self._get_sql(range_query)
-                    sql_file = open('phase1_query.sql', 'w')
-
-                    if self._options['database_options']:
-                        options = self._options['database_options'].copy()
-                        for param in options:
-                            if param == 'maintenance_work_mem':
-                                continue
-                            value = options[param]
-                            sql_file.write(f'SET {param}={value!r};\n')
-
-                    sql_file.write(f'INSERT INTO {TEMP_SCHEMA}.{inter_name1} ')
-                    sql_file.write(range_query_str + ';')
-                    sql_file.close()
-                    os.system('psql -U sdss -d sdss5db -f phase1_query.sql '
-                              '-o phase1_output.log')
-                    min_id += bin_range
-                    max_id += bin_range
+
+            self._setup_transaction(model, phase=1)
+            self.log.debug(f'Selecting linked targets into temporary '
+                           f'table {temp_table!r} with join path '
+                           f'{path}{self._get_sql(query)}')
+
+            query.create_table(temp_table, temporary=True)
 
             self.log.debug(f'Copying data into relational model '
-                           f'{rel_model._meta.table_name!r}.')
-
-            fields = [Intermediate1.target_id, Intermediate1.catalogid,
-                      Intermediate1.version_id, Intermediate1.best,
-                      rel_model.plan_id]
-
-            cursor = rel_model.insert_from(
-                Intermediate1.select(Intermediate1.target_id,
-                                     Intermediate1.catalogid,
-                                     peewee.Value(self.version_id),
-                                     Intermediate1.best,
-                                     self.plan if self.is_addendum else None),
+                           f'{rel_model_sb._meta.table_name!r}.')
+
+            fields = [temp_model.target_id, temp_model.catalogid,
+                      temp_model.version_id, temp_model.best,
+                      temp_model.plan_id, temp_model.added_by_phase]
+
+            nids = rel_model_sb.insert_from(
+                temp_model.select(temp_model.target_id,
+                                  temp_model.catalogid,
+                                  peewee.Value(self.version_id),
+                                  temp_model.best,
+                                  peewee.Value(self.plan) if self.is_addendum else None,
+                                  peewee.Value(1)),
                 fields).returning().execute()
 
-            nids = cursor.rowcount
-            self.log.debug(f'Linked {nids:,} records in'
-                           f' {timer.interval:.3f} s.')
+        self.log.debug(f'Linked {nids.rowcount:,} records in {timer.interval:.3f} s.')
 
         self._phases_run.add(1)
 
-        if nids > 0:
-            self._analyze(rel_model)
+        if nids.rowcount > 0:
+            self._analyze(rel_model_sb)
 
-    def _run_phase_2(self, model):
-        """Associates existing targets in Catalog with entries in the model."""
+    def _run_phase_2(self, model, source=TempCatalog):
+        """Associates existing targets in Catalog with entries in the model.
+
+        Here ``source`` is the catalogue with which we are spatially cross-matching.
+        Normally this is the temporary catalog table which we are building for this
+        cross-match run. But when we are doing an addendum, that table is going to be
+        empty (at least for the first table in the addendum), so we need to also have
+        the option of using the real ``catalog`` table as the source. This method will
+        call itself recursively with ``source=Catalog`` if this is an addendum run.
+ + """ meta = model._meta xmatch = meta.xmatch table_name = meta.table_name - self.log.info('Phase 2: cross-matching against existing targets.') + self.log.info('Phase 2: cross-matching against existing targets ' + f'({source._meta.table_name}).') if 2 in xmatch.skip_phases: self.log.warning('Skipping due to configuration.') return - rel_model = self._get_relational_model(model, create=True) - rel_table_name = rel_model._meta.table_name + rel_model_sb = self.get_relational_model(model, create=True, sandboxed=True) + rel_sb_table_name = rel_model_sb._meta.table_name + + rel_model = self.get_relational_model(model, create=False, sandboxed=False) model_pk = meta.primary_key model_ra = meta.fields[xmatch.ra_column] @@ -1608,10 +1569,7 @@ def _run_phase_2(self, model): max_delta_epoch += .1 # Add .1 yr to be sure it's an upper bound - self.log.debug(f'Maximum epoch delta: ' - f'{max_delta_epoch:.3f} (+ 0.1 year).') - - self.log.debug('Cross-matching model against temporary table.') + self.log.debug(f'Maximum epoch delta: {max_delta_epoch:.3f} (+ 0.1 year).') if use_pm: @@ -1622,28 +1580,29 @@ def _run_phase_2(self, model): q3c_dist = fn.q3c_dist_pm(model_ra, model_dec, model_pmra, model_pmdec, model_is_pmra_cos, model_epoch, - TempCatalog.ra, TempCatalog.dec, + source.ra, source.dec, catalog_epoch) q3c_join = fn.q3c_join_pm(model_ra, model_dec, model_pmra, model_pmdec, model_is_pmra_cos, model_epoch, - TempCatalog.ra, TempCatalog.dec, + source.ra, source.dec, catalog_epoch, max_delta_epoch, query_radius / 3600.) else: q3c_dist = fn.q3c_dist(model_ra, model_dec, - TempCatalog.ra, TempCatalog.dec) + source.ra, source.dec) q3c_join = fn.q3c_join(model_ra, model_dec, - TempCatalog.ra, TempCatalog.dec, + source.ra, source.dec, query_radius / 3600.) # Get the cross-matched catalogid and model target pk (target_id), # and their distance. - xmatched = (TempCatalog - .select(TempCatalog.catalogid, + xmatched = (source + .select(source.catalogid, model_pk.alias('target_id'), - q3c_dist.alias('distance')) + q3c_dist.alias('distance'), + source.version_id) .join(model, peewee.JOIN.CROSS) .where(q3c_join) .where(self._get_sample_where(model_ra, model_dec))) @@ -1682,20 +1641,28 @@ def _run_phase_2(self, model): peewee.Value(self.version_id).alias('version_id'), xmatched.c.distance.alias('distance'), best.alias('best'), - self.plan if self.is_addendum else None)) + self.plan if self.is_addendum else None, + peewee.Value(2).alias('added_by_phase'))) # We only need to care about already linked targets if phase 1 run. if 1 in self._phases_run: in_query = ( - in_query .where( - ~fn.EXISTS( - rel_model .select( - SQL('1')) .where( - (rel_model.version_id == self.version_id) & ( - (rel_model.catalogid == - xmatched.c.catalogid) | ( - (rel_model.target_id == - xmatched.c.target_id))))))) + in_query.where( + xmatched.c.version_id == self.version_id, + ~fn.EXISTS(rel_model_sb.select(SQL('1')) + .where((rel_model_sb.version_id == self.version_id) & + ((rel_model_sb.catalogid == xmatched.c.catalogid) | + ((rel_model_sb.target_id == xmatched.c.target_id)))) + ))) + + if rel_model.table_exists(): + in_query = ( + in_query.where( + ~fn.EXISTS(rel_model.select(SQL('1')) + .where((rel_model.version_id == self.version_id) & + ((rel_model.catalogid == xmatched.c.catalogid) | + ((rel_model.target_id == xmatched.c.target_id)))) + ))) with Timer() as timer: @@ -1710,46 +1677,33 @@ def _run_phase_2(self, model): # 2. Run cross-match and insert data into relational model. 
-            fields = [rel_model.target_id, rel_model.catalogid,
-                      rel_model.version_id, rel_model.distance,
-                      rel_model.best, rel_model.plan_id]
+            fields = [rel_model_sb.target_id, rel_model_sb.catalogid,
+                      rel_model_sb.version_id, rel_model_sb.distance,
+                      rel_model_sb.best, rel_model_sb.plan_id,
+                      rel_model_sb.added_by_phase]
 
-            in_query = (rel_model
-                        .insert_from(in_query.with_cte(xmatched),
-                                     fields).returning())
+            in_query = rel_model_sb.insert_from(in_query.with_cte(xmatched),
+                                                fields).returning()
 
-            self.log.debug(f'Inserting cross-matched data into '
-                           f'relational model {rel_table_name!r} in PSQL: '
+            self.log.debug(f'Running Q3C query and inserting cross-matched data into '
+                           f'relational table {rel_sb_table_name!r}: '
                            f'{self._get_sql(in_query)}')
 
-            in_query_str = self._get_sql(in_query, return_string=True)
-
-            sql_file = open('phase2_query.sql', 'w')
-
-            if self._options['database_options']:
-                options = self._options['database_options'].copy()
-                for param in options:
-                    if param == 'maintenance_work_mem':
-                        continue
-                    value = options[param]
-                    sql_file.write(f'SET {param}={value!r};\n')
+            n_catalogid = in_query.execute().rowcount
 
-            sql_file.write(in_query_str + ';')
-            sql_file.close()
-
-            os.system('psql -U sdss -d sdss5db -f phase2_query.sql -o phase2_output.log')
-            out_file = open('phase2_output.log', 'r')
-            lines = out_file.readlines()
-            out_file.close()
-            n_catalogid = int(lines[-1].split()[-1])
-
-            self.log.debug(f'Cross-matched {TempCatalog._meta.table_name} with '
+            self.log.debug(f'Cross-matched {source._meta.table_name} with '
                            f'{n_catalogid:,} targets in {table_name}. '
                            f'Run in {timer.interval:.3f} s.')
 
         if n_catalogid > 0:
             self._phases_run.add(2)
-            self._analyze(rel_model)
+            self._analyze(rel_model_sb)
+
+        # For addenda it's not sufficient to cross-match with the temporary table, because
+        # that does not contain all the accumulated targets from this cross-match version.
+        # We need to also spatially cross-match with the real catalog table (but only for
+        # the targets with this cross-match version_id).
+        if self.is_addendum and source != Catalog:
+            self._run_phase_2(model, source=Catalog)
 
     def _run_phase_3(self, model):
         """Add non-matched targets to Catalog and the relational table."""
@@ -1759,8 +1713,10 @@ def _run_phase_3(self, model):
 
         self.log.info('Phase 3: adding non cross-matched targets.')
 
-        rel_model = self._get_relational_model(model, create=True)
-        rel_table_name = rel_model._meta.table_name
+        rel_model_sb = self.get_relational_model(model, create=True, sandboxed=True)
+        rel_sb_table_name = rel_model_sb._meta.table_name
+
+        rel_model = self.get_relational_model(model, create=False, sandboxed=False)
 
         if 3 in xmatch.skip_phases:
             self.log.warning('Skipping due to configuration.')
@@ -1785,10 +1741,19 @@ def _run_phase_3(self, model):
             unmatched = (
                 unmatched.where(
                     ~fn.EXISTS(
-                        rel_model .select(
-                            SQL('1')) .where(
-                            rel_model.version_id == self.version_id,
-                            rel_model.target_id == model_pk))))
+                        rel_model_sb.select(SQL('1')).where(
+                            rel_model_sb.version_id == self.version_id,
+                            rel_model_sb.target_id == model_pk,
+                            rel_model_sb.best >> True))))
+
+        if rel_model.table_exists():
+            unmatched = unmatched.where(
+                ~fn.EXISTS(
+                    rel_model.select(SQL('1')).where(
+                        rel_model.version_id == self.version_id,
+                        rel_model.target_id == model_pk,
+                        rel_model.best >> True)))
 
         if xmatch.has_missing_coordinates:
             unmatched = unmatched.where(model_ra.is_null(False),
@@ -1816,82 +1781,51 @@ def _run_phase_3(self, model):
 
             # 1. Run link query and create temporary table with results.
             self._setup_transaction(model, phase=3)
 
-            inter_name3 = 'phase3_content'
-            model_pk_class = model_pk.__class__
-
-            class Intermediate3(peewee.Model):
-                """Model for the intermediate results of phase_3."""
-                catalogid = peewee.BigIntegerField(index=True, null=False)
-                target_id = model_pk_class(null=False, index=True)
-                version_id = peewee.SmallIntegerField(null=False, index=True)
-                best = peewee.BooleanField(null=False)
-                distance = peewee.DoubleField(null=True)
-
-                class Meta:
-                    database = self.database
-                    schema = TEMP_SCHEMA
-                    table_name = inter_name3
-                    primary_key = False
-
-            unmatched_query_str = self._get_sql(unmatched, return_string=True)
+            temp_model = self.get_relational_model(model, temp=True, sandboxed=True)
+            temp_model_name = temp_model._meta.table_name
 
             self.log.debug(f'Selecting unique targets '
-                           f'into table '
-                           f'{self.schema}.{inter_name3!r}'
-                           f' in PSLQ: {self._get_sql(unmatched)}')
-
-            sql_file = open('phase3_query.sql', 'w')
-
-            if self._options['database_options']:
-                options = self._options['database_options'].copy()
-                for param in options:
-                    if param == 'maintenance_work_mem':
-                        continue
-                    value = options[param]
-                    sql_file.write(f'SET {param}={value!r};\n')
-
-            sql_file.write(f'DROP TABLE IF EXISTS {TEMP_SCHEMA}.{inter_name3};')
-            sql_file.write(f'CREATE TABLE {TEMP_SCHEMA}.{inter_name3} AS (')
-            sql_file.write(unmatched_query_str + ');')
-            sql_file.close()
+                           f'into temporary table '
+                           f'{temp_model_name!r}{self._get_sql(unmatched)}')
 
-            os.system('psql -U sdss -d sdss5db -f phase3_query.sql -o phase3_output.log')
+            unmatched.create_table(temp_model_name, temporary=True)
 
             # Analyze the temporary table to gather stats.
             # self.log.debug('Running ANALYZE on temporary table.')
-            # self.database.execute_sql(f'ANALYZE "{temp_table}";')
+            # self.database.execute_sql(f'ANALYZE "{temp_model_name}";')
 
             # 2. Copy data from temporary table to relational table. Add
             # catalogid at this point.
 
-            fields = [Intermediate3.catalogid, Intermediate3.target_id,
-                      Intermediate3.version_id, Intermediate3.best,
-                      rel_model.plan_id]
-
-            rel_insert_query = rel_model.insert_from(
-                Intermediate3.select(Intermediate3.catalogid,
-                                     Intermediate3.target_id,
-                                     self.version_id,
-                                     peewee.SQL('true'),
-                                     self.plan if self.is_addendum else None),
+            fields = [temp_model.catalogid, temp_model.target_id,
+                      temp_model.version_id, temp_model.best,
+                      rel_model_sb.plan_id, rel_model_sb.added_by_phase]
+
+            rel_insert_query = rel_model_sb.insert_from(
+                temp_model.select(temp_model.catalogid,
+                                  temp_model.target_id,
+                                  self.version_id,
+                                  peewee.SQL('true'),
+                                  self.plan if self.is_addendum else None,
+                                  peewee.Value(3).alias('added_by_phase')),
                 fields).returning()
 
             self.log.debug(f'Copying data into relational model '
-                           f'{rel_table_name!r}'
+                           f'{rel_sb_table_name!r}'
                            f'{self._get_sql(rel_insert_query)}')
 
             cursor = rel_insert_query.execute()
             n_rows = cursor.rowcount
 
-            self.log.debug(f'Insertion into {rel_table_name} completed '
+            self.log.debug(f'Insertion into {rel_sb_table_name} completed '
                            f'with {n_rows:,} rows in '
                            f'{timer.elapsed:.3f} s.')
 
             # 3. Fill out the temporary catalog table with the information
             # from the unique targets.
-            temp_table = peewee.Table(inter_name3, schema=TEMP_SCHEMA)
+            temp_table = peewee.Table(temp_model_name)
 
             fields = [TempCatalog.catalogid,
                       TempCatalog.lead,
@@ -1937,28 +1871,40 @@ class Meta:
 
         if n_rows > 0.5 * self._temp_count:  # Cluster if > 50% of rows are new
             self.log.debug(f'Running CLUSTER on {self._temp_table} '
                            f'with q3c index.')
-            self.database.execute_sql(f'CLUSTER {TEMP_SCHEMA}.{self._temp_table} '
+            self.database.execute_sql(f'CLUSTER {self.temp_schema}.{self._temp_table} '
                                       f'using {self._temp_table}_q3c_idx;')
 
         self.log.debug(f'Running ANALYZE on {self._temp_table}.')
-        self.database.execute_sql(f'ANALYZE {TEMP_SCHEMA}.{self._temp_table};')
+        self.database.execute_sql(f'ANALYZE {self.temp_schema}.{self._temp_table};')
+
+        self._analyze(rel_model_sb, catalog=False)
+
+    def load_output_tables(self, model, keep_temp=False):
+        """Loads the temporary tables into the output tables."""
 
-        self._analyze(rel_model, catalog=False)
+        self._load_output_table(TempCatalog, Catalog, keep_temp=keep_temp)
 
-    def _load_output_table(self, keep_temp=False):
+        rel_model_sb = self.get_relational_model(model, sandboxed=True, create=False)
+        rel_model = self.get_relational_model(model, sandboxed=False, create=True)
+        self._load_output_table(rel_model_sb, rel_model, keep_temp=keep_temp)
+
+    def _load_output_table(self, from_model, to_model, keep_temp=False):
         """Copies the temporary table to the real output table."""
 
-        self.log.info('Copying temporary table to output table.')
+        to_table = f'{to_model._meta.schema}.{to_model._meta.table_name}'
+        from_table = f'{from_model._meta.schema}.{from_model._meta.table_name}'
+
+        self.log.info(f'Copying {from_table} to {to_table}.')
 
         with Timer() as timer:
             with self.database.atomic():
 
                 self._setup_transaction()
 
-                insert_query = Catalog.insert_from(
-                    TempCatalog.select(),
-                    TempCatalog.select()._returning).returning()
+                insert_query = to_model.insert_from(
+                    from_model.select(),
+                    from_model.select()._returning).returning()
 
-                self.log.debug(f'Running INSERT query into {self.output_table}'
+                self.log.debug(f'Running INSERT query into {to_table}'
                                f'{self._get_sql(insert_query)}')
 
                 cursor = insert_query.execute()
@@ -1967,15 +1913,14 @@ def _load_output_table(self, from_model, to_model, keep_temp=False):
 
         self.log.debug(f'Inserted {n_rows:,} rows in {timer.elapsed:.3f} s.')
 
         if not keep_temp:
-            self.database.drop_tables([TempCatalog])
-            self.log.info(f'Dropped temporary table {self._temp_table}')
+            self.database.drop_tables([from_model])
+            self.log.info(f'Dropped temporary table {from_table}.')
 
-        self.log.debug(f'Running VACUUM ANALYZE on {self.output_table}.')
-        vacuum_table(self.database, f'{self.schema}.{self.output_table}',
-                     vacuum=True, analyze=True)
+        self.log.debug(f'Running VACUUM ANALYZE on {to_table}.')
+        vacuum_table(self.database, to_table, vacuum=True, analyze=True)
 
     def _get_sql(self, query, return_string=False):
-        """Returns coulourised SQL text for logging."""
+        """Returns colourised SQL text for logging."""
 
         query_str, query_params = query.sql()
 
@@ -1989,6 +1934,8 @@ def _get_sql(self, query, return_string=False):
         query_str = query_str.replace('None', 'Null')
 
         if return_string:
             return query_str
+        elif self.log.rich_console:
+            return f': {rich.markup.escape(query_str)}'
         elif self._options['show_sql']:
             return f': {color_text(query_str, "blue")}'
         else:
@@ -2078,8 +2025,7 @@ def _analyze(self, rel_model, vacuum=False, catalog=False):
         if db_opts:
             work_mem = db_opts.get('maintenance_work_mem', None)
             if work_mem:
-                self.database.execute_sql(f'SET maintenance_work_mem'
-                                          f' = {work_mem!r}')
+                self.database.execute_sql(f'SET maintenance_work_mem = {work_mem!r}')
 
         self.log.debug(f'Running ANALYZE on {table_name}.')
         vacuum_table(self.database, f'{schema}.{table_name}',
diff --git a/setup.cfg b/setup.cfg
index 568e9279..1c400818 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,7 +43,7 @@ install_requires =
 	healpy>=1.13.0
 	pymangle @ git+https://github.com/esheldon/pymangle.git
 	tables>=3.6.1
-	sdsstools>=1.0.0
+	sdsstools>=1.6.0
 	enlighten>=1.4.0
 	mocpy>=0.8.5
 	pymoc>=0.5.0
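(Note: putting the pieces together, the intended ToO refresh workflow looks roughly like this. The plan matches the new '1.1.0' entry above; database connection setup is assumed to happen elsewhere.)

    from target_selection.cartons.too import ToO_Carton

    carton = ToO_Carton('1.1.0')   # the dedicated ToO plan added in this diff
    carton.run()                   # selects ToO targets not yet in the carton
    carton.load(mode='append')     # append only the new targets on each refresh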