ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [26' ICLR] Dataless Weight Disentanglement in Task Arithmetic via Kronecker-Factored Approximate Curvature
    논문 리뷰 2026. 3. 20. 10:28

    한 줄 요약

    Task vector 합산 시 발생하는 cross-task interference를 외부 데이터 없이 최소화하기 위해, KFAC(Kronecker-Factored Approximate Curvature) 기반 regularization을 fine-tuning 과정에 적용하여, 단순한 task vector 덧셈만으로도 SOTA급 multi-task 성능을 달성하는 TAK 프레임워크를 제안한다.

     

    문제점

    Task Arithmetic에서의 Representation Drift

    여러 task vector를 단순 합산하면, task t'의 벡터가 추가될 때 기존 task t의 last-layer activation이 변화하는 representation drift가 발생한다.
    이 drift는 다음과 같이 정량화된다:

    Δ_t→t,t'(x) = ||z_t,t'(x) - z_t(x)||²₂

    이로 인해 개별 task에서는 우수한 task vector가 합산 시 심각한 성능 저하를 유발한다.

    기존 해결책의 한계

    Linearized fine-tuning 계열:

    • Weight disentanglement를 보장하지만, 추가적인 forward-backward pass로 인한 학습 비용 2배 증가가 단점이다.
    • Jacobian projection (τ_Jp)은 다른 task의 데이터에 접근해야 하므로 dataless 조건을 위반한다.

    Post-hoc merging 계열:

    • TIES-Merging, DARE 등은 학습 이후 task vector를 조작하므로, 학습 과정에서의 drift를 근본적으로 방지하지 못한다.
    • 최적 scaling coefficient α에 대한 민감한 grid search가 필요하다.

    TaLoS:

    • Sparse fine-tuning으로 disentanglement를 개선하지만, 명시적 curvature 정보를 활용하지 않는다.

     

    제안 방법

    Representation Drift의 수학적 분석

    Linearized model 가정 하에서, representation drift는 Jacobian Gram matrix를 통한 이차 형식으로 표현된다:

    L^drift_t→t,t'(τ_t') = α²_t' · τ_t'^T · G_t(θ_0) · τ_t'

    여기서 G_t(θ_0) = (1/|D_t|) Σ_x J_θf(x,θ_0)^T J_θf(x,θ_0)는 task t에 대한 Jacobian Gram matrix이다.

    핵심 연결: GGN과의 관계

    Jacobian Gram matrix가 Generalized Gauss-Newton (GGN) matrix의 한 instance임을 밝힌다.
    이 연결을 통해 기존 optimization 문헌의 효율적인 curvature 근사 기법을 활용할 수 있는 이론적 기반을 마련한다.

    KFAC 근사

    P×P 크기의 intractable한 GGN matrix를 layer별 block-diagonal + Kronecker product로 근사한다:

    G(vec W^l) ≈ B^l ⊗ A^l

    • A^l = (1/|D|) Σ_n a^l_n (a^l_n)^T: Input activation의 covariance (pre-activation)
    • B^l = E[g^l_{n,m} (g^l_{n,m})^T]: Output gradient의 covariance (backpropagation gradient)

    이를 통해 저장 공간이 O(D₁D₂ × D₁D₂)에서 O(D₁² + D₂²)로 대폭 감소한다.

    Multi-Task Regularization 목적 함수

    Task t'를 학습할 때, 다른 모든 task에 대한 drift를 방지하는 regularization을 추가한다:

    L_total = L_{D_t'}(τt') + β Σ{t≠t'} λ_t Σ_l τ^l_t'^T (B^l_t ⊗ A^l_t) τ^l_t'

    여기서 β는 전체 regularization 강도, λ_t = |D_t|/Σ|D_t|는 데이터셋 크기 기반 가중치이다.

    Regularizer Merging을 통한 O(1) 복잡도

    Task 수 T에 비례하는 regularization 비용을 줄이기 위해, per-task KFAC factor를 단일 surrogate로 병합한다:

    G_{-t'}(θ^l_0) ≈ (Σ{t≠t'} B^l_t) ⊗ (Σ{t≠t'} λ_t A^l_t)

    이를 통해 task 수에 무관한 상수 복잡도 O(1)를 달성하며, 경험적으로 naive multi-task regularization과 유사한 성능을 보인다.

    Dataless 속성

    핵심적으로, KFAC factor (A^l, B^l)는 각 task의 자체 학습 데이터만으로 사전 계산 가능하다.
    Task t'를 학습할 때 다른 task들의 원본 데이터에 접근할 필요 없이, 미리 계산된 Kronecker factor만 있으면 된다.
    이는 data privacydistributed learning 시나리오에서 중요한 실용적 이점이다.

     

    학습 상세

    Vision 도메인 (8 Vision Benchmark)

    • 아키텍처: CLIP ViT-B/32, ViT-B/16, ViT-L/14
    • 데이터셋: Caltech-101, DTD, EuroSAT, MNIST, RESISC45, SVHN, Sun397, Traffic (8개 task)
    • Linear fine-tuning: 1000 steps/task, LR 0.01, SGD (momentum 0.9), batch size 128
    • Non-linear fine-tuning: 5000 epochs/task, LR 0.001
    • KFAC 계산: Monte Carlo 1 sample, 128 examples/task, 약 4분 소요 (8 task 기준)
    • Regularization β: Validation set에서 tuning

    Language 도메인

    • 아키텍처: T5-base
    • 데이터셋: SNLI, MultiNLI, SICK, SciTail, RTE, QNLI (6개 task)
    • Linear fine-tuning: 1000 steps/task, LR 0.001, AdamW, batch size 32
    • Non-linear (attention-only): 5 epochs/task, LR 0.0001

    KFAC 계산 효율성

    • Exact B matrix 계산: 91.5s → Monte Carlo 근사: 0.2s (450배 가속)
    • 128-256 examples로 성능이 포화되어 소량 데이터로 충분하다.

     

    실험 결과

    Task Addition: Linearized Fine-Tuning Regime

    ViT-B/32 (Abs. / Norm.):

    • Linear FT (α=best): 78.8% / 89.9%
    • TaLoS: 79.7% / 90.8%
    • τ_Jp (non-dataless): 85.6% / 98.2%
    • TAK (Ours, α=1): 85.8% / 97.6%
    • TAK (Ours, α=best): 86.0% / 97.8%

    ViT-B/16 (Abs. / Norm.):

    • Linear FT (α=best): 82.0% / 90.9%
    • τ_Jp (non-dataless): 88.6% / 98.7%
    • TAK (Ours, α=1): 88.3% / 97.9%
    • TAK (Ours, α=best): 88.3% / 98.1%

    ViT-L/14 (Abs. / Norm.):

    • Linear FT (α=best): 88.0% / 94.8%
    • τ_Jp (non-dataless): 91.1% / 98.5%
    • TAK (Ours, α=1): 91.6% / 99.3%

    TAK는 dataless 조건임에도 불구하고, 다른 task 데이터에 접근하는 τ_Jp와 동등하거나 우수한 성능을 보였다.
    특히 ViT-L/14에서 TAK (99.3%)가 τ_Jp (98.5%)를 0.8%p 상회하였다.

    Task Addition: Non-Linear Fine-Tuning Regime

    ViT-B/32 (Abs. / Norm.):

    • Non-linear FT (α=best): 73.5% / 80.4%
    • TaLoS: 79.7% / 90.8%
    • Attn. Only FT + TAK: 83.1% / 91.3%

    Non-linear regime에서도 TAK를 결합하면 TaLoS를 능가하는 성능을 보였다.

    Task Negation

    모델 방법 Target ↓ Control ↑
    ViT-B/32 Linear FT 9.3% 60.5%
    ViT-B/32 TaLoS 11.0% 60.7%
    ViT-B/32 τ_Jp 6.7% 60.8%
    ViT-B/32 TAK 3.4% 62.4%
    ViT-L/14 τ_Jp 3.7% 73.0%
    ViT-L/14 TAK 3.5% 72.6%

    TAK는 task negation에서도 target accuracy를 3.4%까지 낮추면서 control accuracy를 유지하거나 개선하였다.
    이는 KFAC regularization이 task 간 분리를 매우 효과적으로 달성함을 보여준다.

    Language Task (T5-base)

    방법 Abs. Norm.
    Non-linear FT 75.7% 87.7%
    Linear FT 76.9% 92.8%
    TaLoS 76.3% 93.4%
    τ_Jp (non-dataless) 81.3% 100%
    TAK (Ours) 78.7% 98.9%

    Dataless 방법 중에서는 최고 성능이지만, τ_Jp (100%)와는 약간의 gap이 존재한다.

    KFAC Merging Heuristic 검증

    O(T) 복잡도의 naive multi-task regularization과 O(1) 복잡도의 TAK merged regularization을 비교하였다.
    ViT-B/32에서 naive(98.4%) vs TAK(97.6%)로 0.8%p 차이에 불과하며, 실용적으로 무시 가능한 수준이다.

    α 튜닝 불필요성

    TAK는 α=1에서 이미 near-optimal 성능을 달성하여, 별도의 scaling coefficient 탐색이 불필요하다.
    이는 다른 merging 방법(TIES, DARE 등)이 α에 민감한 것과 대조적이다.

    메모리 오버헤드

    • Linearized regime: 11.5GB → 12.9GB (+12%)
    • Attention-only regime: 6.8GB → 8.3GB (+22%)
    • KFAC factor 압축 시: 550MB → 70MB (87% 감소, ~1%p 성능 저하)

     

    한계점 및 개인 의견

    한계점

    1. Layer width에 대한 이차 복잡도: Kronecker factor 저장이 layer width의 제곱에 비례하여, 매우 큰 모델 (LLaMA-70B 등)에서는 메모리 부담이 커질 수 있다.
    2. Language 도메인에서의 gap: Vision에서는 τ_Jp를 능가하지만, T5-base에서는 98.9% vs 100%로 여전히 gap이 존재하며, text domain에서의 curvature 추정 개선이 필요하다.
    3. Merging heuristic의 근사 오차: Factor 병합 시 수학적으로 exact하지 않은 근사를 사용하며, 대규모 task 수에서의 오차 누적 가능성이 미검증이다.
    4. Linearization 가정 의존: 이론적 분석이 linearized model 가정에 기반하므로, non-linear regime에서의 동작에 대한 이론적 보장이 부족하다.
    5. β 하이퍼파라미터 튜닝: Regularization 강도 β는 여전히 validation set에서의 tuning이 필요하다.

    개인 의견

    본 논문은 task arithmetic의 핵심 문제인 cross-task interference를 curvature 관점에서 체계적으로 분석하고, KFAC라는 잘 정립된 도구를 활용하여 해결한다는 점에서 이론적 완성도가 높다.
    특히 representation drift를 GGN matrix로 공식화하고, 이를 KFAC로 tractable하게 근사하는 과정이 자연스럽고 설득력 있다.

    Multi-Delta 간 interference 최소화 관점에서, 이 방법론은 각 task vector를 학습할 때 다른 task들의 curvature landscape를 존중하도록 강제한다는 점에서 근본적이다.
    Post-hoc pruning (TIES, DARE)이나 학습 시 sparsity (TaLoS)와는 달리, "어떤 방향으로 움직이면 안 되는지"를 명시적으로 지정하는 접근이다.

    실무적으로 가장 매력적인 점은 α=1에서 near-optimal이라는 것이다.
    기존 merging 방법들이 α에 매우 민감하여 별도 validation이 필요했던 것과 달리, TAK는 단순히 task vector를 더하기만 하면 된다.
    이는 validation data가 없는 시나리오에서 매우 유용하다.

    향후 연구에서는 KFAC의 메모리 효율적 버전이나, LLM 스케일에서의 적용 가능성 검증이 중요할 것이다.
    또한 TaLoS와 TAK를 결합하는 실험 (Attn. Only FT + TAK)이 이미 유망한 결과를 보여주고 있어, structured sparsity + curvature regularization의 시너지를 더 탐구할 가치가 있다.

     

    논문 정보 및 리소스

    • 논문 제목: Dataless Weight Disentanglement in Task Arithmetic via Kronecker-Factored Approximate Curvature
    • 저자: Angelo Porrello, Pietro Buzzega, Felix Dangel, Thomas Sommariva, Riccardo Salami, Lorenzo Bonicelli, Simone Calderara
    • 학회: ICLR 2026
    • arXiv: https://arxiv.org/abs/2602.17385
    반응형
Designed by Tistory.