Skip to content

Commit

Permalink
MAINT: clean code.
Browse files Browse the repository at this point in the history
  • Loading branch information
oddkiva committed Dec 21, 2023
1 parent 8036076 commit a94e1ab
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ BOOST_AUTO_TEST_SUITE(TestTensorRT)
BOOST_AUTO_TEST_CASE(test_inference_engine)
{
// Load the network on the host device (CPU).
const auto data_dir_path = fs::canonical(fs::path{src_path("data")});
const auto yolov4_tiny_dirpath =
data_dir_path / "trained_models" / "yolov4-tiny";
const auto model_dir_path =
fs::canonical(fs::path{src_path("trained_models")});
const auto yolov4_tiny_dirpath = model_dir_path / "yolov4-tiny";

// Convert it into a TensorRT network object.
auto serialized_net = trt::convert_yolo_v4_network_from_darknet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ namespace sara = DO::Sara;
namespace shakti = DO::Shakti;
namespace trt = shakti::TensorRT;

namespace nvonnx = nvonnxparser;


template <typename T, int N>
using PinnedTensor = sara::Tensor_<T, N, shakti::PinnedMemoryAllocator>;
Expand Down Expand Up @@ -101,9 +99,9 @@ BOOST_AUTO_TEST_CASE(test_yolox_tiny_onnx_conversion_to_trt_serialized_engine)
//
// It is available here:
// https://yolox.readthedocs.io/en/latest/demo/onnx_readme.html#download-onnx-models
const auto data_dir_path = fs::canonical(fs::path{src_path("data")});
const auto yolox_tiny_onnx_filepath =
data_dir_path / "trained_models" / "yolox_tiny.onnx";
const auto model_dir_path =
fs::canonical(fs::path{src_path("trained_models")});
const auto yolox_tiny_onnx_filepath = model_dir_path / "yolox_tiny.onnx";
BOOST_CHECK(fs::exists(yolox_tiny_onnx_filepath));

// Instantiate an ONNX parser and read the ONNX model file.
Expand Down
10 changes: 5 additions & 5 deletions python/oddkiva/shakti/inference/darknet/torch_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def forward(self, x):

def load_weights(self, weights_file: Path):
pass
# with open(weights_file, 'rb') as fp:
# w_data = fp.read(conv.weight.shape.numel() * 4)
# conv.weight.data.copy_(torch.from_numpy
# conv.bias.data = fp.read(conv.bias.shape.numel() * 4)
# with open(weights_file, 'rb') as fp:
# w_data = fp.read(conv.weight.shape.numel() * 4)
# conv.weight.data.copy_(torch.from_numpy
# conv.bias.data = fp.read(conv.bias.shape.numel() * 4)


class MaxPool(nn.Module):
Expand Down Expand Up @@ -209,7 +209,7 @@ def forward(self, x):
# P[object] and P[class|object] probabilities.
for box in range(0, 3):
c_begin = box * num_box_features + 4
c_end = (box + 1) * num_box_features
c_end = (box + 1) * num_box_features
y[:, c_begin:c_end, :, :] = nn.Sigmoid(x[:, c_begin:c_end, :, :])

return y

0 comments on commit a94e1ab

Please sign in to comment.