Add ZSNet #277

Open · wants to merge 4 commits into main
37 changes: 35 additions & 2 deletions LIST_OF_PAPERS.md
@@ -14,7 +14,9 @@
10. Defazio, A., Murrell, T., & Recht, M. P. (2020). [MRI Banding Removal via Adversarial Training](#mri-banding-removal-via-adversarial-training). In *Advances in Neural Information Processing Systems*, 33, pages 7660-7670.
11. Muckley, M. J.\*, Riemenschneider, B.\*, Radmanesh, A., Kim, S., Jeong, G., Ko, J., ... & Knoll, F. (2021). [Results of the 2020 fastMRI Challenge for Machine Learning MR Image Reconstruction](#results-of-the-2020-fastmri-challenge-for-machine-learning-mr-image-reconstruction). *IEEE Transactions on Medical Imaging*, 40(9), pages 2306-2317.
12. Johnson, P. M., Jeong, G., Hammernik, K., Schlemper, J., Qin, C., Duan, J., ..., & Knoll, F. [Evaluation of the Robustness of Learned MR Image Reconstruction to Systematic Deviations Between Training and Test Data for the Models from the fastMRI Challenge](#evaluation-of-the-robustness-of-learned-mr-image-reconstruction-to-systematic-deviations-between-training-and-test-data-for-the-models-from-the-fastmri-challenge). In *MICCAI Machine Learning for Medical Image Reconstruction Workshop*, pages 25–34, 2021.
13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](https://arxiv.org/abs/2203.16392). *Accepted at MIDL, 2022*, to appear.
13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](https://arxiv.org/abs/2203.16392). In *Medical Imaging with Deep Learning*.
14. Radmanesh, A.\*, Muckley, M. J.\*, Murrell, T., Lindsey, E., Sriram, A., Knoll, F., ... & Lui, Y. W. (2022). [Exploring the Acceleration Limits of Deep Learning VarNet-based Two-dimensional Brain MRI](https://doi.org/10.1148/ryai.210313). *Radiology: Artificial Intelligence*, e210313.


## fastMRI: An open dataset and benchmarks for accelerated MRI

@@ -282,4 +284,35 @@
pages={to appear},
year={2022},
}
```

## Exploring the Acceleration Limits of Deep Learning VarNet-based Two-dimensional Brain MRI

[publication](https://doi.org/10.1148/ryai.210313)

**Purpose**

To explore the limits of deep learning-based brain MRI reconstruction and identify useful acceleration ranges for general-purpose imaging and potential screening.

**Materials and Methods**

In this retrospective study conducted from 2019 through 2021, a model was trained for reconstruction on 5,847 brain MRIs. Performance was evaluated across a wide range of accelerations (up to 100-fold along a single phase-encoded direction for two-dimensional [2D] slices) on the fastMRI test set collected by New York University, consisting of 558 image volumes. In a sample of 69 volumes, reconstructions were classified by radiologists for identifying two clinical thresholds: 1) general-purpose diagnostic imaging and 2) potential use in a screening protocol. A Monte Carlo procedure was developed for estimating reconstruction error with only undersampled data. The model was evaluated on both in-domain and out-of-domain data. Confidence intervals were calculated using the percentile bootstrap method.
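
As a rough illustration of the confidence-interval procedure mentioned above, here is a minimal percentile-bootstrap sketch. This is not the authors' code; the function name and the per-volume PSNR values are invented for illustration.

```python
import numpy as np

def percentile_bootstrap_ci(scores, n_boot=10_000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for the mean of per-volume metric scores."""
    rng = np.random.default_rng(seed)
    scores = np.asarray(scores, dtype=float)
    # resample volumes with replacement and recompute the statistic each time
    boot_means = np.array(
        [rng.choice(scores, size=scores.size, replace=True).mean() for _ in range(n_boot)]
    )
    # the alpha/2 and 1 - alpha/2 quantiles bound the (1 - alpha) interval
    return np.percentile(boot_means, [100 * alpha / 2, 100 * (1 - alpha / 2)])

# e.g. per-volume PSNR at one acceleration factor (made-up numbers)
low, high = percentile_bootstrap_ci([33.1, 35.4, 31.8, 34.2, 32.9])
```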

**Results**

Radiologists rated 100% of 69 volumes as having sufficient image quality for general-purpose imaging at up to 4× acceleration and 65 of 69 (94%) of volumes as having sufficient image quality for screening at up to 14× acceleration. The Monte Carlo procedure estimated ground truth peak signal-to-noise ratio and mean squared error with coefficients of determination greater than 0.5 at all accelerations. Out-of-distribution experiments demonstrated the model’s ability to produce images substantially distinct from the training set, even at 100× acceleration.

**Conclusion**

For 2D brain images using deep learning-based reconstruction, maximum acceleration for potential screening was 3–4 times higher than that for diagnostic general-purpose imaging.

```BibTeX
@article{radmanesh2022exploring,
title={Exploring the Acceleration Limits of Deep Learning {VarNet}-based Two-dimensional Brain {MRI}},
author={Radmanesh, Alireza and Muckley, Matthew J and Murrell, Tullie and Lindsey, Emma and Sriram, Anuroop and Knoll, Florian and Sodickson, Daniel K and Lui, Yvonne W},
journal={Radiology: Artificial Intelligence},
pages={e210313},
year={2022},
publisher={Radiological Society of North America}
}
```
3 changes: 2 additions & 1 deletion README.md
@@ -190,4 +190,5 @@
10. Defazio, A., Murrell, T., & Recht, M. P. (2020). [MRI Banding Removal via Adversarial Training](https://papers.nips.cc/paper/2020/hash/567b8f5f423af15818a068235807edc0-Abstract.html). In *Advances in Neural Information Processing Systems*, 33, pages 7660-7670.
11. Muckley, M. J.\*, Riemenschneider, B.\*, Radmanesh, A., Kim, S., Jeong, G., Ko, J., ... & Knoll, F. (2021). [Results of the 2020 fastMRI Challenge for Machine Learning MR Image Reconstruction](https://doi.org/10.1109/TMI.2021.3075856). *IEEE Transactions on Medical Imaging*, 40(9), pages 2306-2317.
12. Johnson, P. M., Jeong, G., Hammernik, K., Schlemper, J., Qin, C., Duan, J., ..., & Knoll, F. (2021). [Evaluation of the Robustness of Learned MR Image Reconstruction to Systematic Deviations Between Training and Test Data for the Models from the fastMRI Challenge](https://doi.org/10.1007/978-3-030-88552-6_3). In *MICCAI MLMIR Workshop*, pages 25–34.
13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](https://arxiv.org/abs/2203.16392). *In MIDL*. Accepted.
13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](https://arxiv.org/abs/2203.16392). In *MIDL*.
14. Radmanesh, A.\*, Muckley, M. J.\*, Murrell, T., Lindsey, E., Sriram, A., Knoll, F., ... & Lui, Y. W. (2022). [Exploring the Acceleration Limits of Deep Learning VarNet-based Two-dimensional Brain MRI](https://doi.org/10.1148/ryai.210313). *Radiology: Artificial Intelligence*, e210313.
1 change: 1 addition & 0 deletions fastmri/models/__init__.py
@@ -5,6 +5,7 @@
LICENSE file in the root directory of this source tree.
"""

from ._zsnet import ZSNet
from .adaptive_varnet import AdaptiveVarNet
from .policy import StraightThroughPolicy
from .unet import Unet
319 changes: 319 additions & 0 deletions fastmri/models/_zsnet.py
@@ -0,0 +1,319 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri.data import transforms
from fastmri.models.varnet import SensitivityModel

from ._zsnet_unet import ZSNetUnet


def _create_zero_tensor(like: torch.Tensor) -> torch.Tensor:
    """Return a broadcastable scalar zero with the dtype and device of ``like``."""
    return torch.zeros(
        (1, 1, 1, 1, 1), dtype=like.dtype, device=like.device
    )


class NormUnet(nn.Module):
"""
Normalized U-Net model.
This is the same as a regular U-Net, but with normalization applied to the
input before the U-Net. This keeps the values more numerically stable
during training.
"""

def __init__(
self,
chans: int,
num_pools: int,
in_chans: int = 2,
out_chans: int = 2,
drop_prob: float = 0.0,
):
"""
Args:
chans: Number of output channels of the first convolution layer.
num_pools: Number of down-sampling and up-sampling layers.
in_chans: Number of channels in the input to the U-Net model.
            out_chans: Number of channels in the output of the U-Net model.
drop_prob: Dropout probability.
"""
super().__init__()

self.unet = ZSNetUnet(
in_chans=in_chans,
out_chans=out_chans,
chans=chans,
num_pool_layers=num_pools,
drop_prob=drop_prob,
)

def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
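        # (b, c, h, w, 2) complex-as-last-dim -> (b, 2 * c, h, w) real channels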
b, c, h, w, two = x.shape
return x.permute(0, 4, 1, 2, 3).reshape(b, two * c, h, w)

def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
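        # inverse of complex_to_chan_dim: (b, 2 * c, h, w) -> (b, c, h, w, 2)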
b, c2, h, w = x.shape
assert c2 % 2 == 0
c = c2 // 2
return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # per-channel normalization (group norm with one group per channel)
        b, c, h, w = x.shape

        x = x.contiguous().view(b, c, h * w)
        mean = x.mean(dim=2).view(b, c, 1, 1)
        std = x.std(dim=2).view(b, c, 1, 1)

        x = x.view(b, c, h, w)

        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        # concat blocks feed 4 input channels through a 2-channel-output U-Net,
        # so keep only the first complex image's statistics when shapes differ
        if mean.shape[1] != x.shape[1]:
            mean = mean[:, :2]
        if std.shape[1] != x.shape[1]:
            std = std[:, :2]
        return x * std + mean

def pad(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
_, _, h, w = x.shape
w_mult = ((w - 1) | 15) + 1
h_mult = ((h - 1) | 15) + 1
w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
# TODO: fix this type when PyTorch fixes theirs
# the documentation lies - this actually takes a list
# https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
# https://github.com/pytorch/pytorch/pull/16949

x = F.pad(x, w_pad + h_pad)

return x, (h_pad, w_pad, h_mult, w_mult)

def unpad(
self,
x: torch.Tensor,
h_pad: List[int],
w_pad: List[int],
h_mult: int,
w_mult: int,
) -> torch.Tensor:
return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]

def forward(self, x: torch.Tensor) -> torch.Tensor:
# get shapes for unet and normalize
x = self.complex_to_chan_dim(x)
x, pad_sizes = self.pad(x)
x, mean, std = self.norm(x)

x = self.unet(x.contiguous())

# get shapes back and unnormalize
x = self.unnorm(x, mean, std)
x = self.unpad(x, *pad_sizes)
x = self.chan_complex_to_last_dim(x)

return x


class ZSNetSensitivityModel(SensitivityModel):
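    """Sensitivity map estimator from VarNet with its U-Net swapped out.

    Identical to ``fastmri.models.varnet.SensitivityModel`` except that the
    normalized U-Net is replaced by a ``ZSNetUnet``-backed ``NormUnet``.
    """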
def __init__(
self,
chans: int,
num_pools: int,
in_chans: int = 2,
out_chans: int = 2,
drop_prob: float = 0.0,
mask_center: bool = True,
):
super().__init__(chans, num_pools, in_chans, out_chans, drop_prob, mask_center)
        # overwrite the parent's U-Net with the ZSNetUnet-backed variant
self.norm_unet = NormUnet(
chans,
num_pools,
in_chans=in_chans,
out_chans=out_chans,
drop_prob=drop_prob,
)


class ZSNet(nn.Module):
def __init__(
self,
image_crop_size: int,
num_concat_cascades: int = 18,
sens_chans: int = 10,
sens_pools: int = 5,
chans: int = 20,
pools: int = 5,
mask_center: bool = False,
):
"""
Args:
num_cascades: Number of cascades (i.e., layers) for variational
network.
sens_chans: Number of channels for sensitivity map U-Net.
sens_pools Number of downsampling and upsampling layers for
sensitivity map U-Net.
chans: Number of channels for cascade U-Net.
pools: Number of downsampling and upsampling layers for cascade
U-Net.
"""
super().__init__()

self.sens_net = ZSNetSensitivityModel(
sens_chans, sens_pools, mask_center=mask_center
)
self.init_layer = ZSNetBlock(NormUnet(chans, pools), crop_size=image_crop_size)
        self.cascades = nn.ModuleList()
        # concat cascades see the current image stacked with the initial
        # image, i.e. twice the usual two complex channels
        chan_mult = 2
        for _ in range(1, num_concat_cascades):
            self.cascades.append(
                ZSNetConcatBlock(
                    NormUnet(chans=chans, num_pools=pools, in_chans=2 * chan_mult),
                    crop_size=image_crop_size,
                )
            )

def forward(
self,
masked_kspace: torch.Tensor,
mask: torch.Tensor,
num_low_frequencies: Optional[int] = None,
only_center: bool = False,
) -> torch.Tensor:
sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
kspace_pred, image = self.init_layer(
masked_kspace.clone(), masked_kspace, mask, sens_maps
)
previous_images = [image]
for cascade in self.cascades:
kspace_pred, image = cascade(
kspace_pred,
masked_kspace,
mask,
sens_maps,
previous_images,
)

        # final image: inverse FFT to image space, magnitude, RSS over coils
        return fastmri.rss(fastmri.complex_abs(fastmri.ifft2c(kspace_pred)), dim=1)


class ZSNetBaseBlock(nn.Module):
def __init__(self, model: nn.Module, crop_size: int):
"""
Args:
model: Module for "regularization" component of variational
network.
"""
super().__init__()

self.model = model
self.crop_size = crop_size
self.dc_weight = nn.Parameter(torch.ones(1))

def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
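        # combined image -> per-coil k-space: multiply by coil sensitivities, FFT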
return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))

def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
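        # per-coil k-space -> combined image: inverse FFT, conjugate coil sum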
return fastmri.complex_mul(
fastmri.ifft2c(x), fastmri.complex_conj(sens_maps)
).sum(dim=1, keepdim=True)

def image_crop(self, image: torch.Tensor) -> torch.Tensor:
input_shape = image.shape
crop_size = (min(self.crop_size, input_shape[-3]), input_shape[-2])

return transforms.complex_center_crop(image, crop_size)

    def image_uncrop(
        self, image: torch.Tensor, original_image: torch.Tensor
    ) -> torch.Tensor:
        """Insert values back into original image."""
        in_height = original_image.shape[-3]
        crop_height = image.shape[-3]
        pad_height = (in_height - crop_height) // 2
        if (in_height - crop_height) % 2 != 0:
            pad_height_top = pad_height + 1
        else:
            pad_height_top = pad_height

        # use an explicit end index so that the no-crop case (pad_height == 0)
        # writes the full height rather than an empty slice
        original_image[..., pad_height_top : in_height - pad_height, :, :] = image  # type: ignore

        return original_image

def apply_model_with_crop(self, image: torch.Tensor) -> torch.Tensor:
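        # run the regularizer on a vertical center crop and write its output
        # back into (the first complex image of) the full-height input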
if self.crop_size is not None:
image = self.image_uncrop(
self.model(self.image_crop(image)), image[..., :2].clone()
)
else:
image = self.model(image)

return image


class ZSNetBlock(ZSNetBaseBlock):
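    """Initial cascade: a VarNet-style block that also returns its
    coil-combined image for use by later cascades."""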
def forward(
self,
current_kspace: torch.Tensor,
ref_kspace: torch.Tensor,
mask: torch.Tensor,
sens_maps: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
zero = _create_zero_tensor(current_kspace)
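        # soft data consistency: pull sampled k-space locations toward the measurements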
soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight
image = self.sens_reduce(current_kspace, sens_maps)
model_term = self.sens_expand(self.apply_model_with_crop(image), sens_maps)

return current_kspace - soft_dc - model_term, image


class ZSNetConcatBlock(ZSNetBaseBlock):
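    """Later cascades: the regularizer input is the current coil-combined
    image concatenated with the carried-over images from earlier cascades."""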
def forward(
self,
current_kspace: torch.Tensor,
ref_kspace: torch.Tensor,
mask: torch.Tensor,
sens_maps: torch.Tensor,
previous_images: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
zero = _create_zero_tensor(current_kspace)
soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight
image = self.sens_reduce(current_kspace, sens_maps)

        # stack the current coil-combined image with the carried-over images
        # along the complex dimension (matching in_chans in the concat
        # NormUnet) before the regularizer, then map back to k-space
        model_term = self.sens_expand(
            self.apply_model_with_crop(torch.cat([image] + previous_images, dim=-1)),
            sens_maps,
        )
return current_kspace - soft_dc - model_term, image
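

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of this PR): the
    # tensor shapes follow the fastMRI convention of (batch, coils, height,
    # width, 2) for complex k-space, and all sizes here are made up.
    model = ZSNet(image_crop_size=320, num_concat_cascades=4, sens_chans=4, chans=8)
    masked_kspace = torch.randn(1, 4, 128, 96, 2)
    mask = torch.zeros(1, 1, 1, 96, 1, dtype=torch.bool)
    mask[..., ::4, :] = True  # equispaced 4x undersampling
    mask[..., 44:52, :] = True  # fully sampled low-frequency band
    with torch.no_grad():
        output = model(masked_kspace, mask, num_low_frequencies=8)
    print(output.shape)  # expected: torch.Size([1, 128, 96])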