-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
32 lines (25 loc) · 1.04 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from PIL import Image
from torchvision import transforms
def predict(device: torch.device, model, image_path: str) -> tuple[float, float, float, float]:
    """
    Predict the bounding-box coordinates for an object in the given image.

    The image is loaded from disk, converted to RGB, turned into a tensor with
    a leading batch dimension of 1, moved to ``device`` and passed through
    ``model`` with gradient tracking disabled.

    Args:
        device: Device the input tensor is moved to before the forward pass.
        model: Callable invoked with the single-image batch tensor. Its output
            must reduce (after ``squeeze()``) to exactly four values.
        image_path: Path of the image file to load.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as plain Python floats.

    Raises:
        ValueError: If the squeezed model output is not a 1-D tensor of four
            values.

    Note:
        This function does not change the model's train/eval mode and does not
        move the model to ``device``; presumably the caller has already done
        both -- verify at the call site.
    """
    to_tensor = transforms.ToTensor()

    # Open the file in a context manager so the handle is released right away;
    # convert() returns an independent in-memory copy that outlives the file.
    with Image.open(image_path) as image_file:
        image = image_file.convert('RGB')

    # Add the batch dimension and move the input to the target device.
    image_tensor = to_tensor(image).unsqueeze(0).to(device)

    # Disable gradient calculation for inference.
    with torch.no_grad():
        outputs = model(image_tensor)

    # Drop the batch dimension and bring the result back to the CPU.
    coords = outputs.cpu().squeeze()
    if coords.dim() != 1 or coords.numel() != 4:
        raise ValueError(
            "expected the model to output 4 bounding-box coordinates, "
            f"got output of shape {tuple(outputs.shape)}"
        )

    # tolist() yields built-in floats, matching the declared return type
    # (NumPy scalars such as float32 are not `float` instances).
    x_min, y_min, x_max, y_max = coords.tolist()
    return x_min, y_min, x_max, y_max