Commit fa235ad
minor refactor
Signed-off-by: manickavela29 <[email protected]>
manickavela29 committed Jun 27, 2024
1 parent f657c36 commit fa235ad
Showing 1 changed file with 20 additions and 9 deletions.
egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -343,7 +343,6 @@ def export_encoder_model_onnx(
     encoder_filename: str,
     opset_version: int = 11,
     feature_dim: int = 80,
-    fp16: bool = False,
 ) -> None:
     encoder_model.encoder.__class__.forward = (
         encoder_model.encoder.__class__.streaming_forward
@@ -489,12 +488,6 @@ def build_inputs_outputs(tensors, i):
 
     add_meta_data(filename=encoder_filename, meta_data=meta_data)
 
-    if(fp16) :
-        logging.info("Exporting Encoder model in fp16")
-        encoder = onnx.load(encoder_filename)
-        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
-        onnx.save(encoder_fp16,encoder_filename)
-
 def export_decoder_model_onnx(
     decoder_model: OnnxDecoder,
     decoder_filename: str,
@@ -741,7 +734,6 @@ def main():
         encoder_filename,
         opset_version=opset_version,
         feature_dim=params.feature_dim,
-        fp16=params.fp16,
     )
     logging.info(f"Exported encoder to {encoder_filename}")
 
@@ -766,8 +758,27 @@
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
 
-    logging.info("Generate int8 quantization models")
+    if(params.fp16) :
+        logging.info("Exporting models in fp16")
+
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
+        onnx.save(encoder_fp16,encoder_filename_fp16)
+
+        decoder = onnx.load(decoder_filename)
+        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
+        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
+        onnx.save(decoder_fp16,decoder_filename_fp16)
+
+        joiner = onnx.load(joiner_filename)
+        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
+        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
+        onnx.save(joiner_fp16,joiner_filename_fp16)
+
+    logging.info("Generate int8 quantization models")
+
     encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
     quantize_dynamic(
         model_input=encoder_filename,
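For context on the fp16 path this commit moves into main(): float16.convert_float_to_float16 comes from the onnxconverter-common package. A minimal standalone sketch of the same conversion; the model.onnx paths are hypothetical, not from this commit:

    import onnx
    from onnxconverter_common import float16

    # Load a float32 ONNX model (hypothetical path).
    model = onnx.load("model.onnx")

    # Convert internal weights and ops to float16.
    # keep_io_types=True preserves float32 graph inputs/outputs,
    # so existing callers can keep feeding float32 tensors.
    model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)

    onnx.save(model_fp16, "model.fp16.onnx")

This matches the call pattern in the diff; hoisting it out of export_encoder_model_onnx lets a single --fp16 flag convert all three exported models (encoder, decoder, joiner) instead of only the encoder.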

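The int8 step that follows in the script uses onnxruntime's dynamic quantization. A minimal sketch, assuming hypothetical file names and the QInt8 weight type; the script's exact quantize_dynamic arguments beyond model_input are not shown in this hunk:

    from onnxruntime.quantization import QuantType, quantize_dynamic

    # Dynamic quantization: weights are converted to int8 offline,
    # while activation quantization parameters are computed on the
    # fly at inference time, so no calibration dataset is needed.
    quantize_dynamic(
        model_input="encoder.onnx",        # hypothetical input path
        model_output="encoder.int8.onnx",  # hypothetical output path
        weight_type=QuantType.QInt8,
    )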