728x90

안녕하세요. 오늘은 인공지능과 딥러닝에서 매우 많이 사용하는 PyTorch의 'Dataset'에 대해 자세하게 설명해보려고 합니다. 인공지능과 딥러닝에서 모델을 학습시키기 위해 Data를 처리하는 코드는 복잡하고 유지보수가 어려울 수 있습니다. 이를 위해 PyTorch에서는 더 나은 가독성과 모듈성을 위해 데이터셋 코드를 분리합니다. 

이 글에서 PyTorch의 'Dataset'은 클래스를 직접 구현해 보며 이해해보려고 합니다.

 

 

Dataset이란?


PyTorch에서 Dataset은 데이터와 레이블을 저장하고 있으며, 데이터에 쉽게 접근할 수 있도록 도와주는 추상 클래스입니다. 이 클래스를 사용하여 다양한 데이터 소스(예: 파일, 데이터베이스, 메모리 등)에서 데이터를 불러오고, 필요한 전처리 작업을 수행할 수 있습니다. 

요약하자면, Dataset은 샘플과 정답들을 저장하는 클래스라고 생각하시면 됩니다.

 

기본적으로 Dataset 클래스는 다음 세 가지 메서드를 구현해야 합니다.

  • __init__: 필요한 변수들을 선언합니다.  init 함수는 Dataset 객체가 생성될 때 한 번만 실행됩니다.
  • __len__: 데이터셋의 총 데이터 개수를 반환합니다
  • __getitem__: 주어진 인덱스에 해당하는 샘플을 데이터셋에서 불러와 반환합니다.

 

 

구현 예제


import torch
from torch.utils.data import Dataset

# 사용자 정의 데이터셋 클래스
class CustomDataset(Dataset):

	# 생성자, 데이터를 전처리하는 부분
    def __init__(self, length=100):
        self.data = torch.randn(length, 10)  # 임의의 데이터 생성 (예: 100x10 텐서)
        self.labels = torch.randint(0, 2, (length,))  # 임의의 레이블 생성 (0 또는 1)

	# Dataset 클래스의 길이를 반환
    def __len__(self):
        return len(self.data)

	# 데이터셋의 데이터중 idx위치의 Data를 반환하는 코드
    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]

        return sample, label
        
# 데이터셋 생성
dataset = CustomDataset()

 

실행결과

>>>len(dataset) # 데이터셋의 길이는 100
100

>>>dataset[0]	# 데이터셋의 첫번째 데이터의 값(샘플, label)
(tensor([ 0.5283, -0.5272, -1.0905,  0.4210,  0.2976, -0.2760, -0.8738,  1.0800,
         -1.9537, -0.2197]), 
tensor(0))

 

 

마치며

PyTorch의 Dataset 클래스를 사용하면 다양한 데이터 소스에서 데이터를 효율적으로 불러오고, 전처리하는 과정을 간소화할 수 있습니다. 사용자 정의 Dataset을 만들 때는 __len__과 __getitem__ 메서드를 구현해야 한다는 점을 기억해주세요.

728x90