From 6de2e6adff240eff6945fa3efc9061d58647836b Mon Sep 17 00:00:00 2001 From: valhassan Date: Thu, 21 Sep 2023 16:14:34 -0400 Subject: [PATCH] added documentation changes --- docs/source/model.rst | 17 +++++++++++++++++ models/hrnet/hrnet_ocr.py | 9 ++++++--- models/segformer.py | 6 ++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/docs/source/model.rst b/docs/source/model.rst index 8b9cce6f..c6dea199 100755 --- a/docs/source/model.rst +++ b/docs/source/model.rst @@ -52,3 +52,20 @@ folder to the complete list on different combinaisons. Also from the same library, another version of *DeepLabV3*, named *DeepLabV3+* of the *Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation* paper. + +Segformer +================================================ + +*Segformer* model implementation is based on the `SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers `_ paper. +The encoder is called from `SMP `_. For more code implementation details check this `repo `_. + +.. autoclass:: models.segformer.SegFormer + + +HRNet + OCR +================================================ + +*HRNet + OCR* model implementation is based on the `HRNet paper `_ and `OCR paper `_. +For more code implementation details check this `repo `_. + +.. autoclass:: models.hrnet.hrnet_ocr.HRNet \ No newline at end of file diff --git a/models/hrnet/hrnet_ocr.py b/models/hrnet/hrnet_ocr.py index bcb807d3..070d6d7f 100644 --- a/models/hrnet/hrnet_ocr.py +++ b/models/hrnet/hrnet_ocr.py @@ -8,9 +8,12 @@ class HRNet(nn.Module): - """ - High Resolution Network (hrnet_w48_v2) with Object Contextual Representation module - + """High Resolution Network (hrnet_w48_v2) with Object Contextual Representation module + + Args: + pretrained (bool): use pretrained weights + in_channels (int): number of bands/channels + classes (int): number of classes """ def __init__(self, pretrained, in_channels, classes) -> None: super(HRNet, self).__init__() diff --git a/models/segformer.py b/models/segformer.py index a61868da..b23a680b 100644 --- a/models/segformer.py +++ b/models/segformer.py @@ -71,6 +71,12 @@ def forward(self, x): class SegFormer(nn.Module): + """Segformer Model + Args: + encoder (str): encoder name + in_channels (int): number of bands/channels + classes (int): number of classes + """ def __init__(self, encoder, in_channels, classes) -> None: super().__init__() self.encoder = smp.encoders.get_encoder(name=encoder, in_channels=in_channels, depth=5, drop_path_rate=0.1)