Skip to content

Commit

Permalink
feat(frontend): use shapes in tfhers-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Dec 18, 2024
1 parent b64e6f7 commit b20e5ae
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
2 changes: 1 addition & 1 deletion frontends/concrete-python/examples/tfhers-ml/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ python -c "print(','.join(map(lambda x: str(x << 10), [$(cat $TDIR/result_plaint
We need to dequantize integer outputs using a pre-built quantizer for our ML model

```sh
../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --config ./output_quantizer.json
../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --shape=5,3 --config ./output_quantizer.json
```

## Compute error
Expand Down
32 changes: 30 additions & 2 deletions frontends/concrete-python/tests/tfhers-utils/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,16 @@ fn main() {
.value_delimiter(',')
.num_args(1..),
)
.arg(
Arg::new("shape")
.short('s')
.long("shape")
.help("shape of values")
.action(ArgAction::Set)
.required(false)
.value_delimiter(',')
.num_args(0..),
)
.arg(
Arg::new("output")
.long("output")
Expand Down Expand Up @@ -476,6 +486,16 @@ fn main() {
.value_delimiter(',')
.num_args(1..),
)
.arg(
Arg::new("shape")
.short('s')
.long("shape")
.help("shape of values")
.action(ArgAction::Set)
.required(false)
.value_delimiter(',')
.num_args(0..),
)
.arg(
Arg::new("output")
.long("output")
Expand Down Expand Up @@ -574,13 +594,17 @@ fn main() {
.get_many::<String>("value")
.unwrap()
.collect();
let shapes: Vec<usize> = match quantize_matches.get_many::<String>("shape") {
Some(shapes) => shapes.into_iter().map(|s| s.parse().unwrap()).collect(),
None => vec![value_str.len()],
};
let config_path = quantize_matches.get_one::<String>("config").unwrap();
let output_path = quantize_matches.get_one::<String>("output");

let quantizer = Quantizer::from_json_file(config_path).unwrap();
let value: Vec<f64> = value_str.iter().map(|v| v.parse().unwrap()).collect();
let quantized_array = quantizer.quantize(
&ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[value.len()]), value).unwrap(),
&ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shapes), value).unwrap(),
);
let quantized_values: Vec<&i64> = quantized_array.iter().collect();
let results_str: Vec<String> = quantized_values.iter().map(|v| v.to_string()).collect();
Expand All @@ -597,13 +621,17 @@ fn main() {
.get_many::<String>("value")
.unwrap()
.collect();
let shapes: Vec<usize> = match dequantize_matches.get_many::<String>("shape") {
Some(shapes) => shapes.into_iter().map(|s| s.parse().unwrap()).collect(),
None => vec![value_str.len()],
};
let config_path = dequantize_matches.get_one::<String>("config").unwrap();
let output_path = dequantize_matches.get_one::<String>("output");

let quantizer = Quantizer::from_json_file(config_path).unwrap();
let value: Vec<i64> = value_str.iter().map(|v| v.parse().unwrap()).collect();
let dequantized_array = quantizer.dequantize(
&ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[value.len()]), value).unwrap(),
&ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shapes), value).unwrap(),
);
let dequantized_values: Vec<&f64> = dequantized_array.iter().collect();
let results_str: Vec<String> =
Expand Down

0 comments on commit b20e5ae

Please sign in to comment.