Contrast augmentation
A large number of MRI contrasts exist. However, most labeled data comes from a small number of sequences, which makes it difficult to train networks that generalize to any contrast. Here, we showcase a contrast augmentation transform that fits a Gaussian mixture model to the intensities of the input image and then randomly shifts the means and variances of its components.
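To make the idea concrete, here is a minimal NumPy sketch of GMM-based contrast remixing. It is not cornucopia's implementation of ContrastMixtureTransform. The following are all assumptions made for illustration:

- the EM initialisation;
- the helper names (fit_gmm_1d, gmm_responsibilities, remap, random_contrast);
- the ranges from which new means and standard deviations are drawn.

Each voxel intensity is standardised within every class and mapped to that class's new mean and standard deviation. The per-class results are then blended by the class responsibilities.

```python
import numpy as np


def gmm_responsibilities(x, mu, var, w):
    """Posterior p(class j | x_n) for a 1D Gaussian mixture, shape (N, K)."""
    logp = (-0.5 * (x[:, None] - mu) ** 2 / var
            - 0.5 * np.log(2 * np.pi * var) + np.log(w))
    logp -= logp.max(axis=1, keepdims=True)  # stabilise the exponential
    r = np.exp(logp)
    return r / r.sum(axis=1, keepdims=True)


def fit_gmm_1d(x, k, n_iter=100):
    """Fit a K-class 1D Gaussian mixture to the intensities x with plain EM."""
    # Assumed initialisation: means at evenly spaced quantiles,
    # a shared variance and equal mixing weights.
    mu = np.quantile(x, (np.arange(k) + 0.5) / k)
    var = np.full(k, x.var() / k + 1e-6)
    w = np.full(k, 1.0 / k)
    for _ in range(n_iter):
        r = gmm_responsibilities(x, mu, var, w)           # E-step
        n_j = r.sum(axis=0) + 1e-12                       # M-step
        w = n_j / n_j.sum()
        mu = (r * x[:, None]).sum(axis=0) / n_j
        var = (r * (x[:, None] - mu) ** 2).sum(axis=0) / n_j + 1e-6
    return mu, var, w


def remap(x, r, mu, sd, new_mu, new_sd):
    """Move each class to a new (mean, std), blending classes by responsibility."""
    z = (x[:, None] - mu) / sd        # standardised intensity within each class
    return (r * (new_mu + new_sd * z)).sum(axis=1)


def random_contrast(img, k=6, rng=None):
    """Hypothetical helper: random GMM-based contrast change of an image array."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(img, dtype=float).ravel()
    mu, var, w = fit_gmm_1d(x, k)
    r = gmm_responsibilities(x, mu, var, w)
    sd = np.sqrt(var)
    # Assumed sampling ranges: new means anywhere in the intensity range,
    # new standard deviations between 0.5x and 2x the fitted ones.
    new_mu = rng.uniform(x.min(), x.max(), size=k)
    new_sd = sd * rng.uniform(0.5, 2.0, size=k)
    return remap(x, r, mu, sd, new_mu, new_sd).reshape(np.shape(img))
```

Because the responsibilities sum to one, passing the fitted means and standard deviations back into remap reproduces the input exactly. Voxels lying between two classes are moved by a weighted mix of both class mappings rather than snapping to one of them.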
First, let's download an example image.
!pushd $TMPDIR \
&& wget https://surfer.nmr.mgh.harvard.edu/pub/data/voxelmorph/tutorial_data.tar.gz -O data.tar.gz \
&& tar -xzvf data.tar.gz \
&& popd
/home/scratch /autofs/space/pade_001/users/yb947/code/yb/cornucopia/docs/examples
--2023-08-15 14:31:35--  https://surfer.nmr.mgh.harvard.edu/pub/data/voxelmorph/tutorial_data.tar.gz
Resolving surfer.nmr.mgh.harvard.edu (surfer.nmr.mgh.harvard.edu)... 132.183.1.43
Connecting to surfer.nmr.mgh.harvard.edu (surfer.nmr.mgh.harvard.edu)|132.183.1.43|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16644702 (16M) [application/x-gzip]
Saving to: ‘data.tar.gz’
data.tar.gz 100%[===================>] 15.87M 63.8MB/s in 0.2s
2023-08-15 14:31:35 (63.8 MB/s) - ‘data.tar.gz’ saved [16644702/16644702]
brain_2d_no_smooth.h5
brain_2d_smooth.h5
brain_3d.h5
fs_rgb.npy
subj1.npz
subj2.npz
tutorial_data.npz
/autofs/space/pade_001/users/yb947/code/yb/cornucopia/docs/examples
import torch
import numpy as np
import matplotlib.pyplot as plt
from cornucopia import ContrastMixtureTransform
import os
fname = os.path.join(os.environ['TMPDIR'], 'tutorial_data.npz')
dat = np.load(fname)['train'][0]
dat = torch.as_tensor(dat)
plt.figure(figsize=(10, 10))
plt.imshow(dat, cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('MRI')
plt.show()
Let us now instantiate a contrast augmentation layer and apply it to the MRI. We use fewer classes (6) than the default (12) because we are dealing with skull-stripped 2D images, which have far fewer intensity modes than an intact 3D volume.
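trf = ContrastMixtureTransform(nk=6)   # 6-class mixture instead of the default 12
aug = trf(dat[None])[0]                # add a leading dimension, augment, then drop it
plt.figure(figsize=(10, 10))
plt.imshow(aug, cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Augmented MRI')
plt.show()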
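The cell above calls the transform as trf(dat[None])[0]. Indexing with None adds a leading dimension of size 1, presumably the channel dimension that cornucopia transforms expect. The [0] removes it again so that matplotlib receives a plain (H, W) array. A small wrapper can make this explicit; apply_to_2d is a hypothetical helper, not part of cornucopia.

```python
import torch


def apply_to_2d(transform, img):
    """Hypothetical helper: apply a channel-first transform to one (H, W) image."""
    if img.dim() != 2:
        raise ValueError(f"expected an (H, W) image, got shape {tuple(img.shape)}")
    out = transform(img[None])   # (H, W) -> (1, H, W): add a leading channel dimension
    return out[0]                # (1, H, W) -> (H, W): drop it again for plotting
```

With this helper, the augmentation line could read aug = apply_to_2d(trf, dat).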
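Now, let's synthesize a grid of 16 augmentations. Each call to the transform draws new random parameters, so every panel shows a different contrast.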
shape = [4, 4]
plt.figure(figsize=(10, 10))
for i in range(shape[0] * shape[1]):
    plt.subplot(*shape, i + 1)
    # every call to trf samples a new random contrast
    plt.imshow(trf(dat[None])[0], cmap='gray', interpolation='nearest')
    plt.axis('off')
plt.show()
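For training rather than plotting, the same pattern can produce a batch. The sketch below uses the hypothetical helper augmented_batch, which calls the transform n times and stacks the results into a (batch, channel, height, width) tensor. It makes two assumptions:

- every call draws fresh random parameters, as the grid above suggests;
- the transform returns a (1, H, W) tensor for a (1, H, W) input.

```python
import torch


def augmented_batch(transform, img, n):
    """Hypothetical helper: stack n independent augmentations of one (H, W) image.

    Returns a tensor of shape (n, 1, H, W), i.e. (batch, channel, *spatial).
    """
    # each call is assumed to redraw the transform's random parameters
    return torch.stack([transform(img[None]) for _ in range(n)])
```

For example, augmented_batch(trf, dat, 8) would yield eight differently contrasted copies of the slice, shaped (8, 1, H, W).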