Skip to content

Commit

Permalink
refine the code structure to support all axis>=3 case
Browse files Browse the repository at this point in the history
  • Loading branch information
vera121 committed Dec 16, 2021
1 parent e434774 commit d2dfb1d
Showing 1 changed file with 5 additions and 22 deletions.
27 changes: 5 additions & 22 deletions src/caffe/layers/conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,28 +236,11 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,

if(this->submanifold_sparse_)
{
if(bottom[0]->num_axes()==4)
{
CHECK_EQ(bottom[0]->height(), top[0]->height())<<
"Input and output blob height not equal! Submanifold sparse computation is invalid!";
CHECK_EQ(bottom[0]->width(), top[0]->width())<<
"Input and output blob width not equal! Submanifold sparse computation is invalid!";
}
else if(bottom[0]->num_axes()==5)
{
CHECK_EQ(bottom[0]->shape(2), top[0]->shape(2))<<
"Input and output blob depth not equal! Submanifold sparse computation is invalid!";
CHECK_EQ(bottom[0]->shape(3), top[0]->shape(3))<<
"Input and output blob height not equal! Submanifold sparse computation is invalid!";
CHECK_EQ(bottom[0]->shape(4), top[0]->shape(4))<<
"Input and output blob width not equal! Submanifold sparse computation is invalid!";
}
else
{
CHECK_EQ(bottom[0]->num_axes(), 3)<<"Not support Submanifold sparse computation for such blob dimension yet!";
CHECK_EQ(bottom[0]->shape(2), top[0]->shape(2))<<
"Input and output blob length not equal! Submanifold sparse computation is invalid!";
}
CHECK_GE(bottom[0]->num_axes(), 3)<<"Input blob dimension must >=3!";
for(int i=2; i<bottom[0]->num_axes();i++)
CHECK_EQ(bottom[0]->shape(i), top[0]->shape(i))<<
"Input and output blob shape does not match! Submanifold sparse computation is invalid!";

LOG(INFO)<<"Starts submanifold sparse computation.";

for(int index=0; index<bottom[0]->count(2); index++)
Expand Down

0 comments on commit d2dfb1d

Please sign in to comment.