Skip to content

Commit

Permalink
I would like to slowly move towards downloading weights automatically…
Browse files Browse the repository at this point in the history
… in python.
  • Loading branch information
pfeatherstone committed Aug 30, 2024
1 parent c8c9b92 commit 183b751
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
5 changes: 0 additions & 5 deletions download_weights.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
mkdir weights
wget https://pjreddie.com/media/files/yolov3-tiny.weights -P weights
wget https://pjreddie.com/media/files/yolov3.weights -P weights
wget https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4.weights -P weights
wget https://github.com/AlexeyAB/darknet/releases/download/yolov4/yolov4-tiny.weights -P weights
wget https://github.com/ultralytics/yolov3/releases/download/v8/yolov3-spp.weights -P weights
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov5nu.pt -P weights
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov5su.pt -P weights
wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov5mu.pt -P weights
Expand Down
21 changes: 16 additions & 5 deletions src/test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import time
import os
import torch
import torchvision
import torchvision.transforms.functional as vF
import matplotlib.pyplot as plt
from models import *

weight_paths = {
'yolov3-tiny' : 'https://pjreddie.com/media/files/yolov3-tiny.weights',
'yolov3' : 'https://pjreddie.com/media/files/yolov3.weights',
'yolov3-spp' : 'https://github.com/ultralytics/yolov3/releases/download/v8/yolov3-spp.weights',
'yolov4' : 'https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4.weights',
'yolov4-tiny' : 'https://github.com/AlexeyAB/darknet/releases/download/yolov4/yolov4-tiny.weights',
}

torch.set_printoptions(5)
def download_if_not_exist(model_type: str, filepath: str):
if not os.path.exists(filepath):
torch.hub.download_url_to_file(weight_paths[model_type], filepath)


def load_from_darknet(net: Union[Yolov3, Yolov3Tiny, Yolov4, Yolov4Tiny], weights_path: str):
Expand Down Expand Up @@ -126,6 +133,8 @@ def params2():


def test(type: str, size: str = ''):
os.makedirs('../weights', exist_ok=True)

match type:
case 'yolov3' : net = Yolov3(80, False).eval()
case 'yolov3-spp': net = Yolov3(80, True).eval()
Expand All @@ -147,7 +156,9 @@ def test(type: str, size: str = ''):
has_obj = False

elif 'yolov3' in type or 'yolov4' in type :
load_from_darknet(net, '../weights/{}.weights'.format(type))
filepath = '../weights/{}.weights'.format(type)
download_if_not_exist(type, filepath)
load_from_darknet(net, filepath)

elif type == 'yolov6':
load_from_yolov6_official(net, "../weights/yolov6{}.pt".format(size))
Expand Down

0 comments on commit 183b751

Please sign in to comment.