questionet
LLM Trend Note2 (2) Supervised Fine-Tuning 본문
SFT
이번 노트에서는 kogpt-2를 instruction dataset으로 SFT를 진행해 보겠습니다.
먼저 필요한 라이브러리들을 불러오겠습니다.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.optim import Adam
from datasets import load_dataset
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import Trainer, TrainingArguments
from copy import deepcopy
import copy
import logging
import json
from dataclasses import dataclass
다음으로 모델과 토크나이저를 불러오겠습니다.
model = AutoModelForCausalLM.from_pretrained('skt/kogpt2-base-v2')
tokenizer = AutoTokenizer.from_pretrained(
'skt/kogpt2-base-v2', bos_token='</s>', eos_token='</s>', unk_token='</s>', pad_token='</s>',
padding_side="right",
model_max_length=512,
)
print(tokenizer)
모델 인퍼런스 단계에서 사용할 prompt 딕셔너리 템플릿과 SFT 데이터셋 클래스를 정의하겠습니다.
from typing import Optional, Dict, Sequence
class SFT_dataset(Dataset):
def __init__(self, data_path_1_SFT: str, tokenizer: transformers.PreTrainedTokenizer, verbose=False):
super(SFT_dataset, self).__init__()
logging.warning("Loading data...")
pattern_instruction = 'prompt' # instruction
pattern_output = 'completion' # response
data_path_1_SFT = './data_kochatgpt/kochatgpt_1_SFT.jsonl'
with open(data_path_1_SFT, "r", encoding='utf-8-sig') as json_file:
list_data_dict = json.load(json_file)
PROMPT_DICT = {
"prompt_input": (
"### Instruction(명령어):\n{prompt}\n\n### Response(응답):"
)
}
prompt_input = PROMPT_DICT["prompt_input"]
sources = []
for example in list_data_dict:
tmp = prompt_input.format_map(example)
sources.append(tmp)
targets = []
for example in list_data_dict:
targets.append(f"{example[pattern_output]}{tokenizer.eos_token}")
examples = [s + t for s, t in zip(sources, targets)]
sources_tokenized = self._tokenize_fn(sources, tokenizer) # source
examples_tokenized = self._tokenize_fn(examples, tokenizer) # source + target
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = -100
data_dict = dict(input_ids=input_ids, labels=labels)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
logging.warning("Loading data done!!: %d"%(len(self.labels)))
def _tokenize_fn(self, strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
SFT_dataset 클래스의 이니셜라이저에서
label[:source_len] = -100 코드의 -100이 의미하는 게 무엇인가요?
해당 코드가 필요한 이유와 그 기능은 무엇인가요?
다음 링크를 참고해보세요.
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
@dataclass
class DataCollatorForSupervisedDataset(object):
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value= -100)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
DataCollatorForSupervisedDataset 클래스에서
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value= -100) 코드의
padding_value= -100 인자의 -100은 어떤 기능을 하나요?
이제 SFT_dataset 클래스를 사용해 훈련셋을 만들고 data collator 인스턴스를 만들겠습니다.
train_dataset = SFT_dataset(data_path_1_SFT='./data_kochatgpt/kochatgpt_1_SFT.jsonl', tokenizer=tokenizer)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
print('input : %s'%train_dataset.input_ids[0])
print('output: %s'%train_dataset.labels[0])
훈련을 위한 마지막 단계로 Training arguments를 사용해 trainer 클래스를 정의하겠습니다.
training_args = TrainingArguments(
output_dir="./test",
overwrite_output_dir=True,
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=5,
prediction_loss_only=True,
fp16 = True
)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset
)
SFT 훈련을 진행해볼까요?
(빠르게 학습해보기 위해 1epoch만 돌려보도록 하겠습니다.)
trainer.train()
model.save_pretrained('./output_1_SFT')
loss가 잘 줄어들었나요?
이제 문장 생성 능력을 확인하기 위해
빠르게 허깅페이스의 pipleline 클래스를 사용하여 generator를 만들어보겠습니다.
generator = pipeline('text-generation', model='./output_1_SFT', tokenizer=tokenizer)
generation_args = dict(
num_beams=4,
repetition_penalty=2.0,
no_repeat_ngram_size=4,
eos_token_id=375, # \n
max_new_tokens=64,
do_sample=True,
top_k=50,
early_stopping=True
)
PROMPT_DICT = {
"prompt_input": (
"### Instruction(명령어):\n{prompt}\n\n### Response(응답):"
)
}
list_prompt = ['불고기용 고기 한우에요?',
'리처드 닉슨이 43대 부통령직을 수행한 년도는?',
'시카고 오헤어 국제공항은 어디에 있어?',
'오늘 미세먼지 어때?']
list_prompt = [PROMPT_DICT['prompt_input'].format_map({'prompt' : tmp}) for tmp in list_prompt]
list_result = generator(list_prompt, **generation_args)
for prompt, result in zip(list_prompt, list_result):
print()
print((result[0]['generated_text']))
SFT 모델의 성능은 어떤가요?
SFT 단계를 최적화하기 위해선 무엇보다도 instruction dataset의 품질과 initial모델의 언어모델링 성능이 중요합니다.
GPT를 새로 pretrain 하여 언어모델 성능을 도약시키는 일은 우리의 학습목표를 넘어서는 일이니
우선은 데이터셋 전처리를 더 수행하고 최상의 디코딩 전략이 적용된 generator를 설계한다면
더 나은 성능을 기대해 볼 수 있을 것입니다.
하지만 지금은 baseline을 빠르게 돌려보는 게 목적입니다.
이제 다음 단계인 reward modeling으로 넘어가 보도록 하겠습니다.
메모리 관리를 위해 캐시를 비우고 넘어가겠습니다.
torch.cuda.empty_cache()
'Deep learning' 카테고리의 다른 글
LLM Trend Note2 (4) PPO(Proximal Policy Optimization) (0) | 2023.06.27 |
---|---|
LLM Trend Note2 (3) Reward Model (0) | 2023.06.27 |
LLM Trend Note2 (1) Base model and Dataset for RLHF (0) | 2023.06.27 |
LLM Trend Note (6) LLM.int8(), LoRA, Prefix LM, Sparse transfomer, Sparse attention, Model parallelism, Data parallelism (2) | 2023.06.27 |
LLM Trend Note (5) GPT-4 vs LLaMA (0) | 2023.06.27 |