Adapting Pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation

Yuhui Zhang, Shih-Cheng Huang, Zhengping Zhou, Matthew P. Lungren, Serena Yeung
Proceedings of the 2nd Machine Learning for Health symposium, PMLR 193:391-404, 2022.

Abstract

Given the prevalence of 3D medical imaging technologies such as MRI and CT, which are widely used in diagnosing and treating diverse diseases, 3D segmentation is one of the fundamental tasks of medical image analysis. Recently, Transformer-based models have started to achieve state-of-the-art performance across many vision tasks, through pre-training on large-scale natural image benchmark datasets. While works on medical image analysis have also begun to explore Transformer-based models, there is currently no optimal strategy to effectively leverage pre-trained Transformers, primarily due to the difference in dimensionality between 2D natural images and 3D medical images. Existing solutions either split 3D images into 2D slices and predict each slice independently, thereby losing crucial depth-wise information, or modify the Transformer architecture to support 3D inputs without leveraging pre-trained weights. In this work, we use a simple yet effective weight inflation strategy to adapt pre-trained Transformers from 2D to 3D, retaining the benefit of both transfer learning and depth information. We further investigate the effectiveness of transfer from different pre-training sources and objectives. Our approach achieves state-of-the-art performance across a broad range of 3D medical image datasets, and can become a standard strategy easily utilized by all work on Transformer-based models for 3D medical images to maximize performance.
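To illustrate the core idea, below is a minimal sketch of average-style weight inflation for a ViT patch embedding, in the spirit of the I3D-style inflation the abstract describes. The PyTorch framing, the function name inflate_patch_embed, and the depth-wise patch size of 4 are illustrative assumptions, not the authors' exact implementation; the transformer blocks themselves operate on tokens and can reuse pre-trained weights unchanged, so only the embedding layer needs inflating.

    import torch
    import torch.nn as nn

    def inflate_patch_embed(conv2d: nn.Conv2d, depth: int) -> nn.Conv3d:
        """Inflate a 2D patch-embedding convolution to 3D.

        The 2D kernel of shape (out, in, k, k) is replicated `depth` times
        along a new depth axis and divided by `depth`, so a volume whose
        slices are all copies of the same 2D image yields the same
        activations as the original 2D model.
        """
        out_ch, in_ch, kh, kw = conv2d.weight.shape
        conv3d = nn.Conv3d(
            in_ch, out_ch,
            kernel_size=(depth, kh, kw),
            stride=(depth, conv2d.stride[0], conv2d.stride[1]),
            bias=conv2d.bias is not None,
        )
        with torch.no_grad():
            # Replicate along the new depth dimension and rescale so the
            # magnitude of the pre-trained activations is preserved.
            inflated = conv2d.weight.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth
            conv3d.weight.copy_(inflated)
            if conv2d.bias is not None:
                conv3d.bias.copy_(conv2d.bias)
        return conv3d

    # Usage (hypothetical ViT-B/16 configuration): inflate the patch
    # embedding, then tokenize a 3D volume of shape (B, C, D, H, W).
    patch_embed_2d = nn.Conv2d(3, 768, kernel_size=16, stride=16)
    patch_embed_3d = inflate_patch_embed(patch_embed_2d, depth=4)
    volume = torch.randn(1, 3, 4, 224, 224)
    tokens = patch_embed_3d(volume).flatten(2).transpose(1, 2)  # (1, 196, 768)

Dividing by depth is the standard averaging scheme from video-model inflation; it keeps the inflated model's outputs consistent with the 2D pre-trained model at initialization, which is what lets transfer learning and depth-wise modeling coexist.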

Cite this Paper

BibTeX
@InProceedings{pmlr-v193-zhang22a,
  title     = {Adapting Pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation},
  author    = {Zhang, Yuhui and Huang, Shih-Cheng and Zhou, Zhengping and Lungren, Matthew P. and Yeung, Serena},
  booktitle = {Proceedings of the 2nd Machine Learning for Health symposium},
  pages     = {391--404},
  year      = {2022},
  editor    = {Parziale, Antonio and Agrawal, Monica and Joshi, Shalmali and Chen, Irene Y. and Tang, Shengpu and Oala, Luis and Subbaswamy, Adarsh},
  volume    = {193},
  series    = {Proceedings of Machine Learning Research},
  month     = {28 Nov},
  publisher = {PMLR},
  pdf       = {https://proceedings.mlr.press/v193/zhang22a/zhang22a.pdf},
  url       = {https://proceedings.mlr.press/v193/zhang22a.html}
}
Endnote
%0 Conference Paper
%T Adapting Pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation
%A Yuhui Zhang
%A Shih-Cheng Huang
%A Zhengping Zhou
%A Matthew P. Lungren
%A Serena Yeung
%B Proceedings of the 2nd Machine Learning for Health symposium
%C Proceedings of Machine Learning Research
%D 2022
%E Antonio Parziale
%E Monica Agrawal
%E Shalmali Joshi
%E Irene Y. Chen
%E Shengpu Tang
%E Luis Oala
%E Adarsh Subbaswamy
%F pmlr-v193-zhang22a
%I PMLR
%P 391--404
%U https://proceedings.mlr.press/v193/zhang22a.html
%V 193
APA
Zhang, Y., Huang, S., Zhou, Z., Lungren, M.P. & Yeung, S. (2022). Adapting Pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation. Proceedings of the 2nd Machine Learning for Health symposium, in Proceedings of Machine Learning Research 193:391-404. Available from https://proceedings.mlr.press/v193/zhang22a.html.
