questionet

LLM Trend Note (5) GPT-4 vs LLaMA 본문

Deep learning

LLM Trend Note (5) GPT-4 vs LLaMA

orthanc 2023. 6. 27. 14:03

지난 노트에서 우리는
OpenAI에서 발표한 InstructGPT 논문을 중심으로
LLM에 RLHF가 도입되기까지의 흐름과 RLHF 학습 메커니즘에 관한 내용을 간단히 살펴보았습니다.

이 글을 쓰고 있는 2023년 4월을 기준으로
가장 강력한 Emergence를 보여주는 모델은 GPT-4입니다.
현재 OpenAI의 유료 chatGPT 서비스는 GPT-4를 사용하고 있기 때문입니다.
문제는 GPT-4의 소스코드가 공개되어 있지 않다는 사실입니다.
우리는 GPT-4가 구체적으로 어떤 모델학습 인프라 위에서
어떤 아키텍쳐를 기반으로 얼마나 많은 파라미터를 가지고,
얼마나 많은 수의 토큰을 봤는지,
어떻게 데이터셋을 구성하여 어떤 RLHF 방식으로 학습시켰는지 알 수 없습니다.
(GPT-4 Technical Report 참고)

오픈소스 정신을 표방하는 OpenAI가 주춤하는 사이
Meta는 자사의 Foundation Model인 LLaMA 를 (자의반 타의반)오픈소스로 제공하고 있습니다.
(LLaMA: Open and Efficient Foundation Language Models논문 및 깃헙 참고 링크)

LLaMA는 GPT-4와 비슷한 시기에 발표되어 GPT-3보다 더 작으면서 성능은 더 좋은 모델로 평가받고 있습니다.

LLaMA는 6.7B, 13B, 32.5B, 65.2B 사이즈의 네가지 버전으로 출시되었습니다.
이 중 LLaMA(13B)모델은 GPT3(175B)보다 1/10 이상 모델 사이즈가 작지만
대부분의 zero-shot 벤치마크에서 GPT3(175B)를 능가하는 성능을 갖는 것으로 평가받았습니다.

심지어 가장 작은 6.7B 모델은 V100 single machine에서도 실행 가능하다고 합니다!

LLaMA는 GPT-3와 똑같은 디코더 기반 트랜스포머 아키텍쳐 모델입니다.
오픈소스로 공개되어 있는 LLaMA에 RLHF를 구현한 레퍼런스 코드가 있다면
이전 노트에서 살펴본 내용을 좀 더 잘 이해해볼 수 있을 것 같습니다.

LLaMA가 공개된 이후로 오픈소스 LLM인 Vicuna가 발표되어 곧바로 LLaMA를 뛰어넘었고
LLaMA를 instruction tuning한 Alpaca가 발표되는 등
RLHF를 사용한 LLM에 대한 접근성이 대중에게 열리기 시작했습니다.

이번 노트에서는 허깅페이스에서 공개한
StackLLaMA: A hands-on guide to train LLaMA with RLHF 를 살펴보고자 합니다.

StackLLaMA 는 StackExchange의 질문 답변 데이터셋
Meta에서 공개한 LLaMA, 그리고 RLHF를 pytorch로 구현한 TRL 라이브러리를 사용해 만든 모델입니다.

구체적으로는 가장 작은 6.7B LLaMA를 initial LM으로 사용했고,
특히 RM의 경우 고품질의 Human label을 필요로 하는 OpenAI방식 대신
Anthropic이 PMP를 위해 사용한 방식을 따랐습니다.

그러면 이제부터 StackLLaMA 의 각 훈련 단계를 살펴보도록 하겠습니다.

1단계 (SFT) :

initial model로 사용할 디코더 기반 트랜스포머를 관심있는 도메인에 fine-tuning 하기 위해
LLaMA를 StackExchange 질문답변 dataset으로 causal language modeling 합니다.
dataset 링크
supervised_finetuning 참고

위의 dataset 링크에서 StackExchange 데이터셋을 어떻게 구축했는지 살펴보세요.
response_j, response_k를 각각 어떻게 만들었고 어떻게 사용될까요?

위의 supervised_finetuning.py에서 ConstantLengthDataset 클래스를 찾아보세요.
해당 클래스의 기능은 무엇일까요?

위의 supervised_finetuning.py 스크립트에서 chars_token_ratio 함수와 create_datasets 함수를 찾아보세요.
각 함수의 기능은 무엇일까요?

2단계 (RM) :

StackExchange 데이터셋을 활용해 upvotes를 기반으로 질문에 대한 답변들의 랭킹을 매기고,
더 높은 랭크의 답변을 인간의 피드백이 반영된(인간이 더 선호하는) 답변으로 간주하는 RM을 학습합니다.
reward_modeling 참고

위의 reward_modeling.py스크립트에서 RewardTrainer 클래스를 찾아보세요.
loss 계산식이 그와 같이 설계된 이유는 무엇일까요?

위의 reward_modeling.py스크립트에서 RewardDataCollatorWithPadding 클래스를 찾아보세요.
해당 클래스의 기능은 무엇일까요?

3단계 (RLHF) :

1단계에서 준비한 LM(RL policy)으로 prompt와 response를 수집하고,
이를 2단계에서 준비한 RM에 넣어 reward를 계산합니다.
LM의 카피본에도 같은 prompt와 response를 넣어 KL-divergence 값을 구해 최종 reward를 계산합니다.
StackLLaMA에서는 peft를 사용하므로 RL policy의 LoRA가중치를 PPO로 최적화합니다.
rl_training 참고
ppo_trainer 참고

위의 rl_training.py 스크립트에서 question_tensors와 response_tensors, rewards 변수를 찾아보세요.
각 변수는 어떻게 생성될까요?

어떠셨나요?

StackExchange라는 특정 도메인에 국한된 데이터셋을 사용해서,
ChatGPT와 같은 Human labeling 방식을 간접적으로 모방한 RM을 학습시켜 모델 전체를 업데이트하지 않고
head만 PPO로 업데이트하는 방식을 취한 것이 StackLLaMA와 ChatGPT의 큰 차이점이라고 볼 수 있겠습니다.

자 그러면 다음 노트에서는 우리가 자세히 살펴보지 않은 나머지 기술적인 부분들을 훑어보고,
LLM과 Emergent abilities에서 시작된 여러가지 질문들에 대해 각자의 답을 내려보는 시간을 가져보겠습니다.

Comments