언어 모델 Continual Learning의 Spurious Forgetting
Spurious Forgetting in Continual Learning of Language Models
TL;DR Highlight
LLM이 새 태스크 학습 후 성능이 떨어지는 건 지식을 잊어서가 아니라 태스크 정렬이 깨져서이며, 하위 레이어만 얼려도 대부분 막을 수 있다.
Who Should Read
LLM을 순차적으로 파인튜닝하거나 안전 정렬(safety alignment) 이후 추가 파인튜닝을 고려하는 ML 엔지니어. 특히 모델이 새 태스크를 학습한 뒤 이전 태스크 성능이 갑자기 추락하는 현상을 경험한 적 있는 사람.
Core Mechanics
- LLM이 새 태스크를 학습할 때 성능이 떨어지는 건 '지식 손실'이 아니라 '태스크 정렬(task alignment) 붕괴'다 — 겨우 10개 예제로 복구 가능하다는 실험으로 입증
- 새 태스크 학습 초반 150 스텝이 문제의 핵심 — 이 구간에서 이전 태스크 alignment를 빠르게 덮어쓴다
- 하위 레이어(임베딩 포함)가 task alignment를 담당하며, 여기서 가중치 업데이트가 직교(orthogonal) 방향으로 일어날 때 spurious forgetting이 발생한다
- 하위 레이어를 동결(Freeze)하기만 해도 SEQ 방식의 Task 0 정확도가 11% → 44%로 향상 — EWC, LAMOL, Gradient Projection 등 기존 방법 모두 22% 이하
- LLaMa-2-7B-Chat의 safety alignment에 적용 시 jailbreak rate 99.80% → 1.15% (6레이어 동결 기준)
- LLaMa-3-8B-Instruct, Qwen2.5-7B-Instruct, Mistral-8B-Instruct-2410 등 여러 모델의 수학·코드 SFT에서도 일반 능력 저하를 Freeze로 완화 확인
Evidence
- Biography 합성 데이터셋에서 Freeze(7레이어 + early stop)가 Task 0 정확도 44.22% 달성 vs SEQ 11.18%, 최고 경쟁 방법(Task Vector) 30.75%
- Safety Alignment: Freeze 6레이어 적용 시 jailbreak rate 99.80% → 1.15% (LLaMa-2-7B-Chat)
- 복구 실험: Task 1 학습 150스텝 후에도 recovered Task 0 정확도 96% 유지 → 지식 자체는 살아있음을 확인
- LLaMa-3-8B-Instruct 수학 SFT(lr=5e-6): 일반 능력 평균 64.15 → Freeze 적용 시 66.11, 수학 능력 유지(80.29→80.17)
How to Apply
- 순차 파인튜닝 시 첫 번째 태스크 학습 직후부터 하위 1~3 레이어(+임베딩)를 동결하고 이후 태스크를 학습하면 된다. 태스크 포맷이 비슷할수록(QA→QA 등) 더 많은 레이어를 얼리는 게 유리하다
- 안전 정렬 이후 추가 파인튜닝이 필요한 경우(예: 도메인 특화 지식 주입), 하위 6레이어를 동결한 채 파인튜닝하면 safety alignment 붕괴를 대폭 억제할 수 있다
- 코드/수학 SFT처럼 단일 태스크 파인튜닝에서도 하위 1레이어만 동결해도 일반 능력 저하를 줄이면서 목표 성능을 유지할 수 있다 — 파라미터 절반 이하만 업데이트하므로 학습 비용도 절감
Code Example
# HuggingFace Transformers로 하위 N 레이어 동결하는 예시
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
N_FREEZE_LAYERS = 2 # 동결할 하위 레이어 수
# 임베딩 + 하위 N 레이어 동결
for param in model.model.embed_tokens.parameters():
param.requires_grad = False
for i in range(N_FREEZE_LAYERS):
for param in model.model.layers[i].parameters():
param.requires_grad = False
# 학습 가능한 파라미터 수 확인
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"학습 파라미터: {trainable/total:.1%}")
# 이후 일반 파인튜닝 진행Terminology
Related Resources
Original Abstract (Expand)
Recent advancements in large language models (LLMs) reveal a perplexing phenomenon in continual learning: despite extensive training, models experience significant performance declines, raising questions about task alignment and underlying knowledge retention. This study first explores the concept of"spurious forgetting", proposing that such performance drops often reflect a decline in task alignment rather than true knowledge loss. Through controlled experiments with a synthesized dataset, we investigate the dynamics of model performance during the initial training phases of new tasks, discovering that early optimization steps can disrupt previously established task alignments. Our theoretical analysis connects these shifts to orthogonal updates in model weights, providing a robust framework for understanding this behavior. Ultimately, we introduce a Freezing strategy that fix the bottom layers of the model, leading to substantial improvements in four continual learning scenarios. Our findings underscore the critical distinction between task alignment and knowledge retention, paving the way for more effective strategies in continual learning.