On-Policy Distillation: LLM 증류 성공 조건과 실패 복구 전략 분석

서론

3주간의 fine-tuning을 마치고 드디어 7B 모델을 1.5B로 증류하려던 순간, 학생 모델의 성능이 교사 모델은커녕 증류 전보다도 하락하는 것을 발견했다면? 이는 최근 LLM 연구 커뮤니티가 직면한 공공연한 비밀이다. On-Policy Distillation(OPD)은 GPT-4나 Claude급 교사 모델의 능력을 경량 모델에 전이하는 핵심 기술로 주목받고 있지만, 실무에서는 실패 사례가 성공 못지않게 흔하다.

2025년 arXiv에 발표된 “Rethinking On-Policy Distillation of Large Language Models” 연구는 바로 이 지점을 정면으로 공략한다. OPD가 왜 어떤 경우에는 획기적인 성능 향상을 가져오고, 어떤 경우에는 오히려 모델을 망가뜨리는지에 대한 체계적인 분석을 제공한다. 단순히 “더 많은 데이터, 더 좋은 하이퍼파라미터” 같은 처방이 아니다. 이 연구는 토큰 수준에서의 메커니즘을 파헤치고, 실패한 증류를 복구하는 구체적인 전략까지 제시한다.

대규모 언어 모델을 실제 서비스에 배포하려는 엔지니어에게, 혹은 제한된 컴퓨팅 자원으로 최고의 성능을 짜내야 하는 연구자에게, 이 연구의 발견은 증류 파이프라인 전체를 재설계하는 계기가 될 수 있다.

본론

On-Policy Distillation의 본질: 행동 공간에서의 만남

전통적인 knowledge distillation이 정적 데이터셋에서 교사의 출력 분포를 모방하는 것이라면, OPD는 동적이다. 학생 모델이 스스로 생성한 시퀀스 위에서 교사 모델의 피드백을 받는 방식이다. 이 차이가 왜 중요할까?

1
2
3
4
5
6
graph TD
    A[학생 모델 샘플링] --> B[시퀀스 생성]
    B --> C[교사 모델 평가]
    C --> D[토큰별 보상 계산]
    D --> E[정책 그래디언트 업데이트]
    E --> A

OPD의 핵심은 학생이 방문한 상태(state)에서 교사가 어떤 토큰을 선호하는지 학습하는 것이다. 하지만 여기에 근본적인 딜레마가 존재한다. 학생이 결코 방문하지 않는 상태 공간에 있는 교사의 지식은 전이될 수 없다. 반대로, 학생과 교사의 사고 패턴이 전혀 다르다면, 교사의 피드백은 일관성 없는 노이즈가 된다.

OPD 성공의 두 가지 조건

연구가 식별한 두 조건은 직관적이면서도 깊은 통찰을 담고 있다.

조건 1: 호환되는 사고 패턴 공유

학생과 교사가 유사한 추론 경로를 따라야 한다. 예를 들어, 수학 문제를 풀 때 교사가 “근의 공식을 사용하자"고 접근하는데 학생이 “대입법을 시도하자"는 방향으로 간다면, 토큰 수준에서의 정렬은 불가능해진다.

조건 2: 진정한 새로운 능력 제공

더 흥미로운 발견은 두 번째 조건이다. 교사와 학생의 사고 패턴이 완벽히 일치하더라도, 교사가 학생의 훈련 분포 내에 있는 능력만 보여준다면 증류는 실패한다. 연구진은 “weak-to-strong reverse distillation” 실험으로 이를 증명했다. 1.5B 모델이 7B 교사를 모방하려 할 때, 같은 모델 패밀리라면 분포적으로 구분이 불가능할 정도로 유사해진다. 즉, 교사가 진정으로 “더 나은” 능력을 보여주지 못하는 것이다.

토큰 수준 메커니즘: 97-99%의 확률 질량 집중

성공적인 OPD의 가장 인상적인 특징은 토큰 분포의 집중 현상이다.

| 지표 | 성공한 OPD | 실패한 OPD | | :— | :— | :— | | 공유 토큰 비율 | 전체 어휘의 극소수 | 상대적으로 넓은 분포 | | 확률 질량 집중도 | 97%-99% | 70%-85% | | 정렬 패턴 | 점진적 수렴 | 불안정한 발산 | | 교사-학생 KL 발산 | 0.1-0.3 (낮음) | 1.5-3.0 (높음) |

이는 무엇을 의미할까? 성공적인 증류에서 교사와 학생은 어휘 집합의 극히 일부분에 대해서만 의미 있는 합의를 이룬다. 나머지 토큰들은 무시해도 될 정도로 낮은 확률을 가진다. 이 “작은 공유 토큰 집합"이 바로 증류가 실제로 일어나는 공간이다.

실전 구현: PyTorch 기반 OPD 파이프라인

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

class OnPolicyDistillator:
    def __init__(
        self, 
        teacher_name: str, 
        student_name: str,
        temperature: float = 2.0,
        alpha: float = 0.5
    ):
        self.teacher = AutoModelForCausalLM.from_pretrained(
            teacher_name, 
            torch_dtype=torch.bfloat16
        )
        self.student = AutoModelForCausalLM.from_pretrained(
            student_name,
            torch_dtype=torch.bfloat16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(teacher_name)
        self.temperature = temperature
        self.alpha = alpha  # KD loss weight
        
        self.teacher.eval()
    
    def compute_shared_token_mass(
        self, 
        teacher_logits: torch.Tensor,
        student_logits: torch.Tensor,
        top_k: int = 10
    ) -> dict:
        """공유 토큰 집합의 확률 질량을 계산"""
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        student_probs = F.softmax(student_logits / self.temperature, dim=-1)
        
        # 각각의 top-k 토큰 추출
        _, teacher_top_indices = torch.topk(teacher_probs, top_k, dim=-1)
        _, student_top_indices = torch.topk(student_probs, top_k, dim=-1)
        
        # 공유 토큰 마스크 생성
        shared_mask = torch.zeros_like(teacher_probs, dtype=torch.bool)
        for i in range(top_k):
            shared_mask.scatter_(
                -1, 
                teacher_top_indices[:,:,i:i+1], 
                True
            )
            shared_mask.scatter_(
                -1, 
                student_top_indices[:,:,i:i+1], 
                True
            )
        
        shared_mass_teacher = (
            teacher_probs * shared_mask
        ).sum(dim=-1).mean().item()
        shared_mass_student = (
            student_probs * shared_mask
        ).sum(dim=-1).mean().
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
item()
        
        return {
            "shared_mass_teacher": shared_mass_teacher,
            "shared_mass_student": shared_mass_student,
            "shared_token_ratio": shared_mask.float().mean().item()
        }
    
    def distillation_step(self, input_ids: torch.Tensor) -> dict:
        """단일 증류 스텝"""
        # 학생으로 시퀀스 생성 (on-policy)
        with torch.no_grad():
            student_outputs = self.student.generate(
                input_ids,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )
        
        # 생성된 시퀀스에 대해 교사/학생 로짓 계산
        with torch.no_grad():
            teacher_logits = self.teacher(student_outputs).logits
        student_logits = self.student(student_outputs).logits
        
        # 공유 토큰 질량 모니터링
        metrics = self.compute_shared_token_mass(
            teacher_logits[:, -128:, :], 
            student_logits[:, -128:, :]
        )
        
        # KL divergence loss
        teacher_probs = F.softmax(
            teacher_logits / self.temperature, dim=-1
        )
        student_log_probs = F.log_softmax(
            student_logits / self.temperature, dim=-1
        )
        
        kd_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='batchmean'
        ) * (self.temperature ** 2)
        
        return {"loss": kd_loss, "metrics": metrics}

이 코드에서 주목할 점은 compute_shared_token_mass 함수다. 실시간으로 교사-학생 간 공유 토큰 집합의 확률 질량을 모니터링하여, 증류가 올바른 방향으로 진행되는지 진단할 수 있다.

실패 복구 전략 1: Off-Policy Cold Start

1
2
3
4
5
graph LR
    A[Cold Start 실패 상태] --> B[Off-Policy 데이터로 초기화]
    B --> C[교사 시퀀스로 예열]
    C --> D[On-Policy로 전환]
    D --> E[정상 증류 진행]

가장 실용적인 복구 전략은 off-policy cold start다. 학생 모델이 처음부터 on-policy로 시작하면, 초기의 낮은 품질 시퀀스 때문에 교사의 피드백이 무의미해진다. 대신, 교사 모델이 생성한 고품질 시퀀스로 먼저 학생을 예열(warm-up)하는 것이다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def off_policy_cold_start(
    distillator: OnPolicyDistillator,
    prompts: list[str],
    warmup_steps: int = 500
):
    """Off-policy cold start 구현"""
    optimizer = torch.optim.AdamW(
        distillator.student.parameters(),
        lr=1e-5
    )
    
    for step in range(warmup_steps):
        prompt = prompts[step % len(prompts)]
        input_ids = distillator.tokenizer(
            prompt, return_tensors="pt"
        ).input_ids.cuda()
        
        # 교사가 시퀀스 생성 (off-policy)
        with torch.no_grad():
            teacher_seq = distillator.teacher.generate(
                input_ids.cuda(),
                max_new_tokens=256,
                do_sample=True,
                temperature=0.8
            )
        
        # 교사 시퀀스에 대해 학생 학습
        with torch.no_grad():
            teacher_logits = distillator.teacher(
                teacher_seq
            ).logits
        
        student_logits = distillator.student(teacher_seq).logits
        
        # 표준 KD loss 계산
        loss = F.kl_div(
            F.log_softmax(student_logits / 2.0, dim=-1),
            F.softmax(teacher_logits / 2.0, dim=-1),
            reduction='batchmean'
        ) * 4.0
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if step % 100 == 0:
            print(f"Warmup step {step}, Loss: {loss.item():.4f}")
    
    return distillator

실패 복구 전략 2: Teacher-Aligned Prompt Selection

두 번째 전략은 프롬프트 선택의 문제로 접근한다. 모든 프롬프트가 증류에 동등하게 유용한 것이 아니다. 교사 모델의 사고 패턴과 정렬된 프롬프트만 선별하여 학습하는 것이다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def select_teacher_aligned_prompts(
    teacher_model,
    student_model,
    tokenizer,
    candidate_prompts: list[str],
    alignment_threshold: float = 0.7
) -> list[str]:
    """교사와 정렬된 프롬프트 선별"""
    aligned_prompts = []
    
    for prompt in candidate_prompts:
        input_ids = tokenizer(
            prompt, return_tensors="pt"
        ).input_ids
        
        # 각 모델의 응답 생성
        with torch.no_grad():
            teacher_response = teacher_model.generate(
                input_ids.cuda(),
                max_new_tokens=64,
                do_sample=False  # 결정론적 비교
            )
            student_response = student_model.generate(
                input_ids.cuda(),
                max_new_tokens=64,
                do_sample=False
            )
        
        # 응답 유사도 계산 (간소화된 버전)
        teacher_text = tokenizer.decode(
            teacher_response[0], skip_special_tokens=True
        )
        student_text = tokenizer.decode(
            student_response[0], skip_special_tokens=True
        )
        
        # 토큰 오버랩 기반 유사도
        teacher_tokens = set(teacher_text.split())
        student_tokens = set(student_text.split())
        
        if len(teacher_tokens) == 0:
            continue
            
        overlap = len(teacher_tokens & student_tokens)
        similarity = overlap / len(teacher_tokens)
        
        if similarity >= alignment_threshold:
            aligned_prompts.append(prompt)
    
    return aligned_prompts

OPD의 숨겨진 비용: Long-Horizon의 함정

연구의 마지막 부분은 다소 우울한, 하지만 반드시 직면해야 할 통찰을 제공한다. OPD의 “공짜 점심” 같은 dense token-level reward에는 숨겨진 비용이 있다.

1
2
3
4
5
6
7
8
graph TD
    A[OPD 시작] --> B{시퀀스 길이}
    B -->|짧은 시퀀스| C[안정적 학습]
    B -->| 시퀀스| D[에러 누적]
    D --> E[분포 편향 증가]
    E --> F[교사-학생 간극 확대]
    C --> G[성공적 증류]
    F --> H[증류 실패]

짧은 시퀀스에서는 각 토�


출처: http://arxiv.org/abs/2604.13016v1

Hugo로 만듦
JimmyStack 테마 사용 중