diff --git a/pdequinox/_hierarchical.py b/pdequinox/_hierarchical.py index 8a85fee..a1175d3 100644 --- a/pdequinox/_hierarchical.py +++ b/pdequinox/_hierarchical.py @@ -18,10 +18,6 @@ class Hierarchical(eqx.Module): - """ - Uses convolution for downsampling instead of max pooling - """ - lifting: Block down_sampling_blocks: List[Block] left_arch_blocks: List[Block] # Includes the bottleneck @@ -54,15 +50,68 @@ def __init__( **boundary_kwargs, ): """ - You are probably looking for pdequinox.arch.ClassicUNet. + Generic constructor for hierarchical block-based architectures like + UNets. (For the classic UNet, use `pdequinox.arch.ClassicUNet` instead. + + Hierarchical architectures us a number of different spatial resolutions. + The lower the resolution, the wider the receptive of convolutions. - num_levels define how deep the UNet goes. If set to 0, this will just be - a classical conv net. If set to 1, this will be a single down and up - sampling block etc. + **Arguments:** - Use the channel multipliers to adjust the channel growth over depth. If - set to None, the channels will grow by a factor of reduction_factor at - each level. + - `num_spatial_dims`: The number of spatial dimensions. For example + traditional convolutions for image processing have this set to `2`. + - `in_channels`: The number of input channels. + - `out_channels`: The number of output channels. + - `hidden_channels`: The number of channels in the hidden layers. This + refers to the highest resolution. Right after the input, the input + channels will be lifted to this feature dimension without changing + the spatial resolution. + - `num_levels`: The number of levels in the hierarchy. This is the + number of down and up sampling blocks. If set to 0, this will just + be a classical conv net. If set to 1, this will be a single down and + up sampling block etc. The total number of resolutions are + `num_levels + 1`. + - `activation`: The activation function to use in the blocks. + - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter + initialisation. (Keyword only argument.) + - `reduction_factor`: The factor by which the spatial resolution is + reduced at each level. This has to be an integer. In order to avoid + ambiguities in shapes, it is best if the input spatial resolution is + a multiple of `reduction_factor ** num_levels`. Default is `2`. + - `boundary_mode`: The boundary mode to use for the convolution. + (Keyword only argument) + - `channel_multipliers`: The factor by which the number of channels is + multiplied at each level. If set to `None`, the channels will grow + by a factor of `reduction_factor` at each level. This is similar to + the classical UNet which trades spatial resolution for feature + dimension. Note however, that the parameters of convolutions scale + with the mapped channels, hence the majority of numbers will then be + in the coarsest representation. Supply a tuple of integers that + represent the desired number of channels at each resolution + different than the original one. The length of the tuple must match + `num_levels`. For example, to not change the number of channels at + any level, set this to `(1,) * num_levels`. Default is `None`. + - `lifting_factory`: The factory to use for the lifting block. + Default is `ClassicDoubleConvBlockFactory` which is a classic double + convolution block. + - `down_sampling_factory`: The factory to use for the down sampling + blocks. This must be a block that is able to change the spatial + resolution. Default is `LinearConvDownBlockFactory` which is a + simple linear strided convolution block. + - `left_arch_factory`: The factory to use for the left architecture + blocks. Default is `ClassicDoubleConvBlockFactory` which is a + classic double convolution block. + - `up_sampling_factory`: The factory to use for the up sampling blocks. + This must be a block that is able to change the spatial resolution. + It should work in conjunction with the `down_sampling_factory`. + Default is `LinearConvUpBlockFactory` which is a simple linear + strided transposed convolution block. + - `right_arch_factory`: The factory to use for the right architecture + blocks. Default is `ClassicDoubleConvBlockFactory` which is a + classic double convolution block. + - `projection_factory`: The factory to use for the projection block. + Default is `LinearChannelAdjustBlockFactory` which is simply a + linear 1x1 convolution for channel adjustment. """ self.down_sampling_blocks = [] self.left_arch_blocks = []