From ebee6e44c15ffcebd54abc182e525a11f6fcdbfb Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Thu, 6 Jun 2024 10:00:47 -0400 Subject: [PATCH] Add support for array aggregation in GroupByAggregate op --- cycquery/ops.py | 2 ++ pyproject.toml | 16 ++++++++-------- tests/cycquery/test_ops.py | 21 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/cycquery/ops.py b/cycquery/ops.py index 48d04f0..1f7c558 100644 --- a/cycquery/ops.py +++ b/cycquery/ops.py @@ -3122,6 +3122,7 @@ class GroupByAggregate(QueryOp): >>> GroupByAggregate("person_id", {"person_id": ("count", "visit_count")})(table) >>> GroupByAggregate("person_id", {"lab_name": "string_agg"}, {"lab_name": ", "})(table) >>> GroupByAggregate("person_id", {"lab_name": ("string_agg", "lab_name_agg"}, {"lab_name": ", "})(table) + >>> GroupByAggregate("person_id", {"lab_name": ("array_agg", "lab_name_array")})(table) """ @@ -3163,6 +3164,7 @@ def __call__(self, table: TableTypes) -> Subquery: "count": func.count, "median": func.percentile_cont(0.5).within_group, "string_agg": func.string_agg, + "array_agg": func.array_agg, } aggfunc_tuples = list(self.aggfuncs.items()) diff --git a/pyproject.toml b/pyproject.toml index 7fb02bd..0e33923 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ plugins = ["sqlalchemy.ext.mypy.plugin"] [tool.ruff] include = ["*.py", "pyproject.toml", "*.ipynb"] -select = [ +lint.select = [ "A", # flake8-builtins "B", # flake8-bugbear "COM", # flake8-commas @@ -93,9 +93,9 @@ select = [ "ERA", # eradicate "PL", # pylint ] -fixable = ["A", "B", "COM", "C4", "RET", "SIM", "ICN", "Q", "RSE", "D", "E", "F", "I", "W", "N", "ERA", "PL"] +lint.fixable = ["A", "B", "COM", "C4", "RET", "SIM", "ICN", "Q", "RSE", "D", "E", "F", "I", "W", "N", "ERA", "PL"] line-length = 88 -ignore = [ +lint.ignore = [ "B905", # `zip()` without an explicit `strict=` parameter "E501", # line too long "D203", # 1 blank line required before class docstring @@ -105,19 +105,19 @@ ignore = [ ] # Ignore import violations in all `__init__.py` files. -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "__init__.py" = ["E402", "F401", "F403", "F811"] -[tool.ruff.pep8-naming] +[tool.ruff.lint.pep8-naming] ignore-names = ["X*", "setUp"] -[tool.ruff.isort] +[tool.ruff.lint.isort] lines-after-imports = 2 -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "numpy" -[tool.ruff.pycodestyle] +[tool.ruff.lint.pycodestyle] max-doc-length = 88 [tool.pytest.ini_options] diff --git a/tests/cycquery/test_ops.py b/tests/cycquery/test_ops.py index 958875e..4f19894 100644 --- a/tests/cycquery/test_ops.py +++ b/tests/cycquery/test_ops.py @@ -376,6 +376,12 @@ def test_group_by_aggregate( {"visit_concept_name": ", "}, ), ).run() + measurements_array_agg = measurements_table.ops( + GroupByAggregate( + "person_id", + {"value_as_number": ("array_agg", "value_as_number_array")}, + ), + ).run() measurements_sum = measurements_table.ops( GroupByAggregate( "person_id", @@ -446,6 +452,21 @@ def test_group_by_aggregate( ].item() == 75.7 ) + assert "value_as_number_array" in measurements_array_agg.columns + assert isinstance( + measurements_array_agg[measurements_array_agg["person_id"] == 33][ + "value_as_number_array" + ][0], + str, + ) + assert ( + len( + measurements_array_agg[measurements_array_agg["person_id"] == 33][ + "value_as_number_array" + ][0], + ) + > 0 + ) @pytest.mark.integration_test()