Skip to content

Commit

Permalink
Generalise the code
Browse files Browse the repository at this point in the history
  • Loading branch information
olejandro committed Dec 25, 2024
1 parent ae4f274 commit 6c405cb
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions xl2times/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2507,19 +2507,26 @@ def apply_transform_tables(
how="inner",
)

# Create dictionaries of processes/commodities indexed by module name
process_by_module = (
# Create a dictionary of processes/commodities indexed by module name
obj_by_module = dict()
obj_by_module["process"] = (
model.processes.groupby("module_name")["process"].agg(set).to_dict()
)
commodity_by_module = (
obj_by_module["commodity"] = (
model.commodities.groupby("module_name")["commodity"].agg(set).to_dict()
)
# Create a dictionary of processes/commodities available in addition to those declared in a module
obj_suppl = dict()
obj_suppl["process"] = set()
obj_suppl["commodity"] = (
obj_by_module["commodity"]
.get("BASE", set())
.union(obj_by_module["commodity"].get("SYSSETTINGS", set()))
)
# Create sets of attributes that require a process/commodity index
attr_with_prc = {
attr.times_name for attr in config.times_xl_maps if "process" in attr.xl_cols
}
attr_with_com = {
attr.times_name for attr in config.times_xl_maps if "commodity" in attr.xl_cols
attr_with_obj = {
obj: {attr.times_name for attr in config.times_xl_maps if obj in attr.xl_cols}
for obj in ["process", "commodity"]
}

if Tag.tfm_comgrp in tables:
Expand Down Expand Up @@ -2720,25 +2727,21 @@ def apply_transform_tables(
if generated_records:
module_data = pd.concat(generated_records, ignore_index=True)
module_type = module_data["module_type"].iloc[0]
if "process" in module_data.columns:
module_data = module_data.explode("process", ignore_index=True)
if module_type in {"base", "subres"}:
drop = ~module_data["process"].isin(
process_by_module.get(data_module, set())
) & module_data["attribute"].isin(attr_with_prc)
if any(drop):
module_data = module_data[~drop]
if "commodity" in module_data.columns:
module_data = module_data.explode("commodity", ignore_index=True)
if module_type in {"base", "subres"}:
valid_commodities = commodity_by_module.get(
data_module, set()
).union(commodity_by_module.get("BASE", set()))
drop = ~module_data["commodity"].isin(
valid_commodities
) & module_data["attribute"].isin(attr_with_com)
if any(drop):
module_data = module_data[~drop]
# Explode process and commodity columns and remove invalid rows
for obj in ["process", "commodity"]:
if obj in module_data.columns:
module_data = module_data.explode(obj, ignore_index=True)
if module_type in {"base", "subres"}:
valid_objs = (
obj_by_module[obj]
.get(data_module, set())
.union(obj_suppl[obj])
)
drop = ~module_data[obj].isin(valid_objs) & module_data[
"attribute"
].isin(attr_with_obj[obj])
if any(drop):
module_data = module_data[~drop]
tables[Tag.fi_t] = pd.concat(
[tables[Tag.fi_t], module_data], ignore_index=True
)
Expand Down

0 comments on commit 6c405cb

Please sign in to comment.