Commit

Merge pull request #945 from AI-Hypercomputer:msingh-trillium
PiperOrigin-RevId: 682461471
maxtext authors committed Oct 4, 2024
2 parents 30dce58 + 4a50ccf commit d847f6c
Showing 3 changed files with 119 additions and 66 deletions.
6 changes: 6 additions & 0 deletions MaxText/inference_mlperf/llama_offline_run.sh
@@ -54,6 +54,11 @@ then
BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54"
fi

if [ -z "$TOK_OUTLEN_MULTIPLIER"];
then
TOK_OUTLEN_MULTIPLIER="3.0"
fi

if [ -z "$MAXENGINE_ARGS"];
then
CHECKPOINT="gs://msingh-bkt/checkpoints/quant_llama2-70b-chat/mlperf_070924/int8_"
@@ -111,6 +116,7 @@ run_loadgen() {
--prefill_lengths_and_batch_sizes ${BATCH_AND_PREFILL_LEN} \
--maxengine_args "${MAXENGINE_ARGS}" \
--output_log_dir ${OUTPUT_LOG_DIR} \
--tok_outlen_multiplier ${TOK_OUTLEN_MULTIPLIER} \
${SKIP_WARMUP_OPTION} ${PROFILER_OPTION} 2>&1 | tee ${OUTPUT_LOG_DIR}/${LOADGEN_RUN_TYPE}_log.log

}
161 changes: 95 additions & 66 deletions MaxText/inference_mlperf/offline_mode.py
@@ -153,6 +153,20 @@
required=False,
)

flags.DEFINE_float(
"tok_outlen_multiplier",
3.0,
"Multiplier for estimating max predicted output len",
required=False,
)

flags.DEFINE_bool(
"allow_skipping_queries",
False,
"Allow skipping queries which have target len greater than 2x configured max prefill len",
required=False,
)

scenario_map = {
"offline": lg.TestScenario.Offline,
"server": lg.TestScenario.Server,
@@ -167,7 +181,13 @@ def pad_tokens(tokens):


def _init_query_batches():
return [[], [], []]
query_batches = {}
len_batch_str = FLAGS.prefill_lengths_and_batch_sizes.split("|")
len_batch = []
for lb in len_batch_str:
l, b = lb.split(",")
query_batches[(int(l), int(b))] = []
return query_batches
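
For illustration, a minimal standalone sketch (not part of the diff) of what the new keying scheme produces for the default BATCH_AND_PREFILL_LEN used in llama_offline_run.sh; each key is a (prefill_length, batch_size) tuple and each value starts as an empty list of queued samples:

# Illustrative only; mirrors the parsing done in _init_query_batches above.
prefill_lengths_and_batch_sizes = "256,216|512,108|1024,54"

query_batches = {}
for lb in prefill_lengths_and_batch_sizes.split("|"):
    l, b = lb.split(",")
    query_batches[(int(l), int(b))] = []

print(sorted(query_batches))  # [(256, 216), (512, 108), (1024, 54)]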


@contextlib.contextmanager
@@ -179,21 +199,25 @@ def timed(msg):
log.info(msg + " done: " + str(end - start))


def _classify_query(dataset_rows, index):
# return grouped indexes
def _classify_query(dataset_rows, index, query_batches):
sample = dataset_rows[index][1]
input_len = sample.tok_input_length
total_len = sample.tok_input_length + 3 * sample.tok_output_length
len_batch_str = FLAGS.prefill_lengths_and_batch_sizes
target_inputs = [int(lb.split(",")[0]) for lb in len_batch_str.split("|")]
total_len = int(sample.tok_input_length + FLAGS.tok_outlen_multiplier * sample.tok_output_length)
query_batch_keys = list(query_batches.keys())
query_batch_keys.sort()
target_inputs = [lb[0] for lb in query_batch_keys]
target_totals = [2 * inp for inp in target_inputs]

if total_len <= target_totals[0] and input_len <= target_inputs[0]:
return 0
elif total_len <= target_totals[1] and input_len <= target_inputs[1]:
return 1
else:
return 2
for i in range(len(target_inputs)):
if total_len <= target_totals[i] and input_len <= target_inputs[i]:
log.debug(f"Added sample of input length {input_len} total_len {total_len} for {query_batch_keys[i]}")
return query_batch_keys[i]
if input_len <= target_inputs[i]:
log.debug(f"Added sample of input length {input_len} total_len {total_len} for {query_batch_keys[i]}")
return query_batch_keys[i]
if not FLAGS.allow_skipping_queries:
assert False, f"Invalid query input_len {input_len} > max prefill_len configured {query_batch_keys[-1]}."
return -1
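
As a worked example (a standalone sketch with illustrative numbers), a sample with 300 input tokens and 150 output tokens gets an estimated total of int(300 + 3.0 * 150) = 750, which fails both checks for the (256, 216) bucket and is placed in (512, 108) by the total-length check:

# Standalone sketch of the bucketing rule in _classify_query (illustrative values).
tok_input_length, tok_output_length, tok_outlen_multiplier = 300, 150, 3.0
query_batch_keys = sorted([(256, 216), (512, 108), (1024, 54)])

total_len = int(tok_input_length + tok_outlen_multiplier * tok_output_length)  # 750
for prefill_len, batch_size in query_batch_keys:
    if total_len <= 2 * prefill_len and tok_input_length <= prefill_len:
        print("bucket:", (prefill_len, batch_size))  # bucket: (512, 108)
        break
    if tok_input_length <= prefill_len:
        print("bucket:", (prefill_len, batch_size))
        break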


def _pick_batch_size(num_samples, max_batch, dataset_size, sample_size):
@@ -217,10 +241,12 @@ def get_warmup_samples(dataset):
jax.block_until_ready(data.tokens)
sample_id_to_input = input_data
for sample_id in range(len(input_data)):
group = _classify_query(pandas_rows, sample_id)
group_idx = _classify_query(pandas_rows, sample_id, query_batches)
if group_idx == -1:
continue
input_ = copy.copy(sample_id_to_input[sample_id])
input_.id = sample_id
query_batches[group].append(input_)
query_batches[group_idx].append(input_)

interesting_buckets = [
0,
@@ -234,22 +260,25 @@ def get_warmup_samples(dataset):
]
warmup_samples = _init_query_batches()

for group_idx, group in enumerate(query_batches):
for start, end in zip(interesting_buckets[: group_idx - 3], interesting_buckets[1 : group_idx - 2]):
for sample in group:
for group_idx in query_batches:
prefill_len = group_idx[0]
idx = int(math.log2(prefill_len)) - 3
for start, end in zip(interesting_buckets[:idx], interesting_buckets[1 : (idx + 1)]):
log.debug(f"idx:{group_idx} start:{start} end:{end}")
for sample in query_batches[group_idx]:
if start < sample.true_length <= end:
warmup_samples[group_idx].append(sample)
log.info(f"Added sample of length {sample.true_length} for ({start}, {end}) bucket for group {group_idx}")
log.debug(f"Added warmup sample of length {sample.true_length} for ({start}, {end}) bucket for group {group_idx}")
break
warmup_samples[group_idx].extend(query_batches[group_idx][:50])
return warmup_samples
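
A quick sketch of the bucket-index arithmetic above: idx = int(math.log2(prefill_len)) - 3 decides how many (start, end] length ranges each group draws warmup samples from. The boundary list below is an assumption for illustration, since the full interesting_buckets is folded out of this diff:

import math

# Assumed power-of-two boundaries; the committed interesting_buckets may differ.
interesting_buckets = [0, 16, 32, 64, 128, 256, 512, 1024]

for prefill_len in (256, 512, 1024):
    idx = int(math.log2(prefill_len)) - 3
    pairs = list(zip(interesting_buckets[:idx], interesting_buckets[1:idx + 1]))
    print(prefill_len, pairs)
# 256  -> [(0, 16), (16, 32), (32, 64), (64, 128), (128, 256)]
# 512  -> the same plus (256, 512)
# 1024 -> the same plus (256, 512) and (512, 1024)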


class SUT:

def __init__(self, data, offline_inf):
# dict of int (cache length) -> offline_inf
self.offline_inf = offline_inf
def __init__(self, data, offline_inf_instances):
# dict of (prefill length, batch size) -> OfflineInference instance
self.offline_inf_instances = offline_inf_instances

# pandas dataframe, it has tok
self._dataset = data
@@ -260,7 +289,6 @@ def __init__(self, data, offline_inf):
# index to loaded data
self._processed_data = None

# self.replicated = self.offline_inf.engine.env.sharding_by_axis(-1)
self._sample_id_to_input = None
self._query_batches = _init_query_batches()

@@ -271,18 +299,23 @@ def issue_queries(self, queries):
self._queries = queries

num_queries = len(self._queries)
num_grouped_queries = [len(q) for q in self._query_batches]
num_skipped_queries = 0
num_grouped_queries = [len(self._query_batches[b]) for b in self._query_batches]
log.info(f"Before Issue {num_queries} queries - classified queries {num_grouped_queries}")
self._query_batches = _init_query_batches()
for q in queries:
group = _classify_query(self.pandas_rows, q.index)
input_data = copy.copy(self._sample_id_to_input[q.index])
input_data.id = q.id
self._query_batches[group].append(input_data)

num_grouped_queries = [len(q) for q in self._query_batches]
log.info(f"Issue {num_queries} queries - classified queries {num_grouped_queries}")
assert len(self._queries) == sum(
group_idx = _classify_query(self.pandas_rows, q.index, self._query_batches)
if group_idx == -1:
num_skipped_queries += 1
log.debug("Filtering out query of input len larger than acceptable configuration")
else:
input_data = copy.copy(self._sample_id_to_input[q.index])
input_data.id = q.id
self._query_batches[group_idx].append(input_data)
num_grouped_queries = [len(self._query_batches[b]) for b in self._query_batches]
log.info(f"Issue {num_queries} queries - classified queries {num_grouped_queries} num_skipped {num_skipped_queries}")

assert len(self._queries) - num_skipped_queries == sum(
num_grouped_queries
), f"num_queries {num_queries} does not match num_grouped_queries {num_grouped_queries}"
# At this point _processed_data is ready
@@ -292,11 +325,12 @@ def issue_queries(self, queries):
def flush_queries(self):
log.info("Flush queries start")
start = time.perf_counter()
for group_idx, group in enumerate(self._query_batches):
for group_idx in self._query_batches:
group = self._query_batches[group_idx]
log.info(f"Flush queries processing {group_idx} with {len(group)} samples")
self.offline_inf[group_idx].init_decode_state()
result = self.offline_inf[group_idx].batch_inference(group, desc=f"batch-{group_idx}")
self.offline_inf[group_idx].decode_state = None
self.offline_inf_instances[group_idx].init_decode_state()
result = self.offline_inf_instances[group_idx].batch_inference(group, desc=f"batch-{group_idx}")
self.offline_inf_instances[group_idx].decode_state = None
gc.collect()
for key, val in result.items():
key = int(key)
@@ -329,7 +363,7 @@ def LoadSamplesToRam(self, sample_list):
log.info(f"LoadSamplesToRam finished: {end - start}s")

def UnloadSamplesFromRam(self, sample_list):
print("UnloadSamplesFromRam called")
log.info("UnloadSamplesFromRam called")
pass


@@ -344,22 +378,18 @@ def make_response(id_, response_token_ids):
return query_sample_response


def _count_by_bucket(dataset):

def _estimated_counts_by_bucket(dataset):
total_len = dataset.tok_input_length + dataset.tok_output_length

group1 = (total_len <= 512) & (dataset.tok_input_length <= 256)
group2 = (total_len <= 1024) & (dataset.tok_input_length <= 512)

# with 5 percent extra
mult = FLAGS.total_sample_count / len(dataset) * 1.05

counts = [
math.ceil(len(dataset[group1]) * mult),
math.ceil(len(dataset[~group1 & group2]) * mult),
math.ceil(len(dataset[~group1 & ~group2]) * mult),
]
return counts
estimates = {}
estimates["<256"] = math.ceil(len(dataset[group1]) * mult)
estimates["256-512"] = math.ceil(len(dataset[~group1 & group2]) * mult)
estimates[">512"] = math.ceil(len(dataset[~group1 & ~group2]) * mult)
return estimates
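
A self-contained toy version of the estimate above; the dataframe rows are fabricated, while the column names follow the dataset schema used elsewhere in this file, and the 5% head-room multiplier is applied as in the code:

import math
import pandas as pd

# Fabricated rows for illustration only.
dataset = pd.DataFrame({
    "tok_input_length": [100, 200, 300, 600, 800],
    "tok_output_length": [100, 400, 500, 300, 900],
})
total_len = dataset.tok_input_length + dataset.tok_output_length
group1 = (total_len <= 512) & (dataset.tok_input_length <= 256)
group2 = (total_len <= 1024) & (dataset.tok_input_length <= 512)
mult = 1.05  # stand-in for FLAGS.total_sample_count / len(dataset) * 1.05

print({
    "<256": math.ceil(len(dataset[group1]) * mult),
    "256-512": math.ceil(len(dataset[~group1 & group2]) * mult),
    ">512": math.ceil(len(dataset[~group1 & ~group2]) * mult),
})  # {'<256': 2, '256-512': 3, '>512': 3}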


def main(argv):
@@ -383,48 +413,47 @@ def main(argv):
log.info("dataset path: %s", FLAGS.dataset_path)
dataset = pd.read_pickle(FLAGS.dataset_path)
rows = list(dataset.iterrows())
counts_by_bucket = _count_by_bucket(dataset)
log.info(f"Counts by bucket {counts_by_bucket}")
estimated_counts_by_bucket = _estimated_counts_by_bucket(dataset)
log.info(f"Estimated counts by bucket {estimated_counts_by_bucket}")
len_batch_str = FLAGS.prefill_lengths_and_batch_sizes
log.info(f"Prefill lengths and Batch sizes: {len_batch_str}")
log.info(f"Maxengine args: {FLAGS.maxengine_args}")
length_and_batch = [tuple(map(int, lb.split(","))) for lb in len_batch_str.split("|")]
engines = []

log.info("Get warmup samples")
warmup_samples = get_warmup_samples(dataset)
offline_inf_instances = {}
query_batches = _init_query_batches()
params = None
base_engine = None
for i, (length, max_batch) in enumerate(length_and_batch):
batch = counts_by_bucket[i]
# Create an engine and corresponding offline_inf_instance per batch of queries
for group_idx in query_batches:
(length, batch) = group_idx
target_length = 2 * length
log.info(f"Using batch size: {max_batch} and length: {length}")
log.info(f"Using batch size: {batch} and length: {length}")
engine = create_engine_from_config_flags(
batch_size=max_batch,
batch_size=batch,
max_prefill_predict_length=length,
max_target_length=target_length,
args_str=FLAGS.maxengine_args,
)
offline_inf = offline_inference.OfflineInference(engine, params, base_engine)
if params is None and offline_inf.params is not None:
base_engine = engine
# offline_inf.dummy = True
params = offline_inf.params
engines.append(offline_inf)
offline_inf_instances[group_idx] = offline_inf

warmup_samples = None
if not FLAGS.skip_warmup:
print("Get warmup samples")
warmup_samples = get_warmup_samples(dataset)

with timed("warmup"):
warmup_grp = 0
for (length, _), engine in zip(length_and_batch, engines):
for group_idx in offline_inf_instances:
(length, batch) = group_idx
log.info(f"warm up for {length}")
engine.init_decode_state()
engine.warmup(length, warmup_samples[warmup_grp])
engine.decode_state = None # drop state
offline_inf_instances[group_idx].init_decode_state()
offline_inf_instances[group_idx].warmup(length, warmup_samples[group_idx])
offline_inf_instances[group_idx].decode_state = None # drop state
gc.collect()
warmup_grp += 1

sut = SUT(dataset, engines)
sut = SUT(dataset, offline_inf_instances)

if FLAGS.mlperf_test_mode == "accuracy":
settings.mode = lg.TestMode.AccuracyOnly
18 changes: 18 additions & 0 deletions MaxText/inference_mlperf/user.conf
@@ -0,0 +1,18 @@
# The format of this config file is 'key = value'.
# The key has the format 'model.scenario.key'. Value is mostly int64_t.
# Model may be '*' as a wildcard. In that case the value applies to all models.
# All times are in milliseconds

# Set performance_sample_count for each model.
llama2-70b.*.performance_sample_count_override = 24576
*.Offline.min_duration = 600000


# In Offline scenario, we always have one query. But LoadGen maps this to
# min_sample_count internally in Offline scenario. If the dataset size is larger
# than 24576 we limit the min_query_count to 24576 and otherwise we use
# the dataset size as the limit
llama2-70b.Offline.min_query_count = 24576

# These fields should be defined and overridden by user.conf.
*.Offline.target_qps = 5.0
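
For context, this file follows the standard LoadGen key = value format; a minimal sketch of how such a config is typically consumed, assuming the stock MLPerf LoadGen Python bindings (already imported as lg in offline_mode.py) and an illustrative path:

# Sketch only; assumes the standard mlperf_loadgen Python bindings.
import mlperf_loadgen as lg

settings = lg.TestSettings()
settings.scenario = lg.TestScenario.Offline
# Pull the 'llama2-70b' + 'Offline' overrides (e.g. min_query_count, target_qps)
# from this user.conf file.
settings.FromConfig("MaxText/inference_mlperf/user.conf", "llama2-70b", "Offline")
settings.mode = lg.TestMode.PerformanceOnly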
