Skip to content

Commit

Permalink
Strip strings in imported databases.
Browse files Browse the repository at this point in the history
Version bump.
  • Loading branch information
romainsacchi authored and romainsacchi committed Nov 27, 2024
1 parent a3cd87f commit 3f3d81b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 34 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def package_files(directory):

setup(
name="unfold",
version="1.2.1",
version="1.2.2",
python_requires=">=3.10",
packages=packages,
author="Romain Sacchi <[email protected]>",
Expand Down
6 changes: 3 additions & 3 deletions unfold/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ("Unfold", "Fold")
__version__ = (1, 2, 1)
__all__ = ("Unfold", "Fold", "clear_cache")
__version__ = (1, 2, 2)


from .fold import Fold
from .unfold import Unfold
from .unfold import Unfold, clear_cache
25 changes: 25 additions & 0 deletions unfold/data_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,28 @@ def change_db_name(data, name):
if exc.get("input") and exc["input"][0] in old_names:
exc["input"] = (name, exc["input"][1])
return data

def clean_fields(database):
    """Strip leading/trailing whitespace from the text fields of a database.

    Mutates the datasets in place: strips ``name`` and ``unit`` (required
    fields), ``location`` and ``reference product`` (optional fields), and
    for each exchange strips ``name``, ``unit``, ``type`` and — for
    technosphere/production exchanges — ``product``.

    :param database: list of dataset dictionaries, each optionally holding
        an ``"exchanges"`` list of exchange dictionaries.
    :return: the same ``database`` list, with string fields stripped.
    """
    for dataset in database:
        # required dataset fields — a missing key here is a genuine data
        # error, so let the KeyError propagate
        dataset["name"] = dataset["name"].strip()
        dataset["unit"] = dataset["unit"].strip()

        # optional dataset fields
        if dataset.get("location"):
            dataset["location"] = dataset["location"].strip()

        if dataset.get("reference product"):
            dataset["reference product"] = dataset["reference product"].strip()

        for exchange in dataset.get("exchanges", []):
            exchange["name"] = exchange["name"].strip()
            exchange["unit"] = exchange["unit"].strip()
            exchange["type"] = exchange["type"].strip()

            # Only technosphere/production exchanges carry a "product"
            # field. Guard with .get() so a missing key does not raise,
            # mirroring the `exc.get("product").strip() if ... else None`
            # convention used elsewhere in this package. (The original
            # also re-stripped "name" here — redundant, removed.)
            if exchange["type"] in ("technosphere", "production") and exchange.get(
                "product"
            ):
                exchange["product"] = exchange["product"].strip()

    return database
67 changes: 37 additions & 30 deletions unfold/unfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
get_outdated_units,
remove_categories_for_technosphere_flows,
remove_missing_fields,
clean_fields
)

try:
Expand Down Expand Up @@ -82,12 +83,12 @@ def get_list_unique_exchanges(databases: list) -> list:
set(
[
(
exchange["name"],
exchange.get("product"),
exchange["name"].strip(),
exchange.get("product").strip() if exchange.get("product") else None,
exchange.get("categories"),
exchange.get("location"),
exchange.get("unit"),
exchange.get("type"),
exchange.get("location").strip() if exchange.get("location") else None,
exchange.get("unit").strip(),
exchange.get("type").strip(),
)
for database in databases
for dataset in database
Expand Down Expand Up @@ -119,7 +120,7 @@ def check_cached_database(name) -> list:

# extract the database, pickle it for next time and return it
print("Cannot find cached database. Will create one now for next time...")
database = extract_brightway2_databases(name)
database = clean_fields(extract_brightway2_databases(name))
pickle.dump(database, open(file_name, "wb"))
return database

Expand Down Expand Up @@ -398,11 +399,11 @@ def store_datasets_metadata(self) -> None:
# store the metadata in a dictionary
self.dict_meta = {
(
dataset["name"],
dataset["reference product"],
dataset["name"].strip(),
dataset["reference product"].strip(),
None,
dataset["location"],
dataset["unit"],
dataset["location"].strip(),
dataset["unit"].strip(),
"production",
): {
key: values
Expand Down Expand Up @@ -575,20 +576,20 @@ def populate_sparse_matrix(self) -> nsp.lil_matrix:
for exc in ds["exchanges"]:
# Source activity
s = (
exc["name"],
exc.get("product"),
exc["name"].strip(),
exc.get("product").strip() if exc.get("product") else None,
exc.get("categories"),
exc.get("location"),
exc["unit"],
exc["type"],
exc.get("location").strip() if exc.get("location") else None,
exc["unit"].strip(),
exc["type"].strip(),
)
# Destination activity
c = (
ds["name"],
ds.get("reference product"),
ds["name"].strip(),
ds.get("reference product").strip() if ds.get("reference product") else None,
ds.get("categories"),
ds.get("location"),
ds["unit"],
ds.get("location").strip() if ds.get("location") else None,
ds["unit"].strip(),
"production",
)
# Add the exchange amount to the corresponding cell in the matrix
Expand Down Expand Up @@ -624,15 +625,16 @@ def write_scaling_factors_in_matrix(
s_name, s_prod, s_loc, s_cat, s_unit, s_type = list(flow_id)[4:]

# Look up the index of the consumer activity in the reversed activities index.
consumer_id = (
c_name,
c_prod,
None,
c_loc,
c_unit,
"production",
)
consumer_idx = self.reversed_acts_indices.get(
(
c_name,
c_prod,
None,
c_loc,
c_unit,
"production",
),
consumer_id,
None,
)

Expand Down Expand Up @@ -683,9 +685,14 @@ def write_scaling_factors_in_matrix(
matrix[supplier_idx, consumer_idx]
)
else:
print(
f"Could not find activity for flow {flow_id} in scenario {scenario_name}."
)
if supplier_idx is None:
print(
f"Could not find supplier {supplier_id} in scenario {scenario_name}."
)
if consumer_idx is None:
print(
f"Could not find consumer {consumer_id} in scenario {scenario_name}."
)

# Return the scaled matrix.
return matrix
Expand Down

0 comments on commit 3f3d81b

Please sign in to comment.