diff --git a/benchmarks/common.py b/benchmarks/common.py
index cba039347..3c20a9ccd 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -74,6 +74,7 @@
"NeuralNetRegressor",
"RandomForestRegressor",
"XGBRegressor",
+ "SGDRegressor",
]
for model_name in REGRESSORS_NAMES:
try:
@@ -341,6 +342,7 @@ def should_test_config_in_fhe(
"TweedieRegressor",
"PoissonRegressor",
"GammaRegressor",
+ "SGDRegressor",
}:
return True
diff --git a/docs/built-in-models/linear.md b/docs/built-in-models/linear.md
index 227b2d43c..5c9c6360a 100644
--- a/docs/built-in-models/linear.md
+++ b/docs/built-in-models/linear.md
@@ -14,6 +14,7 @@ Concrete ML provides several of the most popular linear models for `regression`
| [Lasso](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-lasso) | [Lasso](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html#sklearn.linear_model.Lasso) |
| [Ridge](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-ridge) | [Ridge](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html#sklearn.linear_model.Ridge) |
| [ElasticNet](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-elasticnet) | [ElasticNet](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html#sklearn.linear_model.ElasticNet) |
+| [SGDRegressor](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-sgdregressor) | [SGDRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDRegressor.html) |
Using these models in FHE is extremely similar to what can be done with scikit-learn's [API](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model), making it easy for data scientists who have used this framework to get started with Concrete ML.
diff --git a/docs/developer-guide/api/README.md b/docs/developer-guide/api/README.md
index 0c94b01da..a25a0e776 100644
--- a/docs/developer-guide/api/README.md
+++ b/docs/developer-guide/api/README.md
@@ -42,7 +42,7 @@
- [`concrete.ml.sklearn.base`](./concrete.ml.sklearn.base.md#module-concretemlsklearnbase): Base classes for all estimators.
- [`concrete.ml.sklearn.glm`](./concrete.ml.sklearn.glm.md#module-concretemlsklearnglm): Implement sklearn's Generalized Linear Models (GLM).
- [`concrete.ml.sklearn.linear_model`](./concrete.ml.sklearn.linear_model.md#module-concretemlsklearnlinear_model): Implement sklearn linear model.
-- [`concrete.ml.sklearn.neighbors`](./concrete.ml.sklearn.neighbors.md#module-concretemlsklearnneighbors): Implement sklearn linear model.
+- [`concrete.ml.sklearn.neighbors`](./concrete.ml.sklearn.neighbors.md#module-concretemlsklearnneighbors): Implement sklearn neighbors model.
- [`concrete.ml.sklearn.qnn`](./concrete.ml.sklearn.qnn.md#module-concretemlsklearnqnn): Scikit-learn interface for fully-connected quantized neural networks.
- [`concrete.ml.sklearn.qnn_module`](./concrete.ml.sklearn.qnn_module.md#module-concretemlsklearnqnn_module): Sparse Quantized Neural Network torch module.
- [`concrete.ml.sklearn.rf`](./concrete.ml.sklearn.rf.md#module-concretemlsklearnrf): Implement RandomForest models.
@@ -182,6 +182,7 @@
- [`base.SklearnLinearClassifierMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearclassifiermixin): A Mixin class for sklearn linear classifiers with FHE.
- [`base.SklearnLinearModelMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearmodelmixin): A Mixin class for sklearn linear models with FHE.
- [`base.SklearnLinearRegressorMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearregressormixin): A Mixin class for sklearn linear regressors with FHE.
+- [`base.SklearnSGDRegressorMixin`](./concrete.ml.sklearn.base.md#class-sklearnsgdregressormixin): A Mixin class for sklearn SGD regressors with FHE.
- [`glm.GammaRegressor`](./concrete.ml.sklearn.glm.md#class-gammaregressor): A Gamma regression model with FHE.
- [`glm.PoissonRegressor`](./concrete.ml.sklearn.glm.md#class-poissonregressor): A Poisson regression model with FHE.
- [`glm.TweedieRegressor`](./concrete.ml.sklearn.glm.md#class-tweedieregressor): A Tweedie regression model with FHE.
@@ -190,6 +191,7 @@
- [`linear_model.LinearRegression`](./concrete.ml.sklearn.linear_model.md#class-linearregression): A linear regression model with FHE.
- [`linear_model.LogisticRegression`](./concrete.ml.sklearn.linear_model.md#class-logisticregression): A logistic regression model with FHE.
- [`linear_model.Ridge`](./concrete.ml.sklearn.linear_model.md#class-ridge): A Ridge regression model with FHE.
+- [`linear_model.SGDRegressor`](./concrete.ml.sklearn.linear_model.md#class-sgdregressor): An FHE linear regression model fitted with stochastic gradient descent.
- [`neighbors.KNeighborsClassifier`](./concrete.ml.sklearn.neighbors.md#class-kneighborsclassifier): A k-nearest neighbors classifier model with FHE.
- [`qnn.NeuralNetClassifier`](./concrete.ml.sklearn.qnn.md#class-neuralnetclassifier): A Fully-Connected Neural Network classifier with FHE.
- [`qnn.NeuralNetRegressor`](./concrete.ml.sklearn.qnn.md#class-neuralnetregressor): A Fully-Connected Neural Network regressor with FHE.
diff --git a/docs/developer-guide/api/concrete.ml.common.utils.md b/docs/developer-guide/api/concrete.ml.common.utils.md
index add00f13a..8ee74fa63 100644
--- a/docs/developer-guide/api/concrete.ml.common.utils.md
+++ b/docs/developer-guide/api/concrete.ml.common.utils.md
@@ -17,7 +17,7 @@ Utils that can be re-used by other pieces of code in the module.
______________________________________________________________________
-
+
## function `replace_invalid_arg_name_chars`
@@ -39,7 +39,7 @@ This does not check that the starting character of arg_name is valid.
______________________________________________________________________
-
+
## function `generate_proxy_function`
@@ -65,7 +65,7 @@ This returns a runtime compiled function with the sanitized argument names passe
______________________________________________________________________
-
+
## function `get_onnx_opset_version`
@@ -85,7 +85,7 @@ Return the ONNX opset_version.
______________________________________________________________________
-
+
## function `manage_parameters_for_pbs_errors`
@@ -122,7 +122,7 @@ Note that global_p_error is currently set to 0 in the FHE simulation mode.
______________________________________________________________________
-
+
## function `check_there_is_no_p_error_options_in_configuration`
@@ -140,7 +140,7 @@ It would be dangerous, since we set them in direct arguments in our calls to Con
______________________________________________________________________
-
+
## function `get_model_class`
@@ -159,7 +159,7 @@ The model's class.
______________________________________________________________________
-
+
## function `is_model_class_in_a_list`
@@ -179,7 +179,7 @@ If the model's class is in the list or not.
______________________________________________________________________
-
+
## function `get_model_name`
@@ -198,7 +198,7 @@ the model's name.
______________________________________________________________________
-
+
## function `is_classifier_or_partial_classifier`
@@ -218,7 +218,7 @@ Indicate if the model class represents a classifier.
______________________________________________________________________
-
+
## function `is_regressor_or_partial_regressor`
@@ -238,7 +238,7 @@ Indicate if the model class represents a regressor.
______________________________________________________________________
-
+
## function `is_pandas_dataframe`
@@ -260,7 +260,7 @@ This function is inspired from Scikit-Learn's test validation tools and avoids t
______________________________________________________________________
-
+
## function `is_pandas_series`
@@ -282,7 +282,7 @@ This function is inspired from Scikit-Learn's test validation tools and avoids t
______________________________________________________________________
-
+
## function `is_pandas_type`
@@ -302,7 +302,7 @@ Indicate if the input container is a Pandas DataFrame or Series.
______________________________________________________________________
-
+
## function `check_dtype_and_cast`
@@ -334,7 +334,7 @@ If values types don't match with any supported type or the expected dtype, raise
______________________________________________________________________
-
+
## function `compute_bits_precision`
@@ -354,7 +354,7 @@ Compute the number of bits required to represent x.
______________________________________________________________________
-
+
## function `is_brevitas_model`
@@ -374,7 +374,7 @@ Check if a model is a Brevitas type.
______________________________________________________________________
-
+
## function `to_tuple`
@@ -394,7 +394,7 @@ Make the input a tuple if it is not already the case.
______________________________________________________________________
-
+
## function `all_values_are_integers`
@@ -414,7 +414,7 @@ Indicate if all unpacked values are of a supported integer dtype.
______________________________________________________________________
-
+
## function `all_values_are_floats`
@@ -434,7 +434,7 @@ Indicate if all unpacked values are of a supported float dtype.
______________________________________________________________________
-
+
## function `all_values_are_of_dtype`
@@ -455,7 +455,7 @@ Indicate if all unpacked values are of the specified dtype(s).
______________________________________________________________________
-
+
## class `FheMode`
diff --git a/docs/developer-guide/api/concrete.ml.quantization.quantized_module.md b/docs/developer-guide/api/concrete.ml.quantization.quantized_module.md
index a12d24c9a..74920a642 100644
--- a/docs/developer-guide/api/concrete.ml.quantization.quantized_module.md
+++ b/docs/developer-guide/api/concrete.ml.quantization.quantized_module.md
@@ -67,7 +67,7 @@ Get the post-processing parameters.
______________________________________________________________________
-
+
### method `bitwidth_and_range_report`
diff --git a/docs/developer-guide/api/concrete.ml.sklearn.base.md b/docs/developer-guide/api/concrete.ml.sklearn.base.md
index a18bd0dfb..a131edf9d 100644
--- a/docs/developer-guide/api/concrete.ml.sklearn.base.md
+++ b/docs/developer-guide/api/concrete.ml.sklearn.base.md
@@ -14,7 +14,7 @@ Base classes for all estimators.
______________________________________________________________________
-
+
## class `BaseEstimator`
@@ -26,7 +26,7 @@ This class does not inherit from sklearn.base.BaseEstimator as it creates some c
- `_is_a_public_cml_model` (bool): Private attribute indicating if the class is a public model (as opposed to base or mixin classes).
-
+
### method `__init__`
@@ -84,7 +84,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -100,7 +100,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -116,7 +116,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -150,7 +150,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -172,7 +172,7 @@ This step ensures that the fit method has been called.
______________________________________________________________________
-
+
### method `dump`
@@ -188,7 +188,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -204,7 +204,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -220,7 +220,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -243,7 +243,7 @@ The fitted estimator.
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -270,7 +270,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -292,7 +292,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -312,7 +312,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -336,7 +336,7 @@ For some simple models such a linear regression, there is no post-processing ste
______________________________________________________________________
-
+
### method `predict`
@@ -360,7 +360,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -382,7 +382,7 @@ This step ensures that the fit method has been called.
______________________________________________________________________
-
+
## class `BaseClassifier`
@@ -390,7 +390,7 @@ Base class for linear and tree-based classifiers in Concrete ML.
This class inherits from BaseEstimator and modifies some of its methods in order to align them with classifier behaviors. This notably include applying a sigmoid/softmax post-processing to the predicted values as well as handling a mapping of classes in case they are not ordered.
-
+
### method `__init__`
@@ -472,7 +472,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -488,7 +488,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -504,7 +504,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -538,7 +538,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -560,7 +560,7 @@ This step ensures that the fit method has been called.
______________________________________________________________________
-
+
### method `dump`
@@ -576,7 +576,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -592,7 +592,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -608,7 +608,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -618,7 +618,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -645,7 +645,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -667,7 +667,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -687,7 +687,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -697,7 +697,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -710,7 +710,7 @@ predict(
______________________________________________________________________
-
+
### method `predict_proba`
@@ -734,7 +734,7 @@ Predict class probabilities.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -756,13 +756,13 @@ This step ensures that the fit method has been called.
______________________________________________________________________
-
+
## class `QuantizedTorchEstimatorMixin`
Mixin that provides quantization for a torch module and follows the Estimator API.
-
+
### method `__init__`
@@ -838,7 +838,7 @@ Get the output quantizers.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -854,7 +854,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -870,7 +870,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -888,7 +888,7 @@ compile(
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -898,7 +898,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -914,7 +914,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -930,7 +930,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -946,7 +946,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -971,7 +971,7 @@ The fitted estimator.
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -1002,7 +1002,7 @@ The Concrete ML and equivalent skorch fitted estimators.
______________________________________________________________________
-
+
### method `get_params`
@@ -1024,7 +1024,7 @@ This method is overloaded in order to make sure that auto-computed parameters ar
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -1034,7 +1034,7 @@ get_sklearn_params(deep: 'bool' = True) → Dict
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -1054,7 +1054,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -1064,7 +1064,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -1088,7 +1088,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `prune`
@@ -1116,7 +1116,7 @@ A new pruned copy of the Neural Network model.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -1126,7 +1126,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `BaseTreeEstimatorMixin`
@@ -1134,7 +1134,7 @@ Mixin class for tree-based estimators.
This class inherits from sklearn.base.BaseEstimator in order to have access to scikit-learn's `get_params` and `set_params` methods.
-
+
### method `__init__`
@@ -1194,7 +1194,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -1210,7 +1210,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -1226,7 +1226,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -1236,7 +1236,7 @@ compile(*args, **kwargs) → Circuit
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -1246,7 +1246,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -1262,7 +1262,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -1278,7 +1278,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -1294,7 +1294,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -1304,7 +1304,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -1331,7 +1331,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -1353,7 +1353,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -1373,7 +1373,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -1383,7 +1383,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -1396,7 +1396,7 @@ predict(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -1406,7 +1406,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `BaseTreeRegressorMixin`
@@ -1414,7 +1414,7 @@ Mixin class for tree-based regressors.
This class is used to create a tree-based regressor class that inherits from sklearn.base.RegressorMixin, which essentially gives access to scikit-learn's `score` method for regressors.
-
+
### method `__init__`
@@ -1474,7 +1474,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -1490,7 +1490,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -1506,7 +1506,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -1516,7 +1516,7 @@ compile(*args, **kwargs) → Circuit
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -1526,7 +1526,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -1542,7 +1542,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -1558,7 +1558,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -1574,7 +1574,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -1584,7 +1584,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -1611,7 +1611,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -1633,7 +1633,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -1653,7 +1653,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -1663,7 +1663,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -1676,7 +1676,7 @@ predict(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -1686,7 +1686,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `BaseTreeClassifierMixin`
@@ -1696,7 +1696,7 @@ This class is used to create a tree-based classifier class that inherits from sk
Additionally, this class adjusts some of the tree-based base class's methods in order to make them compliant with classification workflows.
-
+
### method `__init__`
@@ -1780,7 +1780,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -1796,7 +1796,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -1812,7 +1812,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -1822,7 +1822,7 @@ compile(*args, **kwargs) → Circuit
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -1832,7 +1832,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -1848,7 +1848,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -1864,7 +1864,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -1880,7 +1880,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -1890,7 +1890,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -1917,7 +1917,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -1939,7 +1939,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -1959,7 +1959,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -1969,7 +1969,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -1982,7 +1982,7 @@ predict(
______________________________________________________________________
-
+
### method `predict_proba`
@@ -2006,7 +2006,7 @@ Predict class probabilities.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -2016,7 +2016,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnLinearModelMixin`
@@ -2024,7 +2024,7 @@ A Mixin class for sklearn linear models with FHE.
This class inherits from sklearn.base.BaseEstimator in order to have access to scikit-learn's `get_params` and `set_params` methods.
-
+
### method `__init__`
@@ -2086,7 +2086,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -2102,7 +2102,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -2118,7 +2118,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -2152,7 +2152,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -2162,7 +2162,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -2178,7 +2178,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -2194,7 +2194,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -2210,7 +2210,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -2220,7 +2220,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -2247,7 +2247,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### classmethod `from_sklearn_model`
@@ -2274,7 +2274,7 @@ The FHE-compliant fitted model.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -2296,7 +2296,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -2316,7 +2316,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -2340,7 +2340,7 @@ For some simple models such a linear regression, there is no post-processing ste
______________________________________________________________________
-
+
### method `predict`
@@ -2364,7 +2364,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -2374,7 +2374,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnLinearRegressorMixin`
@@ -2382,7 +2382,7 @@ A Mixin class for sklearn linear regressors with FHE.
This class is used to create a linear regressor class that inherits from sklearn.base.RegressorMixin, which essentially gives access to scikit-learn's `score` method for regressors.
-
+
### method `__init__`
@@ -2444,7 +2444,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -2460,7 +2460,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -2476,7 +2476,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -2510,7 +2510,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -2520,7 +2520,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -2536,7 +2536,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -2552,7 +2552,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -2568,7 +2568,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -2578,7 +2578,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -2605,7 +2605,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### classmethod `from_sklearn_model`
@@ -2632,7 +2632,7 @@ The FHE-compliant fitted model.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -2654,7 +2654,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -2674,7 +2674,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -2698,7 +2698,7 @@ For some simple models such a linear regression, there is no post-processing ste
______________________________________________________________________
-
+
### method `predict`
@@ -2722,7 +2722,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -2732,7 +2732,365 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
+
+## class `SklearnSGDRegressorMixin`
+
+A Mixin class for sklearn SGD regressors with FHE.
+
+This class is used to create a SGD regressor class what can be exported to ONNX using Hummingbird.
+
+
+
+### method `__init__`
+
+```python
+__init__(n_bits: 'Union[int, Dict[str, int]]' = 8)
+```
+
+Initialize the FHE linear model.
+
+**Args:**
+
+- `n_bits` (int, Dict\[str, int\]): Number of bits to quantize the model. If an int is passed for n_bits, the value will be used for quantizing inputs and weights. If a dict is passed, then it should contain "op_inputs" and "op_weights" as keys with corresponding number of quantization bits so that:
+ \- op_inputs : number of bits to quantize the input values
+ \- op_weights: number of bits to quantize the learned parameters Default to 8.
+
+______________________________________________________________________
+
+#### property fhe_circuit
+
+Get the FHE circuit.
+
+The FHE circuit combines computational graph, mlir, client and server into a single object. More information available in Concrete documentation (https://docs.zama.ai/concrete/getting-started/terminology_and_structure) Is None if the model is not fitted.
+
+**Returns:**
+
+- `Circuit`: The FHE circuit.
+
+______________________________________________________________________
+
+#### property is_compiled
+
+Indicate if the model is compiled.
+
+**Returns:**
+
+- `bool`: If the model is compiled.
+
+______________________________________________________________________
+
+#### property is_fitted
+
+Indicate if the model is fitted.
+
+**Returns:**
+
+- `bool`: If the model is fitted.
+
+______________________________________________________________________
+
+#### property onnx_model
+
+Get the ONNX model.
+
+Is None if the model is not fitted.
+
+**Returns:**
+
+- `onnx.ModelProto`: The ONNX model.
+
+______________________________________________________________________
+
+
+
+### method `check_model_is_compiled`
+
+```python
+check_model_is_compiled() → None
+```
+
+Check if the model is compiled.
+
+**Raises:**
+
+- `AttributeError`: If the model is not compiled.
+
+______________________________________________________________________
+
+
+
+### method `check_model_is_fitted`
+
+```python
+check_model_is_fitted() → None
+```
+
+Check if the model is fitted.
+
+**Raises:**
+
+- `AttributeError`: If the model is not fitted.
+
+______________________________________________________________________
+
+
+
+### method `compile`
+
+```python
+compile(
+ X: 'Data',
+ configuration: 'Optional[Configuration]' = None,
+ artifacts: 'Optional[DebugArtifacts]' = None,
+ show_mlir: 'bool' = False,
+ p_error: 'Optional[float]' = None,
+ global_p_error: 'Optional[float]' = None,
+ verbose: 'bool' = False
+) → Circuit
+```
+
+Compile the model.
+
+**Args:**
+
+- `X` (Data): A representative set of input values used for building cryptographic parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is usually the training data-set or a sub-set of it.
+- `configuration` (Optional\[Configuration\]): Options to use for compilation. Default to None.
+- `artifacts` (Optional\[DebugArtifacts\]): Artifacts information about the compilation process to store for debugging. Default to None.
+- `show_mlir` (bool): Indicate if the MLIR graph should be printed during compilation. Default to False.
+- `p_error` (Optional\[float\]): Probability of error of a single PBS. A p_error value cannot be given if a global_p_error value is already set. Default to None, which sets this error to a default value.
+- `global_p_error` (Optional\[float\]): Probability of error of the full circuit. A global_p_error value cannot be given if a p_error value is already set. This feature is not supported during the FHE simulation mode, meaning the probability is currently set to 0. Default to None, which sets this error to a default value.
+- `verbose` (bool): Indicate if compilation information should be printed during compilation. Default to False.
+
+**Returns:**
+
+- `Circuit`: The compiled Circuit.
+
+______________________________________________________________________
+
+
+
+### method `dequantize_output`
+
+```python
+dequantize_output(q_y_preds: 'ndarray') → ndarray
+```
+
+______________________________________________________________________
+
+
+
+### method `dump`
+
+```python
+dump(file: 'TextIO') → None
+```
+
+Dump itself to a file.
+
+**Args:**
+
+- `file` (TextIO): The file to dump the serialized object into.
+
+______________________________________________________________________
+
+
+
+### method `dump_dict`
+
+```python
+dump_dict() → Dict[str, Any]
+```
+
+Dump the object as a dict.
+
+**Returns:**
+
+- `Dict[str, Any]`: Dict of serialized objects.
+
+______________________________________________________________________
+
+
+
+### method `dumps`
+
+```python
+dumps() → str
+```
+
+Dump itself to a string.
+
+**Returns:**
+
+- `metadata` (str): String of the serialized object.
+
+______________________________________________________________________
+
+
+
+### method `fit`
+
+```python
+fit(X: 'Data', y: 'Target', **fit_parameters)
+```
+
+______________________________________________________________________
+
+
+
+### method `fit_benchmark`
+
+```python
+fit_benchmark(
+ X: 'Data',
+ y: 'Target',
+ random_state: 'Optional[int]' = None,
+ **fit_parameters
+)
+```
+
+Fit both the Concrete ML and its equivalent float estimators.
+
+**Args:**
+
+- `X` (Data): The training data, as a Numpy array, Torch tensor, Pandas DataFrame or List.
+- `y` (Target): The target data, as a Numpy array, Torch tensor, Pandas DataFrame, Pandas Series or List.
+- `random_state` (Optional\[int\]): The random state to use when fitting. Defaults to None.
+- `**fit_parameters`: Keyword arguments to pass to the float estimator's fit method.
+
+**Returns:**
+The Concrete ML and float equivalent fitted estimators.
+
+______________________________________________________________________
+
+
+
+### classmethod `from_sklearn_model`
+
+```python
+from_sklearn_model(
+ sklearn_model: 'BaseEstimator',
+ X: 'Data',
+ n_bits: 'Union[int, Dict[str, int]]' = 8
+)
+```
+
+Build a FHE-compliant model using a fitted scikit-learn model.
+
+**Args:**
+
+- `sklearn_model` (sklearn.base.BaseEstimator): The fitted scikit-learn model to convert.
+- `X` (Data): A representative set of input values used for computing quantization parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is usually the training data-set or a sub-set of it.
+- `n_bits` (int, Dict\[str, int\]): Number of bits to quantize the model. If an int is passed for n_bits, the value will be used for quantizing inputs and weights. If a dict is passed, then it should contain "op_inputs" and "op_weights" as keys with corresponding number of quantization bits so that:
+ \- op_inputs : number of bits to quantize the input values
+ \- op_weights: number of bits to quantize the learned parameters Default to 8.
+
+**Returns:**
+The FHE-compliant fitted model.
+
+______________________________________________________________________
+
+
+
+### method `get_sklearn_params`
+
+```python
+get_sklearn_params(deep: 'bool' = True) → dict
+```
+
+Get parameters for this estimator.
+
+This method is used to instantiate a scikit-learn model using the Concrete ML model's parameters. It does not override scikit-learn's existing `get_params` method in order to not break its implementation of `set_params`.
+
+**Args:**
+
+- `deep` (bool): If True, will return the parameters for this estimator and contained subobjects that are estimators. Default to True.
+
+**Returns:**
+
+- `params` (dict): Parameter names mapped to their values.
+
+______________________________________________________________________
+
+
+
+### classmethod `load_dict`
+
+```python
+load_dict(metadata: 'Dict[str, Any]') → BaseEstimator
+```
+
+Load itself from a dict.
+
+**Args:**
+
+- `metadata` (Dict\[str, Any\]): Dict of serialized objects.
+
+**Returns:**
+
+- `BaseEstimator`: The loaded object.
+
+______________________________________________________________________
+
+
+
+### method `post_processing`
+
+```python
+post_processing(y_preds: 'ndarray') → ndarray
+```
+
+Apply post-processing to the de-quantized predictions.
+
+This post-processing step can include operations such as applying the sigmoid or softmax function for classifiers, or summing an ensemble's outputs. These steps are done in the clear because of current technical constraints. They most likely will be integrated in the FHE computations in the future.
+
+For some simple models such a linear regression, there is no post-processing step but the method is kept to make the API consistent for the client-server API. Other models might need to use attributes stored in `post_processing_params`.
+
+**Args:**
+
+- `y_preds` (numpy.ndarray): The de-quantized predictions to post-process.
+
+**Returns:**
+
+- `numpy.ndarray`: The post-processed predictions.
+
+______________________________________________________________________
+
+
+
+### method `predict`
+
+```python
+predict(
+ X: 'Data',
+ fhe: 'Union[FheMode, str]' =
+) → ndarray
+```
+
+Predict values for X, in FHE or in the clear.
+
+**Args:**
+
+- `X` (Data): The input values to predict, as a Numpy array, Torch tensor, Pandas DataFrame or List.
+- `fhe` (Union\[FheMode, str\]): The mode to use for prediction. Can be FheMode.DISABLE for Concrete ML Python inference, FheMode.SIMULATE for FHE simulation and FheMode.EXECUTE for actual FHE execution. Can also be the string representation of any of these values. Default to FheMode.DISABLE.
+
+**Returns:**
+
+- `np.ndarray`: The predicted values for X.
+
+______________________________________________________________________
+
+
+
+### method `quantize_input`
+
+```python
+quantize_input(X: 'ndarray') → ndarray
+```
+
+______________________________________________________________________
+
+
## class `SklearnLinearClassifierMixin`
@@ -2742,7 +3100,7 @@ This class is used to create a linear classifier class that inherits from sklear
Additionally, this class adjusts some of the tree-based base class's methods in order to make them compliant with classification workflows.
-
+
### method `__init__`
@@ -2828,7 +3186,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -2844,7 +3202,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -2860,7 +3218,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -2894,7 +3252,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `decision_function`
@@ -2918,7 +3276,7 @@ Predict confidence scores.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -2928,7 +3286,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -2944,7 +3302,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -2960,7 +3318,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -2976,7 +3334,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -2986,7 +3344,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -3013,7 +3371,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### classmethod `from_sklearn_model`
@@ -3040,7 +3398,7 @@ The FHE-compliant fitted model.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -3062,7 +3420,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -3082,7 +3440,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -3092,7 +3450,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -3105,7 +3463,7 @@ predict(
______________________________________________________________________
-
+
### method `predict_proba`
@@ -3118,7 +3476,7 @@ predict_proba(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -3128,7 +3486,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnKNeighborsMixin`
@@ -3136,7 +3494,7 @@ A Mixin class for sklearn KNeighbors models with FHE.
This class inherits from sklearn.base.BaseEstimator in order to have access to scikit-learn's `get_params` and `set_params` methods.
-
+
### method `__init__`
@@ -3148,7 +3506,7 @@ Initialize the FHE knn model.
**Args:**
-- `n_bits` (int): Number of bits to quantize the model. IThe value will be used for quantizing inputs and X_fit. Default to 3.
+- `n_bits` (int): Number of bits to quantize the model. The value will be used for quantizing inputs and X_fit. Default to 3.
______________________________________________________________________
@@ -3196,7 +3554,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -3212,7 +3570,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -3228,17 +3586,41 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
```python
-compile(*args, **kwargs) → Circuit
+compile(
+ X: 'Data',
+ configuration: 'Optional[Configuration]' = None,
+ artifacts: 'Optional[DebugArtifacts]' = None,
+ show_mlir: 'bool' = False,
+ p_error: 'Optional[float]' = None,
+ global_p_error: 'Optional[float]' = None,
+ verbose: 'bool' = False
+) → Circuit
```
+Compile the model.
+
+**Args:**
+
+- `X` (Data): A representative set of input values used for building cryptographic parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is usually the training data-set or a sub-set of it.
+- `configuration` (Optional\[Configuration\]): Options to use for compilation. Default to None.
+- `artifacts` (Optional\[DebugArtifacts\]): Artifacts information about the compilation process to store for debugging. Default to None.
+- `show_mlir` (bool): Indicate if the MLIR graph should be printed during compilation. Default to False.
+- `p_error` (Optional\[float\]): Probability of error of a single PBS. A p_error value cannot be given if a global_p_error value is already set. Default to None, which sets this error to a default value.
+- `global_p_error` (Optional\[float\]): Probability of error of the full circuit. A global_p_error value cannot be given if a p_error value is already set. This feature is not supported during the FHE simulation mode, meaning the probability is currently set to 0. Default to None, which sets this error to a default value.
+- `verbose` (bool): Indicate if compilation information should be printed during compilation. Default to False.
+
+**Returns:**
+
+- `Circuit`: The compiled Circuit.
+
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -3248,7 +3630,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -3264,7 +3646,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -3280,7 +3662,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -3296,7 +3678,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -3306,7 +3688,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -3333,7 +3715,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -3355,7 +3737,31 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
+
+### method `get_topk_labels`
+
+```python
+get_topk_labels(
+ X: 'Data',
+ fhe: 'Union[FheMode, str]' =
+) → ndarray
+```
+
+Return the K-nearest labels of each point.
+
+**Args:**
+
+- `X` (Data): The input values to predict, as a Numpy array, Torch tensor, Pandas DataFrame or List.
+- `fhe` (Union\[FheMode, str\]): The mode to use for prediction. Can be FheMode.DISABLE for Concrete ML Python inference, FheMode.SIMULATE for FHE simulation and FheMode.EXECUTE for actual FHE execution. Can also be the string representation of any of these values. Default to FheMode.DISABLE.
+
+**Returns:**
+
+- `numpy.ndarray`: The K-Nearest labels for each point.
+
+______________________________________________________________________
+
+
### classmethod `load_dict`
@@ -3375,7 +3781,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `majority_vote`
@@ -3395,7 +3801,7 @@ Determine the most common class among nearest neighborsfor each query.
______________________________________________________________________
-
+
### method `post_processing`
@@ -3403,21 +3809,21 @@ ______________________________________________________________________
post_processing(y_preds: 'ndarray') → ndarray
```
-Perform the majority.
+Provide the majority vote among the topk labels of each point.
For KNN, the de-quantization step is not required. Because \_inference returns the label of the k-nearest neighbors.
**Args:**
-- `y_preds` (numpy.ndarray): The topk nearest labels
+- `y_preds` (numpy.ndarray): The topk nearest labels for each point.
**Returns:**
-- `numpy.ndarray`: The majority vote.
+- `numpy.ndarray`: The majority vote for each point.
______________________________________________________________________
-
+
### method `predict`
@@ -3430,7 +3836,7 @@ predict(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -3440,7 +3846,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnKNeighborsClassifierMixin`
@@ -3448,7 +3854,7 @@ A Mixin class for sklearn KNeighbors classifiers with FHE.
This class is used to create a KNeighbors classifier class that inherits from SklearnKNeighborsMixin and sklearn.base.ClassifierMixin. By inheriting from sklearn.base.ClassifierMixin, it allows this class to be recognized as a classifier."
-
+
### method `__init__`
@@ -3460,7 +3866,7 @@ Initialize the FHE knn model.
**Args:**
-- `n_bits` (int): Number of bits to quantize the model. IThe value will be used for quantizing inputs and X_fit. Default to 3.
+- `n_bits` (int): Number of bits to quantize the model. The value will be used for quantizing inputs and X_fit. Default to 3.
______________________________________________________________________
@@ -3508,7 +3914,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -3524,7 +3930,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -3540,17 +3946,41 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
```python
-compile(*args, **kwargs) → Circuit
+compile(
+ X: 'Data',
+ configuration: 'Optional[Configuration]' = None,
+ artifacts: 'Optional[DebugArtifacts]' = None,
+ show_mlir: 'bool' = False,
+ p_error: 'Optional[float]' = None,
+ global_p_error: 'Optional[float]' = None,
+ verbose: 'bool' = False
+) → Circuit
```
+Compile the model.
+
+**Args:**
+
+- `X` (Data): A representative set of input values used for building cryptographic parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is usually the training data-set or a sub-set of it.
+- `configuration` (Optional\[Configuration\]): Options to use for compilation. Default to None.
+- `artifacts` (Optional\[DebugArtifacts\]): Artifacts information about the compilation process to store for debugging. Default to None.
+- `show_mlir` (bool): Indicate if the MLIR graph should be printed during compilation. Default to False.
+- `p_error` (Optional\[float\]): Probability of error of a single PBS. A p_error value cannot be given if a global_p_error value is already set. Default to None, which sets this error to a default value.
+- `global_p_error` (Optional\[float\]): Probability of error of the full circuit. A global_p_error value cannot be given if a p_error value is already set. This feature is not supported during the FHE simulation mode, meaning the probability is currently set to 0. Default to None, which sets this error to a default value.
+- `verbose` (bool): Indicate if compilation information should be printed during compilation. Default to False.
+
+**Returns:**
+
+- `Circuit`: The compiled Circuit.
+
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -3560,7 +3990,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -3576,7 +4006,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -3592,7 +4022,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -3608,7 +4038,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -3618,7 +4048,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -3645,7 +4075,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -3667,7 +4097,31 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
+
+### method `get_topk_labels`
+
+```python
+get_topk_labels(
+ X: 'Data',
+ fhe: 'Union[FheMode, str]' =
+) → ndarray
+```
+
+Return the K-nearest labels of each point.
+
+**Args:**
+
+- `X` (Data): The input values to predict, as a Numpy array, Torch tensor, Pandas DataFrame or List.
+- `fhe` (Union\[FheMode, str\]): The mode to use for prediction. Can be FheMode.DISABLE for Concrete ML Python inference, FheMode.SIMULATE for FHE simulation and FheMode.EXECUTE for actual FHE execution. Can also be the string representation of any of these values. Default to FheMode.DISABLE.
+
+**Returns:**
+
+- `numpy.ndarray`: The K-Nearest labels for each point.
+
+______________________________________________________________________
+
+
### classmethod `load_dict`
@@ -3687,7 +4141,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `majority_vote`
@@ -3707,7 +4161,7 @@ Determine the most common class among nearest neighborsfor each query.
______________________________________________________________________
-
+
### method `post_processing`
@@ -3715,21 +4169,21 @@ ______________________________________________________________________
post_processing(y_preds: 'ndarray') → ndarray
```
-Perform the majority.
+Provide the majority vote among the topk labels of each point.
For KNN, the de-quantization step is not required. Because \_inference returns the label of the k-nearest neighbors.
**Args:**
-- `y_preds` (numpy.ndarray): The topk nearest labels
+- `y_preds` (numpy.ndarray): The topk nearest labels for each point.
**Returns:**
-- `numpy.ndarray`: The majority vote.
+- `numpy.ndarray`: The majority vote for each point.
______________________________________________________________________
-
+
### method `predict`
@@ -3742,7 +4196,7 @@ predict(
______________________________________________________________________
-
+
### method `quantize_input`
diff --git a/docs/developer-guide/api/concrete.ml.sklearn.linear_model.md b/docs/developer-guide/api/concrete.ml.sklearn.linear_model.md
index 31c935dd0..463b6393c 100644
--- a/docs/developer-guide/api/concrete.ml.sklearn.linear_model.md
+++ b/docs/developer-guide/api/concrete.ml.sklearn.linear_model.md
@@ -8,7 +8,7 @@ Implement sklearn linear model.
______________________________________________________________________
-
+
## class `LinearRegression`
@@ -22,7 +22,7 @@ A linear regression model with FHE.
For more details on LinearRegression please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html
-
+
### method `__init__`
@@ -83,7 +83,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -93,7 +93,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -103,7 +103,116 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
+
+## class `SGDRegressor`
+
+An FHE linear regression model fitted with stochastic gradient descent.
+
+**Parameters:**
+
+- `n_bits` (int, Dict\[str, int\]): Number of bits to quantize the model. If an int is passed for n_bits, the value will be used for quantizing inputs and weights. If a dict is passed, then it should contain "op_inputs" and "op_weights" as keys with corresponding number of quantization bits so that:
+ \- op_inputs : number of bits to quantize the input values
+ \- op_weights: number of bits to quantize the learned parameters Default to 8.
+
+For more details on SGDRegressor please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDRegressor.html
+
+
+
+### method `__init__`
+
+```python
+__init__(
+ n_bits=8,
+ loss='squared_error',
+ penalty='l2',
+ alpha=0.0001,
+ l1_ratio=0.15,
+ fit_intercept=True,
+ max_iter=1000,
+ tol=0.001,
+ shuffle=True,
+ verbose=0,
+ epsilon=0.1,
+ random_state=None,
+ learning_rate='invscaling',
+ eta0=0.01,
+ power_t=0.25,
+ early_stopping=False,
+ validation_fraction=0.1,
+ n_iter_no_change=5,
+ warm_start=False,
+ average=False
+)
+```
+
+______________________________________________________________________
+
+#### property fhe_circuit
+
+Get the FHE circuit.
+
+The FHE circuit combines computational graph, mlir, client and server into a single object. More information available in Concrete documentation (https://docs.zama.ai/concrete/getting-started/terminology_and_structure) Is None if the model is not fitted.
+
+**Returns:**
+
+- `Circuit`: The FHE circuit.
+
+______________________________________________________________________
+
+#### property is_compiled
+
+Indicate if the model is compiled.
+
+**Returns:**
+
+- `bool`: If the model is compiled.
+
+______________________________________________________________________
+
+#### property is_fitted
+
+Indicate if the model is fitted.
+
+**Returns:**
+
+- `bool`: If the model is fitted.
+
+______________________________________________________________________
+
+#### property onnx_model
+
+Get the ONNX model.
+
+Is None if the model is not fitted.
+
+**Returns:**
+
+- `onnx.ModelProto`: The ONNX model.
+
+______________________________________________________________________
+
+
+
+### method `dump_dict`
+
+```python
+dump_dict() → Dict[str, Any]
+```
+
+______________________________________________________________________
+
+
+
+### classmethod `load_dict`
+
+```python
+load_dict(metadata: Dict)
+```
+
+______________________________________________________________________
+
+
## class `ElasticNet`
@@ -117,7 +226,7 @@ An ElasticNet regression model with FHE.
For more details on ElasticNet please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html
-
+
### method `__init__`
@@ -185,7 +294,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -195,7 +304,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -205,7 +314,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `Lasso`
@@ -219,7 +328,7 @@ A Lasso regression model with FHE.
For more details on Lasso please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html
-
+
### method `__init__`
@@ -286,7 +395,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -296,7 +405,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -306,7 +415,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `Ridge`
@@ -320,7 +429,7 @@ A Ridge regression model with FHE.
For more details on Ridge please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html
-
+
### method `__init__`
@@ -385,7 +494,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -395,7 +504,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -405,7 +514,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `LogisticRegression`
@@ -419,7 +528,7 @@ A logistic regression model with FHE.
For more details on LogisticRegression please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
-
+
### method `__init__`
@@ -514,7 +623,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -524,7 +633,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
diff --git a/docs/developer-guide/api/concrete.ml.sklearn.neighbors.md b/docs/developer-guide/api/concrete.ml.sklearn.neighbors.md
index 339627a32..bb147fc13 100644
--- a/docs/developer-guide/api/concrete.ml.sklearn.neighbors.md
+++ b/docs/developer-guide/api/concrete.ml.sklearn.neighbors.md
@@ -4,7 +4,7 @@
# module `concrete.ml.sklearn.neighbors`
-Implement sklearn linear model.
+Implement sklearn neighbors model.
______________________________________________________________________
@@ -84,7 +84,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -94,7 +94,27 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
+
+### method `kneighbors`
+
+```python
+kneighbors(X: Union[ndarray, Tensor, ForwardRef('DataFrame'), List]) → ndarray
+```
+
+Return the knearest distances and their respective indices for each query point.
+
+**Args:**
+
+- `X` (Data): The input values to predict, as a Numpy array, Torch tensor, Pandas DataFrame or List.
+
+**Raises:**
+
+- `NotImplementedError`: The method is not implemented for now.
+
+______________________________________________________________________
+
+
### classmethod `load_dict`
@@ -104,7 +124,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
### method `predict_proba`
diff --git a/src/concrete/ml/sklearn/__init__.py b/src/concrete/ml/sklearn/__init__.py
index 39160ffea..1855fffa1 100644
--- a/src/concrete/ml/sklearn/__init__.py
+++ b/src/concrete/ml/sklearn/__init__.py
@@ -15,7 +15,14 @@
_TREE_MODELS,
)
from .glm import GammaRegressor, PoissonRegressor, TweedieRegressor
-from .linear_model import ElasticNet, Lasso, LinearRegression, LogisticRegression, Ridge
+from .linear_model import (
+ ElasticNet,
+ Lasso,
+ LinearRegression,
+ LogisticRegression,
+ Ridge,
+ SGDRegressor,
+)
from .neighbors import KNeighborsClassifier
from .qnn import NeuralNetClassifier, NeuralNetRegressor
from .rf import RandomForestClassifier, RandomForestRegressor
diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py
index 84769b88a..707646b56 100644
--- a/src/concrete/ml/sklearn/base.py
+++ b/src/concrete/ml/sklearn/base.py
@@ -27,6 +27,7 @@
from concrete.fhe.compilation.configuration import Configuration
from concrete.fhe.dtypes.integer import Integer
from sklearn.base import clone
+from sklearn.linear_model import LinearRegression
from sklearn.utils.validation import check_is_fitted
from ..common.check_inputs import check_array_and_assert, check_X_y_and_assert_multi_output
@@ -1679,6 +1680,38 @@ class SklearnLinearRegressorMixin(SklearnLinearModelMixin, sklearn.base.Regresso
"""
+class SklearnSGDRegressorMixin(SklearnLinearRegressorMixin):
+ """A Mixin class for sklearn SGD regressors with FHE.
+
+ This class is used to create a SGD regressor class what can be exported
+ to ONNX using Hummingbird.
+ """
+
+ # Remove once Hummingbird supports SGDRegressor
+ # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4100
+ def _set_onnx_model(self, test_input: numpy.ndarray) -> None:
+ """Retrieve the model's ONNX graph using Hummingbird conversion.
+
+ Args:
+ test_input (numpy.ndarray): An input data used to trace the model execution.
+ """
+ # Check that the underlying sklearn model has been set and fit
+ assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message()
+
+ model_for_onnx = LinearRegression()
+ model_for_onnx.coef_ = self.sklearn_model.coef_
+ model_for_onnx.intercept_ = self.sklearn_model.intercept_
+
+ self.onnx_model_ = hb_convert(
+ model_for_onnx,
+ backend="onnx",
+ test_input=test_input,
+ extra_config={"onnx_target_opset": OPSET_VERSION_FOR_ONNX_EXPORT},
+ ).model
+
+ self._clean_graph()
+
+
class SklearnLinearClassifierMixin(
BaseClassifier, SklearnLinearModelMixin, sklearn.base.ClassifierMixin, ABC
):
diff --git a/src/concrete/ml/sklearn/linear_model.py b/src/concrete/ml/sklearn/linear_model.py
index 2c3cf85b7..96c08f2a9 100644
--- a/src/concrete/ml/sklearn/linear_model.py
+++ b/src/concrete/ml/sklearn/linear_model.py
@@ -3,7 +3,11 @@
import sklearn.linear_model
-from .base import SklearnLinearClassifierMixin, SklearnLinearRegressorMixin
+from .base import (
+ SklearnLinearClassifierMixin,
+ SklearnLinearRegressorMixin,
+ SklearnSGDRegressorMixin,
+)
# pylint: disable=invalid-name,too-many-instance-attributes
@@ -99,6 +103,176 @@ def load_dict(cls, metadata: Dict):
return obj
+class SGDRegressor(SklearnSGDRegressorMixin):
+ """An FHE linear regression model fitted with stochastic gradient descent.
+
+ Parameters:
+ n_bits (int, Dict[str, int]): Number of bits to quantize the model. If an int is passed
+ for n_bits, the value will be used for quantizing inputs and weights. If a dict is
+ passed, then it should contain "op_inputs" and "op_weights" as keys with
+ corresponding number of quantization bits so that:
+ - op_inputs : number of bits to quantize the input values
+ - op_weights: number of bits to quantize the learned parameters
+ Default to 8.
+
+ For more details on SGDRegressor please refer to the scikit-learn documentation:
+ https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDRegressor.html
+ """
+
+ sklearn_model_class = sklearn.linear_model.SGDRegressor
+
+ _is_a_public_cml_model = True
+
+ _args = [
+ "penalty",
+ "alpha",
+ "l1_ratio",
+ "fit_intercept",
+ "max_iter",
+ "tol",
+ "shuffle",
+ "verbose",
+ "epsilon",
+ "random_state",
+ "learning_rate",
+ "eta0",
+ "power_t",
+ "early_stopping",
+ "validation_fraction",
+ "n_iter_no_change",
+ "warm_start",
+ "average",
+ ]
+
+ def __init__(
+ self,
+ n_bits=8,
+ loss="squared_error",
+ *,
+ penalty="l2",
+ alpha=0.0001,
+ l1_ratio=0.15,
+ fit_intercept=True,
+ max_iter=1000,
+ tol=1e-3,
+ shuffle=True,
+ verbose=0,
+ epsilon=0.1,
+ random_state=None,
+ learning_rate="invscaling",
+ eta0=0.01,
+ power_t=0.25,
+ early_stopping=False,
+ validation_fraction=0.1,
+ n_iter_no_change=5,
+ warm_start=False,
+ average=False,
+ ):
+
+ super().__init__(n_bits=n_bits)
+
+ self.loss = loss
+ self.penalty = penalty
+ self.alpha = alpha
+ self.l1_ratio = l1_ratio
+ self.fit_intercept = fit_intercept
+ self.max_iter = max_iter
+ self.tol = tol
+ self.shuffle = shuffle
+ self.verbose = verbose
+ self.epsilon = epsilon
+ self.random_state = random_state
+ self.learning_rate = learning_rate
+ self.eta0 = eta0
+ self.power_t = power_t
+ self.early_stopping = early_stopping
+ self.validation_fraction = validation_fraction
+ self.n_iter_no_change = n_iter_no_change
+ self.warm_start = warm_start
+ self.average = average
+
+ def dump_dict(self) -> Dict[str, Any]:
+ assert self._weight_quantizer is not None, self._is_not_fitted_error_message()
+
+ metadata: Dict[str, Any] = {}
+
+ # Concrete-ML
+ metadata["n_bits"] = self.n_bits
+ metadata["sklearn_model"] = self.sklearn_model
+ metadata["_is_fitted"] = self._is_fitted
+ metadata["_is_compiled"] = self._is_compiled
+ metadata["input_quantizers"] = self.input_quantizers
+ metadata["_weight_quantizer"] = self._weight_quantizer
+ metadata["output_quantizers"] = self.output_quantizers
+ metadata["onnx_model_"] = self.onnx_model_
+ metadata["_q_weights"] = self._q_weights
+ metadata["_q_bias"] = self._q_bias
+ metadata["post_processing_params"] = self.post_processing_params
+
+ # Scikit-Learn
+ metadata["loss"] = self.loss
+ metadata["penalty"] = self.penalty
+ metadata["alpha"] = self.alpha
+ metadata["l1_ratio"] = self.l1_ratio
+ metadata["fit_intercept"] = self.fit_intercept
+ metadata["max_iter"] = self.max_iter
+ metadata["tol"] = self.tol
+ metadata["shuffle"] = self.shuffle
+ metadata["verbose"] = self.verbose
+ metadata["epsilon"] = self.epsilon
+ metadata["random_state"] = self.random_state
+ metadata["learning_rate"] = self.learning_rate
+ metadata["eta0"] = self.eta0
+ metadata["power_t"] = self.power_t
+ metadata["early_stopping"] = self.early_stopping
+ metadata["validation_fraction"] = self.validation_fraction
+ metadata["n_iter_no_change"] = self.n_iter_no_change
+ metadata["warm_start"] = self.warm_start
+ metadata["average"] = self.average
+
+ return metadata
+
+ @classmethod
+ def load_dict(cls, metadata: Dict):
+
+ # Instantiate the model
+ obj = cls(n_bits=metadata["n_bits"])
+
+ # Concrete-ML
+ obj.sklearn_model = metadata["sklearn_model"]
+ obj._is_fitted = metadata["_is_fitted"]
+ obj._is_compiled = metadata["_is_compiled"]
+ obj.input_quantizers = metadata["input_quantizers"]
+ obj.output_quantizers = metadata["output_quantizers"]
+ obj._weight_quantizer = metadata["_weight_quantizer"]
+ obj.onnx_model_ = metadata["onnx_model_"]
+ obj._q_weights = metadata["_q_weights"]
+ obj._q_bias = metadata["_q_bias"]
+ obj.post_processing_params = metadata["post_processing_params"]
+
+ obj.loss = metadata["loss"]
+ obj.penalty = metadata["penalty"]
+ obj.alpha = metadata["alpha"]
+ obj.l1_ratio = metadata["l1_ratio"]
+ obj.fit_intercept = metadata["fit_intercept"]
+ obj.max_iter = metadata["max_iter"]
+ obj.tol = metadata["tol"]
+ obj.shuffle = metadata["shuffle"]
+ obj.verbose = metadata["verbose"]
+ obj.epsilon = metadata["epsilon"]
+ obj.random_state = metadata["random_state"]
+ obj.learning_rate = metadata["learning_rate"]
+ obj.eta0 = metadata["eta0"]
+ obj.power_t = metadata["power_t"]
+ obj.early_stopping = metadata["early_stopping"]
+ obj.validation_fraction = metadata["validation_fraction"]
+ obj.n_iter_no_change = metadata["n_iter_no_change"]
+ obj.warm_start = metadata["warm_start"]
+ obj.average = metadata["average"]
+
+ return obj
+
+
class ElasticNet(SklearnLinearRegressorMixin):
"""An ElasticNet regression model with FHE.
diff --git a/tests/common/test_skearn_model_lists.py b/tests/common/test_skearn_model_lists.py
index baf6abce9..22001a35f 100644
--- a/tests/common/test_skearn_model_lists.py
+++ b/tests/common/test_skearn_model_lists.py
@@ -14,6 +14,7 @@
LinearRegression,
LogisticRegression,
Ridge,
+ SGDRegressor,
)
from concrete.ml.sklearn.neighbors import KNeighborsClassifier
from concrete.ml.sklearn.qnn import NeuralNetClassifier, NeuralNetRegressor
@@ -79,6 +80,7 @@ def test_get_sklearn_models():
LogisticRegression,
PoissonRegressor,
Ridge,
+ SGDRegressor,
TweedieRegressor,
]
@@ -101,5 +103,5 @@ def test_get_sklearn_models():
def test_models_and_datasets():
"""Check that the tested model's configuration lists remain fixed."""
- assert len(MODELS_AND_DATASETS) == 29
- assert len(UNIQUE_MODELS_AND_DATASETS) == 19
+ assert len(MODELS_AND_DATASETS) == 30
+ assert len(UNIQUE_MODELS_AND_DATASETS) == 20
diff --git a/tests/sklearn/test_common.py b/tests/sklearn/test_common.py
index 27356a682..6073774b8 100644
--- a/tests/sklearn/test_common.py
+++ b/tests/sklearn/test_common.py
@@ -26,7 +26,7 @@ def test_sklearn_args():
)
test_counter += 1
- assert test_counter == 19
+ assert test_counter == 20
@pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS)
diff --git a/tests/sklearn/test_dump_onnx.py b/tests/sklearn/test_dump_onnx.py
index 3f72bfea6..454518d5b 100644
--- a/tests/sklearn/test_dump_onnx.py
+++ b/tests/sklearn/test_dump_onnx.py
@@ -397,6 +397,15 @@ def test_dump(
) {
%variable = Gemm[alpha = 1, beta = 1](%input_0, %_operators.0.coefficients, %_operators.0.intercepts)
return %variable
+}""",
+ "SGDRegressor": """graph torch_jit (
+ %input_0[DOUBLE, symx10]
+) initializers (
+ %_operators.0.coefficients[FLOAT, 10x1]
+ %_operators.0.intercepts[FLOAT, 1]
+) {
+ %variable = Gemm[alpha = 1, beta = 1](%input_0, %_operators.0.coefficients, %_operators.0.intercepts)
+ return %variable
}""",
"Lasso": """graph torch_jit (
%input_0[DOUBLE, symx10]
diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py
index 5936f14a5..298f9e1a2 100644
--- a/tests/sklearn/test_sklearn_models.py
+++ b/tests/sklearn/test_sklearn_models.py
@@ -194,6 +194,7 @@ def check_correctness_with_sklearn(
"XGBClassifier": 0.7,
"RandomForestClassifier": 0.8,
"KNeighborsClassifier": 0.9,
+ "SGDRegressor": 0.9,
}
model_name = get_model_name(model_class)