본문 바로가기
파이썬 & 머신러닝

[Pytorch] 진짜 커스텀 데이터셋 만들기, 몇 가지 팁

by 두재 2021. 10. 4.

Pytorch 개발자들이 이미 데이터셋, 데이터로더 클래스를 여러 개 만들어 두었다.

데어터셋의 경우 ImageFolder, DatasetFolder 와 같이 내 폴더 안에 있는 데이터들을 돌게 해주는 애들과 CIFAR10, ImageNet 등 유명한 베이스라인 데이터셋을 다운로드부터 train/test 스플릿까지 손쉽게 해주는 클래스 들이 있다.

이번에는 이런 것보다 조금 더 low-level로 직접 데이터셋 클래스를 만들어서 이를 데이터로더에 집어 넣는 것까지 해보겠다.

핵심은 바로 위 사진에 있는 torch.utils.data.Dataset 이라는 class를 상속받는 자식 클래스를 만들 것이다.

이 자식 클래스가 필요로 하는 메소드는 3가지이며, 다음과 같다. (언더바가 두 개임) 설명을 잘 모르더라도 조금 있다가 코드를 함께 보면 이해가 조금 더 잘 될 것이다.

  1. __init__(self, 인수들) : 데이터셋을 처음 선언할 때, 즉 데이터셋 오브젝트가 생길 때 자동으로 불리는 함수이고, 여기에 우리가 몇 가지 인수들을 입력받도록 만들 수 있다 (path, transform 같은 것들).
  2. __len__(self) : 데이터셋의 길이다. 만약 dataset을 선언하고 나서 len(어떤 dataset)을 하면 내부적으로는 이 len 함수가 불리는 것이다. 이 len은 나중에 데이터셋을 선언하고 데이터로더를 사용할 때 또 내부적으로 사용된다. (데이터셋의 len을 알아야 데이터로더가 미니 배치 샘플링을 하면서 지금 다 돌았는지 아닌지를 알 수 있으니까)
  3. __getitem__(self, idx) : 이름에서 알 수 있듯이 데이터셋의 본분인 데이터 하나씩 뽑기이다. idx는 index를 말하는데, 몇 번째 데이터를 뽑을 건지에 대한 변수이다. 이는 데이터로더에서 또 사용될 것이다.

이러한 함수들을 만들고 나면 우리의 custom dataset을 하나 만든 것이고, 이 데이터셋을 하나 선언해서 사용하면 된다. 이 경우 말한대로 next와 iter 라는 함수를 사용하면 데이터셋 내부의 데이터를 하나하나씩 뽑을 수 있다. (next는 파이썬의 iterator object에서 다음 아이템을 뽑는 함수고 iter는 어떤 오브젝트를 iterator로 바꿔주는 함수이다.) 일단 실생활에서 잘 쓰이지는 않을 정말 단순한 Dataset을 만들어보았다. (이런 간단한 데이터셋도 연구에 사용되는 경우가 있다!...)

 

(난이도 쉬움) 복붙해서 돌려보기를 추천하는 코드 (매우 짧음!)

import torch
from torch.utils.data import Dataset, DataLoader


class SimpleDataset(Dataset):
    def __init__(self, t):
        self.t = t

    def __len__(self):
        return self.t

    def __getitem__(self, idx):
        return torch.LongTensor([idx])
        
if __name__ == "__main__":
    dataset = SimpleDataset(t=5)
    print(len(dataset))
    it = iter(dataset)

    for i in range(10):
        print(i, next(it))

우선 SimpleDataset을 선언했고, torch.utils.data.Dataset를 상속받았다. 하나씩 보자. 이 SimpleDataset은 t라는 숫자를 받고 self.t에 넣어준다.

여기서 중요한 점은, __init__ 함수는 데이터셋이 선언되는 그 때 (여기서는 dataset = SimpleDataset(t=5)) 딱 한 번 불려지고 그 이후에는 단 한 번도 안 사용된다. 그리고 return 도 하지 않기 때문에 __init__(self, 인수) 에서 인수들은 이 함수가 끝나고 나면 다 없어져버린다. 그래서 self.t = t 처럼 self에다가 값을 대입해주는 게 필요하다. 객체 지향 프로그래밍이나 파이썬의 클래스에 대해 알면 뭔 느낌인지 올 것이다.

그리고 len 함수는 입력받은 self.t를 return한다. 때문에 아래에 있는 print(len(dataset))에서는 5가 출력된다.

마지막으로 getitem에서는 idx라는 숫자를 그냥 return한다. 

위 코드를 돌려보면 len은 5가 나오고, i가 range(10)으로 10번 도는 동안 그냥 next(it)에서는 0부터 9가 나온다. 정말 별 거 없다. 그냥 데이터셋 데이터로더의 개념을 보기 위해 만들어봤다. 그러면 이제 데이터로더를 사용해보자. 위에서 설명한 함수들을 데이터로더 내부에서 어떤 식으로 작동하는지와 함께 보면 이해가 좀 된다. 위 코드블럭에서 메인 함수를 다음과 같이 바꿔보자.

 

(난이도 쉬움) 복붙해서 돌려보기를 추천하는 코드 (매우 짧음!)

if __name__ == "__main__":
    dataset = SimpleDataset(t=5)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=2,
                            shuffle=True,
                            drop_last=False)

    for epoch in range(2):
        print(f"epoch : {epoch} ")
        for batch in dataloader:
            print(batch)

이 경우 다음과 같이 출력된다.

설명을 해보자면, 데이터로더를 batch_size는 2로, shuffle은 True로, drop_last는 False로 해놨다. batch size를 바꿔보거나, shuffle, drop_last를 True/False로 바꿔보면서 체크해봐도 좋다. 

결과를 보면 for batch in dataloader에서 처음에는 [[3], [4]]이, 다음에는 [[0], [2]]이 나오고 마지막으로 [[1]]이 나온다. 일단 dataloader가 하는 일은, Dataset의 __len__함수를 통해 길이를 파악한다. 그리고 데이터로더의 shuffle이 False면 0, 1, 2, ... len(dataset) - 1 순서의 index array를 만들고 만약 shuffle이 True면 0부터 len(dataset) - 1까지의 index array를 만들고 순서를 랜덤하게 섞는다. torch.randperm처럼! drop_last라는 거는 batch_size 개수로 미니배치를 돌 것인데 만약에 맨 뒤에 길이가 조금 짧게 자투리가 남으면 이것도 뽑아줄지 아니면 그냥 버릴지에 대한 인수이다. 지금은 True라서 마지막에 길이가 1짜리인 애가 나온다 (len이 5인데 batch_size가 2니까)

https://pytorch.org/docs/stable/generated/torch.randperm.html

그리고 dataloader에서 데이터를 샘플링을 하면 데이터가 아닌 미니 배치가 나온다. 내부적으로는, 아까 만들어놓은 index array를 앞에서부터 차례대로 batch_size 개수만큼 뽑고, 그 뽑은 index를 하나씩 dataset의 __getitem__ 함수에 집어넣어 return된 데이터를 뽑아 놓는다. 그러면 이 데이터들이 batch_size개의 index만큼 따로 있을 것이고, 데이터로더가 이 데이터들을 미니배치라는 큰 텐서 하나로 concatenate을 해준다. 이 경우 concatenate은 0번째 차원으로 하는데, 그 전에 데이터들에 0번째 차원을 만든다. 코드로 보자면 다음일을 한다 (엄밀하게는 이렇지 않지만 이런 플로우로 이루어진다.) 여기서 torch.stack이 새로운 차원을 만들고 concatenate을 하는 코드이다. list안의 모든 tensor를 각각 unsqueeze하고 concatenate하는 것과 stack이 같은 역할을 한다.

index_array = torch.randperm(5)
# [1, 4, 0, 2, 3]
index = index_array[:batch_size]
# [1, 4]
data_list = []
for i in range(index.size(0)):
	data_list.append(dataset(index[i]))
# data_list : [tensor[1], tensor[4]]
batch = torch.stack(data_list, dim=0)

print(batch)
# [[1], [4]]

 

 

대충 __init__, __len__, __getitem__ 함수에 대해서 알 수 있었고, 커스텀 데이터셋을 선언하고 데이터로더를 사용하면 어떤 식으로 나오는지 알 수 있었을 것이다. 사실 여기까지는 약간 너무 디테일하다고 생각할 수도 있지만, 우리가 원하는대로 custom을 하려면 그래도 어떻게 작동하는지는 알아야 잘못된 움직임을 막을 수 있다. 그러면 이제 실생활로 가보자.

 

1. 직접 저장해놓은 cat과 dog 이미지들을 돌아주는 데이터셋

폴더 구조는 다음과 같다고 해보자

그리고 우선 다음과 같이 코드를 짜 보았다.

import glob
import torch
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class catdogDataset(Dataset):
    def __init__(self, path, train=True, transform=None):
        self.path = path
        if train:
            self.cat_path = path + '/cat/train'
            self.dog_path = path + '/dog/train'
        else:
            self.cat_path = path + '/cat/test'
            self.dog_path = path + '/dog/test'
        
        self.cat_img_list = glob.glob(self.cat_path + '/*.png')
        self.dog_img_list = glob.glob(self.dog_path + '/*.png')

        self.transform = transform

        self.img_list = self.cat_img_list + self.dog_img_list
        self.class_list = [0] * len(self.cat_img_list) + [1] * len(self.dog_img_list) 
    
    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self, idx):
        img_path = self.img_list[idx]
        label = self.class_list[idx]
        img = Image.open(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img, label

if __name__ == "__main__":
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )

    dataset = catdogDataset(path='./cat_and_dog', train=True, transform=transform)
    dataloader = DataLoader(dataset=dataset,
                        batch_size=1,
                        shuffle=True,
                        drop_last=False)

    for epoch in range(2):
        print(f"epoch : {epoch} ")
        for batch in dataloader:
            img, label = batch
            print(img.size(), label)

사실 정말 정말 다양하게 코드를 짤 수 있는데 위의 방식은 제일 무난한 방식인 것 같다. (물론 나는 어쩌다 보니 사용하지 않는다)

우선 하던대로 __init__부터 보면, path 와 train, transform을 입력받을 수 있도록 만들어 놓았다 (당연히 추가하거나 바꿔도 된다). path는 데이터셋의 위치를 알기 위해서고, train이 True냐 False냐에 따라서 train 폴더를 볼지, test 폴더를 볼지 if 문을 통해서 만들어 놓았다. 그 이후에는 glob 이라는 것을 이용해 폴더 내에 있는 파일들을 찾는데, 마지막이 png로 끝나는 애들을 찾는다. 그 외에는 init에서는 별 거 없다. 한 가지는 label을 반환하기 위해서 cat image와 dog image에 대응될 수 있게 cat image 개수만큼 0을, dog image 개수만큼 1을 가지고 있는 class_list를 만들어 주었다. getitem에서는 파이토치 텐서를 반환해야만 나중에 데이터로더가 mini batch로 concatenate을 할 수 있기 때문에, 'cat', 'dog'와 같은 string을 쓰면 안 된다.

__len__ 에서는 cat과 dog의 이미지들의 개수를 return 한다.

__getitem__ 에서는 idx번째의 이미지를 PIL.Image를 통해 열고 적절한 transform을 해준다. 여기서는 torchvision이 제공하는 transforms을 사용하기로 했고 이들은 PIL Image를 인풋으로 받기 때문에 skimage.imread와 같이 이미지를 numpy array로 반환하는 애를 사용하면 안되고 Image.open(img_path)로 열었다. 또 방금 말했듯이, 파이토치 텐서를 반환해야 하기 때문에 ToTensor()를 통해 PIL Image를 텐서로 바꿔주었다. ToTensor 이전에 transforms.Resize라거나 CenterCrop이나 RandomHorizontalFlip 등 다양한 pre processing 이나 data augmentation을 사용할 수 있다. 마지막으로 이 이미지에 해당하는 label도 함께 return 한다.

위 코드를 돌리면 다음과 같이 나온다.

2048, 2304 짜리 아무 이미지 하나씩을 넣어놓았다. 지금은 batch_size가 1이라서 하나씩 도는 것을 볼 수 있고, label이 0과 1이 나오는 것도 볼 수 있다. 한 가지 상식으로는 Pytorch에서 2D Image를 가지는 mini-batch의 경우 [B, C, W, H] 의 순으로 텐서를 가진다. B는 배치 사이즈고, C는 Channel로 1이면 흑백, 3이면 RGB를 의미한다. W와 H는 이미지의 위아래 길이다. 여기서 알 수 있는 한 가지 딸려오는 상식은 getitem이 return하는 이미지의 위아래 길이는 매번 항상 같아야 한다. 이는 데이터로더가 concatenate을 해서 mini-batch를 만들어야하기 때문인데, 만약 위아래 차원이 다르면 컨캣이 안 되기 때문이다. 때문에 만약 이미지들이 다 크기가 다르다면 1) batch_size를 1로 해서 concat을 안해도 되게 하거나, 2) transform에 Resize나 CenterCrop을 사용해서 이미지들을 같은 크기로 조정해주어야 한다.

 

2. 이미지들을 __init__ 시점에서 메모리에 모두 올려버리기

이는 3D 데이터를 다루거나 파일의 입출력 시간이 좀 귀찮고 오래 걸리면 필요하다. 위 1번 데이터셋은 __init__ 에서는 이미지들의 경로를 self에 넣어주고, __getitem__에서 매번 path에 해당하는 이미지를 연다. 생각해보면, 우리는 epoch을 백번, 천번, 많게는 십만번도 도는데, 같은 이미지를 매번 그렇게 열어야 하나? 싶으면 이걸 한 번 보면 좋다. (사실 나는 3D data를 다루기 때문에 __init__ 단계에서 skimage.io.imread를 통해 tif를 열어 numpy array로 메모리에 올려놓는 식으로 사용하지만, 1번 데이터셋과 얼추 비슷하게 보여주려고 그냥 PIL Image를 사용했다.

class catdogDataset(Dataset):
    def __init__(self, path, train=True, transform=None):
        self.path = path
        if train:
            self.cat_path = path + '/cat/train'
            self.dog_path = path + '/dog/train'
        else:
            self.cat_path = path + '/cat/test'
            self.dog_path = path + '/dog/test'
        
        self.cat_img_list = glob.glob(self.cat_path + '/*.png')
        self.dog_img_list = glob.glob(self.dog_path + '/*.png')

        self.transform = transform

        self.img_list = self.cat_img_list + self.dog_img_list

        self.Image_list = []  # 바뀐부분!!!!!!!
        for img_path in self.img_list:  # 바뀐부분!!!!!!!
            self.Image_list.append(Image.open(img_path))  # 바뀐부분!!!!!!!

        self.class_list = [0] * len(self.cat_img_list) + [1] * len(self.dog_img_list) 
    
    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self, idx):
        img = self.Image_list[idx]  # 바뀐부분!!!!!!!
        label = self.class_list[idx]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

위 코드에서 "바뀐부분!!!!!!"이라는 주석만 바꿔놨다. 비교해보면 사실 변한 건 이거다. 원래는 __getitem__에서 Image.open을 하는데, 지금은 __init__에서 Image.open을 그냥 싹 다 해놓고 self에 넣어놓아 (메모리에 다 올려놓고) __getitem__ 에서는 그냥 인덱싱만 한다.

이런 식으로 하면 다음과 같은 일들이 이루어진다.

  • __init__ 때 조금 시간이 걸린다. 왜냐면 path에 있는 이미지들을 싹 다 읽으니까.
  • __init__ 이후에 컴퓨터 메모리를 보면 좀 많이 잡아 먹을 수 있다. 왜냐면 path에 있는 이미지들을 싹 다 메모리에 올려놓은 거니까.
  • 대신 __getitem__에서 파일 입출력을 하는 게 아니라 메모리에서 데이터를 가져오는 것이라 __getitem__이 빨라진다.

사실 이건 그냥 CIFAR나 ImageNet 같은 걸 할 때는 별로 추천하지 않는 것이, 이미지 파일 입출력이 그렇게 오래 안 걸린다. 내가 이걸 사용하는 이유는 내가 사용하는 데이터셋이 3-D 이미지 여러 개로 이루어져 있는데, 이 이미지 하나하나가 1기가 정도 되는 큰 이미지들이라 매번 새로 읽어들이면 파일을 읽는 것에 시간을 다 쓴다. 그래서 처음에 __init__ 때 한 번 좀 기다리면서 다 읽어놓고 그 이후에는 메모리 내에서만 데이터를 왔다갔다하려는 것이 목적이라 이렇게 한다.

 

다음은 그냥 몇 가지 팁들을 적어보았는데 아마 처음 보면 다 이해가 하나도 안 갈 수도 있다. 자세한 예시까지 첨부하면 참 좋겠지만, 너무 힘들다. 혹시 이해가 안 되어 질문을 하면 그 때는 다시 컨디션이 괜찮아서 자세히 설명할 수도 있을 것 같다.

 

 


몇 가지 신기한 혹은 꿀팁들

  1. 데이터셋 getitem 에서 numpy array를 return해도 Dataloader를 사용하면 배치를 만드는 과정에서 자동으로 torch tensor로 바꾼다. 0번째 차원을 추가하고 batch_size 개수의 item을 concatenate 하는 것도 당연히 하고!
  2. glob같은 경우 파일 순서가 이상할 수 잇다. 만약 img와 segmentation map img 와 같이 같이 pair 가 되어야하는데 각각이 다른 path에 있고 이름이 막 다른 경우에 glob의 output이 우리가 상식적으로 생각하는 0.png, 1.png, 2.png 이 순서가 아니라서 img 랑 segmentation map 이 다른 순서로 sorting이 되어 있을 수 있다. 이 때 단순히 getitem에서 index 기준으로 그냥 이미지랑 map을 뽑는 식으로 짜면 학습할 때 img와 segmentation map이 페어가 맞지 않을 수도 있다!
  3. getitem에서 그냥 여러 데이터(이미지, label, etc.)를 쉼표로 리턴하는 게 아니라 이 데이터들을 딕셔너리로 리턴할 수도 있을 텐데 (return {'img' : self.img_list[idx], 'label' : self.class_list[idx]}), 이렇게 만들면 데이터로더가 나중에 합칠 때도 이걸 고려해준다! 데이터로더에서 나오는 batch가 {'img' : 뭐시기, 'label' : 뭐시기} 이렇게 나온다!
  4. torchvision transforms 에서 Resize(128)로 해서 하면 이미지가 128, 128이 되는 것이 아니다. 만약 [128, 256] 짜리 이미지가 있고 Resize(128)로 transform을 만들어서 통과시켜도 [128, 256]이다. torchvision의 Resize가 비율을 유지해서 그렇다. 이는 여러 비율을 가지는 이미지들이 존재하면 위에서 잠깐 말한대로 이미지들의 크기가 달라서 concat이 안되어서 데이터로더에서 에러가 뜬다. 그러니까 Resize와 CenterCrop 같은 걸 함께 사용하자.

 


글이 아마 약간 이해하기 어렵게 쓴 느낌이 없잖아 있지만, Pytorch에서 커스텀 데이터셋을 사용하려고 한다면 굉장히 도움이 될만한 내용이 많다고 생각한다. (나름 오랜 기간 커스텀 데이터셋을 삽질하면서 내가 공부한 것이 많다.) 짧게 정리해보자면 1) Pytorch의 데이터로더는 별로 건들 게 없고, Dataset을 짜주면 된다. 2) __init__, __len__, __getitem__ 만 잘 짜주면 된다. 인 것 같다. 커스텀 데이터셋은 사실 내가 짠 방식 말고도 정말 수없이 다양하게 짤 수 있다. 내가 위에 예시 코드를 짜면서도 이거말고 다른방식으로 짤까 라는 고민을 정말 다양한 부분에서 했다. 단순히 CIFAR와 같은 널리 알려진 데이터셋만 사용하는 것이 아니라면 꼭 필수적인 것이 커스텀 데이터셋이다. (그리고 이런건 커스텀으로 짜야 멋이 나지) 아마 위 내용들을 알아두면 도움이 많이 될 것이다. 혹시 문장이 이해가 안가거나 궁금한 것이 있다면 댓글을 남겨주시면 최대한 답변을 할 것이다.