MRI Synthesis
In this notebook, we will use transforms that generate synthetic MRIs with varying contrast and resolution from label maps.
This is a reimplementation of the domain randomization approach described in:
Billot, B., Greve, D.N., Puonti, O., Thielscher, A., Van Leemput, K., Fischl, B., Dalca, A.V. and Iglesias, J.E., 2023. SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining. Medical image analysis, 86, p.102789.
@article{billot2023synthseg,
title = {SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining},
author = {Billot, Benjamin and Greve, Douglas N and Puonti, Oula and Thielscher, Axel and Van Leemput, Koen and Fischl, Bruce and Dalca, Adrian V and Iglesias, Juan Eugenio and others},
journal = {Medical image analysis},
volume = {86},
pages = {102789},
year = {2023},
publisher = {Elsevier},
url = {https://www.sciencedirect.com/science/article/pii/S1361841523000506}
}
In [ ]:
Copied!
!pushd $TMPDIR \
&& curl \
https://github.com/BBillot/SynthSeg/raw/master/data/training_label_maps/training_seg_01.nii.gz \
&& popd
!pushd $TMPDIR \
&& curl \
https://github.com/BBillot/SynthSeg/raw/master/data/training_label_maps/training_seg_01.nii.gz \
&& popd
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0
In [4]:
Copied!
import torch
import os
import matplotlib.pyplot as plt
from cornucopia import (
LoadTransform, RelabelTransform,
IntensityTransform, SynthFromLabelTransform, RandomGaussianMixtureTransform)
import torch
import os
import matplotlib.pyplot as plt
from cornucopia import (
LoadTransform, RelabelTransform,
IntensityTransform, SynthFromLabelTransform, RandomGaussianMixtureTransform)
In [5]:
Copied!
# Load the demo label map and reduce it to a single 2D slice for plotting.
# Fall back to /tmp when TMPDIR is unset (common on Linux) instead of raising KeyError.
fname = os.path.join(os.environ.get('TMPDIR', '/tmp'), 'demo.nii.gz')
lab = LoadTransform(dtype=torch.int)(fname)
# The four-way index implies a 4D tensor, presumably (channel, X, Y, Z) -- TODO confirm.
# Keep the middle slice of dimension 2 (the dimension actually being indexed).
lab = lab[:, :, lab.shape[2] // 2, :]
lab = RelabelTransform()(lab)
plt.figure(figsize=(10, 10))
# Transpose and flip rows for display.
plt.imshow(lab[0].T.flip(0), cmap='tab20', interpolation='nearest')
plt.axis('off')
plt.title('Labels')
plt.show()
# Load the demo label map and reduce it to a single 2D slice for plotting.
# Fall back to /tmp when TMPDIR is unset (common on Linux) instead of raising KeyError.
fname = os.path.join(os.environ.get('TMPDIR', '/tmp'), 'demo.nii.gz')
lab = LoadTransform(dtype=torch.int)(fname)
# The four-way index implies a 4D tensor, presumably (channel, X, Y, Z) -- TODO confirm.
# Keep the middle slice of dimension 2 (the dimension actually being indexed).
lab = lab[:, :, lab.shape[2] // 2, :]
lab = RelabelTransform()(lab)
plt.figure(figsize=(10, 10))
# Transpose and flip rows for display.
plt.imshow(lab[0].T.flip(0), cmap='tab20', interpolation='nearest')
plt.axis('off')
plt.title('Labels')
plt.show()
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[5], line 2 1 fname = os.path.join(os.environ['TMPDIR'], 'demo.nii.gz') ----> 2 lab = LoadTransform(dtype=torch.int)(fname) 3 lab = lab[:, :, lab.shape[-2]//2, :] 4 lab = RelabelTransform()(lab) File ~/Dropbox/Workspace/code/balbasty/cornucopia/cornucopia/base.py:99, in Transform.__call__(self, *a, **k) 98 def __call__(self, *a, **k): ---> 99 out = super().__call__(*a, **k) 100 if isinstance(out, Returned): 101 out = out.obj File ~/miniforge3/envs/pytorch2/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs) 1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(*args, **kwargs) File ~/miniforge3/envs/pytorch2/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(*args, **kwargs) 1543 try: 1544 result = None File ~/Dropbox/Workspace/code/balbasty/cornucopia/cornucopia/base.py:151, in Transform.forward(self, *a, **k) 148 return x 150 # now we're working with a single tensor (or str) --> 151 y = self.xform(x) 152 if not isinstance(y, Returned): 153 if not isinstance(y, type(self.returns)): File ~/Dropbox/Workspace/code/balbasty/cornucopia/cornucopia/io.py:109, in LoadTransform.xform(self, x) 107 message = [f'Could not load {x}:'] + exceptions 108 message = '\n'.join(message) --> 109 raise ValueError(message) ValueError: Could not load /var/folders/sy/4t85jky92nd3f_nbfnjljmr00000gn/T/demo.nii.gz: Empty file: '/var/folders/sy/4t85jky92nd3f_nbfnjljmr00000gn/T/demo.nii.gz' Empty file: '/var/folders/sy/4t85jky92nd3f_nbfnjljmr00000gn/T/demo.nii.gz' No data left in file
Then, combine a `RandomGaussianMixtureTransform` with an `IntensityTransform` (using `+`) and apply the result to our labels.
Note that tensors fed to a `Transform` layer should have a channel dimension but no batch dimension.
In [4]:
Copied!
# Build the intensity model (label 0 passed as `background`) and
# synthesize one image from the label slice.
trf = RandomGaussianMixtureTransform(background=0) + IntensityTransform()
img = trf(lab)
fig, ax = plt.subplots(figsize=(10, 10))
# Drop singleton dims, then transpose and flip rows, matching the label plot's orientation.
ax.imshow(torch.flip(img.squeeze().T, dims=(0,)), cmap='gray', interpolation='nearest')
ax.set_axis_off()
ax.set_title('Synthetic Image')
plt.show()
# Build the intensity model (label 0 passed as `background`) and
# synthesize one image from the label slice.
trf = RandomGaussianMixtureTransform(background=0) + IntensityTransform()
img = trf(lab)
fig, ax = plt.subplots(figsize=(10, 10))
# Drop singleton dims, then transpose and flip rows, matching the label plot's orientation.
ax.imshow(torch.flip(img.squeeze().T, dims=(0,)), cmap='gray', interpolation='nearest')
ax.set_axis_off()
ax.set_title('Synthetic Image')
plt.show()
Now, let's synthesize a bunch of them:
In [5]:
Copied!
# Fill a 4x4 grid, calling the transform once per panel (row-major order).
shape = [4, 4]
fig, axes = plt.subplots(*shape, figsize=(10, 10))
for ax in axes.flat:
    sample = trf(lab)
    ax.imshow(torch.flip(sample.squeeze().T, dims=(0,)), cmap='gray', interpolation='nearest')
    ax.set_axis_off()
plt.show()
# Fill a 4x4 grid, calling the transform once per panel (row-major order).
shape = [4, 4]
fig, axes = plt.subplots(*shape, figsize=(10, 10))
for ax in axes.flat:
    sample = trf(lab)
    ax.imshow(torch.flip(sample.squeeze().T, dims=(0,)), cmap='gray', interpolation='nearest')
    ax.set_axis_off()
plt.show()
Finally, let's try the full pipeline, with deformations:
In [6]:
Copied!
# Full pipeline: SynthFromLabelTransform yields a synthetic image and a
# new label map, shown side by side.
trf = SynthFromLabelTransform()
img, newlab = trf(lab)
fig, (ax_img, ax_lab) = plt.subplots(1, 2, figsize=(10, 10))
ax_img.imshow(torch.flip(img.squeeze().T, dims=(0,)), cmap='gray', interpolation='nearest')
ax_img.set_axis_off()
ax_img.set_title('Synthetic Image')
ax_lab.imshow(torch.flip(newlab.squeeze().T, dims=(0,)), cmap='tab20', interpolation='nearest')
ax_lab.set_axis_off()
ax_lab.set_title('Synthetic Label')
plt.show()
# Full pipeline: SynthFromLabelTransform yields a synthetic image and a
# new label map, shown side by side.
trf = SynthFromLabelTransform()
img, newlab = trf(lab)
fig, (ax_img, ax_lab) = plt.subplots(1, 2, figsize=(10, 10))
ax_img.imshow(torch.flip(img.squeeze().T, dims=(0,)), cmap='gray', interpolation='nearest')
ax_img.set_axis_off()
ax_img.set_title('Synthetic Image')
ax_lab.imshow(torch.flip(newlab.squeeze().T, dims=(0,)), cmap='tab20', interpolation='nearest')
ax_lab.set_axis_off()
ax_lab.set_title('Synthetic Label')
plt.show()
In [7]:
Copied!
# Eight image/label pairs from the full pipeline. Panels alternate
# image, label, image, label... across a 4x4 grid in row-major order.
shape = [4, 4]
fig, axes = plt.subplots(*shape, figsize=(10, 10))
panels = iter(axes.flat)
# zip over the same iterator consumes consecutive panels as (image, label) pairs.
for ax_img, ax_lab in zip(panels, panels):
    img, newlab = trf(lab)
    ax_img.imshow(torch.flip(img.squeeze().T, dims=(0,)), cmap='gray', interpolation='nearest')
    ax_img.set_axis_off()
    ax_lab.imshow(torch.flip(newlab.squeeze().T, dims=(0,)), cmap='tab20', interpolation='nearest')
    ax_lab.set_axis_off()
plt.show()
# Eight image/label pairs from the full pipeline. Panels alternate
# image, label, image, label... across a 4x4 grid in row-major order.
shape = [4, 4]
fig, axes = plt.subplots(*shape, figsize=(10, 10))
panels = iter(axes.flat)
# zip over the same iterator consumes consecutive panels as (image, label) pairs.
for ax_img, ax_lab in zip(panels, panels):
    img, newlab = trf(lab)
    ax_img.imshow(torch.flip(img.squeeze().T, dims=(0,)), cmap='gray', interpolation='nearest')
    ax_img.set_axis_off()
    ax_lab.imshow(torch.flip(newlab.squeeze().T, dims=(0,)), cmap='tab20', interpolation='nearest')
    ax_lab.set_axis_off()
plt.show()