본문 바로가기

공부기록/ML&DL

[Pytorch] Distributed Data Parallel

Pytorch를 사용한 분산 데이터 병렬처리

Data Parallel; DP도 있지만 DDP의 경우 여러 장점이 있다고 한다.
(파이토치 튜토리얼: https://tutorials.pytorch.kr/intermediate/ddp_tutorial.html)

 

실사용 입장에서 DP보다 고르게 데이터를 실을 수 있다는게 더 와닿았던 것 같다. ??


 

1. 기본 패키지 임포트

import os
import gc 			# regarding swap memory, no needed necessary
import torch

import numpy as np
import torch.distributed as dist

from torch.utils.data import DataLoader 
from torch.utils.data.distributed import DistributedSampler

from torch.nn.parallel import DistributedDataParallel as DDP

 

torch에서 distributed,

torch.utils.data에서 각각 데이터로더와 분산 데이터샘플러,

torch.nn.parallel에서 분산 데이터 병렬처리 관련 패키지를 임포트

 

2. 기본 셋업

def ddp_setup(rank, world_size, port_num: str='65535'):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = port_num

    dist.init_process_group("nccl", rank=rank, world_size=world_size)

 

rank,  world_size (GPU 관련: 갯수 및 번호(몇번 GPU인지?)) 파라미터에 추가적으로 port_num를 따로 파라미터로 추가

한 서버에 GPU가 8개가 있다고 할 경우, 작업을 3개 혹은 4개 GPU에 나누어 돌리고 싶을 때, 다른 포트번호를 지정해서 돌아가게끔 만듦

 

3. 데이터 로드 및 학습

### torch.multiprocessing.spawn() 메소드로 감싸서 실행
def ddp_train(proc, device_count, ddp_port, train_dataset, valid_dataset, model, **kwargs):
    torch.manual_seed(777)  # 시드값 고정
    ddp_setup(proc, device_count, ddp_port) ### ddp setup
            
    ### DDP 데이터 샘플러 정의: 매 에포크마다 데이터 자동 샘플링
    train_sampler = DistributedSampler(train_dataset, num_replicas=device_count, rank=proc)
    valid_sampler = DistributedSampler(valid_dataset, num_replicas=device_count, rank=proc)
    
    ### 데이터로더: 각 데이터셋에 대한 샘플러가 에포크마다 샘플링 수행; shuffle=False
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, sampler=train_sampler,
                              num_workers=0, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, sampler=valid_sampler,
                              num_workers=0, pin_memory=True)
        
    ddp_model = DDP(model.to(proc)) ### DDP로 모델 랩핑, 이때 모델의 복사본이 각 프로세스에 전달됨
    if p_trained: ddp_model.module.load_state_dict(p_trained) ### pretrained 불러와서 덮어씌움
    optimizer, loss = optimizer, loss   ### 옵티마이저, 손실함수 정의
        
    
    for epoch in range(epochs):         ### 정해진 에포크 수 만큼 반복
        ### 매 에포크마다 데이터 샘플링
        train_sampler.set_epoch(epoch)
        valid_sampler.set_epoch(epoch)
        
        ### train, eval 수행 
        t_loss, t_acc = train_epoch(proc, ddp_model, train_loader, kwargs)
        v_loss, v_acc = valid_epoch(proc, ddp_model, valid_loader, kwargs)
        
        criteria = np.mean([t_acc, v_acc])
        save_chkpoint(ddp_model, criteria)
        gc.collect()    ### garbage collector, not necessary
        
    dist.destroy_process_group()

 

기본 셋업 후 데이터 샘플러를 선언하여 데이터로더에 파라미터로 전달 -> 다중작업일 경우 샘플러가 샘플링을 수행
(DDP를 사용하지 않을 경우 sampler에 condition을 추가하여 shuffle 파라미터에 if sampler is None과 같이 지정해 주면  Non DDP 환경에서 샘플러 대신 데이터 셔플작업을 수행하도록 할 수도 있음)

 

추가로 Pretrained weight가 있을 경우 model.module.load_state_dict()을 사용해서 가중치를 불러와서 모델에 얹어줄 수 있다. 이거 확인못해서 계속 학습 못하고 쩔쩔매고 있었음.. 

(model.state_dict(): DDP 관련 추가정보 포함되어있음, 저장 후 로드하려고 하면 key값 일치하지 않는다는 에러 발생)

 

이후에는 MLOps (WandB 같은거) 붙여서 모니터링 해 줘도 좋고 그냥 학습만 진행해도 무방!