Skip to content

Commit

Permalink
Add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Apr 5, 2024
1 parent 4b7a5a7 commit 326e76d
Showing 1 changed file with 60 additions and 11 deletions.
71 changes: 60 additions & 11 deletions pdequinox/_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@


class Hierarchical(eqx.Module):
"""
Uses convolution for downsampling instead of max pooling
"""

lifting: Block
down_sampling_blocks: List[Block]
left_arch_blocks: List[Block] # Includes the bottleneck
Expand Down Expand Up @@ -54,15 +50,68 @@ def __init__(
**boundary_kwargs,
):
"""
You are probably looking for pdequinox.arch.ClassicUNet.
Generic constructor for hierarchical block-based architectures like
UNets. (For the classic UNet, use `pdequinox.arch.ClassicUNet` instead.
Hierarchical architectures us a number of different spatial resolutions.
The lower the resolution, the wider the receptive of convolutions.
num_levels define how deep the UNet goes. If set to 0, this will just be
a classical conv net. If set to 1, this will be a single down and up
sampling block etc.
**Arguments:**
Use the channel multipliers to adjust the channel growth over depth. If
set to None, the channels will grow by a factor of reduction_factor at
each level.
- `num_spatial_dims`: The number of spatial dimensions. For example
traditional convolutions for image processing have this set to `2`.
- `in_channels`: The number of input channels.
- `out_channels`: The number of output channels.
- `hidden_channels`: The number of channels in the hidden layers. This
refers to the highest resolution. Right after the input, the input
channels will be lifted to this feature dimension without changing
the spatial resolution.
- `num_levels`: The number of levels in the hierarchy. This is the
number of down and up sampling blocks. If set to 0, this will just
be a classical conv net. If set to 1, this will be a single down and
up sampling block etc. The total number of resolutions are
`num_levels + 1`.
- `activation`: The activation function to use in the blocks.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation. (Keyword only argument.)
- `reduction_factor`: The factor by which the spatial resolution is
reduced at each level. This has to be an integer. In order to avoid
ambiguities in shapes, it is best if the input spatial resolution is
a multiple of `reduction_factor ** num_levels`. Default is `2`.
- `boundary_mode`: The boundary mode to use for the convolution.
(Keyword only argument)
- `channel_multipliers`: The factor by which the number of channels is
multiplied at each level. If set to `None`, the channels will grow
by a factor of `reduction_factor` at each level. This is similar to
the classical UNet which trades spatial resolution for feature
dimension. Note however, that the parameters of convolutions scale
with the mapped channels, hence the majority of numbers will then be
in the coarsest representation. Supply a tuple of integers that
represent the desired number of channels at each resolution
different than the original one. The length of the tuple must match
`num_levels`. For example, to not change the number of channels at
any level, set this to `(1,) * num_levels`. Default is `None`.
- `lifting_factory`: The factory to use for the lifting block.
Default is `ClassicDoubleConvBlockFactory` which is a classic double
convolution block.
- `down_sampling_factory`: The factory to use for the down sampling
blocks. This must be a block that is able to change the spatial
resolution. Default is `LinearConvDownBlockFactory` which is a
simple linear strided convolution block.
- `left_arch_factory`: The factory to use for the left architecture
blocks. Default is `ClassicDoubleConvBlockFactory` which is a
classic double convolution block.
- `up_sampling_factory`: The factory to use for the up sampling blocks.
This must be a block that is able to change the spatial resolution.
It should work in conjunction with the `down_sampling_factory`.
Default is `LinearConvUpBlockFactory` which is a simple linear
strided transposed convolution block.
- `right_arch_factory`: The factory to use for the right architecture
blocks. Default is `ClassicDoubleConvBlockFactory` which is a
classic double convolution block.
- `projection_factory`: The factory to use for the projection block.
Default is `LinearChannelAdjustBlockFactory` which is simply a
linear 1x1 convolution for channel adjustment.
"""
self.down_sampling_blocks = []
self.left_arch_blocks = []
Expand Down

0 comments on commit 326e76d

Please sign in to comment.