original = TensorDicom3D.create('../data/series/radiopaedia_10_85902_1.nii.gz')
mask = TensorMask3D.create('../data/masks/radiopaedia_10_85902_1.nii.gz')
original.show()
mask.show(add_to_existing = True, alpha = 0.25, cmap = 'jet')
Resize3D((10,50,50))(original, split_idx = 0).show()
Resize3D((10,50,50))(mask, split_idx = 0).show(add_to_existing = True, alpha = 0.25, cmap = 'jet')
Unlike normal images, DICOM series occupy a physical volume of space. The number of voxels alone therefore says nothing about the real size of the imaged region; only the voxel count multiplied by the voxel spacing gives the physical dimensions. When processing multiple DICOM series at the same time, the series may have different orientations in space and different physical sizes, although their voxel dimensions are identical. Resample3D adjusts the 3D volumes to a uniform size and orientation, so that the same structure is always present in the same place across different images.
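The physical extent of a volume is just the voxel count multiplied by the voxel spacing; a minimal sketch (the spacing values here are hypothetical, chosen only for illustration):
shape = (20, 224, 224)                     # voxels (z, y, x)
spacing = (5.0, 0.7, 0.7)                  # mm per voxel, hypothetical values
extent = tuple(n * s for n, s in zip(shape, spacing))
print(extent)                              # (100.0, 156.8, 156.8) mm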
Resample3D((20, 224, 224), (20, 5, 7))(original).show()
Pad3D((10, 800, 800))(original).show()
Flipping
In medical images, the left and right side often cannot be differentiated from each other (e.g. scans of the head, hand, knee, ...). Therefore, the image orientation is stored in the image header, enabling the viewer system to accurately display the images. For deep learning, only the pixel array is extracted, so the header information is lost. When displaying only the pixel array, the images might already appear flipped or in inverted slice order. Vertical and horizontal flipping, as well as flipping along the z-axis, can therefore be used for data augmentation.
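Under the hood these flips are plain tensor reversals; a minimal sketch in PyTorch (an illustration, not the library's internals; which axis corresponds to which anatomical direction depends on the acquisition):
flipped_lr = original.flip(-1)   # e.g. left-right
flipped_ap = original.flip(-2)   # e.g. anterior-posterior
flipped_z  = original.flip(-3)   # slice order (z-axis)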
torch.stack((original,
             RandomFlip3D()(original, split_idx = 0),
             RandomFlip3D()(original, split_idx = 0),
             RandomFlip3D()(original, split_idx = 0))).show(nrow = 15)
Rotating
Medical images should show no rotation; however, with the removal of the image file header, the pixel array might appear rotated when displayed and would therefore be presented to the model rotated. Furthermore, in some images the patient might be rotated to some degree. Because of this, rotations of 90° and 180° as well as intermediate angles should be implemented.
torch.stack((original,
             RandomRotate3D()(original, split_idx = 0),
             RandomRotate3D()(original, split_idx = 0),
             RandomRotate3D()(original, split_idx = 0))).show(nrow = 15)
PyTorch does not support rotation of 3D volumes directly, so some transformations need to be applied slice-wise.
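A minimal sketch of what slice-wise means here, using torchvision's 2D rotate (an illustration of the idea, not faimed3d's implementation):
import torch
import torchvision.transforms.functional as TF

def rotate_slicewise(volume, angle):
    "Rotate each (H, W) slice of a (D, H, W) volume by `angle` degrees."
    return torch.stack([TF.rotate(s[None], angle)[0] for s in volume])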
tmp1 = RandomRotate3DBy()(original, split_idx = 0)
tmp2 = RandomRotate3DBy(p=1., degrees=(10, 10, 45), axis=[-1, -2, -3])(original, split_idx = 0)
original.show(nrow = 15)
tmp1.show(nrow = 15)
tmp2.show(nrow = 15)
Rotating by 90 (or 180 and 270) degrees should not be done via RandomRotate3DBy but by rotate_90_3d, as this is approximately 28 times faster.
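The speed difference is plausible because a 90° in-plane rotation is a pure index permutation rather than an interpolation; plain PyTorch exposes the same operation as torch.rot90 (a sketch, not necessarily what rotate_90_3d does internally):
rotated_90  = torch.rot90(original, k=1, dims=(-2, -1))
rotated_180 = torch.rot90(original, k=2, dims=(-2, -1))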
"Dihedral" transformation
As the 3D array can be flipped along three axes but should only be rotated around the z-axis, the resulting set of transforms is not a complete dihedral group. Still, multiple combinations of flipping and rotating should be implemented (a programmatic enumeration follows the list as a cross-check):
- original (= flip ll, flip ap, rotate 180 = same as original image)
- rotate 90
- rotate 180
- rotate 270
- flip ll (= flip ap, rotate 180)
- flip ap
- flip cc
- flip cc, rotate 90
- flip cc, rotate 180
- flip cc, rotate 270
- flip ll, rotate 90
- flip ll, rotate 270
- flip ap, rotate 90
- flip ap, rotate 270
- flip cc, flip ll, rotate 90
- flip cc, flip ll, rotate 270
- flip cc, flip ap, rotate 90
- flip cc, flip ap, rotate 270
I am not sure if this is complete...
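Counting distinct elements gives 16: the in-plane dihedral group of order 8 (4 rotations plus 4 reflections), each with and without a cranio-caudal flip. A minimal sketch that enumerates them with plain tensor ops, useful as a cross-check of the list above (not faimed3d's implementation; vol is a (D, H, W) tensor):
import torch

def dihedral_variants(vol):
    "Return the 16 distinct flip/rotate combinations of a (D, H, W) volume."
    variants = []
    for flip_cc in (False, True):            # with/without cranio-caudal flip
        v = vol.flip(-3) if flip_cc else vol
        for flip_ll in (False, True):        # with/without left-right flip
            w = v.flip(-1) if flip_ll else v
            for k in range(4):               # in-plane rotation by 0/90/180/270 degrees
                variants.append(torch.rot90(w, k, dims=(-2, -1)))
    return variants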
dihedral = RandomDihedral3D()
torch.stack((original,
             dihedral(original, split_idx = 0),
             dihedral(original, split_idx = 0),
             dihedral(original, split_idx = 0),
             dihedral(original, split_idx = 0),
             dihedral(original, split_idx = 0))).show(nrow = 15)
Random crop
A reasonable approach for 3D medical images is to presize to a uniform but large volume and subsequently random-crop to the target dimension. As most areas of interest are located centrally in the image/volume, some cropping can always be applied.
Random cropping should also be applied after any rotation that is not a multiple of 90°, so that the empty margins introduced by the rotation are cropped away.
Crop = RandomCrop3D((10,50,50), (10,20,20), False)
torch.stack((Crop(original, split_idx = 0),
             Crop(original, split_idx = 0),
             Crop(original, split_idx = 0),
             Crop(original, split_idx = 0))).show(nrow = 10)
im = Crop(original).resize_3d((10, 100, 100))
crop_mask = TensorMask3D(torch.ones(4, 100, 20)).pad_to((10, 100, 100))   # thin bar, padded to full size
crop_mask = crop_mask + crop_mask.rotate_90_3d()                          # add the 90°-rotated bar -> cross shape
crop_mask = torch.where(crop_mask == 0, 0, 1)                             # binarize: 1 on the cross, 0 elsewhere
crop_mask2 = TensorMask3D(torch.ones(10, 100, 20)).pad_to((10, 100, 100))
crop_mask2 = crop_mask2 + crop_mask2.rotate_90_3d()
crop_mask2 = torch.where(crop_mask2 == 0, 1, 0)                           # inverted: 1 everywhere except the cross
crop_mask.show()
crop_mask2.show()
MaskErease(mask = crop_mask)(im).show()    # erase with the cross mask
MaskErease(mask = crop_mask2)(im).show()   # erase with the inverted mask
im2 = TensorDicom3D.create('../data/example_grid.nii.gz')
im2 = im2.unsqueeze(0)
im2.show()
RandomPerspective3D(im.size(-1), p = 1.)(im2, split_idx=0).show()
RandomWarp3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomWarp3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomWarp3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomSheer3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomSheer3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomSheer3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomTrapezoid3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomTrapezoid3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomTrapezoid3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
Noise = RandomNoise3D(p=1)
Noise(im.mean_scale(), split_idx=0).show()
Noise(im.mean_scale(), split_idx=0).show()
Noise(im.mean_scale(), split_idx=0).show()
Noise(im.mean_scale(), split_idx=0).show()
RandomBlur3D(p=1., sigma = 10)(im, split_idx=0).show()
torch.stack((im.mean_scale(),
             RandomBrightness3D(p=1., beta_range=[0.9, 1])(im.mean_scale(), split_idx = 0),
             RandomBrightness3D(p=1., beta_range=[-1, -0.9])(im.mean_scale(), split_idx = 0))).show()
im.mean_scale().show()
RandomContrast3D(p=1.)(im.mean_scale(), split_idx = 0).show()
RandomContrast3D(p=1.)(im.mean_scale(), split_idx = 0).show()
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.interpolate import RegularGridInterpolator

def elastic_transform_3d(image, labels=None, alpha=4, sigma=35, bg_val=0.1):
    """
    Elastic deformation of images as described in
    Simard, Steinkraus and Platt, "Best Practices for Convolutional
    Neural Networks applied to Visual Document Analysis", in
    Proc. of the International Conference on Document Analysis and
    Recognition, 2003.

    Modified from:
    https://gist.github.com/chsasank/4d8f68caf01f041a6453e67fb30f8f5a
    https://github.com/fcalvet/image_tools/blob/master/image_augmentation.py#L62

    Modified to take 3D inputs and to deform both the image and the
    corresponding label file. The image is linearly (trilinearly)
    interpolated, label volumes are nearest-neighbour interpolated.
    """
    assert image.ndim == 3
    shape = image.shape
    dtype = image.dtype

    # Define the coordinate system of the original volume
    coords = np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2])

    # Initialize the interpolator for the image
    im_intrp = RegularGridInterpolator(coords, image,
                                       method="linear",
                                       bounds_error=False,
                                       fill_value=bg_val)

    # Get random elastic deformations: smoothed white noise, scaled by alpha
    dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma,
                         mode="constant", cval=0.) * alpha
    dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma,
                         mode="constant", cval=0.) * alpha
    dz = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma,
                         mode="constant", cval=0.) * alpha

    # Define the displaced sample points
    x, y, z = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]]
    indices = np.reshape(x + dx, (-1, 1)), \
              np.reshape(y + dy, (-1, 1)), \
              np.reshape(z + dz, (-1, 1))

    # Interpolate the 3D image at the displaced points
    image = im_intrp(indices).reshape(shape).astype(dtype)

    # Interpolate labels with nearest neighbour to keep them discrete
    if labels is not None:
        lab_intrp = RegularGridInterpolator(coords, labels,
                                            method="nearest",
                                            bounds_error=False,
                                            fill_value=0)
        labels = lab_intrp(indices).reshape(shape).astype(labels.dtype)
        return image, labels

    return image
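A hedged usage sketch, assuming original is a plain (D, H, W) CPU volume (the default alpha/sigma are kept; this call is not part of the original notebook):
vol = np.asarray(original, dtype=np.float32)
warped = elastic_transform_3d(vol)
TensorDicom3D(torch.from_numpy(warped)).show()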
Putting it all together
A good workflow is to apply a random crop after each transformation. For this, the images should be presized to a size just a few pixels larger than the target, then transformed and cropped to the final size. With this approach, empty space, such as that appearing after RandomRotate3DBy, is cropped away and does not affect the accuracy of the model. One only has to make sure that the region of interest, e.g. the prostate, remains inside every cropped image.
Crop = RandomCrop3D((2,10,10), (1,2,2))
tfms = [Pipeline([RandomBrightness3D(p=1.), Crop], split_idx = 0),
        Pipeline([RandomContrast3D(p=1.), Crop], split_idx = 0),
        Pipeline([RandomWarp3D(p=1.), Crop], split_idx = 0),
        Pipeline([RandomDihedral3D(p=1.), Crop], split_idx = 0),
        Pipeline([RandomNoise3D(p=1.), Crop], split_idx = 0),
        Pipeline([RandomRotate3DBy(p=1.), Crop], split_idx = 0)]
comp = setup_aug_tfms(tfms)
comp
ims = [t(im).squeeze() for t in tfms]
torch.stack(ims).show(nrow = 6)
@patch
def make_pseudo_color(t: (TensorDicom3D, TensorMask3D)):
    '''
    The 3D CNN still expects color images, so a pseudo-color image needs
    to be created as long as I don't adapt the 3D CNN.
    '''
    if t.ndim == 3:
        return t.unsqueeze(0).float()
    elif t.ndim == 4:
        return t.unsqueeze(1).float()
    else:
        return t
class PseudoColor(RandTransform):
    split_idx, p = None, 1

    def __init__(self, p=1):
        super().__init__(p=p)

    def __call__(self, b, split_idx=None, **kwargs):
        "changed __call__ to enforce that the Transform is always applied, on every dataset"
        return super().__call__(b, split_idx=split_idx, **kwargs)

    def encodes(self, x: (TensorDicom3D, TensorMask3D)):
        return x.make_pseudo_color()
MakeColor = PseudoColor()
im.shape, MakeColor(im, split_idx = 0).shape
tmp = Pipeline(aug_transforms_3d(p_all = 1.), split_idx=0)(im)
print(tmp.size())
tmp.show()
mask.reduce_classes([1]).unique()