Skip to content

Commit

Permalink
WIP: save work.
Browse files Browse the repository at this point in the history
  • Loading branch information
oddkiva committed Dec 20, 2023
1 parent 2e048b2 commit 86e980b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 7 additions & 1 deletion python/oddkiva/shakti/inference/yolo/darknet_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,16 @@ def __init__(self, in_channels, darknet_params: dict[str, Any], id: int):
raise ValueError(f'No convolutional activation named {activation}')
self.block.add_module(f'{activation}{id}', activation_fn);


def forward(self, x):
return self.block.forward(x)

def load_weights(self, weights_file: Path):
with open(weights_file, 'rb') as fp:
w_data = fp.read(conv.weight.shape.numel() * 4)
conv.weight.data.copy_(torch.from_numpy
conv.bias.data = fp.read(conv.bias.shape.numel() * 4)



class MaxPool(nn.Module):

Expand Down
11 changes: 9 additions & 2 deletions python/oddkiva/shakti/inference/yolo/darknet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ def __init__(self, darknet_config: DarknetConfig):

self.model = self.create_network(darknet_config)

self.major = None
self.minor = None
self.revision = None
self.seen = None
self.transpose = None

def create_network(self, darknet_config: DarknetConfig):
model = nn.ModuleList()

Expand All @@ -26,8 +32,9 @@ def create_network(self, darknet_config: DarknetConfig):

return model

def load_weights(self, weights_file: Path):
pass
def load_convolutional_weights(self, conv, weights_file: Path):
with open(weights_file, 'rb') as fp:
fp.read(

def save_weights(self, weights_file: Path)
pass

0 comments on commit 86e980b

Please sign in to comment.