Skip to content

Commit

Permalink
Load sam inputs from S3
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 25, 2024
1 parent 83fb47e commit 615ac43
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion torchbenchmark/models/sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, test, device, batch_size=1, extra_args=[]):
self.model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
self.model.to(device=device)
data_folder = os.path.dirname(DATA_PATH)
image_path = os.path.join(data_folder, "data", "sam", "truck.jpg")
image_path = os.path.join(data_folder, "data", ".data", "sam_inputs", "truck.jpg")
if not os.path.exists(image_path):
from torchbenchmark.util.framework.fb.installer import install_data

Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/models/sam/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def download_checkpoint():

def download_data():
s3_utils.checkout_s3_data(
"INPUT_TARBALLS", "sam.tar.gz", decompress=True
"INPUT_TARBALLS", "sam_inputs.tar.gz", decompress=True
)


Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/models/sam_fast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, test, device, batch_size=1, extra_args=[]):
self.model = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
self.model.to(device=device)
data_folder = os.path.dirname(DATA_PATH)
image_path = os.path.join(data_folder, "data", "sam", "truck.jpg")
image_path = os.path.join(data_folder, "data", ".data", "sam_inputs", "truck.jpg")
assert os.path.exists(image_path), f"Expected image file exists at {image_path} but not found."
self.image = cv2.imread(image_path)
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/models/sam_fast/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def download_checkpoint():

def download_data():
s3_utils.checkout_s3_data(
"INPUT_TARBALLS", "sam.tar.gz", decompress=True
"INPUT_TARBALLS", "sam_inputs.tar.gz", decompress=True
)


Expand Down

0 comments on commit 615ac43

Please sign in to comment.