본문 바로가기

공부기록/ML&DL

[tool] TorchIO를 활용한 medical dataset processing

Medical Image는 아무래도 전문 지식이 없으면 다루기가 쉽지 않은것같다.

구글링하면 관련 개념과 구현코드는 나와있는데, 뭔가 중구난방되어있고 쓰기 복잡하다고 해야하나

커스텀 네트워크 짜보면서 소 뒷발질하듯이 데이터 전처리 도구를 찾다가 발견한 TorchIO.. 간단하게 쓰기 괜찮아보임

 

* 비슷하게 Nvidia에서 만든 플랫폼 MONAI, 사견이지만 딥하게 들어가려면 TorchIO보다는 이게 더 나은거같다. 다만 쓰기 귀찮을뿐 ㅎ.. 


TorchIO Official DOC

 

TorchIO

PyPI downloads PyPI version Conda version Google Colab notebooks Documentation status Tests status Code style: black Coverage status Code quality Maintainability pre-commit Slack Twitter Twitter co...

torchio.readthedocs.io

 


Summary

TorchIO is an open-source Python library for efficient loading, preprocessing, 
augmentation and patch-based sampling of 3D medical images in deep learning, following the design of PyTorch.
It includes multiple intensity and spatial transforms for data augmentation and preprocessing. 
These transforms include typical computer vision operations such as random affine transformations 
and also domain-specific ones such as simulation of intensity artifacts 
due to MRI magnetic field inhomogeneity (bias) or k-space motion artifacts.

Official docs에 있는 머릿말 그대로, 대충 간추려보면 아래와 같다.

 

- Pytorch 프레임워크를 타겟으로 하여 개발, 3D Shape 메디컬 이미지 데이터셋 로딩 / 전처리 / 증강과 Patch-Based Sampling을 지원

- 데이터 증강 및 전처리를 위해 Multiple Intensity(T1, T2, FLAIR 등 모달리티를 말하는듯?) 및 Spatial Transform 기법을 포함

  (Random Affine Transformation부터 K-space motion Artifacts / Bias Field(자기장) inhomogeneity와 같은 기법까지 다양하게..)

- Pytorch Ecosystem & Developer Day 2021에서 공개된 Official Pytorch Ecosystem, 다양한 그룹이 연구목적으로 많이 사용 중 어쩌구저쩌구..

 


Installation & Examples

$ pip install torchio 		# --upgrade parameter(라이브러리 버전 업데이트/재설치)
$ pip install torchio[plot]	# TorchIO의 plotting 관련 메소드 사용(Matplotlib)

간단한 튜토리얼은 아래 링크 참고

 

GitHub - fepegar/torchio: Medical imaging toolkit for deep learning

Medical imaging toolkit for deep learning. Contribute to fepegar/torchio development by creating an account on GitHub.

github.com

 


Data Structures

크게 Image, Subject, Dataset으로 구성

1. Image(Label): 이미지 및 레이블(객체내용 출력해보면 Path로 지정되어 있는 것 같은데 다시 확인해봐야함)

2. Subject: Image 및 Label을 쌍으로 묶은 형태

3. Dataset: 여러개의 Subjects들이 모인 형태

 

자세한 설명이 있긴 한데 대충 이런 구조라는거만 이해해도 다루는데 크게 문제는 없었음.

 


Patch-Based Pipelines

3D shape 의료영상 이미지는 그 특성상 제약조건이 매우 많다.

일례로 Dimension이 하나 더 추가되면서 요구되는 리소스가 2D shape에 비해 훨씬 커짐

 

* RGB 2D 이미지 1장(가로세로 256 Pixel): 3x256x256

* Z축이 추가된 3D 이미지 1케이스(가로세로+z축 256 Pixel): 1x256x256x256

 

Distributed Learning(with Multiple GPUs)같은 방법 없이 그대로 네트워크에 들여보내면 당연히 토해낸다.

또한 경험상 각 Dimension마다 Pixel 수치는 256으로 항상 고정되어 있지 않고 그 이하 혹은 이상의 값을 가지는 경우가 있으며,

1 케이스에서는 축마다 Pixel 갯수가 일정하지 않는 경우도 많다 - 176x224x256 이런 식으로

Open Source Dataset의 경우 각 데이터마다 크기가 통일되지 않아 하나하나 보긴 힘들기도 하고 (이건 Transformation Part에서 더 자세히..)

 

축마다 잘라서 2D shape로 학습하는 방법을 쓰지 않는 이유는 공간정보가 중요하기 때문? 이라고 생각함

따라서 이미지 원래 크기 말고 일정크기만큼 자른 3D shape, 즉 Patch 단위로 분할된 데이터를 학습 후 결과값을 병합하는 방법이 리소스 / 추론 시간 측면에서 이점을 가질 수 있다

다만 경험해 본 결과 당연하지만 공간정보(Class 형태든 뭐든..)가 필수적으로 들어가야 하는것으로 보임

(특정 부분을 크롭하는 것이므로 각 Patch의 경계부분 - 엣지면의 경우 병합시 말끔하게 이어지지 않는 등 부정확하게 결과물이 생성될 수 있음)

 

아래는 공식 문서에 수록된 간단 모식도

Patch-Based 데이터의 학습/추론 파이프라인

지원되는 Sampling Method(Sampler)는 여러 종류가 있으며 크게 다음과 같음

 

1. UniformSampler

2. WeightedSampler

3. LabelSampler

4. PatchSampler

5. GridSampler(+ GridAggregator): 추론 과정에서만 사용하는듯하다

 

* Sampler 사용시 설정한 크기만큼 잘라서 학습에 사용하지만 엣지부위 연결이 자연스럽게 안되는거 발견하고 뺐음..

각각의 패치를 합칠때 엣지 부분을 자연스럽게 합칠 수 있도록 공간정보도 같이 넣어줘야 할 것 같은데, 아마 지원을 안하거나 내가 못 찾았을 수도? Brain Tumor같은 소규모 구획을 찾는 것에는 효과적일 수 있으나, 전체적으로 부피가 큰 경우 불리할 수 있다고 생각함


Transformation (작성중)

TorchIO를 사용하면서 가장 잘 써먹은 부분. 꽤 유용한 처리 메소드를 제공해준다.

각 메소드들은 Subject 혹은 Image(또는 Subclasses) 타입 - Pytorch Tensor or Numpy array, SimpleITK, Nibabel, Typical Python Dictionary..

모든 메소드들은 torchio.transforms/Transform에서 상속되며, 크게 전처리, 증강 및 Etc..

 

1. Data Preprocessing

2. Data Augmentation

3. Others

 

더보기
    transform = tio.Compose([tio.CropOrPad((256), mask_name='LABEL'), # (N, N, N) >> resize
                             tio.Resize((HWD), label_interpolation='label_gaussian'), 
                             tio.ToCanonical(),
                             tio.RescaleIntensity(out_min_max=(0, 1)),
                            #  tio.EnsureShapeMultiple(16),
                            ])

Executable Code: 아래 클릭하면 오픈 (소스코드 너무 길어서 줄임글 처리)

 

더보기

1. 패키지 임포트 및 커스텀 데이터로더 정의

class torchCustomDataset: #()
    def __init__(self, base_path, num_of_cases, transform=None, isTest=False, shuffle=True):
    	### 데이터셋의 크기 따위가 일정하지 않을 경우 서로 맞지 않는다는 오류가 출력됨: space reference
        ### 땜빵식으로 가장 처음에 불러온 NIfTI 파일로 통일하는 것으로 해결함 (비추)
        ### __getitem__()의 set space reference 부분 참고
        self.FLAG = 0
        self.labels_path = os.path.join(path_address) 
        self.images_path = os.path.join(path_address)
        
        self.transform = transform
    
        self.imagesList = []
        self.labelsList = glob('{}/*.nii.gz'.format(self.labels_path))[:num_of_cases]

        if isTest == True: self.labelsList = self.labelsList[:10] # 10개 모드만 테스트

        if shuffle == True: random.shuffle(self.labelsList)
        else: self.labelsList = sorted(self.labelsList)        

        for labels in self.labelsList: self.imagesList.append(labels.replace('labels', 'images')) # imagesList

    def __len__(self):
        return len(self.imagesList)

    def __getitem__(self, idx) :
        SUBJECTS = []
        for (img_path, lab_path) in tqdm(zip(self.imagesList[idx], self.labelsList[idx])):
            if self.FLAG == 0:
                self.SPACE_REF = tio.LabelMap(lab_path)
                self.FLAG = 1

            subject = tio.Subject(
                IMG = tio.ScalarImage(img_path),
                LABEL = tio.LabelMap(lab_path),
            )
            
            self.transform_init = tio.Compose([tio.ToCanonical(), tio.Resample(self.SPACE_REF)])

            subject_init = self.transform_init(subject)
            subject_transformed = self.transform(subject_init)
            label_target = subject_transformed.LABEL.data
            label_shape = label_target.squeeze()

            ### 하나의 array에 여러개의 class label: ex) (0, 1, 2, 3)
            ### 하나씩 분리해서 쌓아줌 > class별 loss 비교
            bg = torch.zeros((label_target.shape)).squeeze()
            cls1 = torch.zeros((label_target.shape)).squeeze()
            cls2 = torch.zeros((label_target.shape)).squeeze()
            cls3 = torch.zeros((label_target.shape)).squeeze()

            y1, x1, z1 = torch.where((label_shape == 1)) # cls1
            y2, x2, z2 = torch.where((label_shape == 2)) # cls2
            y3, x3, z3 = torch.where((label_shape == 3)) # cls3

            cls1[y1, x1, z1] = 1
            cls2[y2, x2, z2] = 2
            cls3[y3, x3, z3] = 3

            subject_transformed.LABEL.data = torch.stack((cls1, cls2, cls3))            
            SUBJECTS.append(subject_transformed)

        return SUBJECTS

2. 모델 정의: 개인적으로 만들어본거 연결해서 써봄

https://github.com/doodleima/healthcare/blob/main/Net/ViT_fsCNN/network.py


3. 정의한 데이터로더 사용: 데이터 로드 및 분할(Train & Valid dataset)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

subject = torchCustomDataset(BASE, NUM_OF_CASES, transform, IS_TEST, IS_SHUFFLE)
train_subjects = subject[:90]
valid_subjects = subject[90:]

print(len(train_subjects), len(valid_subjects)) # 90, 30

train_dataset = tio.SubjectsDataset(train_subjects)
valid_dataset = tio.SubjectsDataset(valid_subjects)

TRAIN_LOADER = DataLoader(
        train_dataset,
        batch_size = 1,
        shuffle = True,
        num_workers=0   # multiprocessing.cpu_count()
    	)

VALID_LOADER = DataLoader(
        valid_dataset,
        batch_size = 1,
        shuffle = True,
        num_workers=0   # multiprocessing.cpu_count()
    	)

4. 학습 및 평가 

### Train
def train(dataloader, model, optimizer, loss_function, epoch):
    model.train()
    train_loss = 0.0

    for batch_idx, batch in enumerate(tqdm(dataloader, desc=f'[TorchIO] TRAIN [{epoch}]')):
        train_data = batch['IMG'][DATA].to(device, dtype=torch.float)    # RAW NIfTI
        train_label = batch['LABEL'][DATA].to(device, dtype=torch.int8)  # Label(0, 1, 2, 3 ...)

        output_train = model(train_data)        
        t_loss = loss_function(output_train, train_label)
        train_loss += t_loss.item()

        # backprop
        optimizer.zero_grad()
        t_loss.backward()
        optimizer.step()

    return train_loss
    
### Validation
def eval(dataloader, model, loss_function, epoch):
    model.eval()
    val_loss = 0.0

    ### valid model
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f'[TorchIO] VALID [{epoch}]')):
            valid_data = batch['IMG'][DATA].to(device, dtype=torch.float)    # RAW NIfTI
            valid_label = batch['LABEL'][DATA].to(device, dtype=torch.int8)  # Label(0, 1, 2, 3 ...)
            
            output_val = model(valid_data)
            v_loss = loss_function(output_val, valid_label)
            val_loss += v_loss.item()

    return val_loss
    
    
### Monai에서 제공하는 DiceCE(Cross Entropy) 사용
model = customModel(params) 
optimizer = torch.optim.Adam(model.parameters(), lr=1e-04, weight_decay=1e-09)
criterion = DiceCELoss(to_onehot_y=False, sigmoid=True).to(device)

for epochs in range(EPOCH_SIZE):
	train_loss = train(TRAIN_LOADER, model, optimizer, criterion, epochs)
	valid_loss = eval(VALID_LOADER, model, criterion, epochs)

 

더보기

1. 패키지 임포트 및 커스텀 데이터로더 정의

### import packages ###
import os
import random
import torch

import numpy as np
import nibabel as nib
import torchio as tio
import torch.nn as nn

from tqdm import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchsummary import summary


### torchio based dataloader(custom) ###
class customDataset:
    def __init__(self, imglist, lablist):
    	### 데이터셋의 크기 따위가 일정하지 않을 경우 서로 맞지 않는다는 오류가 출력됨: space reference
        ### 땜빵식으로 가장 처음에 불러온 NIfTI 파일로 통일하는 것으로 해결함 (비추)
        ### __getitem__()의 set space reference 부분 참고
        self.FLAG = 0
        self.imagesList = imglist
        self.labelsList = lablist
    
    
    def __len__(self):
        return len(self.imagesList)
    
    
    def __getitem__(self, idx):
        SUBJECTS = []
        for (img, label) in zip(self.imagesList[idx], self.labelsList[idx]):
            ### set space reference
            if self.FLAG ==0:
                self.SPACE_REF = tio.LabelMap(img)
                self.FLAG = 1
               
            if label=='CN': LABEL = [1., 0., 0.]
            elif label=='MCI': LABEL = [0., 1., 0.]
            else: LABEL = [0., 0., 1.]

            subject = tio.Subject(IMG = tio.ScalarImage(img), LABEL = torch.tensor(LABEL))
            
            self.transform = self.transform_tio()
            subject_transformed = self.transform(subject)
        
            SUBJECTS.append(subject_transformed)
        
        return SUBJECTS

    
    def transform_tio(self):
        transform = tio.Compose([tio.Resample(self.SPACE_REF),
                                      tio.ToCanonical(),
                                      tio.CropOrPad((256)),
                                      tio.Resize((128), image_interpolation='label_gaussian')])
        
        return transform

2. 모델 정의:  CNN, 레이어 임의로 쌓아올려서 정의하였음

### Traditional CNN model: pytorch based
class CNN3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 32, kernel_size=(3, 3, 3), padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)
        self.conv2 = nn.Conv3d(32, 32, kernel_size=(3, 3, 3), padding=1)
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=(3, 3, 3), padding=1)
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)
        self.conv4 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=1)
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)

        self.fc1 = nn.Linear(128 * 8 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 3)
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)
        
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = F.relu(x)
        x = self.pool3(x)

        x = self.conv4(x)
        x = F.relu(x)
        x = self.pool4(x)
        
        x = x.view(-1, 128 * 8 * 8 * 8)
        
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.relu(x)
        x = self.fc4(x)
        
        return x

3. 정의한 데이터로더 사용: 데이터 로드 및 분할(Train & Valid dataset)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Datapath ###
imgDir=[]
clsDir=[]

BASE = '/target/data/path/absolute'

for cls in ['class1', 'class2', 'class3']:
    cls_base = os.path.join(BASE, cls)
    FILES = os.listdir(cls_base)
    random.shuffle(FILES)
    
    for file in FILES:
    	imgDir.append(os.path.join(BASE, cls, file))
        clsDir.append(cls)
    
# print(len(imgDir), len(clsDir)) # 120, 120

### dataload > split from 120 cases ###
subject = customDataset(imgDir, clsDir)
train_dataset = tio.SubjectsDataset(subject[:90])
valid_dataset = tio.SubjectsDataset(subject[90:])

print(len(train_dataset), len(valid_dataset)) # 90, 30

TRAIN_LOADER = DataLoader(
    train_dataset,
    batch_size = 1,
    shuffle = False,
    num_workers = 0 # or multiprocessing.cpu_count()
)

VALID_LOADER = DataLoader(
    valid_dataset,
    batch_size = 1,
    shuffle = False,
    num_workers = 0 # or multiprocessing.cpu_count()
)

4. 학습 및 평가

### Train with CNN3D
model = CNN3D().to(device) # cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-04)
criterion = nn.CrossEntropyLoss()

summary(model, torch.randn([1, 1, 128, 128, 128]))

### Train the model ###
model.train()
train_loss = 0.0

for epochs in tqdm(range(200)):
    for batch_idx, batch in enumerate(TRAIN_LOADER):
        t_data = batch['IMG']['data'].to(device, dtype=torch.float)
        t_label = batch['LABEL'].to(device, dtype=torch.float)
        t_output = model(t_data)        
        t_loss = criterion(t_output, t_label)
        t_loss.backward()
        
        optimizer.step()
        
### Evaluate the model ###
### Loss 계산 대신 예측 클래스 바로 프린트 ###
model.eval()
with torch.no_grad():
    for v_batch_idx, v_batch in enumerate(VALID_LOADER):
#     for v_batch_idx, v_batch in enumerate(tqdm(VALID_LOADER, desc=f'Valid')):
        v_data = v_batch['IMG']['data'].to(device, dtype=torch.float)
        v_label = v_batch['LABEL'].to(device, dtype=torch.float)

        v_output = model(v_data)
        v_output = F.softmax(v_output, dim=1)
        _, predicted = torch.max(v_output.data, 1)
        _, actual = torch.max(v_label.data, 1)
        
        print(f'Predicted class: {predicted.item()}\nActual class: {actual.item()}' )