From 63cf77a7b894f1ca9ce0ad58f5b49fcba42b9773 Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sat, 30 Nov 2024 15:45:55 -0500 Subject: [PATCH 01/10] Update scipy version constraints in pyproject.toml --- poetry.lock | 142 +++++++++++-------------------------------------- pyproject.toml | 5 +- 2 files changed, 34 insertions(+), 113 deletions(-) diff --git a/poetry.lock b/poetry.lock index 5f8f9338..a37b3bec 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "alabaster" @@ -1106,73 +1106,6 @@ files = [ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] -[[package]] -name = "pandas" -version = "2.0.3" -description = "Powerful data structures for data analysis, time series, and statistics" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, - {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, - {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"}, - {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"}, - {file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"}, - {file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"}, - {file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"}, - {file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"}, - {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"}, - {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"}, - {file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"}, - {file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"}, - {file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"}, - {file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"}, - {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"}, - {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"}, - {file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"}, - {file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"}, - {file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"}, - {file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"}, - {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"}, - {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"}, - {file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"}, - {file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"}, - {file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"}, -] - -[package.dependencies] -numpy = [ - {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, -] -python-dateutil = ">=2.8.2" -pytz = ">=2020.1" -tzdata = ">=2022.1" - -[package.extras] -all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] -aws = ["s3fs (>=2021.08.0)"] -clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] -compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] -computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"] -excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"] -feather = ["pyarrow (>=7.0.0)"] -fss = ["fsspec (>=2021.07.0)"] -gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"] -hdf5 = ["tables (>=3.6.1)"] -html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"] -mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"] -output-formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"] -parquet = ["pyarrow (>=7.0.0)"] -performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"] -plot = ["matplotlib (>=3.6.1)"] -postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"] -spss = ["pyreadstat (>=1.1.2)"] -sql-other = ["SQLAlchemy (>=1.4.16)"] -test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] -xml = ["lxml (>=4.6.3)"] - [[package]] name = "pandocfilters" version = "1.5.1" @@ -1791,45 +1724,45 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo [[package]] name = "scipy" -version = "1.11.4" +version = "1.13.1" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.9" files = [ - {file = "scipy-1.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc9a714581f561af0848e6b69947fda0614915f072dfd14142ed1bfe1b806710"}, - {file = "scipy-1.11.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cf00bd2b1b0211888d4dc75656c0412213a8b25e80d73898083f402b50f47e41"}, - {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9999c008ccf00e8fbcce1236f85ade5c569d13144f77a1946bef8863e8f6eb4"}, - {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:933baf588daa8dc9a92c20a0be32f56d43faf3d1a60ab11b3f08c356430f6e56"}, - {file = "scipy-1.11.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8fce70f39076a5aa62e92e69a7f62349f9574d8405c0a5de6ed3ef72de07f446"}, - {file = "scipy-1.11.4-cp310-cp310-win_amd64.whl", hash = "sha256:6550466fbeec7453d7465e74d4f4b19f905642c89a7525571ee91dd7adabb5a3"}, - {file = "scipy-1.11.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f313b39a7e94f296025e3cffc2c567618174c0b1dde173960cf23808f9fae4be"}, - {file = "scipy-1.11.4-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1b7c3dca977f30a739e0409fb001056484661cb2541a01aba0bb0029f7b68db8"}, - {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00150c5eae7b610c32589dda259eacc7c4f1665aedf25d921907f4d08a951b1c"}, - {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:530f9ad26440e85766509dbf78edcfe13ffd0ab7fec2560ee5c36ff74d6269ff"}, - {file = "scipy-1.11.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5e347b14fe01003d3b78e196e84bd3f48ffe4c8a7b8a1afbcb8f5505cb710993"}, - {file = "scipy-1.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:acf8ed278cc03f5aff035e69cb511741e0418681d25fbbb86ca65429c4f4d9cd"}, - {file = "scipy-1.11.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:028eccd22e654b3ea01ee63705681ee79933652b2d8f873e7949898dda6d11b6"}, - {file = "scipy-1.11.4-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c6ff6ef9cc27f9b3db93a6f8b38f97387e6e0591600369a297a50a8e96e835d"}, - {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b030c6674b9230d37c5c60ab456e2cf12f6784596d15ce8da9365e70896effc4"}, - {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad669df80528aeca5f557712102538f4f37e503f0c5b9541655016dd0932ca79"}, - {file = "scipy-1.11.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce7fff2e23ab2cc81ff452a9444c215c28e6305f396b2ba88343a567feec9660"}, - {file = "scipy-1.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:36750b7733d960d7994888f0d148d31ea3017ac15eef664194b4ef68d36a4a97"}, - {file = "scipy-1.11.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e619aba2df228a9b34718efb023966da781e89dd3d21637b27f2e54db0410d7"}, - {file = "scipy-1.11.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:f3cd9e7b3c2c1ec26364856f9fbe78695fe631150f94cd1c22228456404cf1ec"}, - {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d10e45a6c50211fe256da61a11c34927c68f277e03138777bdebedd933712fea"}, - {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91af76a68eeae0064887a48e25c4e616fa519fa0d38602eda7e0f97d65d57937"}, - {file = "scipy-1.11.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6df1468153a31cf55ed5ed39647279beb9cfb5d3f84369453b49e4b8502394fd"}, - {file = "scipy-1.11.4-cp39-cp39-win_amd64.whl", hash = "sha256:ee410e6de8f88fd5cf6eadd73c135020bfbbbdfcd0f6162c36a7638a1ea8cc65"}, - {file = "scipy-1.11.4.tar.gz", hash = "sha256:90a2b78e7f5733b9de748f589f09225013685f9b218275257f8a8168ededaeaa"}, + {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, + {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, + {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, + {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, + {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, + {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, + {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, + {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, + {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, + {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, + {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, ] [package.dependencies] -numpy = ">=1.21.6,<1.28.0" +numpy = ">=1.22.4,<2.3" [package.extras] -dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] -doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] -test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "six" @@ -2130,17 +2063,6 @@ files = [ {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, ] -[[package]] -name = "tzdata" -version = "2024.1" -description = "Provider of IANA time zone data" -optional = false -python-versions = ">=2" -files = [ - {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, - {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, -] - [[package]] name = "urllib3" version = "2.2.0" @@ -2198,4 +2120,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8.1, <3.13" -content-hash = "f908f8c1e11f3e2d1ce347673013bc3dd77eef943341e17c3106dd8fc31d6633" +content-hash = "16086795ef82b4a6043e37364fac505084b653bd51b657d11afe1ac76237b49a" diff --git a/pyproject.toml b/pyproject.toml index 2fe8e8a0..e76846af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,12 +14,11 @@ numpy = [ ] progressbar2 = "^4.2.0" scipy = [ - { version = ">=1.10.1,<1.11", python = "<3.9,>=3.8" }, - { version = ">=1.11.1,<1.12", python = "<3.13,>=3.9" } + { version = ">=1.10.1,<1.13", python = "<3.9,>=3.8" }, + { version = ">=1.13.0,<1.14", python = "<3.13,>=3.9" } ] [tool.poetry.group.dev.dependencies] -pandas = ">=1.6" pytest = "^7.2.2" flake8 = "^6.0.0" codecov = "^2.1.12" From 050feeb8319fbaa4d530432b65e8a266064a8a5e Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sat, 30 Nov 2024 17:44:24 -0500 Subject: [PATCH 02/10] Update Python and SciPy version constraints, add .vscode settings to .gitignore, and implement tests for GAMRegressor and GAMClassifier --- .gitignore | 1 + poetry.lock | 181 ++++++++++++-------------------- pygam/sklearn_api.py | 216 ++++++++++++++++++++++++++++++++++++++ pyproject.toml | 10 +- tests/test_sklearn_api.py | 125 ++++++++++++++++++++++ 5 files changed, 412 insertions(+), 121 deletions(-) create mode 100644 pygam/sklearn_api.py create mode 100644 tests/test_sklearn_api.py diff --git a/.gitignore b/.gitignore index a30b5eb4..c4ff995e 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ _build/ # PyCharm ######### .idea/ +.vscode/settings.json diff --git a/poetry.lock b/poetry.lock index a37b3bec..a724bec8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -70,9 +70,6 @@ files = [ {file = "Babel-2.14.0.tar.gz", hash = "sha256:6919867db036398ba21eb5c7a0f6b28ab8cbc3ae7a73a44ebe34ae74a4e7d363"}, ] -[package.dependencies] -pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} - [package.extras] dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] @@ -585,24 +582,6 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] -[[package]] -name = "importlib-resources" -version = "6.1.1" -description = "Read resources from Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"}, - {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"}, -] - -[package.dependencies] -zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"] - [[package]] name = "iniconfig" version = "2.0.0" @@ -689,6 +668,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + [[package]] name = "jsonschema" version = "4.21.1" @@ -702,9 +692,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} jsonschema-specifications = ">=2023.03.6" -pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""} referencing = ">=0.28.4" rpds-py = ">=0.7.1" @@ -724,7 +712,6 @@ files = [ ] [package.dependencies] -importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} referencing = ">=0.31.0" [[package]] @@ -1013,43 +1000,6 @@ nbformat = "*" sphinx = ">=1.8" traitlets = ">=5" -[[package]] -name = "numpy" -version = "1.24.4" -description = "Fundamental package for array computing in Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, - {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, - {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, - {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, - {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, - {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, - {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, - {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, - {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, - {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, - {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, - {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, - {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, - {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, - {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, - {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, - {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, - {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, - {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, - {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, - {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, - {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, - {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, - {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, - {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, -] - [[package]] name = "numpy" version = "1.26.4" @@ -1168,17 +1118,6 @@ files = [ {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, ] -[[package]] -name = "pkgutil-resolve-name" -version = "1.3.10" -description = "Resolve a name to an object." -optional = false -python-versions = ">=3.6" -files = [ - {file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"}, - {file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"}, -] - [[package]] name = "platformdirs" version = "4.2.0" @@ -1401,17 +1340,6 @@ docs = ["mock", "python-utils", "sphinx"] loguru = ["loguru"] tests = ["flake8", "loguru", "pytest", "pytest-asyncio", "pytest-cov", "pytest-mypy", "sphinx", "types-setuptools"] -[[package]] -name = "pytz" -version = "2024.1" -description = "World timezone definitions, modern and historical" -optional = false -python-versions = "*" -files = [ - {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, - {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, -] - [[package]] name = "pywin32" version = "306" @@ -1685,42 +1613,54 @@ files = [ ] [[package]] -name = "scipy" -version = "1.10.1" -description = "Fundamental algorithms for scientific computing in Python" +name = "scikit-learn" +version = "1.5.2" +description = "A set of python modules for machine learning and data mining" optional = false -python-versions = "<3.12,>=3.8" -files = [ - {file = "scipy-1.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e7354fd7527a4b0377ce55f286805b34e8c54b91be865bac273f527e1b839019"}, - {file = "scipy-1.10.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4b3f429188c66603a1a5c549fb414e4d3bdc2a24792e061ffbd607d3d75fd84e"}, - {file = "scipy-1.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1553b5dcddd64ba9a0d95355e63fe6c3fc303a8fd77c7bc91e77d61363f7433f"}, - {file = "scipy-1.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c0ff64b06b10e35215abce517252b375e580a6125fd5fdf6421b98efbefb2d2"}, - {file = "scipy-1.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:fae8a7b898c42dffe3f7361c40d5952b6bf32d10c4569098d276b4c547905ee1"}, - {file = "scipy-1.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f1564ea217e82c1bbe75ddf7285ba0709ecd503f048cb1236ae9995f64217bd"}, - {file = "scipy-1.10.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d925fa1c81b772882aa55bcc10bf88324dadb66ff85d548c71515f6689c6dac5"}, - {file = "scipy-1.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaea0a6be54462ec027de54fca511540980d1e9eea68b2d5c1dbfe084797be35"}, - {file = "scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15a35c4242ec5f292c3dd364a7c71a61be87a3d4ddcc693372813c0b73c9af1d"}, - {file = "scipy-1.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:43b8e0bcb877faf0abfb613d51026cd5cc78918e9530e375727bf0625c82788f"}, - {file = "scipy-1.10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5678f88c68ea866ed9ebe3a989091088553ba12c6090244fdae3e467b1139c35"}, - {file = "scipy-1.10.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:39becb03541f9e58243f4197584286e339029e8908c46f7221abeea4b749fa88"}, - {file = "scipy-1.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bce5869c8d68cf383ce240e44c1d9ae7c06078a9396df68ce88a1230f93a30c1"}, - {file = "scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07c3457ce0b3ad5124f98a86533106b643dd811dd61b548e78cf4c8786652f6f"}, - {file = "scipy-1.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:049a8bbf0ad95277ffba9b3b7d23e5369cc39e66406d60422c8cfef40ccc8415"}, - {file = "scipy-1.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cd9f1027ff30d90618914a64ca9b1a77a431159df0e2a195d8a9e8a04c78abf9"}, - {file = "scipy-1.10.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:79c8e5a6c6ffaf3a2262ef1be1e108a035cf4f05c14df56057b64acc5bebffb6"}, - {file = "scipy-1.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51af417a000d2dbe1ec6c372dfe688e041a7084da4fdd350aeb139bd3fb55353"}, - {file = "scipy-1.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b4735d6c28aad3cdcf52117e0e91d6b39acd4272f3f5cd9907c24ee931ad601"}, - {file = "scipy-1.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:7ff7f37b1bf4417baca958d254e8e2875d0cc23aaadbe65b3d5b3077b0eb23ea"}, - {file = "scipy-1.10.1.tar.gz", hash = "sha256:2cf9dfb80a7b4589ba4c40ce7588986d6d5cebc5457cad2c2880f6bc2d42f3a5"}, +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, ] [package.dependencies] -numpy = ">=1.19.5,<1.27.0" +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" [package.extras] -dev = ["click", "doit (>=0.36.0)", "flake8", "mypy", "pycodestyle", "pydevtool", "rich-click", "typing_extensions"] -doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] -test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] [[package]] name = "scipy" @@ -1988,6 +1928,17 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "tinycss2" version = "1.2.1" @@ -2119,5 +2070,5 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" -python-versions = ">=3.8.1, <3.13" -content-hash = "16086795ef82b4a6043e37364fac505084b653bd51b657d11afe1ac76237b49a" +python-versions = ">=3.9, <=3.13" +content-hash = "84ba875e24ed4d7ac984dabe842dbf2e10056a16d645ae7131bf449b92a14d3f" diff --git a/pygam/sklearn_api.py b/pygam/sklearn_api.py new file mode 100644 index 00000000..b29ce38c --- /dev/null +++ b/pygam/sklearn_api.py @@ -0,0 +1,216 @@ +""" +sklearn_api.py + +This module provides scikit-learn compatible classes for Generalized Additive Models (GAM) regressors and classifiers. +It integrates pygam's GAM capabilities with scikit-learn's estimator interface, enabling seamless use in machine learning pipelines. +""" + +from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin +from pygam import GAM +from pygam.terms import te, TermList # Import te for interactions +import numpy as np + +class GAMRegressor(BaseEstimator, RegressorMixin): + """ + GAMRegressor + + A scikit-learn compatible regressor using Generalized Additive Models (GAM). + + Parameters + ---------- + distribution : str, default='normal' + The distribution of the response variable. + link : str, default='identity' + The link function. + terms : 'auto' or TermList, default='auto' + The terms to include in the model. + interactions : list of tuples, optional + Interaction terms to include in the model. + callbacks : list, default=['deviance', 'diffs'] + List of callbacks to monitor during training. + fit_intercept : bool, default=True + Whether to fit an intercept. + max_iter : int, default=100 + Maximum number of iterations. + tol : float, default=1e-4 + Tolerance for stopping criteria. + verbose : bool, default=False + Verbosity mode. + **gam_params : + Additional parameters for the GAM model. + + Attributes + ---------- + model_ : GAM + The underlying pygam GAM model fitted to the data. + """ + + def __init__( + self, + distribution='normal', + link='identity', + terms='auto', + interactions=None, # Added interactions parameter + callbacks=['deviance', 'diffs'], + fit_intercept=True, + max_iter=100, + tol=1e-4, + verbose=False, + **gam_params + ): + self.distribution = distribution + self.link = link + self.terms = TermList() if terms == 'auto' else terms + self.interactions = interactions # Store interactions + self.callbacks = callbacks + self.fit_intercept = fit_intercept + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.gam_params = gam_params + + # ...existing code... + if self.interactions: + # Ensure terms is a list before appending + if isinstance(self.terms, str): + self.terms = [self.terms] + elif not isinstance(self.terms, list): + self.terms = list(self.terms) + # Convert interaction tuples to te instances + for interaction in self.interactions: + self.terms.append(te(*interaction)) + + if not isinstance(self.terms, TermList): + self.terms = TermList(*self.terms) + + self.model_ = GAM( + distribution=self.distribution, + link=self.link, + terms=self.terms, + callbacks=self.callbacks, + fit_intercept=self.fit_intercept, + max_iter=self.max_iter, + tol=self.tol, + verbose=self.verbose, + **self.gam_params + ) + + def fit(self, X, y): + self.model_.fit(X, y) + return self + + def predict(self, X): + return self.model_.predict(X) + + def score(self, X, y): + return float(self.model_.statistics_.get('pseudo R-squared', 0)) + +class GAMClassifier(BaseEstimator, ClassifierMixin): + """ + GAMClassifier + + A scikit-learn compatible classifier using Generalized Additive Models (GAM). + + Parameters + ---------- + distribution : str, default='binomial' + The distribution of the response variable. + link : str, default='logit' + The link function. + terms : 'auto' or TermList, default='auto' + The terms to include in the model. + interactions : list of tuples, optional + Interaction terms to include in the model. + callbacks : list, default=['deviance', 'diffs', 'accuracy'] + List of callbacks to monitor during training. + fit_intercept : bool, default=True + Whether to fit an intercept. + max_iter : int, default=100 + Maximum number of iterations. + tol : float, default=1e-4 + Tolerance for stopping criteria. + verbose : bool, default=False + Verbosity mode. + **gam_params : + Additional parameters for the GAM model. + + Attributes + ---------- + model_ : GAM + The underlying pygam GAM model fitted to the data. + classes_ : array-like + Unique class labels. + """ + + def __init__( + self, + distribution='binomial', + link='logit', + terms='auto', + interactions=None, # Added interactions parameter + callbacks=['deviance', 'diffs', 'accuracy'], + fit_intercept=True, + max_iter=100, + tol=1e-4, + verbose=False, + **gam_params + ): + self.distribution = distribution + self.link = link + self.terms = TermList() if terms == 'auto' else terms + self.interactions = interactions # Store interactions + self.callbacks = callbacks + self.fit_intercept = fit_intercept + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.gam_params = gam_params + + # ...existing code... + if self.interactions: + # Ensure terms is a list before appending + if isinstance(self.terms, str): + self.terms = [self.terms] + elif not isinstance(self.terms, list): + self.terms = list(self.terms) + # Convert interaction tuples to te instances + for interaction in self.interactions: + self.terms.append(te(*interaction)) + + if not isinstance(self.terms, TermList): + self.terms = TermList(*self.terms) + + self.model_ = GAM( + distribution=self.distribution, + link=self.link, + terms=self.terms, + callbacks=self.callbacks, + fit_intercept=self.fit_intercept, + max_iter=self.max_iter, + tol=self.tol, + verbose=self.verbose, + **self.gam_params + ) + + def fit(self, X, y): + self.model_.fit(X, y) + self.classes_ = np.unique(y) + return self + + def predict(self, X): + proba = self.model_.predict(X) + if len(self.classes_) == 2: + return (proba >= 0.5).astype(int) + else: + return self.classes_[np.argmax(proba, axis=1)] + + def predict_proba(self, X): + proba = self.model_.predict(X) + if len(self.classes_) == 2: + return np.vstack([1 - proba, proba]).T + else: + return proba # Assume GAM model returns probabilities for each class + + def score(self, X, y): + from sklearn.metrics import accuracy_score + return accuracy_score(y, self.predict(X)) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e76846af..f55672ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,16 +7,14 @@ license = "Apache-2.0" readme = "README.md" [tool.poetry.dependencies] -python = ">=3.8.1, <3.13" +python = ">=3.9, <=3.13" numpy = [ { version = ">=1.24.2,<1.25", python = "<3.9,>=3.8" }, { version = ">=1.25", python = "<3.13,>=3.9" }, ] progressbar2 = "^4.2.0" -scipy = [ - { version = ">=1.10.1,<1.13", python = "<3.9,>=3.8" }, - { version = ">=1.13.0,<1.14", python = "<3.13,>=3.9" } -] +scipy = ">=1.11.4" +scikit-learn = "^1.5.2" [tool.poetry.group.dev.dependencies] pytest = "^7.2.2" @@ -57,4 +55,4 @@ build-backend = "poetry_dynamic_versioning.backend" [tool.poetry-dynamic-versioning] enable = true vcs = "git" -style = "semver" +style = "semver" \ No newline at end of file diff --git a/tests/test_sklearn_api.py b/tests/test_sklearn_api.py new file mode 100644 index 00000000..7c03fc16 --- /dev/null +++ b/tests/test_sklearn_api.py @@ -0,0 +1,125 @@ +import pytest +import numpy as np +from sklearn.datasets import make_regression, make_classification +from sklearn.model_selection import train_test_split +from sklearn.metrics import r2_score, accuracy_score +from pygam.sklearn_api import GAMRegressor, GAMClassifier + +@pytest.fixture +def regression_data(): + X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42) + return train_test_split(X, y, test_size=0.2, random_state=42) + +@pytest.fixture +def classification_data(): + X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=42) + return train_test_split(X, y, test_size=0.2, random_state=42) + +def test_gam_regressor_fit_predict(regression_data): + X_train, X_test, y_train, y_test = regression_data + reg = GAMRegressor() + reg.fit(X_train, y_train) + predictions = reg.predict(X_test) + assert predictions.shape == y_test.shape + assert r2_score(y_test, predictions) >= 0 # Basic sanity check + +def test_gam_regressor_score(regression_data): + X_train, X_test, y_train, y_test = regression_data + reg = GAMRegressor() + reg.fit(X_train, y_train) + score = reg.score(X_test, y_test) + assert isinstance(score, float) + assert score >= 0 # R-squared should be non-negative + +def test_gam_classifier_fit_predict(classification_data): + X_train, X_test, y_train, y_test = classification_data + clf = GAMClassifier() + clf.fit(X_train, y_train) + predictions = clf.predict(X_test) + assert predictions.shape == y_test.shape + assert set(predictions).issubset({0, 1}) # Binary classification + +def test_gam_classifier_predict_proba(classification_data): + X_train, X_test, y_train, y_test = classification_data + clf = GAMClassifier() + clf.fit(X_train, y_train) + proba = clf.predict_proba(X_test) + assert proba.shape == (X_test.shape[0], 2) + assert np.allclose(proba.sum(axis=1), 1) + +def test_gam_classifier_score(classification_data): + X_train, X_test, y_train, y_test = classification_data + clf = GAMClassifier() + clf.fit(X_train, y_train) + score = clf.score(X_test, y_test) + assert isinstance(score, float) + assert 0 <= score <= 1 # Accuracy between 0 and 1 + +def test_gam_regressor_with_custom_params(regression_data): + X_train, X_test, y_train, y_test = regression_data + reg = GAMRegressor(distribution='normal', link='identity', max_iter=200, tol=1e-5) + reg.fit(X_train, y_train) + predictions = reg.predict(X_test) + assert r2_score(y_test, predictions) >= 0 + +def test_gam_classifier_with_custom_params(classification_data): + X_train, X_test, y_train, y_test = classification_data + clf = GAMClassifier(distribution='binomial', link='logit', max_iter=200, tol=1e-5) + clf.fit(X_train, y_train) + predictions = clf.predict(X_test) + proba = clf.predict_proba(X_test) + assert accuracy_score(y_test, predictions) >= 0 + assert proba.shape == (X_test.shape[0], 2) + +def test_gam_regressor_with_callbacks(regression_data): + X_train, X_test, y_train, y_test = regression_data + reg = GAMRegressor(callbacks=['deviance', 'diffs']) + reg.fit(X_train, y_train) + assert 'deviance' in reg.model_.logs_ + assert 'diffs' in reg.model_.logs_ + +def test_gam_classifier_with_callbacks(classification_data): + X_train, X_test, y_train, y_test = classification_data + clf = GAMClassifier(callbacks=['deviance', 'diffs', 'accuracy']) + clf.fit(X_train, y_train) + assert 'deviance' in clf.model_.logs_ + assert 'diffs' in clf.model_.logs_ + assert 'accuracy' in clf.model_.logs_ + +def test_gam_regressor_gamma(): + X = np.random.rand(100, 2) + y = np.random.gamma(shape=2.0, scale=1.0, size=100) + model = GAMRegressor(distribution='gamma') + model.fit(X, y) + predictions = model.predict(X) + assert predictions.shape == y.shape + +def test_gam_regressor_poisson(): + X = np.random.rand(100, 2) + y = np.random.poisson(lam=3.0, size=100) + model = GAMRegressor(distribution='poisson') + model.fit(X, y) + predictions = model.predict(X) + assert predictions.shape == y.shape + +def test_gam_regressor_with_interactions(regression_data): + X_train, X_test, y_train, y_test = regression_data + interactions = [(0, 1), (2, 3)] # Specify feature indices for interactions + reg = GAMRegressor(interactions=interactions) + reg.fit(X_train, y_train) + predictions = reg.predict(X_test) + assert predictions.shape == y_test.shape + assert r2_score(y_test, predictions) >= 0 # Basic sanity check + +def test_gam_classifier_with_interactions(classification_data): + X_train, X_test, y_train, y_test = classification_data + interactions = [(0, 1), (2, 3)] # Specify feature indices for interactions + clf = GAMClassifier(interactions=interactions) + clf.fit(X_train, y_train) + predictions = clf.predict(X_test) + proba = clf.predict_proba(X_test) + assert predictions.shape == y_test.shape + assert set(predictions).issubset({0, 1}) # Binary classification + assert proba.shape == (X_test.shape[0], 2) + assert np.allclose(proba.sum(axis=1), 1) + assert accuracy_score(y_test, predictions) >= 0 From 9873e8a521b5109f1ad957fd0b2d2805f50ef8a8 Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sun, 1 Dec 2024 09:08:23 -0500 Subject: [PATCH 03/10] better than before --- pygam/sklearn_api.py | 72 ++++++++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/pygam/sklearn_api.py b/pygam/sklearn_api.py index b29ce38c..f9bc6075 100644 --- a/pygam/sklearn_api.py +++ b/pygam/sklearn_api.py @@ -7,8 +7,10 @@ from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin from pygam import GAM -from pygam.terms import te, TermList # Import te for interactions +from pygam.terms import te, TermList, Term # Import te for interactions import numpy as np +from sklearn.datasets import make_regression +from sklearn.model_selection import train_test_split class GAMRegressor(BaseEstimator, RegressorMixin): """ @@ -49,8 +51,8 @@ def __init__( self, distribution='normal', link='identity', - terms='auto', - interactions=None, # Added interactions parameter + terms='auto', # Will be handled by GAM class + interactions=None, callbacks=['deviance', 'diffs'], fit_intercept=True, max_iter=100, @@ -60,8 +62,8 @@ def __init__( ): self.distribution = distribution self.link = link - self.terms = TermList() if terms == 'auto' else terms - self.interactions = interactions # Store interactions + self.terms = terms # Simply pass through to GAM + self.interactions = interactions self.callbacks = callbacks self.fit_intercept = fit_intercept self.max_iter = max_iter @@ -69,24 +71,24 @@ def __init__( self.verbose = verbose self.gam_params = gam_params - # ...existing code... + # Handle interactions if specified if self.interactions: - # Ensure terms is a list before appending - if isinstance(self.terms, str): + if isinstance(self.terms, str) and self.terms == 'auto': + self.terms = [] # Convert 'auto' to empty list to append to + elif isinstance(self.terms, str): self.terms = [self.terms] elif not isinstance(self.terms, list): self.terms = list(self.terms) - # Convert interaction tuples to te instances + + # Add interaction terms for interaction in self.interactions: self.terms.append(te(*interaction)) - if not isinstance(self.terms, TermList): - self.terms = TermList(*self.terms) - + # Initialize the GAM model self.model_ = GAM( distribution=self.distribution, link=self.link, - terms=self.terms, + terms=self.terms, # Pass terms directly, let GAM handle 'auto' callbacks=self.callbacks, fit_intercept=self.fit_intercept, max_iter=self.max_iter, @@ -146,8 +148,8 @@ def __init__( self, distribution='binomial', link='logit', - terms='auto', - interactions=None, # Added interactions parameter + terms='auto', # Will be handled by GAM class + interactions=None, callbacks=['deviance', 'diffs', 'accuracy'], fit_intercept=True, max_iter=100, @@ -157,8 +159,8 @@ def __init__( ): self.distribution = distribution self.link = link - self.terms = TermList() if terms == 'auto' else terms - self.interactions = interactions # Store interactions + self.terms = terms # Simply pass through to GAM + self.interactions = interactions self.callbacks = callbacks self.fit_intercept = fit_intercept self.max_iter = max_iter @@ -166,24 +168,24 @@ def __init__( self.verbose = verbose self.gam_params = gam_params - # ...existing code... + # Handle interactions if specified if self.interactions: - # Ensure terms is a list before appending - if isinstance(self.terms, str): + if isinstance(self.terms, str) and self.terms == 'auto': + self.terms = [] # Convert 'auto' to empty list to append to + elif isinstance(self.terms, str): self.terms = [self.terms] elif not isinstance(self.terms, list): self.terms = list(self.terms) - # Convert interaction tuples to te instances + + # Add interaction terms for interaction in self.interactions: self.terms.append(te(*interaction)) - if not isinstance(self.terms, TermList): - self.terms = TermList(*self.terms) - + # Initialize the GAM model self.model_ = GAM( distribution=self.distribution, link=self.link, - terms=self.terms, + terms=self.terms, # Pass terms directly, let GAM handle 'auto' callbacks=self.callbacks, fit_intercept=self.fit_intercept, max_iter=self.max_iter, @@ -213,4 +215,22 @@ def predict_proba(self, X): def score(self, X, y): from sklearn.metrics import accuracy_score - return accuracy_score(y, self.predict(X)) \ No newline at end of file + return accuracy_score(y, self.predict(X)) + + +if __name__ == '__main__': + + # Generate synthetic data + X, y = make_regression(n_samples=100, n_features=3, noise=0.1) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) + + # Initialize GAMRegressor with 'auto' terms + model = GAMRegressor(terms='auto', verbose=True) + model.fit(X_train, y_train) + + # Inspect the generated terms + print(model.model_.terms) + + # Predict and evaluate + y_pred = model.predict(X_test) + print(f"Test RMSE: {model.rmse(X_test, y_test):.4f}") \ No newline at end of file From 31352c232734580370fb1f4e9893f5c378b62712 Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sun, 1 Dec 2024 09:12:21 -0500 Subject: [PATCH 04/10] Fix term list bug in test cases --- pygam/sklearn_api.py | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/pygam/sklearn_api.py b/pygam/sklearn_api.py index f9bc6075..5aa1a2aa 100644 --- a/pygam/sklearn_api.py +++ b/pygam/sklearn_api.py @@ -7,11 +7,12 @@ from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin from pygam import GAM -from pygam.terms import te, TermList, Term # Import te for interactions +from pygam.terms import te, TermList, Term # Import te for interactions import numpy as np from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split + class GAMRegressor(BaseEstimator, RegressorMixin): """ GAMRegressor @@ -40,7 +41,7 @@ class GAMRegressor(BaseEstimator, RegressorMixin): Verbosity mode. **gam_params : Additional parameters for the GAM model. - + Attributes ---------- model_ : GAM @@ -58,7 +59,7 @@ def __init__( max_iter=100, tol=1e-4, verbose=False, - **gam_params + **gam_params, ): self.distribution = distribution self.link = link @@ -70,7 +71,7 @@ def __init__( self.tol = tol self.verbose = verbose self.gam_params = gam_params - + # Handle interactions if specified if self.interactions: if isinstance(self.terms, str) and self.terms == 'auto': @@ -79,11 +80,15 @@ def __init__( self.terms = [self.terms] elif not isinstance(self.terms, list): self.terms = list(self.terms) - + # Add interaction terms for interaction in self.interactions: self.terms.append(te(*interaction)) - + + # Convert terms to TermList if necessary + if isinstance(self.terms, list): + self.terms = TermList(*self.terms) + # Initialize the GAM model self.model_ = GAM( distribution=self.distribution, @@ -94,7 +99,7 @@ def __init__( max_iter=self.max_iter, tol=self.tol, verbose=self.verbose, - **self.gam_params + **self.gam_params, ) def fit(self, X, y): @@ -107,6 +112,7 @@ def predict(self, X): def score(self, X, y): return float(self.model_.statistics_.get('pseudo R-squared', 0)) + class GAMClassifier(BaseEstimator, ClassifierMixin): """ GAMClassifier @@ -135,7 +141,7 @@ class GAMClassifier(BaseEstimator, ClassifierMixin): Verbosity mode. **gam_params : Additional parameters for the GAM model. - + Attributes ---------- model_ : GAM @@ -155,7 +161,7 @@ def __init__( max_iter=100, tol=1e-4, verbose=False, - **gam_params + **gam_params, ): self.distribution = distribution self.link = link @@ -167,7 +173,7 @@ def __init__( self.tol = tol self.verbose = verbose self.gam_params = gam_params - + # Handle interactions if specified if self.interactions: if isinstance(self.terms, str) and self.terms == 'auto': @@ -176,11 +182,15 @@ def __init__( self.terms = [self.terms] elif not isinstance(self.terms, list): self.terms = list(self.terms) - + # Add interaction terms for interaction in self.interactions: self.terms.append(te(*interaction)) - + + # Convert terms to TermList if necessary + if isinstance(self.terms, list): + self.terms = TermList(*self.terms) + # Initialize the GAM model self.model_ = GAM( distribution=self.distribution, @@ -191,7 +201,7 @@ def __init__( max_iter=self.max_iter, tol=self.tol, verbose=self.verbose, - **self.gam_params + **self.gam_params, ) def fit(self, X, y): @@ -215,8 +225,9 @@ def predict_proba(self, X): def score(self, X, y): from sklearn.metrics import accuracy_score + return accuracy_score(y, self.predict(X)) - + if __name__ == '__main__': @@ -233,4 +244,4 @@ def score(self, X, y): # Predict and evaluate y_pred = model.predict(X_test) - print(f"Test RMSE: {model.rmse(X_test, y_test):.4f}") \ No newline at end of file + print(f"Test RMSE: {model.rmse(X_test, y_test):.4f}") From a856b6a5279b5644d4d5c906d71cc7f037a99da8 Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sun, 1 Dec 2024 09:14:52 -0500 Subject: [PATCH 05/10] Refactor tests for GAMRegressor and GAMClassifier: simplify assertions and restore test_sklearn_api.py --- pygam/tests/test_penalties.py | 14 +++++++------- {tests => pygam/tests}/test_sklearn_api.py | 0 2 files changed, 7 insertions(+), 7 deletions(-) rename {tests => pygam/tests}/test_sklearn_api.py (100%) diff --git a/pygam/tests/test_penalties.py b/pygam/tests/test_penalties.py index bf6dd68a..43b3ee24 100644 --- a/pygam/tests/test_penalties.py +++ b/pygam/tests/test_penalties.py @@ -23,13 +23,13 @@ def test_single_spline_penalty(): monotonic_ and convexity_ should be 0. """ coef = np.array(1.0) - assert np.alltrue(derivative(1, coef).A == 0.0) - assert np.alltrue(l2(1, coef).A == 1.0) - assert np.alltrue(monotonic_inc(1, coef).A == 0.0) - assert np.alltrue(monotonic_dec(1, coef).A == 0.0) - assert np.alltrue(convex(1, coef).A == 0.0) - assert np.alltrue(concave(1, coef).A == 0.0) - assert np.alltrue(none(1, coef).A == 0.0) + assert np.all(derivative(1, coef).A == 0.0) + assert np.all(l2(1, coef).A == 1.0) + assert np.all(monotonic_inc(1, coef).A == 0.0) + assert np.all(monotonic_dec(1, coef).A == 0.0) + assert np.all(convex(1, coef).A == 0.0) + assert np.all(concave(1, coef).A == 0.0) + assert np.all(none(1, coef).A == 0.0) def test_wrap_penalty(): diff --git a/tests/test_sklearn_api.py b/pygam/tests/test_sklearn_api.py similarity index 100% rename from tests/test_sklearn_api.py rename to pygam/tests/test_sklearn_api.py From f772acc5aad671ce999d9f2a86141c49143ff871 Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sun, 1 Dec 2024 09:22:39 -0500 Subject: [PATCH 06/10] Reorganize imports in sklearn_api.py for clarity and consistency --- pygam/sklearn_api.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pygam/sklearn_api.py b/pygam/sklearn_api.py index 5aa1a2aa..4b041861 100644 --- a/pygam/sklearn_api.py +++ b/pygam/sklearn_api.py @@ -5,12 +5,18 @@ It integrates pygam's GAM capabilities with scikit-learn's estimator interface, enabling seamless use in machine learning pipelines. """ -from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin -from pygam import GAM -from pygam.terms import te, TermList, Term # Import te for interactions +# Standard library imports import numpy as np + +# Third-party imports +from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score + +# Local application imports +from pygam import GAM +from pygam.terms import te, TermList, Term # Import te for interactions class GAMRegressor(BaseEstimator, RegressorMixin): @@ -224,8 +230,6 @@ def predict_proba(self, X): return proba # Assume GAM model returns probabilities for each class def score(self, X, y): - from sklearn.metrics import accuracy_score - return accuracy_score(y, self.predict(X)) From c9a672bbbc1b32ed29c28f05c814fc324a955d8d Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sun, 1 Dec 2024 09:33:08 -0500 Subject: [PATCH 07/10] Add pandas and pytz dependencies to poetry.lock and update pyproject.toml --- poetry.lock | 61 +++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index a724bec8..07eadd60 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1056,6 +1056,54 @@ files = [ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] +[[package]] +name = "pandas" +version = "1.5.3" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3749077d86e3a2f0ed51367f30bf5b82e131cc0f14260c4d3e499186fccc4406"}, + {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:972d8a45395f2a2d26733eb8d0f629b2f90bebe8e8eddbb8829b180c09639572"}, + {file = "pandas-1.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:50869a35cbb0f2e0cd5ec04b191e7b12ed688874bd05dd777c19b28cbea90996"}, + {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3ac844a0fe00bfaeb2c9b51ab1424e5c8744f89860b138434a363b1f620f354"}, + {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a0a56cef15fd1586726dace5616db75ebcfec9179a3a55e78f72c5639fa2a23"}, + {file = "pandas-1.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:478ff646ca42b20376e4ed3fa2e8d7341e8a63105586efe54fa2508ee087f328"}, + {file = "pandas-1.5.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6973549c01ca91ec96199e940495219c887ea815b2083722821f1d7abfa2b4dc"}, + {file = "pandas-1.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c39a8da13cede5adcd3be1182883aea1c925476f4e84b2807a46e2775306305d"}, + {file = "pandas-1.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f76d097d12c82a535fda9dfe5e8dd4127952b45fea9b0276cb30cca5ea313fbc"}, + {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e474390e60ed609cec869b0da796ad94f420bb057d86784191eefc62b65819ae"}, + {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f2b952406a1588ad4cad5b3f55f520e82e902388a6d5a4a91baa8d38d23c7f6"}, + {file = "pandas-1.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:bc4c368f42b551bf72fac35c5128963a171b40dce866fb066540eeaf46faa003"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:14e45300521902689a81f3f41386dc86f19b8ba8dd5ac5a3c7010ef8d2932813"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9842b6f4b8479e41968eced654487258ed81df7d1c9b7b870ceea24ed9459b31"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:26d9c71772c7afb9d5046e6e9cf42d83dd147b5cf5bcb9d97252077118543792"}, + {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fbcb19d6fceb9e946b3e23258757c7b225ba450990d9ed63ccceeb8cae609f7"}, + {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:565fa34a5434d38e9d250af3c12ff931abaf88050551d9fbcdfafca50d62babf"}, + {file = "pandas-1.5.3-cp38-cp38-win32.whl", hash = "sha256:87bd9c03da1ac870a6d2c8902a0e1fd4267ca00f13bc494c9e5a9020920e1d51"}, + {file = "pandas-1.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:41179ce559943d83a9b4bbacb736b04c928b095b5f25dd2b7389eda08f46f373"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c74a62747864ed568f5a82a49a23a8d7fe171d0c69038b38cedf0976831296fa"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4c00e0b0597c8e4f59e8d461f797e5d70b4d025880516a8261b2817c47759ee"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a50d9a4336a9621cab7b8eb3fb11adb82de58f9b91d84c2cd526576b881a0c5a"}, + {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd05f7783b3274aa206a1af06f0ceed3f9b412cf665b7247eacd83be41cf7bf0"}, + {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f69c4029613de47816b1bb30ff5ac778686688751a5e9c99ad8c7031f6508e5"}, + {file = "pandas-1.5.3-cp39-cp39-win32.whl", hash = "sha256:7cec0bee9f294e5de5bbfc14d0573f65526071029d036b753ee6507d2a21480a"}, + {file = "pandas-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:dfd681c5dc216037e0b0a2c821f5ed99ba9f03ebcf119c7dac0e9a7b960b9ec9"}, + {file = "pandas-1.5.3.tar.gz", hash = "sha256:74a3fd7e5a7ec052f183273dc7b0acd3a863edf7520f5d3a1765c04ffdb3b0b1"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, +] +python-dateutil = ">=2.8.1" +pytz = ">=2020.1" + +[package.extras] +test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] + [[package]] name = "pandocfilters" version = "1.5.1" @@ -1340,6 +1388,17 @@ docs = ["mock", "python-utils", "sphinx"] loguru = ["loguru"] tests = ["flake8", "loguru", "pytest", "pytest-asyncio", "pytest-cov", "pytest-mypy", "sphinx", "types-setuptools"] +[[package]] +name = "pytz" +version = "2024.2" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, + {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, +] + [[package]] name = "pywin32" version = "306" @@ -2071,4 +2130,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9, <=3.13" -content-hash = "84ba875e24ed4d7ac984dabe842dbf2e10056a16d645ae7131bf449b92a14d3f" +content-hash = "324627eba8faf274ca5bee70303e1854f31029b86a8fc61aaa9b27df4b87a2ed" diff --git a/pyproject.toml b/pyproject.toml index f55672ae..4170a419 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ numpy = [ progressbar2 = "^4.2.0" scipy = ">=1.11.4" scikit-learn = "^1.5.2" +pandas = "^1.3.3" [tool.poetry.group.dev.dependencies] pytest = "^7.2.2" From b2623eda041087333ea79484a16a0dc863100a1a Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sun, 1 Dec 2024 11:03:45 -0500 Subject: [PATCH 08/10] Add .env to .gitignore to prevent environment variable files from being tracked --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index c4ff995e..ab63179a 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,4 @@ _build/ ######### .idea/ .vscode/settings.json +.env From 0afb8f024625c0884a793a99c11820371846d108 Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sun, 1 Dec 2024 18:13:32 -0500 Subject: [PATCH 09/10] Update .gitignore and pyproject.toml; enhance GAMRegressor and GAMClassifier to support categorical features --- .gitignore | 1 + poetry.lock | 119 ++++++++++++++++++++++++++++++------------- pygam/sklearn_api.py | 117 +++++++++++++++++++++++++----------------- pyproject.toml | 2 +- 4 files changed, 157 insertions(+), 82 deletions(-) diff --git a/.gitignore b/.gitignore index ab63179a..63479341 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,4 @@ _build/ .idea/ .vscode/settings.json .env +.vscode/launch.json diff --git a/poetry.lock b/poetry.lock index 07eadd60..e177aab3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1058,51 +1058,89 @@ files = [ [[package]] name = "pandas" -version = "1.5.3" +version = "2.2.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3749077d86e3a2f0ed51367f30bf5b82e131cc0f14260c4d3e499186fccc4406"}, - {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:972d8a45395f2a2d26733eb8d0f629b2f90bebe8e8eddbb8829b180c09639572"}, - {file = "pandas-1.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:50869a35cbb0f2e0cd5ec04b191e7b12ed688874bd05dd777c19b28cbea90996"}, - {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3ac844a0fe00bfaeb2c9b51ab1424e5c8744f89860b138434a363b1f620f354"}, - {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a0a56cef15fd1586726dace5616db75ebcfec9179a3a55e78f72c5639fa2a23"}, - {file = "pandas-1.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:478ff646ca42b20376e4ed3fa2e8d7341e8a63105586efe54fa2508ee087f328"}, - {file = "pandas-1.5.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6973549c01ca91ec96199e940495219c887ea815b2083722821f1d7abfa2b4dc"}, - {file = "pandas-1.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c39a8da13cede5adcd3be1182883aea1c925476f4e84b2807a46e2775306305d"}, - {file = "pandas-1.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f76d097d12c82a535fda9dfe5e8dd4127952b45fea9b0276cb30cca5ea313fbc"}, - {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e474390e60ed609cec869b0da796ad94f420bb057d86784191eefc62b65819ae"}, - {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f2b952406a1588ad4cad5b3f55f520e82e902388a6d5a4a91baa8d38d23c7f6"}, - {file = "pandas-1.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:bc4c368f42b551bf72fac35c5128963a171b40dce866fb066540eeaf46faa003"}, - {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:14e45300521902689a81f3f41386dc86f19b8ba8dd5ac5a3c7010ef8d2932813"}, - {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9842b6f4b8479e41968eced654487258ed81df7d1c9b7b870ceea24ed9459b31"}, - {file = "pandas-1.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:26d9c71772c7afb9d5046e6e9cf42d83dd147b5cf5bcb9d97252077118543792"}, - {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fbcb19d6fceb9e946b3e23258757c7b225ba450990d9ed63ccceeb8cae609f7"}, - {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:565fa34a5434d38e9d250af3c12ff931abaf88050551d9fbcdfafca50d62babf"}, - {file = "pandas-1.5.3-cp38-cp38-win32.whl", hash = "sha256:87bd9c03da1ac870a6d2c8902a0e1fd4267ca00f13bc494c9e5a9020920e1d51"}, - {file = "pandas-1.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:41179ce559943d83a9b4bbacb736b04c928b095b5f25dd2b7389eda08f46f373"}, - {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c74a62747864ed568f5a82a49a23a8d7fe171d0c69038b38cedf0976831296fa"}, - {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4c00e0b0597c8e4f59e8d461f797e5d70b4d025880516a8261b2817c47759ee"}, - {file = "pandas-1.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a50d9a4336a9621cab7b8eb3fb11adb82de58f9b91d84c2cd526576b881a0c5a"}, - {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd05f7783b3274aa206a1af06f0ceed3f9b412cf665b7247eacd83be41cf7bf0"}, - {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f69c4029613de47816b1bb30ff5ac778686688751a5e9c99ad8c7031f6508e5"}, - {file = "pandas-1.5.3-cp39-cp39-win32.whl", hash = "sha256:7cec0bee9f294e5de5bbfc14d0573f65526071029d036b753ee6507d2a21480a"}, - {file = "pandas-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:dfd681c5dc216037e0b0a2c821f5ed99ba9f03ebcf119c7dac0e9a7b960b9ec9"}, - {file = "pandas-1.5.3.tar.gz", hash = "sha256:74a3fd7e5a7ec052f183273dc7b0acd3a863edf7520f5d3a1765c04ffdb3b0b1"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f"}, + {file = "pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32"}, + {file = "pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, + {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, + {file = "pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc6b93f9b966093cb0fd62ff1a7e4c09e6d546ad7c1de191767baffc57628f39"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5dbca4c1acd72e8eeef4753eeca07de9b1db4f398669d5994086f788a5d7cc30"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8cd6d7cc958a3910f934ea8dbdf17b2364827bb4dafc38ce6eef6bb3d65ff09c"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99df71520d25fade9db7c1076ac94eb994f4d2673ef2aa2e86ee039b6746d20c"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31d0ced62d4ea3e231a9f228366919a5ea0b07440d9d4dac345376fd8e1477ea"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7eee9e7cea6adf3e3d24e304ac6b8300646e2a5d1cd3a3c2abed9101b0846761"}, + {file = "pandas-2.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:4850ba03528b6dd51d6c5d273c46f183f39a9baf3f0143e566b89450965b105e"}, + {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, ] [package.dependencies] numpy = [ - {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] -python-dateutil = ">=2.8.1" +python-dateutil = ">=2.8.2" pytz = ">=2020.1" +tzdata = ">=2022.7" [package.extras] -test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] +all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] +aws = ["s3fs (>=2022.11.0)"] +clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] +compression = ["zstandard (>=0.19.0)"] +computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +consortium-standard = ["dataframe-api-compat (>=0.1.7)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] +feather = ["pyarrow (>=10.0.1)"] +fss = ["fsspec (>=2022.11.0)"] +gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] +hdf5 = ["tables (>=3.8.0)"] +html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] +mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] +parquet = ["pyarrow (>=10.0.1)"] +performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] +plot = ["matplotlib (>=3.6.3)"] +postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] +pyarrow = ["pyarrow (>=10.0.1)"] +spss = ["pyreadstat (>=1.2.0)"] +sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.9.2)"] [[package]] name = "pandocfilters" @@ -2073,6 +2111,17 @@ files = [ {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, ] +[[package]] +name = "tzdata" +version = "2024.2" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"}, + {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, +] + [[package]] name = "urllib3" version = "2.2.0" @@ -2130,4 +2179,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9, <=3.13" -content-hash = "324627eba8faf274ca5bee70303e1854f31029b86a8fc61aaa9b27df4b87a2ed" +content-hash = "6374b0bba22da8a0e240ae5f655897a449c79134d4e4ef416ec5a1f270385d2f" diff --git a/pygam/sklearn_api.py b/pygam/sklearn_api.py index 4b041861..fda4882b 100644 --- a/pygam/sklearn_api.py +++ b/pygam/sklearn_api.py @@ -17,6 +17,22 @@ # Local application imports from pygam import GAM from pygam.terms import te, TermList, Term # Import te for interactions +from pygam.terms import s, f, l, intercept # Import s, f, l for splines + +def create_default_terms(X, categorical_features=None): + """Generate default terms for each feature in X, handling categoricals.""" + n_features = X.shape[1] + terms = [] + if categorical_features is None: + categorical_features = [] + for i in range(n_features): + if i in categorical_features: + terms.append(f(i)) + elif not np.issubdtype(X[:, i].dtype, np.number): + terms.append(f(i)) + else: + terms.append(s(i)) + return terms class GAMRegressor(BaseEstimator, RegressorMixin): @@ -31,9 +47,10 @@ class GAMRegressor(BaseEstimator, RegressorMixin): The distribution of the response variable. link : str, default='identity' The link function. - terms : 'auto' or TermList, default='auto' - The terms to include in the model. - interactions : list of tuples, optional + terms : 'auto', None, or list of Term objects, default='auto' + The terms to include in the model. If 'auto', terms are automatically inferred based on X. + If None, no terms are used. If a list of Term objects, they are used as specified. + interactions : None or list of Term objects, optional Interaction terms to include in the model. callbacks : list, default=['deviance', 'diffs'] List of callbacks to monitor during training. @@ -45,6 +62,8 @@ class GAMRegressor(BaseEstimator, RegressorMixin): Tolerance for stopping criteria. verbose : bool, default=False Verbosity mode. + categorical_features : list, optional + List of indices of categorical features. **gam_params : Additional parameters for the GAM model. @@ -58,48 +77,50 @@ def __init__( self, distribution='normal', link='identity', - terms='auto', # Will be handled by GAM class + terms='auto', interactions=None, callbacks=['deviance', 'diffs'], fit_intercept=True, max_iter=100, tol=1e-4, verbose=False, + categorical_features=None, **gam_params, ): self.distribution = distribution self.link = link - self.terms = terms # Simply pass through to GAM + self.terms = terms self.interactions = interactions self.callbacks = callbacks self.fit_intercept = fit_intercept self.max_iter = max_iter self.tol = tol self.verbose = verbose + self.categorical_features = categorical_features self.gam_params = gam_params - # Handle interactions if specified - if self.interactions: - if isinstance(self.terms, str) and self.terms == 'auto': - self.terms = [] # Convert 'auto' to empty list to append to - elif isinstance(self.terms, str): - self.terms = [self.terms] - elif not isinstance(self.terms, list): - self.terms = list(self.terms) + def fit(self, X, y): + if self.terms == 'auto': + self.terms_ = create_default_terms(X, self.categorical_features) + elif self.terms is None: + self.terms_ = [] + else: + self.terms_ = self.terms - # Add interaction terms - for interaction in self.interactions: - self.terms.append(te(*interaction)) + if self.interactions is not None: + self.interactions_ = [te(*interaction) if isinstance(interaction, tuple) else interaction for interaction in self.interactions] + else: + self.interactions_ = [] - # Convert terms to TermList if necessary - if isinstance(self.terms, list): - self.terms = TermList(*self.terms) + # Combine terms and interactions + terms = self.terms_ + self.interactions_ + terms = TermList(*terms) - # Initialize the GAM model + # Create the GAM model with the specified terms self.model_ = GAM( distribution=self.distribution, link=self.link, - terms=self.terms, # Pass terms directly, let GAM handle 'auto' + terms=terms, callbacks=self.callbacks, fit_intercept=self.fit_intercept, max_iter=self.max_iter, @@ -107,8 +128,6 @@ def __init__( verbose=self.verbose, **self.gam_params, ) - - def fit(self, X, y): self.model_.fit(X, y) return self @@ -131,9 +150,10 @@ class GAMClassifier(BaseEstimator, ClassifierMixin): The distribution of the response variable. link : str, default='logit' The link function. - terms : 'auto' or TermList, default='auto' - The terms to include in the model. - interactions : list of tuples, optional + terms : 'auto', None, or list of Term objects, default='auto' + The terms to include in the model. If 'auto', terms are automatically inferred based on X. + If None, no terms are used. If a list of Term objects, they are used as specified. + interactions : None or list of Term objects, optional Interaction terms to include in the model. callbacks : list, default=['deviance', 'diffs', 'accuracy'] List of callbacks to monitor during training. @@ -145,6 +165,8 @@ class GAMClassifier(BaseEstimator, ClassifierMixin): Tolerance for stopping criteria. verbose : bool, default=False Verbosity mode. + categorical_features : list, optional + List of indices of categorical features. **gam_params : Additional parameters for the GAM model. @@ -160,48 +182,53 @@ def __init__( self, distribution='binomial', link='logit', - terms='auto', # Will be handled by GAM class + terms='auto', interactions=None, callbacks=['deviance', 'diffs', 'accuracy'], fit_intercept=True, max_iter=100, tol=1e-4, verbose=False, + categorical_features=None, **gam_params, ): self.distribution = distribution self.link = link - self.terms = terms # Simply pass through to GAM + self.terms = terms self.interactions = interactions self.callbacks = callbacks self.fit_intercept = fit_intercept self.max_iter = max_iter self.tol = tol self.verbose = verbose + self.categorical_features = categorical_features self.gam_params = gam_params - # Handle interactions if specified - if self.interactions: - if isinstance(self.terms, str) and self.terms == 'auto': - self.terms = [] # Convert 'auto' to empty list to append to - elif isinstance(self.terms, str): - self.terms = [self.terms] - elif not isinstance(self.terms, list): - self.terms = list(self.terms) + def fit(self, X, y): + if self.terms == 'auto': + self.terms_ = create_default_terms(X, self.categorical_features) + elif self.terms is None: + self.terms_ = [] + else: + self.terms_ = self.terms - # Add interaction terms - for interaction in self.interactions: - self.terms.append(te(*interaction)) + if self.interactions is not None: + self.interactions_ = [ + te(*interaction) if isinstance(interaction, tuple) else interaction + for interaction in self.interactions + ] + else: + self.interactions_ = [] - # Convert terms to TermList if necessary - if isinstance(self.terms, list): - self.terms = TermList(*self.terms) + # Combine terms and interactions + terms = self.terms_ + self.interactions_ + terms = TermList(*terms) - # Initialize the GAM model + # Create the GAM model with the specified terms self.model_ = GAM( distribution=self.distribution, link=self.link, - terms=self.terms, # Pass terms directly, let GAM handle 'auto' + terms=terms, callbacks=self.callbacks, fit_intercept=self.fit_intercept, max_iter=self.max_iter, @@ -209,8 +236,6 @@ def __init__( verbose=self.verbose, **self.gam_params, ) - - def fit(self, X, y): self.model_.fit(X, y) self.classes_ = np.unique(y) return self diff --git a/pyproject.toml b/pyproject.toml index 4170a419..2cf5ff35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ numpy = [ progressbar2 = "^4.2.0" scipy = ">=1.11.4" scikit-learn = "^1.5.2" -pandas = "^1.3.3" +pandas = ">=1.4.0" # Updated line [tool.poetry.group.dev.dependencies] pytest = "^7.2.2" From 7069f5a4b1a8307f9a2b1e077cdba121b1bd050b Mon Sep 17 00:00:00 2001 From: Nicholas Corona Date: Sun, 1 Dec 2024 18:17:52 -0500 Subject: [PATCH 10/10] Refactor GAMRegressor and GAMClassifier: improve code readability and remove main execution block --- pygam/sklearn_api.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/pygam/sklearn_api.py b/pygam/sklearn_api.py index fda4882b..f218ff11 100644 --- a/pygam/sklearn_api.py +++ b/pygam/sklearn_api.py @@ -19,6 +19,7 @@ from pygam.terms import te, TermList, Term # Import te for interactions from pygam.terms import s, f, l, intercept # Import s, f, l for splines + def create_default_terms(X, categorical_features=None): """Generate default terms for each feature in X, handling categoricals.""" n_features = X.shape[1] @@ -108,7 +109,10 @@ def fit(self, X, y): self.terms_ = self.terms if self.interactions is not None: - self.interactions_ = [te(*interaction) if isinstance(interaction, tuple) else interaction for interaction in self.interactions] + self.interactions_ = [ + te(*interaction) if isinstance(interaction, tuple) else interaction + for interaction in self.interactions + ] else: self.interactions_ = [] @@ -256,21 +260,3 @@ def predict_proba(self, X): def score(self, X, y): return accuracy_score(y, self.predict(X)) - - -if __name__ == '__main__': - - # Generate synthetic data - X, y = make_regression(n_samples=100, n_features=3, noise=0.1) - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) - - # Initialize GAMRegressor with 'auto' terms - model = GAMRegressor(terms='auto', verbose=True) - model.fit(X_train, y_train) - - # Inspect the generated terms - print(model.model_.terms) - - # Predict and evaluate - y_pred = model.predict(X_test) - print(f"Test RMSE: {model.rmse(X_test, y_test):.4f}")