본문 바로가기

Tech/Medical AI

[Torchio]-3D Segmentation

Torchio?

Torchio는 Pytorch를 기반으로 구현되어 있으며, 3D Segmentation 특히 의료 분야의 데이터를 분석하기 용이한 오픈소스 라이브러리입니다. 이번 포스팅에서는 Torchio 내에서 3D 데이터를 로드하고 모델에 입력해 주는 과정을 담당하는 DataStructures, 다양한 기법으로 데이터를 전처리 및 증강 해주는 Transform, 큰 용량의 3D 데이터를 효과적으로 학습하게 해주는 Patch-Based Pipeline 부분을 소개해 드리고자 합니다. torchio를 활용해 3D segmentation을 직접 진행할 필요가 있으시다면, 아래 글을 읽고 torchio에서 제공해 주는 공식 튜토리얼을 학습해 보는 걸 추천드립니다.🤠
(Torchio 공식 tutorial : Getting started)

DataStructures

Image

torchio에서 Image는 데이터를 취급하는 가장 작은 단위입니다. Image에는 다시 ScalarImage와 LabelMap으로 나눠볼 수 있습니다.

  • ScalarImage
    ScalarImage는 보통 모델 학습 시 원본에 해당하는 이미지를 넣어줍니다. MRI 혹은 CT와 같이 분석하고자 하는 입력 데이터를 말하죠.
  • LabelMapLabelMap에는 분석하고자 하는 입력 데이터에 대한 라벨링 데이터를 넣어줍니다. 3D segmentation에선 흔히 mask라고 표현하기도 합니다.

이렇게 이미지를 ScalarImage와 LabelMap으로 나누는 이유는, 후에 augmentation을 위한 transform을 적용하는 데에 있습니다. 만약 Blur나 Noise와 같은 transform을 적용한다고 할 때(이를 보통 Intensity transforms이라 부릅니다), ScalarImage에만 적용되어야 하며 LabelMap에는 적용되어선 안됩니다. 반면 Flip이나 Rotation과 같은 transform이 적용될 때는(이를 보통 Spatial Transforms라 부릅니다), ScalarImage와 LabelMap 둘 다 적용이 되어야 합니다. torchio에선 이를 위해 ScalarImage와 LabelMap을 따로 구분하게 됩니다.

import torch
import torchio as tio

image = tio.ScalarImage('image.nii.gz')
mask = tio.LabelMap('mask.nii.gz')

Subject

Subject는 위에서 생성한 ScalarImage, LabelMap과 같은 이미지들을 포함한 모든 메타데이터를 저장해놓은 객체입니다. 의료 관점으로 봤을 땐, 한 환자를 분석하는데 필요한 모든 정보를 포함해 놓은 장부라 볼 수 있습니다.

코드로는 python의 dict를 활용하여 subject를 정의할 수 있습니다.

import torchio as tio

subject_dict = {
    'image': tio.ScalarImage('path_to_image.nii.gz'),
    'mask': tio.LabelMap('path_to_seg.nii.gz'),
    'name': "rimo",
    'age': 25,
    'hospital': 'Hospital Juan Negrín',
}
subject = tio.Subject(subject_dict)

SubjectDataset

Subject가 한 환자에 대한 모든 데이터 정보를 포함한 객체라면, SubjectDataset은 학습에 필요한 모든 환자에 대한 Subject들을 한데 모은 데이터 셋이라 볼 수 있겠습니다. pytorch의 Dataset 클래스를 상속하고 있으며, pythorch의 dataloader와 함께 쓰일 수 있습니다.

import torchio as tio

# 환자 a에 대한 subject
subject_a = tio.Subject(
    t1=tio.ScalarImage('t1.nrrd',),
    t2=tio.ScalarImage('t2.mha',),
    label=tio.LabelMap('t1_seg.nii.gz'),
    age=31,
    name='Fernando Perez',
)
# 환자 b에 대한 subject
subject_b = tio.Subject(
    t1=tio.ScalarImage('colin27_t1_tal_lin.minc',),
    t2=tio.ScalarImage('colin27_t2_tal_lin_dicom',),
    label=tio.LabelMap('colin27_seg1.nii.gz'),
    age=56,
    name='Colin Holmes',
)
# subject들을 하나의 리스트로 꾸려줍니다.
subjects_list = [subject_a, subject_b]

# 적용할 transform을 정의해 줍니다.
transforms = [
    tio.RescaleIntensity(out_min_max=(0, 1)),
    tio.RandomAffine(),
]
transform = tio.Compose(transforms)

# subject들을 담은 리스트를 SubjectDatset으로 묶어 줍니다.
subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)

Transform

torchio를 도입하기로 한 결정적인 요인입니다. 2D 이미지를 다룰 때는 Torchvision, Albumentation 등 다양한 라이브러리를 활용해 transform을 적용할 수 있었습니다. 하지만 3D 데이터는 2D에 비해 transform을 적용하기 더 복잡해졌습니다. 고려해야 할 차원이 하나 더 늘어났기 때문이죠. 분명 누군가는 3D 데이터를 위한 transform 라이브러리를 제작해 놨을 거란 생각에 서칭하다 찾은 게 바로 torchio입니다. 특히 의료 데이터에 많이 적용되는 augmentation 기법들이 많았고, 코드로 적용하기 쉬워 torchio로 pipeline을 개발하기로 했습니다. torchio에 있는 여러 augmentation들은 아래 공식 문서에 친절히 시각화로 설명되어 있으니 참고하시면 큰 도움이 될 수 있습니다. (Torchio Transforms : Transforms)

training_transform = tio.Compose([
    tio.ToCanonical(),
    tio.Resample(2),
    tio.RandomNoise(p=0.2),
    tio.RandomFlip(p=0.4),
    tio.RandomBlur(p=0.2),
    tio.RandomBiasField(p=0.1),
    tio.RandomMotion(p=0.2, degrees=10, translation=10, num_transforms=2),
    tio.RandomAffine(p=0.2, degrees=(2,2)),
    tio.ZNormalization(p=1),
    tio.Lambda(to_float),
    tio.OneHot(num_classes=4),
])

validation_transform = tio.Compose([
    tio.ToCanonical(),
    tio.Resample(2),
    tio.ZNormalization(p=1),
    tio.Lambda(to_float),
    tio.OneHot(num_classes=4),
])

 

코드로 봤을 때는 기존의 다른 transform 라이브러리와 큰 차이가 없습니다. 혹시나 torchio에 존재하지 않는 다른 transform을 직접 정의하여 사용하고 싶다면, tio.Lambda(to_float)과 같은 lambda 함수로 묶어 자신만의 transform을 적용할 수도 있습니다.

Patch-Based Pipeline

 

3D 이미지는 2D 이미지에 비해서 GPU 메모리를 훨씬 많이 소모하게 됩니다. (512,512) 사이즈인 2D 이미지 한 장을 512개 쌓아서 만든 게 (512,512,512) 사이즈인 하나의 3D 데이터가 되기 때문이죠. 보통 딥러닝 모델을 학습시킬 때는 하나의 배치에 여러 개의 데이터를 학습시키게 됩니다. 특히 BatchNormalization의 효과를 보기 위해선 하나의 배치에 여러 장의 데이터를 태워야 효과를 보게 되죠. 하지만, 3D 데이터는 용량이 크기 때문에 배치 사이즈를 늘리는데 제한이 있습니다. (모델이 클 경우, 배치 사이즈를 1로만 두어도 cuda out of memory가 뜰 경우가 종종 있었습니다.) 이러한 문제점을 해소하기 위해 나온 기법이 Patch 기법입니다.

위의 그림에서 보듯이, 하나의 3D 데이터를 여러 patch로 나누어 학습에 이용한다는 개념입니다. 큰 용량의 3D 데이터를 작은 용량의 3D 데이터로 쪼개서 학습을 시킨다는 것이죠. 이렇게 하면 GPU의 부족한 메모리를 어느 정도 해소할 수 있게 됩니다.

Patch Sampler

위의 그림에서는 단순히 하나의 데이터를 grid 하게 나누어 모든 patch를 활용하는 샘플링 기법입니다.(아래의 GridSampler에 해당합니다) 하지만, 데이터의 유형에 따라서 patch를 나누는 방법을 달리하여 좀 더 효율적인 patch를 사용할 수도 있습니다. 아래는 torchio에서 제공해 주는 여러 샘플링 기법입니다.

  • GridSampler
    voxel을 grid하게 구간을 나누어, 모든 영역을 patch로 샘플링 해냅니다.
  • UniformSampler
    voxel내에서 균일한 확률로 patch들을 샘플링 합니다.
  • WeightedSampler
    voxel내에서 특정 영역의 patch를 더 많이 추출해 내고 싶을 때, 영역마다 확률값을 지정해주어 patch를 추출해 냅니다.
  • LabelSampler
    특정 라벨을 가진 영역의 patch가 추출되도록 컨트롤 하며 샘플링할 수 있습니다.

학습 시 patch 기법을 사용할 수 있도록 코드를 구현해 보면 아래와 같습니다.

# 학습 시
patch_size = (32,128,128)
samples_per_volume = CFG["samples_per_volume"]
max_queue_length = CFG["max_queue_length"]
sampler = tio.data.LabelSampler(patch_size, label_name="label")

patches_training_set = tio.Queue(
    subjects_dataset=train_dataset,
    max_length=max_queue_length,
    samples_per_volume=samples_per_volume,
    sampler=sampler,
    num_workers=0,
    shuffle_subjects=True,
    shuffle_patches=True,
)

patches_validation_set = tio.Queue(
    subjects_dataset=valid_dataset,
    max_length=max_queue_length,
    samples_per_volume=samples_per_volume,
    sampler=sampler,
    num_workers=0,
    shuffle_subjects=False,
    shuffle_patches=False,
)

이렇게 Sampler를 활용해 하나의 데이터를 여러 조각으로 쪼개어 학습을 했다면, 추론 시에는 다시 쪼개진 큐브들을 하나의 voxel로 합쳐줄 필요가 있습니다. 이때는 torchio의 aggregator를 활용해 다시 원본으로 합쳐줄 수 있습니다.

# 추론 시
patch_size = (32,128,128)
patch_overlap = (5, 5, 5)
grid_sampler = tio.inference.GridSampler(
    subject,
    patch_size,
    patch_overlap,
)

patch_loader = torch.utils.data.DataLoader(
				grid_sampler, 
				batch_size=validation_batch_size)
                
aggregator = tio.inference.GridAggregator(grid_sampler)

이러한 patch 기반의 파이프라인은 제한된 GPU 메모리 환경에서도 학습 및 추론을 가능하게 해주지만, 하나의 3D 데이터를 쪼개는 과정에서 일부 공간 정보가 흩어지게 될 수 있습니다. 분석하고자 하는 데이터의 특성에 맞게 적용해야 할 필요가 있는 부분입니다.

마무리

3D segmentation은 처음 해보는 task였습니다. 항상 처음은 낯설고 어디서부터 시작해야 할지 감이 잘 오지 않습니다. 하지만 대게 처음 해보는 것들은, 이미 누군가가 열심히 연구해놓고, 만들어놓고, 공유해놓기 마련입니다. torchio를 제작해 주신 분들께 감사드리며, 이번 포스팅 또한 누군가가 처음 시작하는 길목에서 도움이 되는 글이 되었으면 좋겠습니다. 마지막으로 torchio를 활용해 학습한 모델로, CT 내의 신장 영역을 추출해 낸 결과를 시각화 하며 포스팅을 마무리하겠습니다. 읽어 주셔서 감사합니다.🤠