binary_classifier_client.py
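"""gRPC client for a binary-classifier model served by Triton Inference Server.

Reads events from a CSV file, batches them, and sends each batch to the
server for inference, printing the rounded (0/1) predictions. (This docstring
summarizes the code below; it does not change the script's behaviour.)
"""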
import argparse
import sys
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

if sys.version_info >= (3, 0):
    import queue
else:
    import Queue as queue


def parse_model(model_metadata, model_config):
    """Check that the model matches this client's expectations and return the
    properties needed to build inference requests."""
    input_metadata = model_metadata.inputs[0]
    input_config = model_config.input[0]
    output_metadata = model_metadata.outputs[0]

    if output_metadata.datatype != "FP32":
        raise Exception("expecting output datatype to be FP32, model '" +
                        model_metadata.name + "' output type is " +
                        output_metadata.datatype)

    # The output must be a vector: at most one non-batch dimension may be
    # larger than one.
    output_batch_dim = (model_config.max_batch_size > 0)
    non_one_cnt = 0
    for dim in output_metadata.shape:
        if output_batch_dim:
            output_batch_dim = False
        elif dim > 1:
            non_one_cnt += 1
            if non_one_cnt > 1:
                raise Exception("expecting model output to be a vector")

    # Model input must have a single feature dimension, plus a batch
    # dimension when the model supports batching.
    input_batch_dim = (model_config.max_batch_size > 0)
    n_features = input_metadata.shape[1 if input_batch_dim else 0]

    return (model_config.max_batch_size, input_metadata.name,
            output_metadata.name, n_features, input_config.format,
            input_metadata.datatype)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-m',
                        '--model-name',
                        type=str,
                        required=True,
                        help='Name of model')
    parser.add_argument(
        '-x',
        '--model-version',
        type=str,
        required=False,
        default="",
        help='Version of model. Default is to use latest version.')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        required=False,
                        default=1,
                        help='Batch size. Default is 1.')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='agc-triton-inference-server:8001',
                        help='Inference server URL. Default is '
                             'agc-triton-inference-server:8001.')
    parser.add_argument('-n',
                        '--num-batches',
                        type=int,
                        default=1,
                        help='Number of batches to send to the inference server.')
    parser.add_argument('test_events',
                        type=str,
                        default=None,
                        help='Path to a CSV file with one event per line to infer.')
    FLAGS = parser.parse_args()
    try:
        # Create the gRPC client used to communicate with the server.
        triton_client = grpcclient.InferenceServerClient(url=FLAGS.url,
                                                         verbose=FLAGS.verbose)
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)

    # Make sure the model matches our requirements, and get the model
    # properties needed for preprocessing.
    try:
        print("model_name = ", FLAGS.model_name)
        print("model_version = ", FLAGS.model_version)
        model_metadata = triton_client.get_model_metadata(
            model_name=FLAGS.model_name, model_version=FLAGS.model_version)
    except InferenceServerException as e:
        print("failed to retrieve the metadata: " + str(e))
        sys.exit(1)

    try:
        model_config = triton_client.get_model_config(
            model_name=FLAGS.model_name, model_version=FLAGS.model_version)
    except InferenceServerException as e:
        print("failed to retrieve the config: " + str(e))
        sys.exit(1)
    model_config = model_config.config

    max_batch_size, input_name, output_name, n_features, input_format, dtype = parse_model(
        model_metadata, model_config)

    supports_batching = max_batch_size > 0
    if not supports_batching and FLAGS.batch_size != 1:
        print("ERROR: This model doesn't support batching.")
        sys.exit(1)
    # Load the CSV of events into a float32 numpy array.
    test_events = FLAGS.test_events
    data = np.loadtxt(test_events, dtype=np.float32, delimiter=',')
    print(data.shape)

    # Split the data into batches and send one inference request per batch.
    # Never request more batches than the data actually provides.
    data_length = data.shape[0]
    num_batches = int(np.ceil(data_length / FLAGS.batch_size))
    startind = 0
    for i in range(min(FLAGS.num_batches, num_batches)):
        data_current = data[startind:startind + FLAGS.batch_size, :]
        startind += FLAGS.batch_size

        inpt = [grpcclient.InferInput(input_name, data_current.shape, dtype)]
        inpt[0].set_data_from_numpy(data_current)
        output = grpcclient.InferRequestedOutput(output_name)

        results = triton_client.infer(model_name=FLAGS.model_name,
                                      inputs=inpt,
                                      outputs=[output])
        inference_output = results.as_numpy(output_name)
        print(np.round(inference_output))
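
# Example invocation (sketch only: the model name and CSV path below are
# placeholders, not values taken from this repository; the URL is the
# script's default):
#
#   python binary_classifier_client.py -m binary_classifier \
#       -u agc-triton-inference-server:8001 -b 8 -n 4 events.csv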