본문 바로가기

Python/PyTorch

Custom Dataset 생성

Dataset의 기본 구성 요소

from torch.utils.data import Dataset # torch.utils.data의 Dataset 라이브러리 상속

class CustomDataset(Dataset):
    def __init__(self):
        pass
    def __getitem__(self, index):
        pass
  
    def __len__(self):
        pass

 

__init__ 메서드

데이터의 위치나 파일명과 같은 초기화 작업을 위해 동작한다. 일반적으로 CSV파일이나 XML파일과 같은 데이터를 이때 불러온다. 이렇게 함으로써 모든 데이터를 메모리에 로드하지 않고 효율적으로 사용할 수 있다. 여기에 이미지를 처리할 transforms들을 Compose 해서 정의해 둔다.

 

__len__ 메서드

Dataset의 최대 요소 수를 반환하는 데 사용된다. __len__ 매서드를 통해서 현재 불러오는 데이터의 인덱스가 적절한 범위 안에 있는지 확인할 수 있다.

 

__getitem__ 메서드

데이터셋의 idx번째 데이터를 반환하는데 사용된다. 일반적으로 원본 데이터를 가져와서 전처리하고 데이터 증강하는 부분이 모두 여기에서 진행된다.