Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SGD regressor #374

Merged
merged 8 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"NeuralNetRegressor",
"RandomForestRegressor",
"XGBRegressor",
"SGDRegressor",
bcm-at-zama marked this conversation as resolved.
Show resolved Hide resolved
]
for model_name in REGRESSORS_NAMES:
try:
Expand Down Expand Up @@ -341,6 +342,7 @@ def should_test_config_in_fhe(
"TweedieRegressor",
"PoissonRegressor",
"GammaRegressor",
"SGDRegressor",
}:
return True

Expand Down
1 change: 1 addition & 0 deletions docs/built-in-models/linear.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Concrete ML provides several of the most popular linear models for `regression`
| [Lasso](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-lasso) | [Lasso](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html#sklearn.linear_model.Lasso) |
| [Ridge](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-ridge) | [Ridge](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html#sklearn.linear_model.Ridge) |
| [ElasticNet](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-elasticnet) | [ElasticNet](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html#sklearn.linear_model.ElasticNet) |
| [SGDRegressor](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-sgdregressor) | [SGDRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDRegressor.html) |

Using these models in FHE is extremely similar to what can be done with scikit-learn's [API](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model), making it easy for data scientists who have used this framework to get started with Concrete ML.

Expand Down
4 changes: 3 additions & 1 deletion docs/developer-guide/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
- [`concrete.ml.sklearn.base`](./concrete.ml.sklearn.base.md#module-concretemlsklearnbase): Base classes for all estimators.
- [`concrete.ml.sklearn.glm`](./concrete.ml.sklearn.glm.md#module-concretemlsklearnglm): Implement sklearn's Generalized Linear Models (GLM).
- [`concrete.ml.sklearn.linear_model`](./concrete.ml.sklearn.linear_model.md#module-concretemlsklearnlinear_model): Implement sklearn linear model.
- [`concrete.ml.sklearn.neighbors`](./concrete.ml.sklearn.neighbors.md#module-concretemlsklearnneighbors): Implement sklearn linear model.
- [`concrete.ml.sklearn.neighbors`](./concrete.ml.sklearn.neighbors.md#module-concretemlsklearnneighbors): Implement sklearn neighbors model.
- [`concrete.ml.sklearn.qnn`](./concrete.ml.sklearn.qnn.md#module-concretemlsklearnqnn): Scikit-learn interface for fully-connected quantized neural networks.
- [`concrete.ml.sklearn.qnn_module`](./concrete.ml.sklearn.qnn_module.md#module-concretemlsklearnqnn_module): Sparse Quantized Neural Network torch module.
- [`concrete.ml.sklearn.rf`](./concrete.ml.sklearn.rf.md#module-concretemlsklearnrf): Implement RandomForest models.
Expand Down Expand Up @@ -182,6 +182,7 @@
- [`base.SklearnLinearClassifierMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearclassifiermixin): A Mixin class for sklearn linear classifiers with FHE.
- [`base.SklearnLinearModelMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearmodelmixin): A Mixin class for sklearn linear models with FHE.
- [`base.SklearnLinearRegressorMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearregressormixin): A Mixin class for sklearn linear regressors with FHE.
- [`base.SklearnSGDRegressorMixin`](./concrete.ml.sklearn.base.md#class-sklearnsgdregressormixin): A Mixin class for sklearn SGD regressors with FHE.
- [`glm.GammaRegressor`](./concrete.ml.sklearn.glm.md#class-gammaregressor): A Gamma regression model with FHE.
- [`glm.PoissonRegressor`](./concrete.ml.sklearn.glm.md#class-poissonregressor): A Poisson regression model with FHE.
- [`glm.TweedieRegressor`](./concrete.ml.sklearn.glm.md#class-tweedieregressor): A Tweedie regression model with FHE.
Expand All @@ -190,6 +191,7 @@
- [`linear_model.LinearRegression`](./concrete.ml.sklearn.linear_model.md#class-linearregression): A linear regression model with FHE.
- [`linear_model.LogisticRegression`](./concrete.ml.sklearn.linear_model.md#class-logisticregression): A logistic regression model with FHE.
- [`linear_model.Ridge`](./concrete.ml.sklearn.linear_model.md#class-ridge): A Ridge regression model with FHE.
- [`linear_model.SGDRegressor`](./concrete.ml.sklearn.linear_model.md#class-sgdregressor): An FHE linear regression model fitted with stochastic gradient descent.
- [`neighbors.KNeighborsClassifier`](./concrete.ml.sklearn.neighbors.md#class-kneighborsclassifier): A k-nearest neighbors classifier model with FHE.
- [`qnn.NeuralNetClassifier`](./concrete.ml.sklearn.qnn.md#class-neuralnetclassifier): A Fully-Connected Neural Network classifier with FHE.
- [`qnn.NeuralNetRegressor`](./concrete.ml.sklearn.qnn.md#class-neuralnetregressor): A Fully-Connected Neural Network regressor with FHE.
Expand Down
42 changes: 21 additions & 21 deletions docs/developer-guide/api/concrete.ml.common.utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Utils that can be re-used by other pieces of code in the module.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L94"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L95"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `replace_invalid_arg_name_chars`

Expand All @@ -39,7 +39,7 @@ This does not check that the starting character of arg_name is valid.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L113"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L114"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `generate_proxy_function`

Expand All @@ -65,7 +65,7 @@ This returns a runtime compiled function with the sanitized argument names passe

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L154"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L155"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_onnx_opset_version`

Expand All @@ -85,7 +85,7 @@ Return the ONNX opset_version.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L169"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L170"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `manage_parameters_for_pbs_errors`

Expand Down Expand Up @@ -122,7 +122,7 @@ Note that global_p_error is currently set to 0 in the FHE simulation mode.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L214"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L215"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `check_there_is_no_p_error_options_in_configuration`

Expand All @@ -140,7 +140,7 @@ It would be dangerous, since we set them in direct arguments in our calls to Con

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L235"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L236"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_model_class`

Expand All @@ -159,7 +159,7 @@ The model's class.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L257"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L258"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_model_class_in_a_list`

Expand All @@ -179,7 +179,7 @@ If the model's class is in the list or not.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L271"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L272"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_model_name`

Expand All @@ -198,7 +198,7 @@ the model's name.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L284"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L285"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_classifier_or_partial_classifier`

Expand All @@ -218,7 +218,7 @@ Indicate if the model class represents a classifier.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L296"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L297"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_regressor_or_partial_regressor`

Expand All @@ -238,7 +238,7 @@ Indicate if the model class represents a regressor.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L308"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L309"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_pandas_dataframe`

Expand All @@ -260,7 +260,7 @@ This function is inspired from Scikit-Learn's test validation tools and avoids t

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L324"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L325"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_pandas_series`

Expand All @@ -282,7 +282,7 @@ This function is inspired from Scikit-Learn's test validation tools and avoids t

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L340"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L341"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_pandas_type`

Expand All @@ -302,7 +302,7 @@ Indicate if the input container is a Pandas DataFrame or Series.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L435"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L436"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `check_dtype_and_cast`

Expand Down Expand Up @@ -334,7 +334,7 @@ If values types don't match with any supported type or the expected dtype, raise

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L487"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L488"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `compute_bits_precision`

Expand All @@ -354,7 +354,7 @@ Compute the number of bits required to represent x.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L499"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L500"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_brevitas_model`

Expand All @@ -374,7 +374,7 @@ Check if a model is a Brevitas type.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L517"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L518"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `to_tuple`

Expand All @@ -394,7 +394,7 @@ Make the input a tuple if it is not already the case.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L533"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L534"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `all_values_are_integers`

Expand All @@ -414,7 +414,7 @@ Indicate if all unpacked values are of a supported integer dtype.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L546"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L547"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `all_values_are_floats`

Expand All @@ -434,7 +434,7 @@ Indicate if all unpacked values are of a supported float dtype.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L559"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L560"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `all_values_are_of_dtype`

Expand All @@ -455,7 +455,7 @@ Indicate if all unpacked values are of the specified dtype(s).

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L51"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L52"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>class</kbd> `FheMode`

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Get the post-processing parameters.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L703"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L698"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `bitwidth_and_range_report`

Expand Down
Loading
Loading