
Commit

Fixes (#304)
IlyasMoutawwakil authored Dec 9, 2024
1 parent 4a05fc1 commit 9c7f088
Showing 4 changed files with 41 additions and 85 deletions.
40 changes: 32 additions & 8 deletions optimum_benchmark/trackers/latency.py
@@ -46,7 +46,10 @@ def __getitem__(self, index) -> float:
     def __sub__(self, latency: "Latency") -> "Latency":
         latencies = [lat - latency.mean for lat in self.values]
 
-        assert all(latency >= 0 for latency in latencies)
+        assert all(latency >= 0 for latency in latencies), (
+            "Found some negative latencies while performing subtraction. "
+            "Please increase the dimensions of your benchmark or the number of warmup runs."
+        )
 
         return Latency.from_values(values=latencies, unit=self.unit)

@@ -275,7 +278,10 @@ def get_latency(self) -> Latency:
             (end_event - start_event) for start_event, end_event in zip(self.start_events, self.end_events)
         ]
 
-        assert all(latency >= 0 for latency in latencies)
+        assert all(latency >= 0 for latency in latencies), (
+            "Found some negative latencies while performing subtraction. "
+            "Please increase the dimensions of your benchmark or the number of warmup runs."
+        )
 
         return Latency.from_values(latencies, unit=LATENCY_UNIT)

@@ -386,7 +392,10 @@ def get_prefill_latency(self) -> Latency:
             for start_event, end_event in zip(self.prefill_start_events, self.prefill_end_events)
         ]
 
-        assert all(latency >= 0 for latency in latencies)
+        assert all(latency >= 0 for latency in latencies), (
+            "Found some negative latencies while performing subtraction. "
+            "Please increase the dimensions of your benchmark or the number of warmup runs."
+        )
 
         return Latency.from_values(latencies, unit=LATENCY_UNIT)

@@ -406,7 +415,10 @@ def get_decode_latency(self) -> Latency:
             for start_event, end_event in zip(self.decode_start_events, self.decode_end_events)
         ]
 
-        assert all(latency >= 0 for latency in latencies)
+        assert all(latency >= 0 for latency in latencies), (
+            "Found some negative latencies while performing subtraction. "
+            "Please increase the dimensions of your benchmark or the number of warmup runs."
+        )
 
         return Latency.from_values(latencies, unit=LATENCY_UNIT)

@@ -426,7 +438,10 @@ def get_per_token_latency(self) -> Latency:
             for start_event, end_event in zip(self.per_token_start_events, self.per_token_end_events)
         ]
 
-        assert all(latency >= 0 for latency in latencies)
+        assert all(latency >= 0 for latency in latencies), (
+            "Found some negative latencies while performing subtraction. "
+            "Please increase the dimensions of your benchmark or the number of warmup runs."
+        )
 
         return Latency.from_values(latencies, unit=LATENCY_UNIT)

@@ -525,7 +540,10 @@ def get_step_latency(self) -> Latency:
             for start_event, end_event in zip(self.per_step_start_events, self.per_step_end_events)
         ]
 
-        assert all(latency >= 0 for latency in latencies)
+        assert all(latency >= 0 for latency in latencies), (
+            "Found some negative latencies while performing subtraction. "
+            "Please increase the dimensions of your benchmark or the number of warmup runs."
+        )
 
         return Latency.from_values(latencies, unit=LATENCY_UNIT)

@@ -545,7 +563,10 @@ def get_call_latency(self) -> Latency:
             for start_event, end_event in zip(self.call_start_events, self.call_end_events)
         ]
 
-        assert all(latency >= 0 for latency in latencies)
+        assert all(latency >= 0 for latency in latencies), (
+            "Found some negative latencies while performing subtraction. "
+            "Please increase the dimensions of your benchmark or the number of warmup runs."
+        )
 
        return Latency.from_values(latencies, unit=LATENCY_UNIT)

@@ -597,6 +618,9 @@ def get_latency(self) -> Latency:
             (end_event - start_event) for start_event, end_event in zip(self.start_events, self.end_events)
         ]
 
-        assert all(latency >= 0 for latency in latencies)
+        assert all(latency >= 0 for latency in latencies), (
+            "Found some negative latencies while performing subtraction. "
+            "Please increase the dimensions of your benchmark or the number of warmup runs."
+        )
 
         return Latency.from_values(latencies, unit=LATENCY_UNIT)
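For context, a minimal sketch (not part of this commit) of how the new assertion message surfaces: Latency.__sub__ subtracts the mean of the other Latency (typically the warmup measurements) from every stored value, so any measurement faster than that mean produces a negative entry. The unit string "s" and the sample numbers below are illustrative assumptions.

# Illustrative only: the unit string and the values are assumed, not taken from the repository.
from optimum_benchmark.trackers.latency import Latency

warmup = Latency.from_values(values=[0.9, 1.1], unit="s")         # mean = 1.0 s
measured = Latency.from_values(values=[0.8, 1.2, 1.3], unit="s")  # 0.8 s < warmup mean

# Before this commit the AssertionError carried no message; it now tells the user to
# increase the benchmark dimensions or the number of warmup runs.
corrected = measured - warmup  # raises AssertionError because 0.8 - 1.0 < 0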
4 changes: 2 additions & 2 deletions tests/configs/_inference_.yaml
@@ -14,8 +14,8 @@ scenario:
     sequence_length: 16
 
   generate_kwargs:
-    max_new_tokens: 4
-    min_new_tokens: 4
+    max_new_tokens: 16
+    min_new_tokens: 16
 
   call_kwargs:
     num_inference_steps: 4
76 changes: 4 additions & 72 deletions tests/test_api.py
@@ -9,33 +9,17 @@
 import torch
 
 from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig, TrainingConfig
-from optimum_benchmark.backends.diffusers_utils import (
-    extract_diffusers_shapes_from_model,
-    get_diffusers_pretrained_config,
-)
-from optimum_benchmark.backends.timm_utils import extract_timm_shapes_from_config, get_timm_pretrained_config
-from optimum_benchmark.backends.transformers_utils import (
-    extract_transformers_shapes_from_artifacts,
-    get_transformers_pretrained_config,
-    get_transformers_pretrained_processor,
-)
-from optimum_benchmark.generators.dataset_generator import DatasetGenerator
-from optimum_benchmark.generators.input_generator import InputGenerator
 from optimum_benchmark.import_utils import get_git_revision_hash
 from optimum_benchmark.system_utils import is_nvidia_system, is_rocm_system
 from optimum_benchmark.trackers import LatencySessionTracker, MemoryTracker
 
 PUSH_REPO_ID = os.environ.get("PUSH_REPO_ID", "optimum-benchmark/local")
 
 LIBRARIES_TASKS_MODELS = [
-    ("timm", "image-classification", "timm/resnet50.a1_in1k"),
-    ("transformers", "text-generation", "openai-community/gpt2"),
-    ("transformers", "fill-mask", "google-bert/bert-base-uncased"),
-    ("transformers", "multiple-choice", "FacebookAI/roberta-base"),
-    ("transformers", "text-classification", "FacebookAI/roberta-base"),
-    ("transformers", "token-classification", "microsoft/deberta-v3-base"),
-    ("transformers", "image-classification", "google/vit-base-patch16-224"),
-    ("diffusers", "text-to-image", "CompVis/stable-diffusion-v1-4"),
+    ("timm", "image-classification", "timm/tiny_vit_21m_224.in1k"),
+    ("transformers", "fill-mask", "hf-internal-testing/tiny-random-BertModel"),
+    ("transformers", "text-generation", "hf-internal-testing/tiny-random-LlamaForCausalLM"),
+    ("diffusers", "text-to-image", "hf-internal-testing/tiny-stable-diffusion-torch"),
 ]
 
 INPUT_SHAPES = {
@@ -177,58 +161,6 @@ def test_api_push_to_hub_mixin():
     assert from_hub_artifact.to_dict() == artifact.to_dict()
 
 
-@pytest.mark.parametrize("library,task,model", LIBRARIES_TASKS_MODELS)
-def test_api_input_generator(library, task, model):
-    if library == "transformers":
-        model_config = get_transformers_pretrained_config(model)
-        model_processor = get_transformers_pretrained_processor(model)
-        model_shapes = extract_transformers_shapes_from_artifacts(model_config, model_processor)
-    elif library == "timm":
-        model_config = get_timm_pretrained_config(model)
-        model_shapes = extract_timm_shapes_from_config(model_config)
-    elif library == "diffusers":
-        model_config = get_diffusers_pretrained_config(model)
-        model_shapes = extract_diffusers_shapes_from_model(model)
-    else:
-        raise ValueError(f"Unknown library {library}")
-
-    input_generator = InputGenerator(
-        task=task,
-        input_shapes=INPUT_SHAPES,
-        model_shapes=model_shapes,
-    )
-    generated_inputs = input_generator()
-
-    assert len(generated_inputs) > 0, "No inputs were generated"
-
-    for key in generated_inputs:
-        assert len(generated_inputs[key]) == INPUT_SHAPES["batch_size"], "Incorrect batch size"
-
-
-@pytest.mark.parametrize("library,task,model", LIBRARIES_TASKS_MODELS)
-def test_api_dataset_generator(library, task, model):
-    if library == "transformers":
-        model_config = get_transformers_pretrained_config(model=model)
-        model_shapes = extract_transformers_shapes_from_artifacts(config=model_config)
-    elif library == "timm":
-        model_config = get_timm_pretrained_config(model)
-        model_shapes = extract_timm_shapes_from_config(config=model_config)
-    elif library == "diffusers":
-        model_config = get_diffusers_pretrained_config(model)
-        model_shapes = extract_diffusers_shapes_from_model(model)
-    else:
-        raise ValueError(f"Unknown library {library}")
-
-    if task == "multiple-choice":
-        DATASET_SHAPES["num_choices"] = 2
-
-    generator = DatasetGenerator(task=task, dataset_shapes=DATASET_SHAPES, model_shapes=model_shapes)
-    generated_dataset = generator()
-
-    assert len(generated_dataset) > 0, "No dataset was generated"
-    assert len(generated_dataset) == DATASET_SHAPES["dataset_size"]
-
-
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
 @pytest.mark.parametrize("backend", ["pytorch", "other"])
 def test_api_latency_tracker(device, backend):
6 changes: 3 additions & 3 deletions tests/test_cli.py
@@ -65,8 +65,8 @@ def test_cli_exit_code_0(launcher):
         "name=test",
         "launcher=" + launcher,
         # compatible task and model
+        "backend.model=google-bert/bert-base-uncased",
         "backend.task=text-classification",
-        "backend.model=bert-base-uncased",
         "backend.device=cpu",
         # input shapes
         "+scenario.input_shapes.batch_size=1",
@@ -91,8 +91,8 @@ def test_cli_exit_code_1(launcher):
         "name=test",
         "launcher=" + launcher,
         # incompatible task and model to trigger an error
+        "backend.model=google-bert/bert-base-uncased",
         "backend.task=image-classification",
-        "backend.model=bert-base-uncased",
         "backend.device=cpu",
         # input shapes
         "+scenario.input_shapes.batch_size=1",
@@ -117,8 +117,8 @@ def test_cli_numactl(launcher):
         "name=test",
         "launcher=" + launcher,
         "launcher.numactl=True",
+        "backend.model=google-bert/bert-base-uncased",
         "backend.task=text-classification",
-        "backend.model=bert-base-uncased",
         "backend.device=cpu",
         # input shapes
         "+scenario.input_shapes.batch_size=1",
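For readers unfamiliar with these Hydra-style overrides, a hedged sketch of roughly the same configuration expressed through the Python API (the config classes are those imported in tests/test_api.py; the exact field names and the Benchmark.launch call are assumptions based on the library's documented usage, not part of this commit):

# Hedged sketch: field names and the launch call are assumed from documented usage.
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig

config = BenchmarkConfig(
    name="test",
    launcher=ProcessConfig(),  # "launcher=process"
    scenario=InferenceConfig(input_shapes={"batch_size": 1}),  # "+scenario.input_shapes.batch_size=1"
    backend=PyTorchConfig(
        model="google-bert/bert-base-uncased",  # "backend.model=..."
        task="text-classification",  # "backend.task=..."
        device="cpu",  # "backend.device=cpu"
    ),
)
report = Benchmark.launch(config)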
