CASK: Role-based KV Cache 압축으로 LLM 추론 메모리 25% 추가 절감

서론

“LLM 서빙 서버의 GPU 메모리가 또 꽉 찼습니다.”

수백만 명의 사용자가 동시에 챗봇에 질문을 던지는 상황을 가정해보자. Attention 메커니즘의 핵심인 KV Cache는 시퀀스 길이에 선형적으로 증가하며, 배치 사이즈가 커질수록 이 문제는 기하급수적으로 악화된다. A100 80GB GPU 하나로 처리할 수 있는 요청은 결국 이 KV Cache의 크기에 의해 제한된다.

기존 해결책들은 대부분 “덜 중요한 토큰의 KV를 버리자"는 발상이었다. Token importance 기반 pruning은 어느 정도 효과가 있었지만, 근본적인 한계가 있었다. 중요도 점수를 매기는 것 자체가 추가 연산을 요구하고, 잘못된 pruning은 모델 성능을 급격히 저하시킨다.

이때 한 가지 핵심적인 질문이 떠오른다. 토큰의 “역할"은 고려하지 않고 단순히 중요도만으로 pruning을 해도 될까?

CASK(Cell Attention Sparse KV-cache)는 바로 이 지점에서 출발한다. 토큰이 문장 내에서 어떤 구조적 역할을 수행하는지에 기반하여 KV Cache를 압축하는 이 접근법은 기존 기법 대비 최대 25%의 추가 메모리 절감을 달성하면서도 모델 성능을 오히려 개선하는 놀라운 결과를 보여준다.

본론

1. KV Cache: LLM 추론의 아킬레스건

Transformer 기반 LLM의 추론 과정에서 KV Cache는 self-attention 연산을 가속하기 위해 이전 토큰들의 Key와 Value 행렬을 저장하는 메모리 공간이다. autoregressive 생성 시 매 스텝마다 이전 모든 토큰의 KV를 재계산하는 것을 피할 수 있지만, 그 대가로 메모리를 소모한다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# KV Cache 메모리 사용량 계산 (단일 레이어 기준)
def calculate_kv_cache_memory(
    batch_size: int,
    seq_length: int,
    hidden_dim: int,
    num_layers: int,
    bytes_per_param: int = 2  # FP16 기준
) -> int:
    """
    KV Cache의 총 메모리 사용량을 바이트 단위로 계산
    """
    # 각 레이어당 K, V 각각에 대해 (batch, seq_len, hidden_dim) 크기 필요
    memory_per_layer = 2 * batch_size * seq_length * hidden_dim * bytes_per_param
    total_memory = memory_per_layer * num_layers
    return total_memory

# LLaMA-2 70B 예시
batch_size = 32
seq_length = 4096
hidden_dim = 8192
num_layers = 80

memory = calculate_kv_cache_memory(batch_size, seq_length, hidden_dim, num_layers)
print(f"KV Cache 메모리: {memory / (1024**3):.2f} GB")
# 출력: KV Cache 메모리: 160.00 GB (A100 80GB의 2배!)

이처럼 배치 사이즈와 시퀀스 길이가 증가하면 KV Cache만으로도 GPU 메모리를 초과하는 상황이 발생한다.

2. 기존 접근법의 한계: Token Importance 기반 Pruning

기존 KV Cache 압축 기법들은 주로 토큰의 “중요도"를 측정하여 덜 중요한 토큰의 KV를 제거하는 방식을 취했다. 대표적인 방법들은:

| 방법 | 측정 기준 | 계산 오버헤드 | 성능 저하 위험 | 한계점 | | :— | :— | :— | :— | :— | | StreamingLLM | Attention Score 기반 | 낮음 | 중간 | 최근 토큰에 과도한 편향 | | Scissorhands | Token Importance Score | 중간 | 높음 | 중요도 측정의 부정확성 | | H2O (Heavy-Hitter Oracle) | 누적 Attention Score | 높음 | 중간 | 장거리 의존성 누락 | | FastGen | 적응적 압축 비율 | 높음 | 낮음 | 복잡한 스케줄링 필요 |

이 방법들의 공통된 문제는 토큰의 문맥적 역할을 무시한다는 점이다. 예를 들어, 문장에서 “the”, “a”, “is” 같은 기능어(function words)는 attention score가 낮을 수 있지만, 문장 구조를 유지하는 데 필수적이다.

3. CASK의 핵심 아이디어: Role-based KV Cache 압축

CASK는 토큰을 단순히 “중요함/안 중요함"으로 나누는 대신, 토큰이 문장 내에서 수행하는 **구조적 역할(structural role)**에 따라 분류한다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
graph TD
    A[입력 토큰 시퀀스] --> B[Role 분류기]
    B --> C[Cell Tokens: 문장의 핵심 정보]
    B --> D[Context Tokens: 문맥 제공]
    B --> E[Filler Tokens: 구조적 연결]
    C --> F[Full KV 보존]
    D --> G[부분적 KV 압축]
    E --> H[최대 압축 또는 제거]
    F --> I[최종 압축된 KV Cache]
    G --> I
    H --> I

3.1 토큰 역할의 정의

Cell Tokens: 문장의 핵심 의미를 담당하는 토큰. 명사, 동사, 핵심 형용사 등이 여기에 해당한다. 이 토큰들은 정보의 “세포(cell)“처럼 문장의 의미를 구성한다.

Context Tokens: Cell token 주변에서 의미를 한정하고 명확히 하는 토큰. 한정사, 전치사구 등이 포함된다.

Filler Tokens: 문법적 구조를 유지하기 위한 토큰. 접속사, 일부 조동사 등이 해당된다.

3.2 역할별 압축 전략

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
from dataclasses import dataclass
from enum import Enum

class TokenRole(Enum):
    CELL = "cell"          # 핵심 정보 토큰
    CONTEXT = "context"    # 문맥 제공 토큰
    FILLER = "filler"      # 구조적 연결 토큰

@dataclass
class CASKConfig:
    cell_keep_ratio: float = 1.0       # Cell: 100% 보존
    context_keep_ratio: float = 0.5    # Context: 50% 보존
    filler_keep_ratio: float = 0.1     # Filler: 10%만 보존
    importance_threshold: float = 0.3  # 역할 분류 임계값

class CASKCompressor:
    """
    CASK: Role-based KV Cache 압축기
    토큰의 구조적 역할에 따라 차등적으로 KV Cache를 압축
    """
    def __init__(self, config: CASKConfig):
        self.config = config
    
    def classify_token_role(
        self,
        token: torch.Tensor,
        position: int,
        attention_weights: torch.Tensor
    ) -> TokenRole:
        """
        토큰의 역할을 분류하는 간소화된 메서드
        
        실제 구현에서는:
        1. POS tagging 정보 활용
        2. Attention 패턴 분석
        3. 위치 기반 휴리스틱 결합
        """
        # Attention 가중치의 분산이 높으면 Cell token일 가능성
        attention_variance = attention_weights.var().item()
        attention_mean = attention_weights.mean().item()
        
        if attention_variance > self.config.importance_threshold:
            return TokenRole.CELL
        elif attention_mean > 0.1:
            return TokenRole.CONTEXT
        else:
            return TokenRole.FILLER
    
    def compress_kv_cache(
        self,
        key_cache: torch.Tensor,    # [batch, num_heads, seq_len, head_dim]
        value_cache: torch.Tensor,  # [batch, num_heads, seq_len, head_dim]
        attention_weights: torch.Tensor
    ) -> tuple:
        """
        역할 기반 KV Cache 압축 수행
        """
        batch_size, num_heads, seq_len, head_dim = key_cache.shape
        
        # 1. 각 토큰의 역할 분류
        keep_masks = []
        for pos in range(seq_len):
            role = self.
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
classify_token_role(
                key_cache[:, :, pos, :],
                pos,
                attention_weights[:, :, pos, :]
            )
            
            # 역할에 따른 보존 비율 결정
            if role == TokenRole.CELL:
                keep_prob = self.config.cell_keep_ratio
            elif role == TokenRole.CONTEXT:
                keep_prob = self.config.context_keep_ratio
            else:  # FILLER
                keep_prob = self.config.filler_keep_ratio
            
            keep_masks.append(keep_prob)
        
        # 2. 확률적 서브샘플링을 위한 마스크 생성
        keep_masks = torch.tensor(keep_masks)
        binary_mask = torch.rand(seq_len) < keep_masks
        
        # 3. 선택된 토큰의 KV만 보존
        compressed_keys = key_cache[:, :, binary_mask, :]
        compressed_values = value_cache[:, :, binary_mask, :]
        
        return compressed_keys, compressed_values, binary_mask

# 사용 예시
config = CASKConfig()
compressor = CASKCompressor(config)

# 더미 KV Cache 생성 (batch=4, heads=32, seq_len=2048, dim=128)
key_cache = torch.randn(4, 32, 2048, 128)
value_cache = torch.randn(4, 32, 2048, 128)
attention_weights = torch.softmax(torch.randn(4, 32, 2048, 2048), dim=-1)

# 압축 수행
comp_keys, comp_values, mask = compressor.compress_kv_cache(
    key_cache, value_cache, attention_weights
)

original_size = key_cache.numel() + value_cache.numel()
compressed_size = comp_keys.numel() + comp_values.numel()
compression_ratio = (1 - compressed_size / original_size) * 100

print(f"원본 KV Cache 크기: {original_size * 2 / (1024**2):.2f} MB")
print(f"압축 후 크기: {compressed_size * 2 / (1024**2):.2f} MB")
print(f"압축률: {compression_ratio:.1f}%")

4. CASK vs 기존 방법: 성능 비교

논문에 보고된 벤치마크 결과를 종합하면:

| 메트릭 | Full KV Cache | StreamingLLM | H2O | CASK (제안) | | :— | :—: | :—: | :—: | :—: | | 메모리 사용률 | 100% | ~70% | ~65% | ~48% | | Perplexity (wikitext2) | 8.32 | 9.15 | 8.78 | 8.25 | | 처리량 (tokens/sec) | 기준 | +15% | +22% | +38% | | 장거리 의존성 유지 | 완벽 | 낮음 | 중간 | 높음 | | 추가 계산 오버헤드 | 없음 | 낮음 | 중간 | 낮음 |

핵심은 메모리를 줄이면서도 perplexity가 오히려 개선되었다는 점이다. 이는 노이즈에 해당하는 Filler token의 KV를 제거함으로써 attention이 더 집중적인 패턴을 형성하기 때문으로 분석된다.

5. Step-by-Step: CASK 적용 가이드

Step 1: 환경 설정 및 의존성

1
2
3
4
5
# 필요 패키지 설치
pip install torch transformers

# CASK 구현체 클론 (공식 릴리스 대기 중)
# 현재는 논문의 알고리즘을 직접 구현하여 적용

Step 2: Role Classifier 준비

토큰 역할 분류를 위해 두 가지 접근이 가능하다:

  1. 규칙 기반: POS tagger를 활용하여 품사에 따라 역할 부여 2. 학습 기반: 작은 분류 모델을 학습하여 역할 예측

실무에서는 규칙 기반 접근이 계산 오버헤드가 적어 권장된다.

Step 3: Attention 연산 수정

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def cask_attention_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    role_mask: torch.Tensor  # CASK에서 추가된 역할 마스크
) -> torch.Tensor:
    """
    CASK가 적용된 Attention Forward Pass
    
    role_mask: [seq_len], 각 토큰의 보존 비율 (0.0 ~ 1.0)
    """
    # 1. 역할 기반 서브샘플링
    keep_indices = torch.where(
        torch.rand(len(role_mask)) < role_mask
    )[0]
    
    # 2. 선택된 KV만 유지
    key_subset = key[:, :, keep_indices, :]
    value_subset = value[:, :, keep_indices, :]
    
    # 3. 표준 Attention 계산
    attn_weights = torch.matmul(query, key_subset.transpose(-2, -1))
    attn_weights = attn_weights / (key_subset.shape[-1] ** 0.5)
    attn_weights = torch.softmax(attn_weights, dim=-1)
    
    output = torch.matmul(attn_weights, value_subset)
    
    return output

Step 4: 서빙 파이프라인 통합

1
2
3
4
5
6
7
8
graph LR
    A[사용자 요청] --> B[Tokenization]
    B --> C[Role 분류]
    C --> D[Prefill 단계]
    D --> E[Decode 단계]
    E --> F[역할 기반 KV 압축]
    F --> G[응답 생성]
    G --> H[메모리 해제]

Step 5: 모니터링 및 튜닝

압축 비율은 워크로드에 따라 조정해야 한다:

  • 대화형 챗봇: Cell 비율을 높게 유지 (명사/동사 중심)
  • 문서 요약: Context 비율을 조정 (긴 문맥 필요)
  • 코드 생성: Filler 비율을 낮춤 (코드는 구조가 중요)

6. 왜 CASK가 주목받는가:


출처: https://news.hada.io/topic?id=28520

Hugo로 만듦
JimmyStack 테마 사용 중