From e73a13678bff35a99e9c88d4415e0a117179354d Mon Sep 17 00:00:00 2001 From: Yvonne Chen Date: Mon, 10 Jan 2022 06:25:36 +0100 Subject: [PATCH] add round_mode into pooling param origin is https://github.com/BVLC/caffe/commit/24b09053b44f47b1aa5ad57008809cb2c0e1f592 and we have some fix --- include/caffe/layers/pooling_layer.hpp | 1 + src/caffe/layers/pooling_layer.cpp | 20 +++++++++++++++++++- src/caffe/proto/caffe.proto | 6 ++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/include/caffe/layers/pooling_layer.hpp b/include/caffe/layers/pooling_layer.hpp index c124a2be..75251aa0 100644 --- a/include/caffe/layers/pooling_layer.hpp +++ b/include/caffe/layers/pooling_layer.hpp @@ -55,6 +55,7 @@ class PoolingLayer : public Layer { Blob max_idx_; int pad_type_; //CUSTOMIZATION bool ceil_mode_; + PoolingParameter_RoundMode round_mode_; int pad_l_; //CUSTOMIZATION int pad_r_; //CUSTOMIZATION int pad_t_; //CUSTOMIZATION diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp index e1677007..9a1af25a 100644 --- a/src/caffe/layers/pooling_layer.cpp +++ b/src/caffe/layers/pooling_layer.cpp @@ -35,7 +35,25 @@ void PoolingLayer::LayerSetUp(const vector*>& bottom, || (!pool_param.has_stride_h() && !pool_param.has_stride_w())) << "Stride is stride OR stride_h and stride_w are required."; global_pooling_ = pool_param.global_pooling(); - ceil_mode_ = pool_param.ceil_mode(); + + if(pool_param.has_round_mode()) + { + round_mode_ = pool_param.round_mode(); + switch (round_mode_) { + case PoolingParameter_RoundMode_CEIL: + ceil_mode_ = true; + break; + case PoolingParameter_RoundMode_FLOOR: + ceil_mode_ = false; + break; + default: + LOG(FATAL) << "Unknown rounding mode."; + } + } + else{ + ceil_mode_ = pool_param.ceil_mode(); + } + if (global_pooling_) { kernel_h_ = bottom[0]->height(); kernel_w_ = bottom[0]->width(); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index d5cca188..b98a25d0 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -2212,6 +2212,12 @@ message PoolingParameter { optional bool global_pooling = 12 [default = false]; optional bool ceil_mode = 14 [default = true]; // Specify floor/ceil mode rounding + // How to calculate the output size - using ceil (default) or floor rounding. + enum RoundMode { + CEIL = 0; + FLOOR = 1; + } + optional RoundMode round_mode = 15 [default = CEIL]; // functionality similar to ceil_mode; higher priority than ceil_mode } message Pooling3DParameter {