feat: SGDClassifier training in FHE
* Add SGDClassifier
* Add SGDClassifier FHE training (see the usage sketch after this list)
  - FHE training still has some limitations, mainly:
    - Binary classification only
    - Single target only
* Add an option to copy inputs in QuantizedReduceSum
  - For model training in FHE, the inputs of the reduce-sum operation need to
    be copied to avoid bit-width propagation. This behavior is not always
    desirable, since the copy adds a PBS that is not needed for all models, so
    a flag was added to enable it on demand.
* Fix macOS grep usage in our Makefile
  - macOS ships a grep with a different API than GNU grep, which crashed one of
    our scripts. A simple if-else and a change of option name fixes that.
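
A minimal usage sketch of the new FHE training path: the `fit_encrypted` and `parameters_range` constructor arguments and the `fhe` keyword of `fit`/`predict` are inferred from this changeset rather than a confirmed API, and the data set and hyper-parameters are purely illustrative.

```python
import numpy
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from concrete.ml.sklearn import SGDClassifier

# Binary classification with a single target: the only setting supported for FHE training.
X, y = make_classification(n_samples=200, n_features=8, n_classes=2, random_state=42)
X = X.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# `fit_encrypted=True` requests training on encrypted data; `parameters_range` bounds the
# weights so they can be quantized. Both arguments are assumptions based on this changeset.
model = SGDClassifier(
    fit_encrypted=True, parameters_range=(-1.0, 1.0), max_iter=15, random_state=42
)

# Simulation first; switching to fhe="execute" would run the training steps in FHE.
model.fit(X_train, y_train, fhe="simulate")

# Inference is then compiled and run as with the other Concrete ML linear models.
model.compile(X_train)
accuracy = (model.predict(X_test, fhe="simulate") == y_test).mean()
print(f"Simulated accuracy: {accuracy:.2f}")
```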

closes zama-ai/concrete-ml-internal#3579
closes zama-ai/concrete-ml-internal#4181
closes zama-ai/concrete-ml-internal#4112
fd0r committed Jan 4, 2024
1 parent e90251a commit 0893718
Showing 25 changed files with 3,063 additions and 209 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/refresh-one-notebook.yaml
@@ -28,6 +28,7 @@ on:
- LinearRegression \n
- LinearSVR \n
- LogisticRegression \n
- LogisticRegressionTraining \n
- PerrorImpactOnFMNIST \n
- PoissonRegression \n
- QGPT2Evaluate \n
@@ -72,6 +73,7 @@ env:
LinearRegression: "docs/advanced_examples/LinearRegression.ipynb"
LinearSVR: "docs/advanced_examples/LinearSVR.ipynb"
LogisticRegression: "docs/advanced_examples/LogisticRegression.ipynb"
LogisticRegressionTraining: "docs/advanced_examples/LogisticRegressionTraining.ipynb"
PerrorImpactOnFMNIST: "use_case_examples/cifar/cifar_brevitas_finetuning/PerrorImpactOnFMNIST.ipynb"
PoissonRegression: "docs/advanced_examples/PoissonRegression.ipynb"
QGPT2Evaluate: "use_case_examples/llm/QGPT2Evaluate.ipynb"
454 changes: 454 additions & 0 deletions docs/advanced_examples/LogisticRegressionTraining.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docs/developer-guide/api/README.md
@@ -188,6 +188,7 @@
- [`base.SklearnLinearClassifierMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearclassifiermixin): A Mixin class for sklearn linear classifiers with FHE.
- [`base.SklearnLinearModelMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearmodelmixin): A Mixin class for sklearn linear models with FHE.
- [`base.SklearnLinearRegressorMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearregressormixin): A Mixin class for sklearn linear regressors with FHE.
- [`base.SklearnSGDClassifierMixin`](./concrete.ml.sklearn.base.md#class-sklearnsgdclassifiermixin): A Mixin class for sklearn SGD classifiers with FHE.
- [`base.SklearnSGDRegressorMixin`](./concrete.ml.sklearn.base.md#class-sklearnsgdregressormixin): A Mixin class for sklearn SGD regressors with FHE.
- [`glm.GammaRegressor`](./concrete.ml.sklearn.glm.md#class-gammaregressor): A Gamma regression model with FHE.
- [`glm.PoissonRegressor`](./concrete.ml.sklearn.glm.md#class-poissonregressor): A Poisson regression model with FHE.
@@ -197,6 +198,7 @@
- [`linear_model.LinearRegression`](./concrete.ml.sklearn.linear_model.md#class-linearregression): A linear regression model with FHE.
- [`linear_model.LogisticRegression`](./concrete.ml.sklearn.linear_model.md#class-logisticregression): A logistic regression model with FHE.
- [`linear_model.Ridge`](./concrete.ml.sklearn.linear_model.md#class-ridge): A Ridge regression model with FHE.
- [`linear_model.SGDClassifier`](./concrete.ml.sklearn.linear_model.md#class-sgdclassifier): An FHE linear classifier model fitted with stochastic gradient descent.
- [`linear_model.SGDRegressor`](./concrete.ml.sklearn.linear_model.md#class-sgdregressor): An FHE linear regression model fitted with stochastic gradient descent.
- [`neighbors.KNeighborsClassifier`](./concrete.ml.sklearn.neighbors.md#class-kneighborsclassifier): A k-nearest neighbors classifier model with FHE.
- [`qnn.NeuralNetClassifier`](./concrete.ml.sklearn.qnn.md#class-neuralnetclassifier): A Fully-Connected Neural Network classifier with FHE.
10 changes: 5 additions & 5 deletions docs/developer-guide/api/concrete.ml.onnx.convert.md
@@ -13,7 +13,7 @@ ONNX conversion related code.

______________________________________________________________________

<a href="../../../src/concrete/ml/onnx/convert.py#L25"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/onnx/convert.py#L26"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `fuse_matmul_bias_to_gemm`

@@ -33,7 +33,7 @@ Fuse sequence of matmul -> add into a gemm node.

______________________________________________________________________

<a href="../../../src/concrete/ml/onnx/convert.py#L116"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/onnx/convert.py#L117"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_equivalent_numpy_forward_from_torch`

@@ -59,7 +59,7 @@ Get the numpy equivalent forward of the provided torch Module.

______________________________________________________________________

<a href="../../../src/concrete/ml/onnx/convert.py#L165"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/onnx/convert.py#L168"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `preprocess_onnx_model`

@@ -84,7 +84,7 @@ Get the numpy equivalent forward of the provided ONNX model.

______________________________________________________________________

<a href="../../../src/concrete/ml/onnx/convert.py#L227"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/onnx/convert.py#L230"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_equivalent_numpy_forward_from_onnx`

@@ -108,7 +108,7 @@ Get the numpy equivalent forward of the provided ONNX model.

______________________________________________________________________

<a href="../../../src/concrete/ml/onnx/convert.py#L252"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/onnx/convert.py#L255"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_equivalent_numpy_forward_from_onnx_tree`

2 changes: 1 addition & 1 deletion docs/developer-guide/api/concrete.ml.quantization.md
@@ -10,8 +10,8 @@ Modules for quantization.

- **quantizers**
- **base_quantized_op**
- **quantized_module**
- **quantized_ops**
- **quantized_module**
- **quantized_module_passes**
- **post_training**
- **qat_quantizers**
@@ -14,13 +14,13 @@ QuantizedModule API.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L81"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L83"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>class</kbd> `QuantizedModule`

Inference for a quantized model.

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L91"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L93"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `__init__`

@@ -67,7 +67,7 @@ Get the post-processing parameters.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L725"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L739"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `bitwidth_and_range_report`

@@ -83,7 +83,7 @@ Report the ranges and bit-widths for layers that mix encrypted integer values.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L193"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L207"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `check_model_is_compiled`

@@ -99,7 +99,7 @@ Check if the quantized module is compiled.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L596"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L610"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `compile`

@@ -139,7 +139,7 @@ Compile the module's forward function.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L552"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L566"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `dequantize_output`

@@ -159,7 +159,7 @@ Take the last layer q_out and use its de-quant function.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L176"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L190"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `dump`

@@ -175,7 +175,7 @@ Dump itself to a file.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L123"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L137"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `dump_dict`

@@ -191,7 +191,7 @@ Dump itself to a dict.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L168"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L182"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `dumps`

@@ -207,7 +207,7 @@ Dump itself to a string.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L284"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L298"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `forward`

@@ -235,7 +235,7 @@ This method executes the forward pass in the clear, with simulation or in FHE. I

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L142"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L156"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `load_dict`

@@ -255,7 +255,7 @@ Load itself from a string.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L224"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L238"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `post_processing`

@@ -277,7 +277,7 @@ For quantized modules, there is no post-processing step but the method is kept t

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L527"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L541"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `quantize_input`

@@ -297,7 +297,7 @@ Take the inputs in fp32 and quantize it using the learned quantization parameter

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L358"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L372"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `quantized_forward`

@@ -321,7 +321,7 @@ Forward function for the FHE circuit.
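
For context, a minimal sketch of how the methods documented in this file chain together for quantized inference; the `compile_torch_model` helper and the exact signatures (notably the `fhe` keyword) are assumptions based on the existing Concrete ML API and should be checked against the reference above.

```python
import numpy
import torch

from concrete.ml.torch.compile import compile_torch_model

# A tiny illustrative torch model; compile_torch_model is assumed to return a QuantizedModule.
torch_model = torch.nn.Sequential(torch.nn.Linear(4, 2))
x_calib = numpy.random.uniform(-1, 1, size=(100, 4)).astype(numpy.float32)

quantized_module = compile_torch_model(torch_model, x_calib, n_bits=4)

x_test = x_calib[:5]
q_x = quantized_module.quantize_input(x_test)      # fp32 inputs -> quantized integers
q_y = quantized_module.quantized_forward(q_x)      # integer forward pass of the circuit
y_pred = quantized_module.dequantize_output(q_y)   # quantized integers -> fp32 outputs

# `forward` wraps the same pipeline and can run it in the clear, with simulation or in FHE.
y_sim = quantized_module.forward(x_test, fhe="simulate")
```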

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L578"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L592"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `set_inputs_quantization_parameters`

@@ -334,3 +334,17 @@ Set the quantization parameters for the module's inputs.
**Args:**

- <b>`*input_q_params (UniformQuantizer)`</b>: The quantizer(s) for the module.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L126"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `set_reduce_sum_copy`

```python
set_reduce_sum_copy()
```

Set whether the reduce-sum operations should copy their inputs.

Due to bit-width propagation during compilation, we may or may not want to copy the inputs with a PBS in order to avoid that propagation.
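
As an illustration, a hypothetical sketch of using this toggle before compiling a module whose graph contains a QuantizedReduceSum, as FHE training does; `training_module` and `x_calib` are placeholder names and the call order is an assumption.

```python
# Hypothetical placeholders: `training_module` is a QuantizedModule whose graph
# contains a QuantizedReduceSum (as in FHE training) and `x_calib` a calibration set.

# Ask reduce-sum operations to copy their inputs: this inserts a PBS, which costs
# some latency but prevents bit-width constraints from propagating to other ops.
training_module.set_reduce_sum_copy()

# Compile as usual; leaving the flag unset avoids the extra PBS when it is not needed.
training_module.compile(x_calib)
```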