Skip to content

Commit

Permalink
MAINT: fix network forward function.
Browse files Browse the repository at this point in the history
  • Loading branch information
oddkiva committed Dec 22, 2023
1 parent a92bc31 commit 67c8d78
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 26 deletions.
8 changes: 4 additions & 4 deletions python/oddkiva/sara/pybind11/test/test_image_processing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from oddkiva.sara import resize
import oddkiva.sara as sara

import numpy as np


def test_resize():
src = np.arange(16).reshape((4, 4))
src = np.array([src, src, src]).astype(np.float32)
print(src)

dst = np.zeros((3, 4, 4)).astype(np.float32)

resize(src, dst)
print(dst)
sara.resize(src, dst)

assert np.linalg.norm(src - dst) < 1e-12
15 changes: 8 additions & 7 deletions python/oddkiva/shakti/inference/darknet/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def forward(self, x):
ys = [x]
boxes = []
for block in self.model:
logging.debug(f'{block}')
if type(block) is darknet.ConvBNA:
conv = block
x = ys[-1]
Expand All @@ -248,11 +249,8 @@ def forward(self, x):
elif type(block) is darknet.RouteConcat:
concat = block
if len(concat.layers) == 2:
x1, x2 = [ys[l] for l in concat.layers]
y = concat(x1, x2)
elif len(concat.layers) == 4:
x1, x2, x3, x4 = [ys[l] for l in concat.layers]
y = concat(x1, x2, x3, x4)
xs = [ys[l] for l in concat.layers]
y = concat(*xs)
else:
raise NotImplementedError(
'Unsupported number of inputs for concat'
Expand All @@ -265,8 +263,11 @@ def forward(self, x):
yolo = block
x = ys[-1]
y = yolo(x)
boxes.append(y)
else:
raise NotImplementedError

return boxes
ys.append(y)
if type(block) is darknet.Yolo:
boxes.append(y)

return boxes
26 changes: 11 additions & 15 deletions python/oddkiva/shakti/inference/darknet/torch_layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

Expand Down Expand Up @@ -117,7 +118,7 @@ def __init__(self, layer: int,
self.id = id

def forward(self, x):
if groups == 1:
if self.groups == 1:
return x
else:
# Get the number of channels.
Expand All @@ -126,7 +127,7 @@ def forward(self, x):
group_size = C // self.groups
# Get the slice that we want.
c1 = self.group_id * group_size
c2 = c1 + self.group_size
c2 = c1 + group_size
return x[:, c1:c2, :, :]


Expand All @@ -137,15 +138,10 @@ def __init__(self, layers: [int], id: Optional[int] = None):
self.layers = layers
self.id = id

def forward(self, x1, x2):
if len(self.layers) != 2:
raise RuntimeError("This route-concat layer requires 2 inputs")
return torch.cat((x1, x2), 1)

def forward(self, x1, x2, x3, x4):
if len(self.layers) != 4:
raise RuntimeError("This route-concat layer requires 4 inputs")
return torch.cat((x1, x2, x3, x4), 1)
def forward(self, *xs):
if len(self.layers) != len(xs):
raise RuntimeError(f"This route-concat layer requires {self.layers} inputs")
return torch.cat(xs, 1)


class Shortcut(nn.Module):
Expand Down Expand Up @@ -185,7 +181,7 @@ def __init__(self, darknet_params: dict[str, Any]):
self.masks = darknet_params['mask']
self.anchors = darknet_params['anchors']
self.scale_x_y = darknet_params['scale_x_y']
self.classes = darknet_params['classes']
self.num_classes = darknet_params['classes']

self.alpha = self.scale_x_y
self.beta = -0.5 * (self.scale_x_y - 1)
Expand All @@ -201,13 +197,13 @@ def forward(self, x):
ys = [box * num_box_features + 1 for box in range(3)]
# ws = [box * num_box_features + 2 for box in range(3)]
# hs = [box * num_box_features + 3 for box in range(3)]
y[:, xs, :, :] = self.alpha * nn.Sigmoid(x[:, xs, :, :]) + self.beta
y[:, ys, :, :] = self.alpha * nn.Sigmoid(x[:, ys, :, :]) + self.beta
y[:, xs, :, :] = self.alpha * torch.sigmoid(x[:, xs, :, :]) + self.beta
y[:, ys, :, :] = self.alpha * torch.sigmoid(x[:, ys, :, :]) + self.beta

# P[object] and P[class|object] probabilities.
for box in range(0, 3):
c_begin = box * num_box_features + 4
c_end = (box + 1) * num_box_features
y[:, c_begin:c_end, :, :] = nn.Sigmoid(x[:, c_begin:c_end, :, :])
y[:, c_begin:c_end, :, :] = torch.sigmoid(x[:, c_begin:c_end, :, :])

return y

0 comments on commit 67c8d78

Please sign in to comment.