body_3d = create_body(r3d_18, pretrained = False)
Fastai's DynamicUnet allows construction of a UNet using any pretrained CNN as backbone/encoder. A key module is nn.PixelShuffle, which allows subpixel convolutions for upscaling in the UNet blocks. However, nn.PixelShuffle works only on 2D images, so in faimed3d nn.ConvTranspose3d is used instead.
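The difference can be illustrated with plain PyTorch: nn.PixelShuffle accepts only 4D (B, C, H, W) input, while nn.ConvTranspose3d can upscale all three spatial dimensions of a 5D volume. A minimal sketch (the layer sizes here are arbitrary, not taken from faimed3d):

```python
import torch
import torch.nn as nn

# nn.PixelShuffle rearranges channel groups into 2D spatial blocks,
# turning (B, C*r^2, H, W) into (B, C, H*r, W*r) -- 4D input only.
ps = nn.PixelShuffle(upscale_factor=2)
x2d = torch.randn(1, 64, 13, 13)
print(ps(x2d).size())   # torch.Size([1, 16, 26, 26])

# A transposed convolution with kernel_size=2, stride=2 instead
# doubles all three spatial dims of a 5D (B, C, D, H, W) volume.
ct = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
x3d = torch.randn(1, 64, 3, 13, 13)
print(ct(x3d).size())   # torch.Size([1, 32, 6, 26, 26])
```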
Fastai's PixelShuffle_ICNR first performs a convolution to increase the layer size, then applies PixelShuffle to resize the image. A special initialization technique (ICNR) is applied to PixelShuffle, which can reduce checkerboard artifacts (see https://arxiv.org/pdf/1707.02937.pdf). It is probably not needed for nn.ConvTranspose3d.
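The idea behind the ICNR initialization can be sketched in a few lines: every group of scale**2 output channels is given the same kernel, so that directly after initialization PixelShuffle behaves like nearest-neighbour upsampling. This is a hypothetical re-implementation of the idea, not the fastai code:

```python
import torch
import torch.nn as nn

def icnr_sketch(weight, scale=2):
    # weight: conv weight of shape (out_ch, in_ch, kH, kW), with out_ch
    # divisible by scale**2. Initialise one kernel per output group and
    # repeat it scale**2 times, so PixelShuffle initially acts like
    # nearest-neighbour upsampling (hypothetical sketch of ICNR).
    out_ch, in_ch, kh, kw = weight.shape
    sub = torch.nn.init.kaiming_normal_(
        weight.new_zeros(out_ch // scale**2, in_ch, kh, kw))
    return sub.repeat_interleave(scale**2, dim=0)

conv = nn.Conv2d(32, 64, kernel_size=3, padding=1)
with torch.no_grad():
    conv.weight.copy_(icnr_sketch(conv.weight))
# channels 0..3 now share one kernel, channels 4..7 the next, etc.
```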
ConvTranspose3D(256, 128)(torch.randn((1, 256, 3, 13, 13))).size()
ConvTranspose3D(256, 128, blur = True)(torch.randn((1, 256, 3, 13, 13))).size()
To work with 3D data, the UnetBlock from fastai is adapted, replacing PixelShuffle_ICNR with the above created ConvTranspose3D and also adapting all conv layers and norm layers to the third dimension. As small differences in size may appear, the forward function contains an interpolation step, which is also adapted to work with 5D input instead of 4D. UnetBlock3D receives the lower-level features as hooks.
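The interpolation step can be sketched as follows: if the upsampled decoder features and the hooked encoder features differ slightly in size, F.interpolate aligns them before concatenation. A sketch assuming 'nearest' mode and 5D (B, C, D, H, W) tensors, with made-up sizes:

```python
import torch
import torch.nn.functional as F

up = torch.randn(1, 64, 5, 24, 24)    # upsampled decoder features
skip = torch.randn(1, 64, 5, 25, 25)  # hooked encoder features

# For 5D input the target size covers the last three dims (D, H, W),
# not the last two as in the 2D fastai UnetBlock.
if up.shape[-3:] != skip.shape[-3:]:
    up = F.interpolate(up, size=skip.shape[-3:], mode='nearest')

cat = torch.cat([up, skip], dim=1)
print(cat.size())   # torch.Size([1, 128, 5, 25, 25])
```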
The output size of the last UNet block can differ slightly from the original input size, so one of the last steps in DynamicUnet is ResizeToOrig, which is also adapted to work with 5D instead of 4D input images.
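The adapted layer can be sketched as a small module that interpolates the decoder output over the last three dimensions whenever it does not match the original input. This is a hypothetical stand-alone version; in fastai's DynamicUnet the original input is tracked by the surrounding SequentialEx rather than passed in explicitly:

```python
import torch
import torch.nn.functional as F
from torch import nn

class ResizeToOrig3DSketch(nn.Module):
    # Hypothetical stand-alone sketch: resize the decoder output back
    # to the original input's (D, H, W) size if the sizes differ.
    def __init__(self, mode='nearest'):
        super().__init__()
        self.mode = mode

    def forward(self, x, orig):
        if x.shape[-3:] != orig.shape[-3:]:
            x = F.interpolate(x, size=orig.shape[-3:], mode=self.mode)
        return x

resize = ResizeToOrig3DSketch()
out = resize(torch.randn(2, 2, 10, 48, 48), torch.randn(2, 3, 10, 50, 50))
print(out.size())   # torch.Size([2, 2, 10, 50, 50])
```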
DynamicUnet3D is the main UNet class for faimed3d and is very similar to fastai's DynamicUnet.
unet3d = DynamicUnet3D(body_3d, 2, (10,50,50), 1)
unet3d(torch.rand(2, 3, 10, 50, 50)).size()