AI & DL/Pytorch Lightning

[Pytorch Lightning] BART를 훈련해 Text를 요약해보자 - 1

Giliit 2024. 7. 11. 21:57
728x90
  • 1편 : 프로젝트 요약 및 LightningModule 설계
  • 2편 : Dataset 구축 및 DataLoader 구축
  • 3편 : logger 작성, wandb 연동 및 확인
  • 4편 : 모델 로딩 및 실행결과(wandb) 확인

 

오늘은 이전에 캡스톤디자인에서 BART를 이용해서 Text를 요약하는 Task를 진행했었는데, 이것에 대해 설명하면서 코드를 어떻게 작성했는지 설명하려고 한다. 

간단하게 요약하면 다음과 같다. 카페 메뉴에 대한 Text가 입력이 되면, BART를 이용해서 영어로 번역한 뒤 필요한 내용만 출력하는 것이다. 아이스 아메리카노 하나, 카페라떼 하나를 영어로 요약하는 것이다. 다음과 같은 Task를 위해 BART를 훈련시키는 코드에 대해 설명하겠다.

 

필요한 라이브러리 설치

pip install datasets==2.17.0
pip install pytorch_lightning==2.2.4
pip install torkenizers==0.15.2
pip install transformers==4.38.1
pip install pandas==2.0.3

위의 라이브러리는 제가 사용한 라이브러리들의 버전이다. 혹시 버전오류가 발생한다면 다음과 같이 라이브러리를 재설치하면 된다.

 

훈련 모듈 설계

요약

training_step 훈련 데이터의 각 배치에 대해 실행  →  Train_dataloader에 대해 실행 

e.g., trainer.fit(module, train_dataloader, valid_dataloader)
validation_step 검증 데이터의 각 배치에 대해 실행 → Valid_dataloader에 대해 실행

e.g., trainer.fit(module, train_dataloader, valid_dataloader)
test_step 시험 데이터의 각 배치에 대해 실행 → Test_dataloader에 대해 실행

e.g., trainer.test(module, test_dataloader)
configure_optimizers 모델의 optimizer를 구성
on_validation_epoch_end 각 validation_epoch가 종료될때마다 실행

 

모델 초기화

def __init__(
        self,
        model,
        model_save_dir,
        total_steps,
        max_learning_rate: float = 2e-4,
        min_learning_rate: float = 2e-5,
        warmup_rate: float = 0.1,
    ):
    super().__init__()

    self.model = model
    self.total_steps = total_steps
    self.max_learning_rate = max_learning_rate
    self.min_learning_rate = min_learning_rate
    self.warmup_rate = warmup_rate
    self.model_save_dir = model_save_dir
    self.validation_step_loss = []

모델 설계에 필요한 인수들을 입력받습니다.

  • model : BART 모델 인스턴스
  • model_save_dir : 모델을 저장할 경로
  • total_steps, max_learning_step, min_learning_rate, warmup_rate : 학습에 사용될 hyper parameter
  • validation_step_loss : validation_epoch가 끝난 후 손실을 저장하기 위한 리스트

 

훈련 단계( training_step)

def training_step(self, batch, batch_idx):
    output = self.model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=batch["decoder_input_ids"],
        decoder_attention_mask=batch["decoder_attention_mask"],
        return_dict=True,
    )

    labels = batch["decoder_input_ids"][:, 1:].reshape(-1)
    logits = output["logits"][:, :-1].reshape([labels.shape[0], -1])

    loss = F.cross_entropy(logits, labels, ignore_index=self.model.config.pad_token_id)
    metrics = {"loss": loss}
    self.log_dict(metrics, prog_bar=True, logger=True, on_step=True)

    return metrics

모델에 특정 배치가 들어갔을 때 훈련을 하게 된다. 

  • input_ids: 입력 시퀀스의 토큰 ID 배열.
  • attention_mask: 입력 시퀀스의 어텐션 마스크 배열, 패딩 된 부분을 0으로, 유효한 토큰을 1로 표시.
  • decoder_input_ids: 디코더에 입력되는 시퀀스의 토큰 ID 배열.
  • decoder_attention_mask: 디코더 시퀀스의 어텐션 마스크 배열.
  • labels : 맨 앞 토큰([BOS])을 제외하고 나머지 토큰을 사용
  • logits : 모델 출력에서 마지막 토큰을 제외한 나머지
  • cross_entropy : label과 logits의 손실을 계산하며 ignore_index를 통해 [PAD] 토큰 무시
  • log_dict : 여러 메트릭을 로깅하는데 사용
    • metrics: 로깅하고자 하는 메트릭을 담고 있는 딕셔너리. 키는 메트릭의 이름, 값은 로깅하려는 수치.
    • prog_bar: 메트릭이 훈련 프로그레스 바에 표시. 
    • logger: 로깅하는 메트릭이 백엔드 로거(예: TensorBoard, CSVLogger 등)에 전송.
    • on_epoch: 로깅되는 메트릭이 에폭 단위로 기록. 각 에폭에서 메트릭의 평균값을 계산하여 로깅

cross_entropy함수를 이용해서 정답과 예측값의 손실을 계산한다. 또한 ignore_index인수를 통해 [PAD] 토큰들을 무시하도록 합니다. 마지막

 

검증 단계, 테스트 단계( validation_step, test_step)

def validation_step(self, batch, batch_idx):
    output = self.model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=batch["decoder_input_ids"],
        decoder_attention_mask=batch["decoder_attention_mask"],
        return_dict=True,
    )

    labels = batch["decoder_input_ids"][:, 1:].reshape(-1)
    logits = output["logits"][:, :-1].reshape([labels.shape[0], -1])

    loss = F.cross_entropy(logits, labels, ignore_index=self.model.config.pad_token_id)
    metrics = {"loss(v)": loss}
    self.validation_step_loss.append(loss)

    self.log_dict(metrics, prog_bar=True, logger=True, on_epoch=True)

    return metrics
    

def test_step(self, *args, **kwargs):
    return self.validation_step(*args, **kwargs)

train_step과 거의 동일하지만, 다른점은 validation_step_loss를 저장하며 test_step은 validation_step과 동일하다.

 

옵티마이저 설정( configure_optimizers)

def configure_optimizers(self):
    optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=self.max_learning_rate)

    return {
        "optimizer": optimizer
    }

optimizer를 adamW로 설정하고 학습률에 따라서 model의 파라미터를 업데이트한 것을 optimizer를 반환하여 가중치를 업데이트한다.

 

검증 에폭 종료 시 ( on_validation_epoch_end)

def on_validation_epoch_end(self):
    if self.trainer.is_global_zero:
        losses = [output.mean() for output in self.validation_step_loss]
        loss_mean = sum(losses) / len(losses)

        self.model.save_pretrained(
            os.path.join(
                self.model_save_dir,
                f"model-{self.current_epoch:02d}epoch-{self.global_step}steps-{loss_mean:.4f}loss",
            ),
        )

    self.validation_step_loss.clear()  # free memory
  • self.trainer.is_global_zero : 멀티 GPU 또는 분산 설정에서의 주 노드에서 실행되고 있는지를 검사
  • self.model.save_pretrained : 모델의 파라미터를 저장하며, 제목은 다음과 같이 한다.
  • self.validation_step_loss.clear : 리스트를 비우고 다음 에폭에서 새로운 손실 데이터를 깨끗하게 수집

 

전체코드

import os

import torch
import pytorch_lightning as pl
import torch.nn.functional as F


class StoryModule(pl.LightningModule):
    """
    Attributes:
        model: BART model
        total_steps: total training steps for lr scheduling
        max_learning_rate: Max LR
        min_learning_rate: Min LR
        warmup_rate: warmup step rate
        model_save_dir: path to save model
    """

    def __init__(
        self,
        model,
        model_save_dir,
        total_steps,
        max_learning_rate: float = 2e-4,
        min_learning_rate: float = 2e-5,
        warmup_rate: float = 0.1,
    ):
        super().__init__()

        self.model = model
        self.total_steps = total_steps
        self.max_learning_rate = max_learning_rate
        self.min_learning_rate = min_learning_rate
        self.warmup_rate = warmup_rate
        self.model_save_dir = model_save_dir
        self.validation_step_loss = []

        self.save_hyperparameters(
            {
                **model.config.to_dict(),
                "total_steps": total_steps,
                "max_learning_rate": self.max_learning_rate,
                "min_learning_rate": self.min_learning_rate,
                "warmup_rate": self.warmup_rate,
            }
        )

    def training_step(self, batch, batch_idx):
        output = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            decoder_input_ids=batch["decoder_input_ids"],
            decoder_attention_mask=batch["decoder_attention_mask"],
            return_dict=True,
        )

        labels = batch["decoder_input_ids"][:, 1:].reshape(-1)
        logits = output["logits"][:, :-1].reshape([labels.shape[0], -1])

        loss = F.cross_entropy(logits, labels, ignore_index=self.model.config.pad_token_id)
        metrics = {"loss": loss}
        self.log_dict(metrics, prog_bar=True, logger=True, on_step=True)

        return metrics
    
    def validation_step(self, batch, batch_idx):
        output = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            decoder_input_ids=batch["decoder_input_ids"],
            decoder_attention_mask=batch["decoder_attention_mask"],
            return_dict=True,
        )

        labels = batch["decoder_input_ids"][:, 1:].reshape(-1)
        logits = output["logits"][:, :-1].reshape([labels.shape[0], -1])

        loss = F.cross_entropy(logits, labels, ignore_index=self.model.config.pad_token_id)
        metrics = {"loss(v)": loss}
        self.validation_step_loss.append(loss)
        
        self.log_dict(metrics, prog_bar=True, logger=True, on_epoch=True)

        return metrics

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=self.max_learning_rate)

        return {
            "optimizer": optimizer
        }

    def on_validation_epoch_end(self):
        if self.trainer.is_global_zero:
            losses = [output.mean() for output in self.validation_step_loss]
            loss_mean = sum(losses) / len(losses)

            self.model.save_pretrained(
                os.path.join(
                    self.model_save_dir,
                    f"model-{self.current_epoch:02d}epoch-{self.global_step}steps-{loss_mean:.4f}loss",
                ),
            )

        self.validation_step_loss.clear()  # free memory

 

다음은 dataset 및 DataLoader 구축에 대해 설명하겠습니다!