Merge pull request #367 from alercebroker/refactor/lc-classifier-group-by-oid

Refactor LC Classification Step: Use oid to group objects
dirodriguezm authored Jan 3, 2024
2 parents 77bcc18 + 12aa92e commit d1e20c1
Showing 24 changed files with 312 additions and 311 deletions.
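
The central change is the grouping key: parsers and DTO builders now index and group messages by oid (the per-survey object identifier) instead of aid. A minimal sketch of the new grouping, mirroring what create_features_dto does after this refactor; the message contents, feature names, and oid values below are illustrative only:

    import pandas as pd

    # Hypothetical input messages; the real step consumes Kafka messages.
    messages = [
        {"oid": "oid1", "features": {"feat1": 1, "feat2": 2}},
        {"oid": "oid1", "features": {"feat1": 2, "feat2": 3}},
        {"oid": "oid2", "features": {"feat1": 4, "feat2": 5}},
    ]

    # Group by oid: keep the last record per object and index the frame by oid.
    features = pd.DataFrame.from_records(
        [{"oid": m["oid"], **m["features"]} for m in messages]
    )
    features = features.drop_duplicates("oid", keep="last").set_index("oid")
    print(features)
    #       feat1  feat2
    # oid
    # oid1      2      3
    # oid2      4      5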
1 change: 0 additions & 1 deletion .github/workflows/template_build_with_dagger.yaml
@@ -30,7 +30,6 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
with:
ref: main
submodules: ${{ inputs.submodules}}
token: ${{ secrets.GH_TOKEN }}
- name: Install poetry
33 changes: 16 additions & 17 deletions .github/workflows/xmatch_step.yaml
@@ -5,39 +5,38 @@ on:
branches:
- main
paths:
- 'xmatch_step/**'
- '!xmatch_step/README.md'
- "xmatch_step/**"
- "!xmatch_step/README.md"

jobs:
xmatch_step_lint:
uses: ./.github/workflows/lint-template.yaml
with:
base-folder: 'xmatch_step'
sources-folder: 'xmatch_step'
base-folder: "xmatch_step"
sources-folder: "xmatch_step"
xmatch_step_unittest:
uses: ./.github/workflows/poetry-tests-template.yaml
with:
base-folder: 'xmatch_step'
python-version: "3.7"
base-folder: "xmatch_step"
python-version: "3.10"
test-folder: "tests/unittest"
secrets:
GH_TOKEN: '${{ secrets.ADMIN_TOKEN }}'
GH_TOKEN: "${{ secrets.ADMIN_TOKEN }}"
xmatch_step_integration:
uses: ./.github/workflows/poetry-tests-template.yaml
with:
base-folder: 'xmatch_step'
python-version: '3.7'
sources-folder: 'xmatch_step'
test-folder: 'tests/integration'
codecov-flags: '' # Do not upload
base-folder: "xmatch_step"
python-version: "3.10"
sources-folder: "xmatch_step"
test-folder: "tests/integration"
codecov-flags: "" # Do not upload
secrets:
GH_TOKEN: '${{ secrets.ADMIN_TOKEN }}'
GH_TOKEN: "${{ secrets.ADMIN_TOKEN }}"

build-xmatch-dagger:

uses: ./.github/workflows/template_build_with_dagger.yaml
with:
stage: staging
extra-args: xmatch_step --dry-run
stage: staging
extra-args: xmatch_step --dry-run
secrets:
GH_TOKEN: ${{ secrets.ADMIN_TOKEN }}
GH_TOKEN: ${{ secrets.ADMIN_TOKEN }}
1 change: 1 addition & 0 deletions lc_classification_step/.gitignore
@@ -8,3 +8,4 @@ htmlcov/
__SUCCESS__
.env
.env.test
.env.el
@@ -18,14 +18,14 @@ def parse(
parsed = []
features.replace({np.nan: None}, inplace=True)
messages_df = pd.DataFrame(
[{"aid": message.get("aid")} for message in messages]
[{"oid": message.get("oid")} for message in messages]
)

# maybe this won't be enough
probs_copy = model_output.probabilities.copy()
probs_copy.pop("classifier_name")
try:
probs_copy.set_index("aid", inplace=True)
probs_copy.set_index("oid", inplace=True)
except KeyError:
pass
tree_output = {
@@ -34,43 +34,43 @@ def parse(
"class": probs_copy.idxmax(axis=1),
}

messages_df.drop_duplicates("aid", inplace=True)
messages_df.drop_duplicates("oid", inplace=True)
for _, row in messages_df.iterrows():
aid = row.aid
oid = row.oid
try:
features_aid = features.loc[aid].to_dict()
features_oid = features.loc[oid].to_dict()
except KeyError:
continue

tree_aid = self._get_aid_tree(tree_output, aid)
tree_oid = self._get_oid_tree(tree_output, oid)
write = {
"aid": aid,
"features": features_aid,
"lc_classification": tree_aid,
"oid": oid,
"features": features_oid,
"lc_classification": tree_oid,
}
parsed.append(write)

return KafkaOutput(parsed)

def _get_aid_tree(self, tree, aid):
tree_aid = {}
def _get_oid_tree(self, tree, oid):
tree_oid = {}
for key in tree:
data = tree[key]
if isinstance(data, pd.DataFrame):
try:
data_cpy = data.set_index("aid")
tree_aid[key] = data_cpy.loc[aid].to_dict()
if "classifier_name" in tree_aid[key]:
tree_aid[key].pop("classifier_name")
data_cpy = data.set_index("oid")
tree_oid[key] = data_cpy.loc[oid].to_dict()
if "classifier_name" in tree_oid[key]:
tree_oid[key].pop("classifier_name")
except KeyError as e:
if not data.index.name == "aid":
if not data.index.name == "oid":
raise e
else:
tree_aid[key] = data.loc[aid].to_dict()
if "classifier_name" in tree_aid[key]:
tree_aid[key].pop("classifier_name")
tree_oid[key] = data.loc[oid].to_dict()
if "classifier_name" in tree_oid[key]:
tree_oid[key].pop("classifier_name")
elif isinstance(data, pd.Series):
tree_aid[key] = data.loc[aid]
tree_oid[key] = data.loc[oid]
elif isinstance(data, dict):
tree_aid[key] = self._get_aid_tree(data, aid)
return tree_aid
tree_oid[key] = self._get_oid_tree(data, oid)
return tree_oid
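
For context, a small self-contained sketch of what the _get_oid_tree lookup extracts for one object. This is a simplified re-sketch, not the method above (it omits the classifier_name and index edge-case handling), and the tree contents, class names, and oid values are made up:

    import pandas as pd

    # Made-up hierarchical output: a probabilities DataFrame plus a Series of
    # top classes, both keyed by oid.
    probs = pd.DataFrame(
        {"oid": ["oid1", "oid2"], "SNIa": [0.7, 0.2], "SNII": [0.3, 0.8]}
    )
    tree = {
        "probabilities": probs,
        "class": probs.set_index("oid").idxmax(axis=1),
    }

    def get_oid_tree(tree, oid):
        # Simplified analogue of parser._get_oid_tree(tree, oid)
        out = {}
        for key, data in tree.items():
            if isinstance(data, pd.DataFrame):
                out[key] = data.set_index("oid").loc[oid].to_dict()
            elif isinstance(data, pd.Series):
                out[key] = data.loc[oid]
            elif isinstance(data, dict):
                out[key] = get_oid_tree(data, oid)
        return out

    print(get_oid_tree(tree, "oid1"))
    # {'probabilities': {'SNIa': 0.7, 'SNII': 0.3}, 'class': 'SNIa'}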
@@ -32,7 +32,7 @@ def parse(
]

if len(new_detection) == 0:
## case when no new detections
# case when no new detections
new_detection = [
det for det in message["detections"] if det["has_stamp"]
]
@@ -47,9 +47,7 @@ def parse(

new_detection = new_detection[0]

        # terrible patch (this code is transitional,
        # change before the conference week)
detection_extra_info[new_detection["aid"]] = {
detection_extra_info[new_detection["oid"]] = {
"candid": new_detection["extra_fields"].get(
"alertId", new_detection["candid"]
),
@@ -68,14 +66,14 @@ def parse(
predictions = NoClassifiedPostProcessor(
messages, predictions
).get_modified_classifications()
predictions["aid"] = predictions.index
predictions["oid"] = predictions.index
for class_name in self.ClassMapper.get_class_names():
if class_name not in predictions.columns:
predictions[class_name] = 0.0
classifications = predictions.to_dict(orient="records")
output = []
for classification in classifications:
aid = classification.pop("aid")
oid = classification.pop("oid")
if "classifier_name" in classification:
classification.pop("classifier_name")

@@ -86,14 +84,15 @@ def parse(
}
for predicted_class, prob in classification.items()
]
print(detection_extra_info)
response = {
"alertId": int(detection_extra_info[aid]["candid"]),
"diaSourceId": int(detection_extra_info[aid]["diaSourceId"]),
"alertId": int(detection_extra_info[oid]["candid"]),
"diaSourceId": int(detection_extra_info[oid]["diaSourceId"]),
"elasticcPublishTimestamp": int(
detection_extra_info[aid]["elasticcPublishTimestamp"]
detection_extra_info[oid]["elasticcPublishTimestamp"]
),
"brokerIngestTimestamp": int(
detection_extra_info[aid]["brokerIngestTimestamp"]
detection_extra_info[oid]["brokerIngestTimestamp"]
),
"classifications": output_classification,
"brokerVersion": classifier_version,
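
A sketch of the ELAsTiCC-style response assembled per classification, now looked up by oid in detection_extra_info. The field values below are placeholders, and the key names inside the classifications list are an assumption (the dict comprehension that builds them is collapsed in this view):

    # Hypothetical extra info recorded for one object, keyed by oid.
    detection_extra_info = {
        "oid1": {
            "candid": 123456789,
            "diaSourceId": 987654321,
            "elasticcPublishTimestamp": 1700000000000,
            "brokerIngestTimestamp": 1700000000500,
        }
    }

    oid = "oid1"
    response = {
        "alertId": int(detection_extra_info[oid]["candid"]),
        "diaSourceId": int(detection_extra_info[oid]["diaSourceId"]),
        "elasticcPublishTimestamp": int(
            detection_extra_info[oid]["elasticcPublishTimestamp"]
        ),
        "brokerIngestTimestamp": int(
            detection_extra_info[oid]["brokerIngestTimestamp"]
        ),
        # classification entry keys assumed, not taken from this diff
        "classifications": [{"classId": 111, "probability": 0.44}],
        "brokerVersion": "1.0.0",
    }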
46 changes: 23 additions & 23 deletions lc_classification_step/lc_classification/core/parsers/input_dto.py
@@ -32,40 +32,40 @@ def create_detections_dto(messages: List[dict]) -> pd.DataFrame:
>>> messages = [
{
"detections": [
{"aid": "aid1", "candid": "cand1"},
{"aid": "aid1", "candid": "cand2"},
{"oid": "oid1", "candid": "cand1"},
{"oid": "oid1", "candid": "cand2"},
]
},
{
"detections": [
{"aid": "aid2", "candid": "cand3"},
{"oid": "oid2", "candid": "cand3"},
]
},
]
>>> create_detections_dto(messages)
candid
aid
aid1 cand1
aid2 cand3
oid
oid1 cand1
oid2 cand3
"""
detections = [
pd.DataFrame.from_records(msg["detections"]) for msg in messages
]
detections = pd.concat(detections)
detections.drop_duplicates("candid", inplace=True)
detections = detections.set_index("aid")
detections = detections.set_index("oid")
detections["extra_fields"] = parse_extra_fields(detections)

if detections is not None:
return detections
else:
raise ValueError("Could not set index aid on features dataframe")
raise ValueError("Could not set index oid on features dataframe")


def parse_extra_fields(detections: pd.DataFrame) -> List[dict]:
for ef in detections["extra_fields"]:
for key in ef.copy():
if type(ef[key]) == bytes:
if type(ef[key]) is bytes:
extra_field = pickle.loads(ef[key])
# the loaded pickle is a list of one element
ef[key] = extra_field[0]
@@ -75,40 +75,40 @@ def parse_extra_fields(detections: pd.DataFrame) -> List[dict]:
def create_features_dto(messages: List[dict]) -> pd.DataFrame:
"""Creates a pandas dataframe with all the features from all messages
The index is the aid and each feature is a column.
The index is the oid and each feature is a column.
Parameters
-------
messages : list
a list of dictionaries with at least aid and features keys.
a list of dictionaries with at least oid and features keys.
Returns
-------
pd.DataFrame
A dataframe where each feature is a column indexed by aid.
Duplicated aid are removed.
A dataframe where each feature is a column indexed by oid.
Duplicated oid are removed.
Examples
--------
>>> messages = [
{
'aid': 'aid1',
'oid': 'oid1',
'features': {'feat1': 1, 'feat2': 2}
},
{
'aid': 'aid1',
'oid': 'oid1',
'features': {'feat1': 2, 'feat2': 3}
},
{
'aid': 'aid2',
'oid': 'oid2',
'features': {'feat1': 4, 'feat2': 5}
}
]
>>> create_features_dto(messages)
feat1 feat2
aid
aid2 4 5
aid1 2 3
oid
oid2 4 5
oid1 2 3
"""
if len(messages) == 0 or "features" not in messages[0]:
return pd.DataFrame()
@@ -119,15 +119,15 @@ def create_features_dto(messages: List[dict]) -> pd.DataFrame:
entry = {
feat: message["features"][feat] for feat in message["features"]
}
entry["aid"] = message["aid"]
entry["oid"] = message["oid"]
entries.append(entry)
if len(entries) == 0:
return pd.DataFrame()

features = pd.DataFrame.from_records(entries)
features.drop_duplicates("aid", inplace=True, keep="last")
features = features.set_index("aid")
features.drop_duplicates("oid", inplace=True, keep="last")
features = features.set_index("oid")
if features is not None:
return features
else:
raise ValueError("Could not set index aid on features dataframe")
raise ValueError("Could not set index oid on features dataframe")
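
The parse_extra_fields hunk above unpickles byte-encoded extra_fields values and keeps the single element of the loaded list. A small self-contained sketch of that behaviour, assuming a made-up field name extra1 and toy detection data:

    import pickle

    import pandas as pd

    # One detection whose extra field arrives as a pickled, single-element list.
    detections = pd.DataFrame(
        [{"oid": "oid1", "candid": "cand1",
          "extra_fields": {"extra1": pickle.dumps([{"alertId": 42}])}}]
    ).set_index("oid")

    for ef in detections["extra_fields"]:
        for key in ef.copy():
            if type(ef[key]) is bytes:
                loaded = pickle.loads(ef[key])
                ef[key] = loaded[0]  # the loaded pickle is a list of one element

    print(detections.loc["oid1", "extra_fields"])  # {'extra1': {'alertId': 42}}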
@@ -33,24 +33,24 @@ def parse(self, to_parse: OutputDTO, **kwargs) -> KafkaOutput[List[dict]]:
{
'top': Periodic Stochastic Transient
aid
oid
vbKsodtqMI 0.434 0.21 0.356,
'children': {
'Transient': SLSN SNII SNIa SNIbc
aid
oid
vbKsodtqMI 0.082 0.168 0.444 0.306,
'Stochastic': AGN Blazar CV/Nova QSO YSO
aid
oid
vbKsodtqMI 0.032 0.056 0.746 0.01 0.156,
'Periodic': CEP DSCT E LPV Periodic-Other RRL
aid
oid
vbKsodtqMI 0.218 0.082 0.158 0.028 0.12 0.394
}
}
to_parse.probabilities
SLSN SNII SNIa SNIbc ... E LPV Periodic-Other RRL
aid ...
oid ...
vbKsodtqMI 0.029192 0.059808 0.158064 0.108936 ... 0.068572 0.012152 0.05208 0.170996,
}
"""
@@ -64,9 +64,9 @@ def parse(self, to_parse: OutputDTO, **kwargs) -> KafkaOutput[List[dict]]:
results = [top, probabilities]

results = pd.concat(results)
if not results.index.name == "aid":
if not results.index.name == "oid":
try:
results.set_index("aid", inplace=True)
results.set_index("oid", inplace=True)
except KeyError as e:
if not is_all_strings(results.index.values):
raise e
@@ -81,7 +81,6 @@ def get_scribe_messages(classifications_by_classifier: pd.DataFrame):
"type": "update_probabilities",
"criteria": {
"_id": idx,
"oid": kwargs["oids"].get(idx, []),
},
"data": {
"classifier_name": row["classifier_name"],
@@ -94,8 +93,8 @@ def get_scribe_messages(classifications_by_classifier: pd.DataFrame):
commands.append(command)
return classifications_by_classifier

for aid in results.index.unique():
results.loc[[aid], :].groupby(
for oid in results.index.unique():
results.loc[[oid], :].groupby(
"classifier_name", group_keys=False
).apply(get_scribe_messages)

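
A sketch of a scribe command as built by get_scribe_messages after this change, with _id now carrying the oid directly and the separate oid list dropped from criteria. The classifier name and oid value are placeholders, and the remaining data fields are collapsed in the diff above, so they are omitted here:

    # Illustrative command for one (oid, classifier_name) group.
    command = {
        "type": "update_probabilities",
        "criteria": {"_id": "oid1"},  # the oid is now the lookup key
        "data": {
            "classifier_name": "lc_classifier",
            # further data fields elided in this diff view
        },
    }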