LLM의 Knowledge Attribution 탐지: 모델이 '어디서' 답했는지 알 수 있을까
Probing for Knowledge Attribution in Large Language Models
TL;DR Highlight
LLM 내부 hidden state에서 간단한 선형 분류기만으로 모델이 컨텍스트를 썼는지 파라메트릭 메모리를 썼는지 0.96 F1로 구분할 수 있다.
Who Should Read
RAG 파이프라인에서 모델이 retrieved context를 무시하고 hallucination을 내뱉는 문제를 디버깅하려는 백엔드/ML 엔지니어. LLM 응답의 신뢰도를 높이거나 hallucination 감지 시스템을 구축 중인 개발자.
Core Mechanics
- LLM이 답변을 생성할 때 '프롬프트/컨텍스트에서 읽었는지' vs '내부 파라미터 기억에서 꺼냈는지' 정보가 중간~상위 레이어 hidden state에 선형으로 인코딩되어 있음
- Llama-3.1-8B, Mistral-7B, Qwen2.5-7B에서 가중치 학습 가능한 레이어별 집계 + 로지스틱 회귀만으로 최대 Macro-F1 0.96 달성
- SQuAD, WebQuestions 같은 완전히 다른 도메인 데이터셋에도 재학습 없이 0.94~0.99 정확도로 전이됨
- Attribution mismatch(엉뚱한 소스 사용)가 있을 때 오답률이 최대 70% 증가 — 특히 misleading context가 parametric 기억을 override할 때 심각
- 복잡한 MLP 분류기는 entity 반복 같은 lexical shortcut을 학습해 오히려 robustness가 떨어짐. 단순 선형 probe가 더 안정적
- AttriWiki라는 자동 데이터 파이프라인 공개: Wikipedia + GPT-4o-mini로 '컨텍스트 전용' vs '파라미터 전용' 예제를 자동 생성
Evidence
- Layer-weighted 로지스틱 회귀가 Mistral-7B에서 LTE 토큰 기준 Macro-F1 0.961 달성 (Final-Layer 단독 0.904 대비 +5.7pp)
- SQuAD(contextual), WebQuestions(parametric) 외부 벤치마크에서 Qwen2.5-7B 기준 각각 0.997, 0.999 정확도 — 재학습 없이
- Attribution mismatch 시 misleading context 조건에서 오답률 최대 70% 증가, parametric 기억 우선 사용 조건에서는 30% 증가
- 텍스트 기반 BoW/임베딩 분류기는 F1 0.65~0.68에 그쳐 hidden state 없이는 attribution 판단이 어려움을 입증
How to Apply
- RAG 시스템에서 모델 응답 생성 시 첫 번째 생성 토큰(FTG)의 hidden state를 추출해 attribution probe를 통과시키면, 모델이 retrieved context를 실제로 사용했는지 실시간으로 감지할 수 있다. context 무시로 판단되면 re-retrieval 또는 경고 트리거 가능.
- Llama/Mistral/Qwen 계열 모델에서 AttriWiki 파이프라인을 자체 도메인 데이터(법률, 의료 등)로 재현해 도메인 특화 attribution classifier를 만들면, 특정 도메인에서 hallucination 발생 패턴을 사전에 포착할 수 있다.
- 챗봇 UI에서 attribution 결과를 노출할 때 '이 답변은 제공된 문서 기반' vs '모델 내부 지식 기반'으로 구분 표시하면, 사용자가 답변 신뢰도를 스스로 판단하도록 유도할 수 있다.
Code Example
# Hidden state 기반 attribution probe 추론 예시 (Llama-3.1-8B 기준)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression
import numpy as np
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
def get_first_token_hidden_states(prompt: str):
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
# 모든 레이어 hidden states: (num_layers+1, seq_len, hidden_size)
hidden_states = outputs.hidden_states
# 마지막 입력 토큰(첫 번째 생성 직전)의 각 레이어 representation
layer_reps = torch.stack([h[0, -1, :] for h in hidden_states[1:]]) # (L, H)
return layer_reps.cpu().numpy()
# probe 학습 (AttriWiki 데이터로 학습된 가중치 로드 가정)
# alpha: 각 레이어 softmax 가중치, probe: 학습된 LogisticRegression
def predict_attribution(prompt: str, alpha: np.ndarray, probe) -> str:
layer_reps = get_first_token_hidden_states(prompt) # (L, H)
# 레이어 가중 평균
blended = (alpha[:, None] * layer_reps).sum(axis=0) # (H,)
pred = probe.predict([blended])[0]
return "contextual" if pred == 1 else "parametric"
# 사용 예시
context_prompt = "Based on this document: [DOC]. Q: What is X? A:"
result = predict_attribution(context_prompt, alpha, probe)
if result == "parametric":
print("⚠️ 모델이 제공된 컨텍스트를 무시하고 내부 지식에서 답변 중")Terminology
Related Resources
Original Abstract (Expand)
Large language models (LLMs) often generate fluent but unfounded claims, or hallucinations, which fall into two types: (i) faithfulness violations - misusing user context - and (ii) factuality violations - errors from internal knowledge. Proper mitigation depends on knowing whether a model's answer is based on the prompt or its internal weights. This work focuses on the problem of contributive attribution: identifying the dominant knowledge source behind each output. We show that a probe, a simple linear classifier trained on model hidden representations, can reliably predict contributive attribution. For its training, we introduce AttriWiki, a self-supervised data pipeline that prompts models to recall withheld entities from memory or read them from context, generating labelled examples automatically. Probes trained on AttriWiki data reveal a strong attribution signal, achieving up to 0.96 Macro-F1 on Llama-3.1-8B, Mistral-7B, and Qwen-7B, transferring to out-of-domain benchmarks (SQuAD, WebQuestions) with 0.94-0.99 Macro-F1 without retraining. Attribution mismatches raise error rates by up to 70%, demonstrating a direct link between knowledge source confusion and unfaithful answers. Yet, models may still respond incorrectly even when attribution is correct, highlighting the need for broader detection frameworks.