Skip to content

Commit

Permalink
[Benchmarking-Py] Adding TF Model - SpineNet49 Mobile
Browse files Browse the repository at this point in the history
  • Loading branch information
DEKHTIARJonathan committed Sep 22, 2022
1 parent a337598 commit 6cf16c4
Show file tree
Hide file tree
Showing 7 changed files with 439 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/#!/usr/bin/env bash
#!/usr/bin/env bash

nvidia-smi

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# --experiment_type=retinanet_mobile_coco
# RetinaNet experiment configuration using a SpineNet-49 Mobile backbone.
# This variant uses bfloat16 for mixed precision and for both data pipelines.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
losses:
l2_weight_decay: 3.0e-05
model:
# Anchor configuration: 3 scales combined with 3 aspect ratios.
anchor:
anchor_size: 3
aspect_ratios: [0.5, 1.0, 2.0]
num_scales: 3
# Backbone selection; model_id is given as a quoted string, not an integer.
backbone:
spinenet_mobile:
stochastic_depth_drop_rate: 0.2
model_id: '49'
se_ratio: 0.2
type: 'spinenet_mobile'
# NOTE(review): 'identity' presumably means no extra decoder is applied on top
# of the backbone outputs - confirm against the Model Garden decoder factory.
decoder:
type: 'identity'
head:
num_convs: 4
num_filters: 48
use_separable_conv: true
# Height, width, channels. The export script passes a matching
# --input_image_size=384,384.
input_size: [384, 384, 3]
# Feature levels 3 through 7 are used.
max_level: 7
min_level: 3
norm_activation:
activation: 'swish'
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
train_data:
dtype: 'bfloat16'
global_batch_size: 256
is_training: true
# Training-time augmentation: random horizontal flip and scale jitter in
# the range [0.5, 2.0].
parser:
aug_rand_hflip: true
aug_scale_max: 2.0
aug_scale_min: 0.5
validation_data:
dtype: 'bfloat16'
global_batch_size: 8
is_training: false
trainer:
# checkpoint_interval, steps_per_loop and validation_interval all share the
# same value (462); presumably one epoch at global_batch_size 256 - TODO confirm.
checkpoint_interval: 462
optimizer_config:
learning_rate:
# Step schedule: the rate is divided by 10 at each boundary.
# 263340 = 570 * 462 and 272580 = 590 * 462.
stepwise:
boundaries: [263340, 272580]
values: [0.32, 0.032, 0.0032]
type: 'stepwise'
# Linear warmup starting from 0.0067 over the first 2000 steps.
warmup:
linear:
warmup_learning_rate: 0.0067
warmup_steps: 2000
steps_per_loop: 462
# 277200 = 600 * 462.
train_steps: 277200
validation_interval: 462
# 625 steps * validation global_batch_size 8 = 5000 samples per validation run.
validation_steps: 625
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# --experiment_type=retinanet_mobile_coco
# RetinaNet experiment configuration using a SpineNet-49 Mobile backbone.
# This variant uses float32 for mixed_precision_dtype and for both data
# pipelines.
runtime:
# NOTE(review): the strategy is still 'tpu' in this float32 variant - confirm
# whether the consumer of this file ignores or overrides it.
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
losses:
l2_weight_decay: 3.0e-05
model:
# Anchor configuration: 3 scales combined with 3 aspect ratios.
anchor:
anchor_size: 3
aspect_ratios: [0.5, 1.0, 2.0]
num_scales: 3
# Backbone selection; model_id is given as a quoted string, not an integer.
backbone:
spinenet_mobile:
stochastic_depth_drop_rate: 0.2
model_id: '49'
se_ratio: 0.2
type: 'spinenet_mobile'
# NOTE(review): 'identity' presumably means no extra decoder is applied on top
# of the backbone outputs - confirm against the Model Garden decoder factory.
decoder:
type: 'identity'
head:
num_convs: 4
num_filters: 48
use_separable_conv: true
# Height, width, channels.
input_size: [384, 384, 3]
# Feature levels 3 through 7 are used.
max_level: 7
min_level: 3
norm_activation:
activation: 'swish'
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
train_data:
dtype: 'float32'
global_batch_size: 256
is_training: true
# Training-time augmentation: random horizontal flip and scale jitter in
# the range [0.5, 2.0].
parser:
aug_rand_hflip: true
aug_scale_max: 2.0
aug_scale_min: 0.5
validation_data:
dtype: 'float32'
global_batch_size: 8
is_training: false
trainer:
# checkpoint_interval, steps_per_loop and validation_interval all share the
# same value (462).
checkpoint_interval: 462
optimizer_config:
learning_rate:
# Step schedule: the rate is divided by 10 at each boundary.
# 263340 = 570 * 462 and 272580 = 590 * 462.
stepwise:
boundaries: [263340, 272580]
values: [0.32, 0.032, 0.0032]
type: 'stepwise'
# Linear warmup starting from 0.0067 over the first 2000 steps.
warmup:
linear:
warmup_learning_rate: 0.0067
warmup_steps: 2000
steps_per_loop: 462
# 277200 = 600 * 462.
train_steps: 277200
validation_interval: 462
# 625 steps * validation global_batch_size 8 = 5000 samples per validation run.
validation_steps: 625
145 changes: 145 additions & 0 deletions tftrt/benchmarking-python/tf_models/spinetnet49_mobile/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import math
import os
import sys

import numpy as np
import tensorflow as tf

# Allow import of top level python files
import inspect

# Resolve this file's location through the interpreter frame, then climb two
# directory levels and put that directory first on the module search path so
# the top-level benchmark modules can be imported.
_this_file = inspect.getfile(inspect.currentframe())
currentdir = os.path.dirname(os.path.abspath(_this_file))
parentdir = os.path.dirname(os.path.dirname(currentdir))

sys.path.insert(0, parentdir)

from benchmark_args import BaseCommandLineAPI
from benchmark_runner import BaseBenchmarkRunner


class CommandLineAPI(BaseCommandLineAPI):
    """Command-line interface: the base benchmark flags plus `--input_size`."""

    def __init__(self):
        super().__init__()

        # Side length, in pixels, of the square synthetic input images.
        self._parser.add_argument(
            '--input_size',
            type=int,
            default=384,
            help='Size of input images expected by the model',
        )

    def _validate_args(self, args):
        """Run the base validation, then enforce this benchmark's limits.

        Raises:
            ValueError: if `--use_synthetic_data` is not set, or if
                `--batch_size` is anything other than 1.
        """
        super()._validate_args(args)

        # TODO: Remove when proper dataloading is implemented
        synthetic_only = bool(args.use_synthetic_data)
        if not synthetic_only:
            raise ValueError(
                "This benchmark does not currently support non-synthetic data "
                "--use_synthetic_data"
            )

        # This model requires that the batch size is 1
        if args.batch_size == 1:
            return

        raise ValueError(
            "This benchmark does not currently support "
            "--batch_size != 1"
        )


class BenchmarkRunner(BaseBenchmarkRunner):
    """Benchmark runner that feeds synthetic uint8 images to the model."""

    def get_dataset_batches(self):
        """Build the synthetic input pipeline.

        Returns:
            - dataset: a TF Dataset object. Each element is a dict
              `{"inputs": <uint8 tensor>}` holding `self._args.batch_size`
              copies of one random image of shape
              `(input_size, input_size, 3)`; the dataset repeats
              indefinitely.
            - bypass_data_to_eval: any object type that will be passed
              unmodified to `evaluate_result()`. Always `None` here.
        Note: script arguments can be accessed using `self._args.attr`
        """

        # Fixed seed so every run draws the same synthetic image.
        tf.random.set_seed(10)

        # A single random image. `maxval` is kept at 255 to leave the
        # benchmark inputs unchanged.
        inputs = tf.random.uniform(
            shape=(1, self._args.input_size, self._args.input_size, 3),
            maxval=255,
            dtype=tf.int32
        )

        # Slicing along the first axis yields exactly one image.
        dataset = tf.data.Dataset.from_tensor_slices(inputs)

        dataset = dataset.map(
            lambda x: {"inputs": tf.cast(x, tf.uint8)},
            num_parallel_calls=tf.data.AUTOTUNE
        )

        # Repeat before batching so any batch size can be filled from the
        # single source image.
        dataset = dataset.repeat()
        dataset = dataset.batch(self._args.batch_size)

        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset, None

    def preprocess_model_inputs(self, data_batch):
        """This function prepares the `data_batch` generated from the dataset.
        Returns:
            x: input of the model (the batch, passed through unchanged)
            y: data to be used for model evaluation (always `None` here)
        Note: script arguments can be accessed using `self._args.attr`
        """

        return data_batch, None

    def postprocess_model_outputs(self, predictions, expected):
        """Convert the predictions and expected values to numpy.

        `preprocess_model_inputs` always returns `None` as the expected
        value, so `None` is passed through instead of calling `.numpy()` on
        it. Nested structures (dict / list / tuple of tensors) are converted
        leaf by leaf and keep their structure; a single tensor still yields
        a single numpy array.

        Note: script arguments can be accessed using `self._args.attr`
        """

        # NOTE: accuracy is not measured right now, so this only converts.
        def _to_numpy(tensor):
            return None if tensor is None else tensor.numpy()

        return (
            tf.nest.map_structure(_to_numpy, predictions),
            tf.nest.map_structure(_to_numpy, expected),
        )

    def evaluate_model(self, predictions, expected, bypass_data_to_eval):
        """Evaluate result predictions for entire dataset.

        No metric is computed for this benchmark: the returned metric value
        is always `None`, paired with a metric_units string.
        Note: script arguments can be accessed using `self._args.attr`
        """
        # NOTE(review): "Raw Pitch Accuracy" looks copied from a different
        # benchmark and does not match this model - confirm the intended
        # metric name before real evaluation is added.
        return None, "Raw Pitch Accuracy"


if __name__ == '__main__':

    # Parse and validate the command line, then hand the arguments to the
    # runner and start the benchmark.
    args = CommandLineAPI().parse_args()
    BenchmarkRunner(args).execute_benchmark()
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env bash

# Pin tf-models-official to the same 2.9.2 release as the config fetched below.
pip install tf-models-official==2.9.2

# The downloaded copy is saved with an "fp16" suffix; the float32 variant is
# derived from it by rewriting every "bfloat16" occurrence.
CONFIG_URL="https://raw.githubusercontent.com/tensorflow/models/v2.9.2/official/vision/configs/experiments/retinanet/coco_spinenet49_mobile_tpu.yaml"
DOWNLOADED_CONFIG="coco_spinenet49_mobile_tpu_fp16.yaml"
FP32_CONFIG="coco_spinenet49_mobile_tpu_fp32.yaml"

wget "${CONFIG_URL}" -O "${DOWNLOADED_CONFIG}"

sed 's/bfloat16/float32/g' "${DOWNLOADED_CONFIG}" > "${FP32_CONFIG}"

BATCH_SIZES=("1" "8" "16" "32" "64" "128")

MODEL_DIR="/models/tf_models/spinetnet49_mobile"

# Export one SavedModel per batch size, then store its signature report
# next to it.
for batch_size in "${BATCH_SIZES[@]}"; do

    export_subdir="saved_model_bs${batch_size}"
    saved_model_dir="${MODEL_DIR}/${export_subdir}/"

    python -m official.vision.serving.export_saved_model \
        --experiment="retinanet_mobile_coco" \
        --checkpoint_path="${MODEL_DIR}/checkpoint/" \
        --config_file="${FP32_CONFIG}" \
        --export_dir="${MODEL_DIR}/" \
        --export_saved_model_subdir="${export_subdir}" \
        --input_image_size=384,384 \
        --batch_size="${batch_size}"

    saved_model_cli show --dir "${saved_model_dir}" --all 2>&1 \
        | tee "${saved_model_dir}analysis.txt"

done
48 changes: 48 additions & 0 deletions tftrt/benchmarking-python/tf_models/spinetnet49_mobile/run_all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/bin/bash

# Runs the spinetnet49_mobile inference benchmark in seven configurations
# (TF native, TF-XLA and TF-TRT at several precisions) and saves one log file
# per configuration.

# NOTE(review): SCRIPT_DIR is not referenced anywhere else in this script.
SCRIPT_DIR=""

EXPERIMENT_NAME="spinetnet49_mobile"

# Start from an empty export directory: results of any previous run are deleted.
BASE_BENCHMARK_DATA_EXPORT_DIR="/workspace/benchmark_data/${EXPERIMENT_NAME}"
rm -rf ${BASE_BENCHMARK_DATA_EXPORT_DIR}
mkdir -p ${BASE_BENCHMARK_DATA_EXPORT_DIR}

# Metrics upload is disabled: the variant with an upload endpoint is commented
# out and an empty flag is used instead.
# EXPERIMENT_FLAG="--experiment_name=${EXPERIMENT_NAME} --upload_metrics_endpoint=http://10.31.241.12:5000/record_metrics/"
EXPERIMENT_FLAG=""

#########################

# Absolute path of the directory containing this script.
BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"

BENCHMARK_DATA_EXPORT_DIR="${BASE_BENCHMARK_DATA_EXPORT_DIR}/tf_models/"
mkdir -p ${BENCHMARK_DATA_EXPORT_DIR}

model_name="spinetnet49_mobile"

# Arguments shared by every run: the batch-size-1 SavedModel, synthetic data,
# 200 warmup iterations and 500 measured iterations.
RUN_ARGS="${EXPERIMENT_FLAG} --data_dir=/tmp --input_saved_model_dir=/models/tf_models/${model_name}/saved_model_bs1/ "
RUN_ARGS="${RUN_ARGS} --debug --batch_size=1 --display_every=5 --use_synthetic_data --num_warmup_iterations=200 --num_iterations=500"
# Extra arguments appended only for the TF-TRT and TF-XLA runs respectively.
TF_TRT_ARGS="--use_tftrt --use_dynamic_shape --num_calib_batches=10"
TF_XLA_ARGS="--use_xla_auto_jit"

export TF_TRT_SHOW_DETAILED_REPORT=1
# export TF_TRT_BENCHMARK_EARLY_QUIT=1

MODEL_DATA_EXPORT_DIR="${BENCHMARK_DATA_EXPORT_DIR}/${model_name}"
mkdir -p ${MODEL_DATA_EXPORT_DIR}

SCRIPT_PATH="${BASE_DIR}/run_inference.sh"
# NOTE(review): METRICS_JSON_FLAG is defined but not passed to any of the
# commands below - confirm whether JSON metrics export was intended.
METRICS_JSON_FLAG="--export_metrics_json_path=${MODEL_DATA_EXPORT_DIR}"

# Each run is wrapped in `script -q -c "<cmd>" /dev/null` and its output is
# piped through `tee` into a per-configuration log file.

# TF Native
script -q -c "${SCRIPT_PATH} ${RUN_ARGS} --precision=FP32" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tf_fp32.log
script -q -c "${SCRIPT_PATH} ${RUN_ARGS} --precision=FP16" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tf_fp16.log

# TF-XLA manual
# NOTE(review): the heading says "manual" but the flag used is
# --use_xla_auto_jit - confirm which mode is intended.
script -q -c "${SCRIPT_PATH} ${RUN_ARGS} ${TF_XLA_ARGS} --precision=FP32" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tfxla_fp32.log
script -q -c "${SCRIPT_PATH} ${RUN_ARGS} ${TF_XLA_ARGS} --precision=FP16" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tfxla_fp16.log

# TF-TRT
# TF_TRT_EXPORT_GRAPH_VIZ_PATH is set per run to a .dot file in the model's
# export directory.
script -q -c "TF_TRT_EXPORT_GRAPH_VIZ_PATH=${MODEL_DATA_EXPORT_DIR}/tftrt_fp32.dot ${SCRIPT_PATH} ${RUN_ARGS} ${TF_TRT_ARGS} --precision=FP32" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tftrt_fp32.log
script -q -c "TF_TRT_EXPORT_GRAPH_VIZ_PATH=${MODEL_DATA_EXPORT_DIR}/tftrt_fp16.dot ${SCRIPT_PATH} ${RUN_ARGS} ${TF_TRT_ARGS} --precision=FP16" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tftrt_fp16.log
script -q -c "TF_TRT_EXPORT_GRAPH_VIZ_PATH=${MODEL_DATA_EXPORT_DIR}/tftrt_int8.dot ${SCRIPT_PATH} ${RUN_ARGS} ${TF_TRT_ARGS} --precision=INT8" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tftrt_int8.log
Loading

0 comments on commit 6cf16c4

Please sign in to comment.