
Commit

update test
inisis committed Sep 24, 2024
1 parent 965eaec commit 5bd596c
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions tests/test_onnx_nets.py
@@ -6,10 +6,13 @@
 import pytest
 import timm
 import torch
+from torch.utils.data import RandomSampler
 import torchvision.models as models
 
 FUSE = True
 PRETRAINED = False
+MEMORY_LIMIT_GB = 0.75 # User's memory limit
+MEMORY_PER_PARAM = 4e-9 # Approximate memory required per parameter in GB
 
 os.makedirs("tmp", exist_ok=True)
 
@@ -47,14 +50,32 @@ def test_torchvision(self, request, model, shape=(1, 3, 224, 224)):
 
 
 class TestTimmClass:
-    @pytest.fixture(params=[model for model in timm.list_models() if 'huge' not in model and 'giant' not in model])
+    @pytest.fixture(params=timm.list_models())
     def model_name(self, request):
         """Yields names of models available in TIMM (https://github.com/rwightman/pytorch-image-models) for pytest fixture parameterization."""
         yield request.param
 
+    skip_keywords = ["enormous", "giant", "huge", "xlarge"]
+
     def test_timm(self, request, model_name):
         """Tests a TIMM model's forward pass with a random input tensor of the appropriate size."""
-        model = timm.create_model(model_name, pretrained=PRETRAINED)
+        if any(keyword in model_name.lower() for keyword in self.skip_keywords):
+            pytest.skip(f"Skipping model due to size keyword in name: {model_name}")
+
+        try:
+            model = timm.create_model(model_name, pretrained=PRETRAINED)
+        except RuntimeError as e:
+            if "out of memory" in str(e):
+                pytest.skip(f"Skipping model {model_name} due to memory error.")
+
+        num_params = sum(p.numel() for p in model.parameters())
+
+        # Calculate estimated memory requirement
+        estimated_memory = num_params * MEMORY_PER_PARAM
+
+        if estimated_memory > MEMORY_LIMIT_GB:
+            pytest.skip(f"Skipping model {model_name}: estimated memory {estimated_memory:.2f} GB exceeds limit.")
+
         input_size = model.default_cfg.get("input_size")
         x = torch.randn((1,) + input_size)
         directory = f"tmp/{request.node.name}"
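
For context on the new constants: the test estimates a model's memory footprint from its parameter count (roughly 4 bytes, i.e. 4e-9 GB, per fp32 parameter) and skips any model whose estimate exceeds MEMORY_LIMIT_GB. A minimal standalone sketch of that heuristic, outside pytest and using a hypothetical should_skip helper, could look like this:

# Standalone sketch of the memory-estimation heuristic shown in the diff above.
# should_skip is a hypothetical helper, not part of the repository.
import timm

MEMORY_LIMIT_GB = 0.75   # assumed per-model budget, as in the diff
MEMORY_PER_PARAM = 4e-9  # ~4 bytes (one fp32 weight) per parameter, in GB

def should_skip(model_name: str) -> bool:
    """Return True if the model's estimated memory exceeds the budget."""
    model = timm.create_model(model_name, pretrained=False)
    num_params = sum(p.numel() for p in model.parameters())
    estimated_memory_gb = num_params * MEMORY_PER_PARAM
    return estimated_memory_gb > MEMORY_LIMIT_GB

# Example: resnet18 has ~11.7M parameters, i.e. ~0.05 GB by this estimate.
print(should_skip("resnet18"))  # expected: False

Note that the model still has to be instantiated before its parameters can be counted, which is why the diff also wraps timm.create_model in a try/except and skips on an out-of-memory RuntimeError before the estimate is computed.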

