Add the ruff linter to the pre-commit check #214

Merged: 5 commits, Mar 13, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -24,3 +24,4 @@ docs/api/
*.log
/profile.*
xl2times/.cache/
*.log.zip
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
@@ -5,11 +5,20 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 22.8.0
hooks:
- id: black

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.304
hooks:
- id: pyright

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
hooks:
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [ --fix, --exit-non-zero-on-fix ]
37 changes: 36 additions & 1 deletion pyproject.toml
@@ -36,7 +36,8 @@ dev = [
"tabulate",
"pytest",
"pytest-cov",
"poethepoet"
"poethepoet",
"ruff"
]

[project.urls]
@@ -61,3 +62,37 @@ benchmark = { cmd = "python utils/run_benchmarks.py benchmarks.yml --run", help =
benchmark_all = { shell = "python utils/run_benchmarks.py benchmarks.yml --verbose | tee out.txt", help = "Run the project", interpreter = "posix" }
lint = { shell = "git add .pre-commit-config.yaml & pre-commit run", help = "Run pre-commit hooks", interpreter = "posix" }
test = { cmd = "pytest --cov-report term --cov-report html --cov=xl2times --cov=utils", help = "Run unit tests with pytest" }


# Config for various pre-commit checks are below
# Ruff linting rules - see https://github.com/charliermarsh/ruff and https://beta.ruff.rs/docs/rules/
[tool.ruff]
target-version = "py311"
line-length = 88

# Option 1: use basic rules only.
lint.select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"UP", # pyupgrade
"N", # pep8 naming
"I", # isort
"TID", # tidy imports
"UP", # pyupgrade
"NPY", # numpy style
"PL", # pylint
# "PD", # pandas style # TODO enable later
# "C90", # code complexity # TODO enable later
Collaborator (Author): These last two will require a reasonable number of (simple) changes, so I've left them off for now so this diff isn't too polluted.

Collaborator: Nice idea, thanks. Looking forward to the next linting PR. :)
]

# Add specific rule codes/groups here to ignore them, or add a '#noqa' comment to the line of code to skip all checks.
lint.ignore = [
"PLR", # complexity rules
"PD901", "PD011", # pandas 'df''
"E501", # line too long, handled by black
]

# Ruff rule-specific options:
[tool.ruff.mccabe]
max-complexity = 12 # increase max function 'complexity'
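
As a rough, hypothetical illustration (none of this code is from the repository), the rule groups selected above flag patterns like the following, and the '# noqa' escape hatch mentioned in the comment suppresses checks on a single line:

import os, sys  # E (pycodestyle): E401, multiple imports on one line; I (isort) would also regroup these imports
from typing import Dict, List  # UP (pyupgrade): prefer builtin dict/list generics on a py311 target

import numpy as np


def Build_Index(paths: List[str]) -> Dict[str, int]:  # N (pep8-naming): function names should be lowercase
    unused = os.sep  # F (pyflakes): F841, local variable assigned but never used
    return {p: len(p) for p in paths}


np.random.seed(0)  # NPY: legacy numpy random API, a numpy Generator is preferred

# A targeted '# noqa: E731' suppresses only the "do not assign a lambda" rule on this line:
shout = lambda s: s.upper()  # noqa: E731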
8 changes: 4 additions & 4 deletions tests/test_transforms.py
@@ -2,16 +2,16 @@

import pandas as pd

from xl2times import transforms, utils, datatypes
from xl2times import datatypes, transforms, utils
from xl2times.transforms import (
_process_comm_groups_vectorised,
_count_comm_group_vectorised,
_match_wildcards,
_process_comm_groups_vectorised,
commodity_map,
Collaborator (Author): ^ This was the original motivation - standard import sorting/formatting.
expand_rows,
get_matching_commodities,
get_matching_processes,
_match_wildcards,
process_map,
commodity_map,
)

logger = utils.get_logger()
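For context, the isort-style "I" rules behind the reordering in this file (and the next) group imports into standard-library, third-party, and first-party blocks, alphabetised within each block. A minimal hypothetical sketch of the convention:

# Standard library
import os
from collections import defaultdict

# Third-party
import pandas as pd

# First-party
from xl2times import transforms, utils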
3 changes: 2 additions & 1 deletion tests/test_utils.py
@@ -1,6 +1,7 @@
from xl2times import utils
import pandas as pd

from xl2times import utils


class TestUtils:
def test_explode(self):
Expand Down
31 changes: 15 additions & 16 deletions utils/dd_to_csv.py
@@ -1,10 +1,9 @@
import argparse
import sys
from collections import defaultdict
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
import pandas as pd
@@ -13,7 +12,7 @@

def parse_parameter_values_from_file(
path: Path,
) -> Tuple[Dict[str, List], Dict[str, set]]:
) -> tuple[dict[str, list], dict[str, set]]:
"""
Parse *.dd to turn it into CSV format
There are parameters and sets, and each has a slightly different format
@@ -35,11 +34,11 @@ def parse_parameter_values_from_file(

"""

data = list(open(path, "r"))
data = list(open(path))
data = [line.rstrip() for line in data]

param_value_dict: Dict[str, List] = dict()
set_data_dict: Dict[str, set] = dict()
param_value_dict: dict[str, list] = dict()
set_data_dict: dict[str, set] = dict()
index = 0
while index < len(data):
if data[index].startswith("PARAMETER"):
@@ -124,8 +123,8 @@ def parse_parameter_values_from_file(


def save_data_with_headers(
param_data_dict: Dict[str, Union[pd.DataFrame, List[str]]],
headers_data: Dict[str, List[str]],
param_data_dict: dict[str, pd.DataFrame | list[str]],
Collaborator (Author): What Python versions are you planning to support? This was done with the py311 setting (to match CI), but I'll revert and re-run for py39 if that's the minimum target. (I think the | type hint only works from 3.10 onwards.)

Member: We haven't had that discussion actually... How about py312?

Collaborator (Author): Works for me, type hints are a lot cleaner in 3.12 for instance. I just thought someone mentioned a use-case for running with an older version is all.

Collaborator: I think we use the match-case statement somewhere in transforms.py, which is >= 3.11. I feel perhaps we should try to keep the min version as low as possible, just to make it easier for users? Not sure how much of a pain it is to install a specific Python version on Windows (on Linux I use pyenv).

Member: I'd say it is easy! :-) But I am also fine with sticking to 3.11.

Collaborator (Author): Yeah, I use pyenv on Windows also, works very well. Less savvy users can just download and run the official installers etc. OK, so we'll target >=3.11, cool.

headers_data: dict[str, list[str]],
save_dir: str,
) -> None:
"""
@@ -157,7 +156,7 @@ def save_data_with_headers(
return


def generate_headers_by_attr() -> Dict[str, List[str]]:
def generate_headers_by_attr() -> dict[str, list[str]]:
with open("xl2times/config/times-info.json") as f:
attributes = json.load(f)

@@ -173,7 +172,7 @@ def generate_headers_by_attr():


def convert_dd_to_tabular(
basedir: str, output_dir: str, headers_by_attr: Dict[str, List[str]]
basedir: str, output_dir: str, headers_by_attr: dict[str, list[str]]
) -> None:
dd_files = [p for p in Path(basedir).rglob("*.dd")]

@@ -201,15 +200,15 @@ def convert_dd_to_tabular(
os.makedirs(set_path, exist_ok=True)

# Extract headers with key=param_name and value=List[attributes]
lines = list(open("xl2times/config/times_mapping.txt", "r"))
lines = list(open("xl2times/config/times_mapping.txt"))
headers_data = headers_by_attr
# The following will overwrite data obtained from headers_by_attr
# TODO: Remove once migration is done?
for line in lines:
line = line.strip()
if line != "":
param_name = line.split("[")[0]
attributes = line.split("[")[1].split("]")[0].split(",")
ln = line.strip()
if ln != "":
param_name = ln.split("[")[0]
attributes = ln.split("[")[1].split("]")[0].split(",")
headers_data[param_name] = [*attributes]

save_data_with_headers(all_parameters, headers_data, param_path)
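The annotation changes in this file line up with the version discussion above: on a 3.9+ target the typing.Dict/List/Tuple aliases can become builtin generics (PEP 585), and from 3.10 onwards Union/Optional can be written with the | operator (PEP 604), which is what pyupgrade rewrites them to. A hypothetical before/after sketch (not a function from this file):

# Before: def load(path: str) -> Tuple[Dict[str, List[str]], Optional[int]]: ...
# After, valid on the agreed >=3.11 target:
def load(path: str) -> tuple[dict[str, list[str]], int | None]:
    return ({path: []}, None)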
19 changes: 12 additions & 7 deletions utils/run_benchmarks.py
@@ -9,22 +9,22 @@
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from os import path, symlink
from typing import Any, Tuple
from typing import Any

import git
import pandas as pd
import yaml
from dd_to_csv import main
from tabulate import tabulate

from dd_to_csv import main
from xl2times import utils
from xl2times.__main__ import parse_args, run
from xl2times.utils import max_workers

logger = utils.get_logger()


def parse_result(output: str) -> Tuple[float, int, int]:
def parse_result(output: str) -> tuple[float, int, int]:
# find pattern in multiline string
m = re.findall(
r"(\d+\.\d)% of ground truth rows present in output \((\d+)/(\d+)\), (\d+) additional rows",
@@ -65,6 +65,7 @@ def run_gams_gdxdiff(
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
check=False,
)
if res.returncode != 0:
logger.info(res.stdout)
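The explicit check=False added throughout this file is presumably there to satisfy the pylint-derived rule that flags subprocess.run calls without an explicit check argument; the behavioural difference is sketched below with a hypothetical command (not one the benchmark script runs):

import subprocess

# check=False: a non-zero exit status is reported via returncode and the caller decides,
# which matches the manual res.returncode handling above.
res = subprocess.run(["git", "--version"], capture_output=True, text=True, check=False)
if res.returncode != 0:
    print(res.stdout)

# check=True: a non-zero exit status raises subprocess.CalledProcessError instead.
try:
    subprocess.run(["git", "no-such-subcommand"], capture_output=True, text=True, check=True)
except subprocess.CalledProcessError as err:
    print(f"command failed with exit code {err.returncode}")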
@@ -96,6 +97,7 @@ def run_gams_gdxdiff(
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
check=False,
)
if res.returncode != 0:
logger.info(res.stdout)
@@ -119,6 +121,7 @@
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
check=False,
)
if verbose:
logger.info(res.stdout)
@@ -138,7 +141,7 @@ def run_benchmark(
out_folder: str = "out",
verbose: bool = False,
debug: bool = False,
) -> Tuple[str, float, str, float, int, int]:
) -> tuple[str, float, str, float, int, int]:
xl_folder = path.join(benchmarks_folder, "xlsx", benchmark["input_folder"])
dd_folder = path.join(benchmarks_folder, "dd", benchmark["dd_folder"])
csv_folder = path.join(benchmarks_folder, "csv", benchmark["name"])
@@ -160,6 +163,7 @@
stderr=subprocess.STDOUT,
text=True,
shell=True if os.name == "nt" else False,
check=False,
)
if res.returncode != 0:
# Remove partial outputs
@@ -191,7 +195,7 @@
if "regions" in benchmark:
args.extend(["--regions", benchmark["regions"]])
if "inputs" in benchmark:
args.extend((path.join(xl_folder, b) for b in benchmark["inputs"]))
args.extend(path.join(xl_folder, b) for b in benchmark["inputs"])
else:
args.append(xl_folder)
start = time.time()
@@ -203,6 +207,7 @@
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
check=False,
)
else:
# If debug option is set, run as a function call to allow stepping with a debugger.
@@ -295,7 +300,6 @@ def run_all_benchmarks(
for benchmark in benchmarks:
with open(
path.join(benchmarks_folder, "out-main", benchmark["name"], "stdout"),
"r",
) as f:
result = parse_result(f.readlines()[-1])
# Use a fake runtime and GAMS result
@@ -330,7 +334,8 @@ def run_all_benchmarks(
results_main = list(executor.map(run_a_benchmark, benchmarks))

# Print table with combined results to make comparison easier
trunc = lambda s: s[:10] + "\u2026" if len(s) > 10 else s
trunc = lambda s: s[:10] + "\u2026" if len(s) > 10 else s # noqa

combined_results = [
(
f"{b:<20}",
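The bare '# noqa' added to the trunc lambda above suppresses every rule on that line; most likely it is silencing E731 ("do not assign a lambda expression, use a def") from the selected pycodestyle group. A def-based equivalent, shown only for comparison:

def trunc(s: str) -> str:
    # Truncate to 10 characters and append an ellipsis if the string was longer.
    return s[:10] + "\u2026" if len(s) > 10 else s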