Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GAMRegressor and GAMClassifier classes to sklearn_api.py #364

Open
wants to merge 10 commits into
base: master
Choose a base branch
from

Conversation

nickcorona
Copy link

@nickcorona nickcorona commented Dec 1, 2024

Description

This pull request introduces two new classes, GAMRegressor and GAMClassifier, to the sklearn_api.py file. These classes provide scikit-learn compatible models using Generalized Additive Models (GAM) from the pygam library. The integration allows seamless use of GAM within scikit-learn's estimator interface, enabling these models to be used in standard machine learning pipelines.

Changes

  • Added GAMRegressor class to support regression tasks using GAM.
  • Added GAMClassifier class to support classification tasks using GAM.
  • Both classes support various configurations, including distribution, link function, terms, interactions, callbacks, fit intercept, maximum iterations, tolerance, verbosity, and additional GAM parameters.
  • Included example usage of GAMRegressor in the __main__ block with synthetic data generation, model fitting, and evaluation.

Example Usage

if __name__ == '__main__':
    # Generate synthetic data
    X, y = make_regression(n_samples=100, n_features=3, noise=0.1)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    # Initialize GAMRegressor with 'auto' terms
    model = GAMRegressor(terms='auto', verbose=True)
    model.fit(X_train, y_train)

    # Inspect the generated terms
    print(model.model_.terms)

    # Predict and evaluate
    y_pred = model.predict(X_test)
    print(f"Test RMSE: {model.rmse(X_test, y_test):.4f}")

Copy link

codecov bot commented Dec 1, 2024

Codecov Report

Attention: Patch coverage is 92.12963% with 17 lines in your changes missing coverage. Please review.

Project coverage is 94.81%. Comparing base (a6c14e4) to head (0afb8f0).

Files with missing lines Patch % Lines
pygam/sklearn_api.py 82.82% 17 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #364      +/-   ##
==========================================
- Coverage   95.00%   94.81%   -0.20%     
==========================================
  Files          22       24       +2     
  Lines        3202     3411     +209     
==========================================
+ Hits         3042     3234     +192     
- Misses        160      177      +17     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant