Theory 3 papers

Theory Digest — Apr 20, 2026

Today’s Digest at a Glance

Preliminary

Today’s papers tackle discrete diffusion modeling through continuous-time Markov chains, theoretical foundations of adversarial training for large language models, and scaling laws for mixture-of-experts architectures.

Continuous-Time Markov Chains for Discrete Diffusion

Discrete diffusion models face a fundamental challenge: unlike continuous diffusion where Gaussian noise enables tractable reverse processes, discrete state spaces lack natural diffusion kernels. The naive approach of directly parameterizing discrete transition matrices leads to complex coupling between when transitions occur and where they go, making optimization difficult.

Continuous-Time Markov Chains (CTMCs) solve this by separating the timing and direction of state transitions. A CTMC is characterized by a rate matrix $Q(t)$ where $Q_{ij}(t)$ gives the instantaneous rate of transitioning from state $i$ to state $j$. The key insight is decomposing this into exit rates $\lambda_i(t) = -Q_{ii}(t)$ (how fast we leave state $i$) and jump distributions $r_{ij}(t) = Q_{ij}(t)/\lambda_i(t)$ (where we go when leaving). This decoupling allows independent parameterization of “when” and “where” components.

Instead of modeling complex discrete transition matrices directly, CTMCs let us model exit rates as scalar functions and jump distributions as categorical distributions, making the optimization landscape much more tractable.
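The split into exit rates and jump distributions can be illustrated with a minimal sketch. The function name `decompose_rate_matrix` and the 3-state matrix `Q` are illustrative assumptions, not taken from any of the papers; the code only applies the definitions $\lambda_i = -Q_{ii}$ and $r_{ij} = Q_{ij}/\lambda_i$ stated above.

```python
def decompose_rate_matrix(Q):
    """Split a CTMC rate matrix into exit rates and jump distributions."""
    n = len(Q)
    # Exit rate of state i is the negated diagonal entry.
    exit_rates = [-Q[i][i] for i in range(n)]
    jump_dists = []
    for i in range(n):
        if exit_rates[i] <= 0:
            # Absorbing state: no jump distribution is defined, so use zeros.
            jump_dists.append([0.0] * n)
            continue
        # Normalize the off-diagonal rates of row i; the diagonal gets mass 0.
        jump_dists.append(
            [0.0 if j == i else Q[i][j] / exit_rates[i] for j in range(n)]
        )
    return exit_rates, jump_dists


# Hypothetical 3-state rate matrix (each row sums to zero).
Q = [
    [-3.0, 1.0, 2.0],
    [0.5, -0.5, 0.0],
    [1.0, 1.0, -2.0],
]
exit_rates, jump_dists = decompose_rate_matrix(Q)
```

Each row of `jump_dists` is a categorical distribution over destination states, while `exit_rates` carries all of the timing information.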

Linear Self-Attention with Embedding (LSA-E)

Understanding why embedding-space adversarial training helps against token-space attacks requires bridging the gap between continuous perturbations during training and discrete token substitutions at test time. Full softmax transformers are too complex to analyze in closed form.

LSA-E simplifies transformers to enable theoretical analysis while preserving key properties. It uses linear attention (removing softmax) and explicit embedding matrices, giving predictions of the form $\hat{y} = W^V \frac{EE^T}{N} W^{KQ} x$ where $E$ contains token embeddings and $W^{KQ}, W^V$ are parameter matrices. This linearity allows closed-form analysis of how embedding perturbations propagate through the model.

The model captures the essential mechanism: adversarial training in embedding space creates robustness that transfers to discrete token substitutions because both operate through the same embedding lookup process.
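A minimal sketch of the simplified prediction $\hat{y} = W^V \frac{EE^T}{N} W^{KQ} x$ may make the linearity concrete. The shapes used here ($E$ of shape $d \times N$, $W^{KQ}$ of shape $d \times d$, $W^V$ of shape $1 \times d$) and all numeric values are assumptions chosen for illustration, not the paper's configuration.

```python
def matmul(A, B):
    """Plain nested-list matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def lsa_e_predict(W_V, E, W_KQ, x):
    """Compute W_V (E E^T / N) W_KQ x, with E of assumed shape (d, N)."""
    N = len(E[0])
    gram = [[v / N for v in row] for row in matmul(E, transpose(E))]
    x_col = [[v] for v in x]
    return matmul(W_V, matmul(gram, matmul(W_KQ, x_col)))[0][0]


# Hypothetical toy values with d = 2 and N = 3.
E = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
W_KQ = [[1.0, 0.0], [0.0, 1.0]]
W_V = [[1.0, 0.0]]
x = [3.0, 0.0]
delta = [0.3, 0.0]  # a small perturbation of the query

clean = lsa_e_predict(W_V, E, W_KQ, x)
perturbed = lsa_e_predict(W_V, E, W_KQ, [a + b for a, b in zip(x, delta)])
shift = lsa_e_predict(W_V, E, W_KQ, delta)
```

Because the map is linear in `x`, the change in the prediction equals the prediction on the perturbation alone, which is what makes closed-form perturbation analysis possible.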

Reading Guide

The Neural CTMC paper introduces a novel decomposition approach for discrete diffusion that could inform the architectural choices in the MoE scaling analysis. The adversarial training analysis provides theoretical grounding for robustness techniques that complement the generalization bounds derived for MoE models. All three papers demonstrate how careful mathematical modeling can yield both theoretical insights and practical improvements.


Neural Continuous-Time Markov Chain: Discrete Diffusion via Decoupled Jump Timing and Direction

Authors: Jingyuan Li, Xiaoyi Jiang, Fukang Wen, Wei Liu et al. (8 authors) · Institution: Tsinghua University, Beijing Institute of Mathematical Sciences and Applications · Category: cs.LG

Decomposes discrete diffusion reverse processes into separate exit rate and jump distribution components, achieving strong empirical results with uniform forward processes.

Tags: discrete diffusion · continuous-time Markov chains · CTMC · path measures · variational bounds · language modeling · score-based generative models · jump processes


Problem Formulation

Motivation: Discrete diffusion models based on continuous-time Markov chains (CTMCs) have shown strong performance on language and discrete data generation. However, existing approaches parameterize the reverse rate matrix as a single object rather than aligning with the intrinsic CTMC decomposition into jump timing and direction. This misalignment with the fundamental structure of CTMCs may limit model performance.

Mathematical setup: Let $\mathcal{S} = \{1, \ldots, S\}$ be a finite state space. A CTMC over $[0,T]$ is characterized by a rate matrix $R_t$ with $R_t(i,j) \geq 0$ for $i \neq j$ and $\sum_j R_t(i,j) = 0$ for every $i$.

The forward process corrupts data $X_0 \sim p_{data}$ via a CTMC with rate matrix $R_t$, yielding conditionals $q_{t|0}(x|x_0)$ and marginals $q_t(x)$.

Any rate matrix admits the unique decomposition, for $j \neq i$ and whenever $\lambda_t(i) > 0$:

\[R_t(i,j) = \lambda_t(i) \cdot r_t(j|i)\]

where:

\[\lambda_t(i) := \sum_{j \neq i} R_t(i,j) \quad \text{(exit rate)}\] \[r_t(j|i) := \frac{R_t(i,j)}{\lambda_t(i)} \quad \text{(jump distribution)}\]

The true reverse CTMC has rate matrix:

\[\hat{R}_t(i,j) = R_t(j,i) \frac{q_t(j)}{q_t(i)}\]

Assumptions:

  1. Forward quantities $(q_t, q_{t|0}, p(x_0|x_t))$ are independent of $\theta$
  2. $R_t^\theta(i,j) > 0$ and differentiable in $\theta$ for all $i \neq j$
  3. Differentiation can be exchanged with expectation/integration

    Toy example: When $S = 2$ with uniform forward process $R_t(1,2) = R_t(2,1) = \beta_t$, the exit rates are $\lambda_t(i) = \beta_t$ and jump distributions are $r_t(j|i) = 1$ for $j \neq i$ (deterministic jump to the other state). The reverse process becomes $\hat{R}_t(i,j) = \beta_t \frac{q_t(j)}{q_t(i)}$.
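The two-state reverse rates can be computed in closed form if one additionally assumes a constant rate $\beta_t = \beta$ (a simplification made here for illustration; the digest leaves the schedule general). Under that assumption the symmetric chain has marginal $q_t(1) = \tfrac{1}{2} + (q_0(1) - \tfrac{1}{2}) e^{-2\beta t}$. The function names and the value $q_0(1) = 0.9$ are hypothetical.

```python
import math


def two_state_marginal(q0_state1, beta, t):
    """Marginal (q_t(1), q_t(2)) of the symmetric chain with constant rate beta."""
    p1 = 0.5 + (q0_state1 - 0.5) * math.exp(-2.0 * beta * t)
    return p1, 1.0 - p1


def reverse_rates(q0_state1, beta, t):
    """Return (R_hat(1,2), R_hat(2,1)) using R_hat(i,j) = beta * q_t(j) / q_t(i)."""
    p1, p2 = two_state_marginal(q0_state1, beta, t)
    return beta * p2 / p1, beta * p1 / p2


# Hypothetical data distribution that favors state 1.
r12, r21 = reverse_rates(q0_state1=0.9, beta=1.0, t=0.5)
```

Since state 1 is more likely under the data, the reverse process leaves state 2 faster than it leaves state 1, and both rates approach $\beta$ as $t$ grows and the marginal becomes uniform.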

    Formal objective: Minimize the negative log-likelihood:

    \[\min_\theta \mathbb{E}_{x_0 \sim p_{data}}[-\log p_\theta(x_0)]\]

Method

Method: Neural CTMC separately parameterizes the reverse process through exit rate and jump distribution using two dedicated network heads:

\[\Phi_\theta(x_t, t) = (\lambda_t^\theta(x_t), r_t^\theta(\cdot | x_t))\] \[R_t^\theta(i,j) = \lambda_t^\theta(i) \cdot r_t^\theta(j|i) \quad \text{for } j \neq i\]

Training objective: The method shows that the $\theta$-dependent part of the ELBO equals a reverse-process KL that decomposes as:

\[\text{KL}(\hat{Q} \| P_\theta) = \int_0^T \sum_{i \in S} q_t(i) \left[ \text{KL}_{Poi}(\hat{\lambda}_t(i) \| \lambda_t^\theta(i)) + \hat{\lambda}_t(i) \cdot \text{KL}_{Cat}(\hat{r}_t(\cdot|i) \| r_t^\theta(\cdot|i)) \right] dt\]

where:

\[\text{KL}_{Poi}(\lambda \| \lambda^\theta) := \lambda \log \frac{\lambda}{\lambda^\theta} - \lambda + \lambda^\theta\]
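The timing/direction split can be checked numerically for a single state $i$. The sketch below assumes the per-state integrand has the generalized-KL form $\sum_{j \neq i} [a_j \log(a_j/b_j) - a_j + b_j]$ over off-diagonal rates, which is what the Poisson and categorical terms above sum to; the rate values are hypothetical.

```python
import math


def kl_poisson(lam_hat, lam_theta):
    """KL_Poi as defined above: lam*log(lam/lam_theta) - lam + lam_theta."""
    return lam_hat * math.log(lam_hat / lam_theta) - lam_hat + lam_theta


def kl_categorical(r_hat, r_theta):
    """KL divergence between two categorical distributions."""
    return sum(p * math.log(p / q) for p, q in zip(r_hat, r_theta) if p > 0)


def rate_row_divergence(rates_hat, rates_theta):
    """Sum over destinations of a*log(a/b) - a + b for off-diagonal rates."""
    return sum(a * math.log(a / b) - a + b for a, b in zip(rates_hat, rates_theta))


# Hypothetical off-diagonal reverse rates out of one state (true vs. model).
rates_hat = [0.6, 0.3, 0.1]
rates_theta = [0.5, 0.5, 0.2]

lam_hat, lam_theta = sum(rates_hat), sum(rates_theta)
r_hat = [a / lam_hat for a in rates_hat]
r_theta = [b / lam_theta for b in rates_theta]

lhs = rate_row_divergence(rates_hat, rates_theta)
rhs = kl_poisson(lam_hat, lam_theta) + lam_hat * kl_categorical(r_hat, r_theta)
```

The two sides agree to floating-point precision, so a mismatch in total exit rate and a mismatch in direction are penalized by separate terms.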

The practical training loss becomes:

\[L(\theta) = \mathbb{E}_{t,x_0,x_t} \left[ \sum_{j \neq x_t} \lambda_t^\theta(x_t) r_t^\theta(j|x_t) - \sum_{j \neq x_t} R_t(j,x_t) \frac{q_{t|0}(j|x_0)}{q_{t|0}(x_t|x_0)} \log(\lambda_t^\theta(x_t) r_t^\theta(j|x_t)) \right] + C\]

Application to toy example: For the $S=2$ case, the method would learn $\lambda_t^\theta(1)$, $\lambda_t^\theta(2)$ (scalar exit rates) and trivial jump distributions $r_t^\theta(2|1) = r_t^\theta(1|2) = 1$, with the Poisson KL terms driving the learning of when to jump between the two states.
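A minimal sketch of the bracketed term in $L(\theta)$ for one sample $(t, x_0, x_t)$ follows. The function name `loss_term` and the weight values are hypothetical; `weights[j]` stands in for $R_t(j,x_t)\, q_{t|0}(j|x_0)/q_{t|0}(x_t|x_0)$, which in practice would come from the forward process.

```python
import math


def loss_term(lam_theta, r_theta, weights):
    """Bracketed term of L(theta) for one (t, x_0, x_t) sample.

    r_theta and weights are dicts over destination states j != x_t.
    weights[j] plays the role of R_t(j, x_t) * q_{t|0}(j|x_0) / q_{t|0}(x_t|x_0).
    """
    total = 0.0
    for j, w in weights.items():
        rate = lam_theta * r_theta[j]  # model rate R_theta(x_t, j)
        total += rate - w * math.log(rate)
    return total


# Hypothetical conditional reverse rates toward states 1 and 2.
weights = {1: 0.2, 2: 0.6}
lam_star = sum(weights.values())
r_star = {j: w / lam_star for j, w in weights.items()}

best = loss_term(lam_star, r_star, weights)
worse_timing = loss_term(2.0 * lam_star, r_star, weights)
worse_direction = loss_term(lam_star, {1: 0.5, 2: 0.5}, weights)
```

For a single sample the term is minimized when the model rate matches the conditional reverse rate; errors in either the exit rate or the jump distribution raise the loss.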
Novelty & Lineage

Prior work:

  1. D3PM/Multinomial Diffusion (Austin et al. 2021, Hoogeboom et al. 2021): Introduced discrete diffusion using transition matrices in discrete time.
  2. SEDD (Lou et al. 2024): Parameterized reverse process via concrete score ratios $q_t(j)/q_t(x_t)$, achieving strong language modeling results.
  3. MDLM (Sahoo et al. 2024): Simplified training via clean-data prediction under masked forward process.

    Delta: This paper introduces the first reverse-process parameterization that explicitly aligns with the intrinsic CTMC decomposition into exit rate and jump distribution. Prior methods implicitly determine the full rate matrix $R_t^\theta(i,j)$ through proxy quantities, while this work directly parameterizes the timing ($\lambda_t^\theta$) and direction ($r_t^\theta$) components separately.

    Theory-specific assessment:

    • Main theorem: The decomposition of the reverse-process KL into Poisson KL (timing) + categorical KL (direction) is mathematically natural but not particularly surprising given the CTMC structure.
    • Proof technique: Routine application of path measure theory and Campbell-Mecke formula. The key insight is recognizing that the path-space objective inherits the timing/direction factorization, but this follows directly from CTMC theory.
    • Bounds: No new concentration or approximation bounds are provided. The work shows equivalence between conditional and marginal objectives under standard regularity assumptions.

    The theoretical contribution is solid but incremental: it makes explicit a decomposition that was implicit in the CTMC structure. The main value lies in showing this decomposition leads to improved empirical performance with uniform forward processes.

    Verdict: INCREMENTAL — The theoretical decomposition follows naturally from CTMC theory, though the empirical demonstration that this structure improves uniform-noise discrete diffusion is valuable.

Proof Techniques

Main proof strategy:

  1. Path measure decomposition: Uses classical CTMC path measure theory to express the log Radon-Nikodym derivative:

    \[\log \frac{dP_\theta}{dQ_{x_0}}(\omega) = \log p_{prior}(x_n) + \sum_{k=1}^n \log \frac{R_t^{\theta}(x_k, x_{k-1})}{R_t(x_{k-1}, x_k)} + \int_0^T [\lambda_t^{fwd}(X_t) - \lambda_t^\theta(X_t)] dt\]
  2. Importance sampling + Jensen’s inequality: Applies Jensen to get ELBO:

    \[-\log p_\theta(x_0) \leq -\mathbb{E}_{q_{T|0}}[\log p_{prior}(X_T)] + \int_0^T \sum_{i \in S} q_{t|0}(i|x_0) \ell_t(i) dt\]
  3. Campbell-Mecke formula: Key technical tool for handling jump processes. For measurable function $f$:

    \[\mathbb{E}_{Q_{x_0}}\left[\sum_{k=1}^N f(t_k, X_{t_k^-}, X_{t_k})\right] = \int_0^T \sum_{i \in S} q_{t|0}(i|x_0) \sum_{j \neq i} R_t(i,j) f(t,i,j) dt\]
  4. Rate matrix substitution: Exploits the reverse rate relationship $\hat{R}_t(i,j) = R_t(j,i) \frac{q_t(j)}{q_t(i)}$ and the factorization $R_t(i,j) = \lambda_t(i) r_t(j|i)$, applied to both $\hat{R}_t$ and $R_t^\theta$, to decompose:
    \[\sum_{j \neq i} \hat{R}_t(i,j) \log \frac{\hat{R}_t(i,j)}{R_t^\theta(i,j)} = \hat{\lambda}_t(i) \log \frac{\hat{\lambda}_t(i)}{\lambda_t^\theta(i)} + \hat{\lambda}_t(i) \sum_{j \neq i} \hat{r}_t(j|i) \log \frac{\hat{r}_t(j|i)}{r_t^\theta(j|i)}\]
  5. Conditional-marginal equivalence: Shows that replacing intractable marginal quantities $q_t(x)$ with tractable conditional counterparts $q_{t|0}(x|x_0)$ preserves gradients and minimizers under regularity assumptions.
Experiments & Validation

Datasets:

  • Binarized MNIST: Flattened to sequences of length 784 over $\mathcal{S} = \{0,1,\ldots,255\}$
  • TinyStories: Small-scale language dataset
  • OpenWebText: Large-scale language corpus (262B tokens for fair comparison)

Baselines: SEDD, MDLM, GIDD with various interpolation parameters $p_{unif} \in \{0.0, 0.1, 0.2\}$

Key numbers:

  • TinyStories: Neural CTMC achieves ≤16.36 generative perplexity vs ≤37.60 for GIDD and ≤42.66 for MDLM
  • OpenWebText: At 32 sampling steps, τ-leaping achieves ≤258.8 vs ≤553.7 (MDLM) and ≤398.9 (GIDD)
  • First uniform-noise method to outperform mask-based approaches on OpenWebText
  • Remains competitive with SEDD despite using 2.6× fewer training tokens (262B vs 682B)

Architecture: Uses DiT (Diffusion Transformer) backbone with separate heads for exit rate and jump distribution

Samplers: Two sampling schemes tested - τ-leaping (allows multiple jumps per step) and Euler (single jump per step)
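The Euler scheme for a single token can be sketched as follows, assuming the usual first-order discretization in which the token jumps with probability $\lambda_t^\theta(x_t)\,\tau$ and otherwise stays put. The function `euler_step` and its arguments are hypothetical names; τ-leaping is not shown.

```python
import random


def euler_step(state, lam, r, tau, rng):
    """One Euler step for a single token.

    lam is the exit rate at the current state, r maps destination states
    (excluding the current one) to jump probabilities, tau is the step size.
    """
    p_jump = lam * tau
    if p_jump > 1.0:
        # The jump probability must be a valid probability.
        raise ValueError("step too large: lam * tau must be <= 1")
    if rng.random() >= p_jump:
        return state  # no jump in this step
    dests = list(r)
    return rng.choices(dests, weights=[r[j] for j in dests], k=1)[0]


rng = random.Random(0)
next_state = euler_step(state=0, lam=2.0, r={1: 0.25, 2: 0.75}, tau=0.1, rng=rng)
```

The exit rate decides whether a jump happens and the jump distribution decides where it lands, mirroring the two network heads.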

Limitations & Open Problems

Limitations:

  1. Regularity assumptions - TECHNICAL: The equivalence between conditional and marginal objectives requires standard regularity conditions (differentiability, exchange of differentiation/integration) that are typically satisfied but not verified.

  2. Uniform forward process restriction - NATURAL: Empirical evaluation focuses primarily on uniform forward processes, though the framework theoretically handles general forward processes including masked/GIDD-style schedules.

  3. Finite vocabulary assumption - NATURAL: Framework assumes discrete finite state space $\mathcal{S} = \{1,\ldots,S\}$, standard for discrete diffusion but excludes continuous or countably infinite spaces.

  4. Single-step approximation in Euler sampler - TECHNICAL: Euler sampling requires $\lambda_t^\theta(x_t) \tau \leq 1$ so that the per-step jump probability is valid, which limits step sizes.

  5. Training budget comparison - TECHNICAL: OpenWebText comparison with SEDD uses different training token counts (262B vs 682B), making direct comparison imperfect.

Open problems:

  1. Theoretical characterization of uniform vs masked forward processes: Why does the timing/direction decomposition particularly benefit uniform-noise processes? Can this be formalized theoretically?

  2. Optimal sampling schemes: Can the natural τ-leaping sampler be improved further, or are there other sampling schemes that better exploit the exit rate/jump distribution structure?


Understanding and Improving Continuous Adversarial Training for LLMs via In-context Learning Theory

Authors: Shaopeng Fu, Di Wang · Institution: King Abdullah University of Science and Technology · Category: cs.LG

First theoretical analysis explaining why embedding-space adversarial training improves LLM robustness against token-space jailbreak attacks, leading to improved embedding regularization method.

Tags: adversarial training · large language models · jailbreak attacks · in-context learning · transformer theory · embedding regularization · continuous optimization · robust generalization


Problem Formulation
  1. Motivation: Large language models (LLMs) are vulnerable to jailbreak attacks that induce harmful responses. While adversarial training (AT) defends against these attacks, standard AT on LLMs is computationally expensive due to discrete optimization in token space. Continuous AT (CAT) addresses this by performing adversarial perturbations in the embedding space, but the underlying mechanism—why embedding-space perturbations help defend against token-space attacks—remains theoretically unexplained.

  2. Mathematical setup: Consider an LLM with embedding function $E(\cdot)$ parameterized by embedding matrix $W^E \in \mathbb{R}^{d \times |V|}$, where $d$ is the embedding dimension and $|V|$ is the vocabulary size. For any token sequence $x \in V^{|x|}$, define the embedding as:
    \[E(x) := (W^E_{:,x_1} \cdots W^E_{:,x_{|x|}}) \in \mathbb{R}^{d \times |x|}\]

    Standard CAT solves:

    \[\min_\theta \left\{ -\alpha \cdot \mathbb{E}_{(x,y,\tilde{y}) \in D^{(h)}} \left[ \log p_\theta(\tilde{y}|E(x) + \delta^*) - \log p_\theta(y|E(x) + \delta^*) \right] - \mathbb{E}_{(x,y) \in D^{(u)}} \log p_\theta(y|x) \right\}\]
    where $\delta^* = \arg\max_{\|\delta_1\|_2, \ldots, \|\delta_{|x|}\|_2 \leq \epsilon} \log p_\theta(y|E(x) + \delta)$ is the embedding-space adversarial perturbation.

    Key assumptions:

    1. Linear transformer approximation (LSA-E model)
    2. In-context learning setup with Gaussian task weights $w_\tau \sim \mathcal{N}(0, I_{d_0})$
    3. Input covariance structure: $x_{\tau,i}, x_{\tau,q} \sim \mathcal{N}(0, \Lambda)$ where $\Lambda \in \mathbb{R}^{d_0 \times d_0}$
  3. Toy example: When $d_0 = 2$, $\Lambda = I_2$, and context length $N = 2$, the ICL input becomes:

    \[Z_\tau = \begin{pmatrix} x_{\tau,1} & x_{\tau,2} & x_{\tau,q} \\ y_{\tau,1} & y_{\tau,2} & 0 \end{pmatrix} \in \mathbb{R}^{3 \times 3}\]

    The core difficulty is bridging why adversarial perturbations $\delta^E_{\tau,i} \in \mathbb{R}^d$ added to embeddings $W^E x_{\tau,i}$ improve robustness against perturbations $\delta^O_{\tau,i} \in \mathbb{R}^{d_0}$ added directly to inputs $x_{\tau,i}$.

  4. Formal objective: Prove a robust generalization bound for the ICL embedding AT risk:

    \[\mathcal{R}^{adv}_{\rho,M}(\theta^*) = \mathbb{E}_\tau \left[ \max_{\|\Delta^O_\tau\|_{2,\infty} \leq \rho} \frac{1}{2}|\hat{y}_{q,\theta}(Z^{adv}_{\tau,M}) - y_{\tau,q}|^2 \right]\]
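The perturbation budget $\|\cdot\|_{2,\infty} \leq \rho$ can be made concrete with a short sketch. It assumes the usual reading of the $(2,\infty)$ norm as the largest Euclidean norm among per-token columns, which matches the per-token constraints in the CAT objective; the helper names and values are hypothetical.

```python
import math


def norm_2_inf(delta):
    """Largest Euclidean norm among the columns of delta (a list of columns)."""
    return max(math.sqrt(sum(v * v for v in col)) for col in delta)


def project_2_inf(delta, radius):
    """Rescale every column whose norm exceeds radius back onto the ball."""
    projected = []
    for col in delta:
        n = math.sqrt(sum(v * v for v in col))
        scale = 1.0 if n <= radius else radius / n
        projected.append([v * scale for v in col])
    return projected


# Two hypothetical per-token perturbations, stored column by column.
delta = [[3.0, 4.0], [0.3, 0.4]]
clipped = project_2_inf(delta, radius=1.0)
```

A projection of this kind is what a projected-gradient attack would apply after each ascent step; columns already inside the ball are left unchanged.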
Method

The paper develops a Linear Self-Attention with Embedding (LSA-E) model to theoretically analyze CAT. The LSA-E model prediction is:

\[\hat{y}_{q,\theta}(Z_\tau) := \begin{pmatrix} (w^V_{21})^\top \\ w^V_{22} \end{pmatrix}^\top \frac{E(Z_\tau)E(Z_\tau)^\top}{N} \begin{pmatrix} W^{KQ}_{11} \\ (w^{KQ}_{21})^\top \end{pmatrix} W^E x_{\tau,q}\]

The ICL embedding AT formulation is:

\[\min_\theta \mathbb{E}_\tau \left[ \max_{\|\Delta^E_\tau\|_{2,\infty} \leq \epsilon} \frac{1}{2}|\hat{y}^{adv}_{q,\theta}(Z_\tau, \Delta^E_\tau) - y_{\tau,q}|^2 \right]\]

The method involves three key steps:

  1. Derive surrogate upper bound: Replace the original minimax objective with a tractable upper bound consisting of four terms $\ell_1(\theta) + \ell_2(\theta) + \ell_3(\theta) + \ell_4(\theta)$.

  2. Solve closed-form optimization: Under initialization assumptions, the optimal embedding matrix satisfies:

    \[w^{V*}_{22}(W^{E*})^\top W^{KQ*}_{11} W^{E*} = (W^{E*})^\top \left[ W^{E*} \Gamma_N \Lambda (W^{E*})^\top + \text{Tr}(\Lambda)\epsilon^2 I_d \right]^{-1} W^{E*} \Lambda\]

    where $\Gamma_N := \frac{N+1}{N}\Lambda + \frac{1}{N}\text{Tr}(\Lambda)I_{d_0}$.

  3. Prove robust bound: Establish the main generalization bound relating embedding-space AT to input-space robustness.

    Applied to toy example: When $d_0 = d = 2$, $N = 2$, $\Lambda = I_2$, the method computes the optimal $2 \times 2$ embedding matrix and shows how embedding-space perturbations with radius $\epsilon$ improve robustness against input-space perturbations with radius $\rho$.
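As a hedged sanity check of the closed-form condition in step 2, consider the toy setting under the additional assumption (made here for illustration, not taken from the paper) that the optimal embedding is isotropic, $W^{E*} = \sigma I_2$ with $\sigma > 0$. With $N = 2$ and $\Lambda = I_2$:

\[\Gamma_N = \tfrac{3}{2} I_2 + \tfrac{1}{2} \cdot 2 \cdot I_2 = \tfrac{5}{2} I_2, \qquad \text{Tr}(\Lambda) = 2\]

\[w^{V*}_{22} \sigma^2 W^{KQ*}_{11} = \sigma \left[ \tfrac{5}{2}\sigma^2 I_2 + 2\epsilon^2 I_2 \right]^{-1} \sigma = \frac{2\sigma^2}{5\sigma^2 + 4\epsilon^2} I_2\]

\[w^{V*}_{22} W^{KQ*}_{11} = \frac{2}{5\sigma^2 + 4\epsilon^2} I_2\]

Under this assumption the effective attention weight shrinks as $\epsilon$ grows, so a larger embedding-space radius acts like a stronger ridge-type regularizer on the learned in-context predictor.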

Novelty & Lineage

Step 1 — Prior work:

  • Zhang et al. (2024) “Trained transformers learn linear models in-context” analyzed ICL for linear transformers but without embedding modules or adversarial training
  • Fu et al. (2025) “Short-length adversarial training helps LLMs defend long-length jailbreak attacks” studied adversarial training for ICL models but only with input-space perturbations
  • Xhonneux et al. (2024) “Efficient adversarial training in LLMs with continuous attacks” introduced continuous AT empirically but without theoretical foundation

Step 2 — Delta: This paper provides the first theoretical explanation for why embedding-space adversarial training improves input-space robustness. Key contributions:

  1. Novel LSA-E model bridging embedding and attention mechanisms
  2. Closed-form solution for ICL embedding AT
  3. Robust generalization bound showing that the bound on input-space robust risk decreases as the embedding perturbation radius $\epsilon$ grows.

    Step 3 — Theory-specific assessment:

    • Main theorem is moderately surprising: the connection between embedding-space and input-space robustness was empirically known but theoretically unexplained
    • Proof technique combines standard ICL analysis with novel embedding regularization, requiring new technical machinery but building on established ICL theory
    • Bound tightness unclear: no matching lower bounds presented; the $O(\frac{\sum_{i=1}^d \sigma_i(W^E)^4}{\sigma_{\min}(W^E)^4 + \epsilon^4})$ dependence seems reasonable but optimality unknown

    The theoretical framework successfully explains empirical CAT success, but the specific bound may not be tight. The analysis requires strong assumptions (linear transformers, Gaussian distributions) that limit practical applicability.

    Verdict: SIGNIFICANT — First rigorous theoretical explanation of continuous adversarial training mechanism, with actionable insights for improving LLM robustness through embedding regularization.

Proof Techniques

The proof strategy involves four main stages:

  1. Upper bound derivation: Apply Cauchy-Schwarz and triangle inequalities to bound the original minimax objective. The key inequality is:

    \[|\hat{y}^{adv}_{q,\theta}(Z_\tau, \Delta^E_\tau) - y_{\tau,q}|^2 \leq 2|\hat{y}_{q,\theta}(Z_\tau) - y_{\tau,q}|^2 + \text{perturbation terms}\]

    This leads to the surrogate upper bound with explicit dependence on $\epsilon^2$ and $\epsilon^4$ terms.

  2. Gradient flow analysis: Under initialization Assumption 1, prove that certain parameters ($w^{KQ}_{21}$, $w^V_{21}$) remain zero during continuous gradient flow. This uses the key observation:

    \[\frac{\partial \tilde{L}^{adv}_{LSAE}(\theta)}{\partial w^V_{21}} = 0 \text{ and } \frac{\partial \tilde{L}^{adv}_{LSAE}(\theta)}{\partial w^{KQ}_{21}} = 0\]

    at initialization and throughout training.

  3. Closed-form solution via matrix calculus: The optimal embedding matrix satisfies the matrix equation derived using matrix derivatives. The crucial step uses the identity:

    \[\frac{\partial}{\partial W^E} \text{Tr}(W^E A (W^E)^\top B) = B^\top W^E A^\top + B W^E A\]

  4. Robust bound proof: The main technical insight uses expectation manipulation for Gaussian random vectors. Key relations include:

    \[\mathbb{E}[\|W^E X_\tau\|^2_2] = \text{Tr}(W^E \Lambda (W^E)^\top)\] \[\max_{\|\Delta^O_\tau\|_{2,\infty} \leq \rho} \|(W^E)^\top \Delta^O_\tau\|^2_2 \leq M\rho^2 \sigma_{\max}(W^E)^2\]

    The final bound combines matrix perturbation theory with concentration inequalities, leveraging the structure:

    \[\mathcal{R}^{adv}_{\rho,M}(\theta^*) \leq O\left(\frac{(1 + M\rho^2/N^2) \sum_{i=1}^d \sigma_i(W^{E*})^4}{\sigma_{\min}(W^{E*})^4 + \epsilon^4}\right) + O(1)\]

    The proof critically uses the embedding matrix singular value structure to connect embedding-space perturbations to input-space robustness.
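The dependence of the final bound on the singular values of $W^{E*}$ and on $\epsilon$ can be explored with a small sketch. It evaluates only the leading ratio inside the first $O(\cdot)$ term, with all hidden constants dropped; the function name, default arguments, and singular values are hypothetical.

```python
def bound_ratio(singular_values, eps, M=1, rho=0.0, N=1):
    """Leading ratio of the robust bound, with the O(.) constants dropped."""
    numerator = (1.0 + M * rho**2 / N**2) * sum(s**4 for s in singular_values)
    denominator = min(singular_values) ** 4 + eps**4
    return numerator / denominator


# Well-conditioned embedding, without and with embedding-space perturbation.
flat_no_at = bound_ratio([1.0, 1.0], eps=0.0)
flat_with_at = bound_ratio([1.0, 1.0], eps=1.0)

# Ill-conditioned embedding: one large and one small singular value.
spread_no_at = bound_ratio([2.0, 0.5], eps=0.0)
spread_with_at = bound_ratio([2.0, 0.5], eps=1.0)
```

In this sketch the ratio falls as `eps` increases and is much larger when the singular values are spread out, which is consistent with the digest's point that regularizing the embedding's singular values is the actionable lever.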

Experiments & Validation

Six real-world LLMs tested: Vicuna-7B-v1.5, Mistral-7B-Instruct-v0.3, Llama-2-7B-Chat, Llama-3.1-8B-Instruct, Qwen2.5-7B-Instruct, Gemma-2B-it.

Datasets: HarmBench (safety), UltraChat 200K (utility), AdvBench (evaluation), AlpacaEval (utility metric).

Six jailbreak attacks: GCG, BEAST, GCQ, Zhu’s AutoDAN (token-level); DeepInception, PAIR (prompt-level).

Key experimental results:

  • ER-CAT (the paper's embedding-regularized variant of CAT) achieves a better robustness-utility tradeoff than standard CAT
  • Vicuna/Mistral: ER-CAT maintains similar robustness (ASR within 4%) while doubling utility (LC-WinRate)
  • Llama-2/Qwen2.5: ER-CAT reduces ASR by 7-11% while maintaining utility within 3%
  • Computational overhead: only 100-200 seconds additional training time
  • Ablation studies show robustness to hyperparameter β selection

Results validate theoretical predictions that embedding matrix singular value regularization improves robustness-utility balance.

Limitations & Open Problems

Limitations:

  1. Linear transformer assumption - RESTRICTIVE: Real LLMs use non-linear attention and feed-forward layers, limiting direct applicability of theoretical results.

  2. Gaussian distribution assumption - TECHNICAL: ICL inputs assumed Gaussian with specific covariance structure, needed for tractable analysis but unrealistic for language data.

  3. Embedding dimension constraint $d \leq d_0$ - TECHNICAL: Required for proof technique but counterintuitive since LLM embeddings typically have higher dimension than input features.

  4. In-context learning setup - RESTRICTIVE: Analysis limited to regression tasks within ICL framework, whereas jailbreak attacks target general language generation.

  5. Surrogate objective bound - TECHNICAL: Upper bound may not be tight, potentially leading to loose generalization guarantees.

  6. Single-layer analysis - RESTRICTIVE: Multi-layer transformer dynamics not captured, limiting insights for deep LLMs.

Open problems:

  1. Extend theoretical analysis to non-linear transformers with ReLU activations and multi-layer architectures that better approximate real LLMs.

  2. Develop embedding regularization techniques that provably improve robustness for more general adversarial attacks beyond suffix perturbations, including semantic-preserving jailbreaks.


Generalization and Scaling Laws for Mixture-of-Experts Transformers

Authors: Mansour Zoubeirou a Mayaki · Institution: Université Lyon 2, LIRIS UMR 5205 · Category: cs.LG

Develops generalization theory for MoE Transformers showing they scale like dense networks when measured against active parameters, with routing contributing logarithmic overhead in worst case.

Tags: mixture-of-experts transformer-theory scaling-laws generalization-bounds intrinsic-dimension sparse-neural-networks approximation-theory covering-numbers


Problem Formulation
  1. Motivation: Mixture-of-Experts (MoE) Transformers enable conditional computation where only a subset of experts is activated per token, decoupling active parameters from total parameters. Despite empirical gains, existing scaling theory tracks error with total parameters, while MoE requires refined accounting for active capacity and routing overhead.

  2. Mathematical setup: Consider i.i.d. samples $(x_i, y_i)_{i=1}^n$ from an unknown distribution $Q$ over $\mathcal{X} \times \mathcal{Y}$. The model is an MoE Transformer with $L_T$ blocks, each containing an attention layer and an MoE layer with $M$ experts, of which at most $k$ are activated per token.

    For token embedding $\tilde{h}_t$, the MoE layer computes:

    \[h_t^{(j)} = \tilde{h}_t + \sum_{m=1}^M g_{j,m}(\tilde{h}_t) E_{j,m}(\tilde{h}_t)\]

    Routing constraints:

    \[g_{j,m}(h) \geq 0, \quad \sum_{m=1}^M g_{j,m}(h) = 1, \quad |\{m : g_{j,m}(h) \neq 0\}| \leq k\]

    Assumptions:

    1. Data lies on $d$-dimensional manifold $\mathcal{M} \subset \mathbb{R}^D$
    2. Target function $f \in C^{\beta}(\mathcal{M})$ with $\|f\|_{C^{\beta}} \leq B$
    3. Hard top-$k$ routing with bounded parameters $\|\theta\|_{\infty} \leq \kappa$
    4. Router can implement $k$-sparse partition of unity

    Active parameter budget:

    \[N_{\text{act}} := L_T \Pi_{\text{attn}} + L_T k \Pi_{\text{exp}}\]
  3. Toy example: Consider $d=2$ manifold, $L_T=2$ layers, $M=4$ experts with $k=2$ active per token. If each expert has $\Pi_{\text{exp}}=100$ parameters and attention has $\Pi_{\text{attn}}=50$, then $N_{\text{act}} = 2(50 + 2 \cdot 100) = 500$ active parameters.

  4. Formal objective: Minimize the population risk:

    \[\mathbb{E}\|\hat{T}_n - f\|_{L^2(Q)}^2\]
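The MoE layer and its routing constraints can be sketched in a few lines. The gate used here, a softmax restricted to the $k$ largest router logits, is an assumption chosen because it satisfies the three constraints above; the setup only states the constraints, not this particular form. Router logits are passed in directly rather than computed from $\tilde{h}_t$, and the scalar-multiple experts are hypothetical.

```python
import math


def top_k_gates(logits, k):
    """Softmax over the k largest router logits; all other gates are zero."""
    top = sorted(range(len(logits)), key=lambda m: logits[m], reverse=True)[:k]
    peak = max(logits[m] for m in top)  # subtract the max for stability
    exps = {m: math.exp(logits[m] - peak) for m in top}
    total = sum(exps.values())
    return [exps[m] / total if m in exps else 0.0 for m in range(len(logits))]


def moe_layer(h, router_logits, experts, k):
    """Residual MoE update h + sum_m g_m * E_m(h), skipping inactive experts."""
    gates = top_k_gates(router_logits, k)
    out = list(h)
    for g, expert in zip(gates, experts):
        if g > 0.0:
            out = [o + g * e for o, e in zip(out, expert(h))]
    return out


# Four hypothetical experts that scale the input by 1, 2, 3, and 4.
experts = [lambda h, s=s: [s * v for v in h] for s in (1.0, 2.0, 3.0, 4.0)]
router_logits = [0.1, 2.0, -1.0, 2.0]

gates = top_k_gates(router_logits, k=2)
h_out = moe_layer([1.0, -1.0], router_logits, experts, k=2)
```

Only the two selected experts are evaluated, which is why the analysis counts $k \Pi_{\text{exp}}$ rather than $M \Pi_{\text{exp}}$ expert parameters per block.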
Method

The method decomposes MoE analysis into approximation and generalization components by conditioning on fixed routing patterns.

Key steps:

  1. Construct MoE approximator for $C^{\beta}$ function using $k$-sparse routing
  2. Derive covering number by union-bounding over routing patterns
  3. Apply standard ERM analysis with routing overhead

    For approximation (Theorem 3.2), construct MoE with experts as ReLU MLPs of depth $L_{\text{FFN}}$ and width $w_{\text{FFN}}$. The router implements $k$-sparse partition of unity on the manifold.

    For generalization (Theorem 3.4), condition on fixed routing pattern $\pi$, reducing MoE to deterministic active subnetwork. The covering number satisfies:

    \[\log \mathcal{N}(\delta, \mathcal{T}_{\text{MoE}}, \|\cdot\|_{\infty}) \leq C_1 N_{\text{act}} \log\left(\frac{C_2\kappa R M_0}{\delta}\right) + C_3 L_T \ell k \log\left(\frac{eM}{k}\right)\]

    Applied to toy example: With $N_{\text{act}} = 500$, $L_T = 2$, $\ell = 128$, $k = 2$, $M = 4$, the routing term contributes $2 \cdot 128 \cdot 2 \cdot \log(e \cdot 4/2) = 512 \log(2e) \approx 867$ (natural logarithm, before the constant $C_3$) to the log-covering-number bound.
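The toy numbers can be reproduced with a short sketch that also compares the routing term against the exact log-count of hard routing patterns, $\log \binom{M}{k}^{L_T \ell}$. The function names are hypothetical, the natural logarithm is assumed, and the constant $C_3$ is omitted.

```python
import math


def active_params(L_T, pi_attn, k, pi_exp):
    """N_act = L_T * Pi_attn + L_T * k * Pi_exp."""
    return L_T * pi_attn + L_T * k * pi_exp


def routing_overhead(L_T, ell, k, M):
    """L_T * ell * k * log(e * M / k), natural log, constant C_3 omitted."""
    return L_T * ell * k * math.log(math.e * M / k)


def log_pattern_count(L_T, ell, k, M):
    """Log of binom(M, k) ** (L_T * ell), the number of hard routing patterns."""
    return L_T * ell * math.log(math.comb(M, k))


n_act = active_params(L_T=2, pi_attn=50, k=2, pi_exp=100)
overhead = routing_overhead(L_T=2, ell=128, k=2, M=4)
exact = log_pattern_count(L_T=2, ell=128, k=2, M=4)
```

The routing term is an upper bound on the exact log-count because $\binom{M}{k} \leq (eM/k)^k$, so it overstates the overhead even before considering that trained routers may use far fewer patterns.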

Novelty & Lineage

Step 1 — Prior work:

  • “Understanding scaling laws with statistical and approximation theory for transformer neural networks on intrinsically low-dimensional data” (Havrilla & Liao, 2024): developed dense Transformer approximation/generalization theory with intrinsic dimension $d$
  • “Scaling laws for neural language models” (Kaplan et al., 2020): established empirical power-law scaling for dense models with total parameters
  • “Switch transformers: scaling to trillion parameter models” (Fedus et al., 2022): introduced practical MoE architectures but without theoretical analysis

Step 2 — Delta: This paper extends intrinsic-dimension theory to MoE by:

  • Separating active capacity $N_{\text{act}}$ from total parameters in approximation bounds
  • Deriving routing-specific covering number with $L_T \ell k \log(eM/k)$ overhead term
  • Providing constructive approximation theorem for MoE architectures

Step 3 — Theory-specific assessment:

  • Main theorem is predictable: applies standard manifold approximation + union bound over routing patterns
  • Proof routine: assembles known covering number techniques with MoE-specific routing decomposition
  • Bounds appear loose: routing term from worst-case union bound over all $\binom{M}{k}$ patterns, while practice likely uses much fewer effective patterns
  • No known lower bounds for MoE-specific complexity

The approximation rate $N_{\text{act}}^{-2\beta/d}$ matches dense networks when measured against active parameters, with routing adding logarithmic overhead. The key insight is that MoE behaves like dense networks in active parameter budget, not total parameters.

Verdict: INCREMENTAL — solid extension of existing intrinsic-dimension theory to MoE setting with predictable techniques and likely loose bounds.

Proof Techniques

Main proof strategy involves conditioning on routing patterns to reduce MoE to deterministic subnetworks:

  1. Approximation construction (Theorem 3.2): Construct the MoE approximator by:

    • Partitioning the manifold $\mathcal{M}$ with a $k$-sparse partition of unity
    • Having each expert approximate a local smooth function via standard ReLU approximation
    • Combining the expert approximations into a global approximator

    Key approximation inequality:

    \[\inf_{T \in \mathcal{T}_{\text{MoE}}} \|T - f\|_{\infty}^2 \leq C \cdot \min\{N_{\text{act}}^{-2\beta/d}, M^{-2\beta/d}\}\]

  2. Covering number bound (Lemma 3.7): For a fixed routing pattern $\pi$:

    • A parameter-to-function Lipschitz bound holds for attention and FFN layers
    • An $\eta$-grid on the parameter space $[-\kappa, \kappa]^{N_{\text{act}}}$ induces a $\delta$-cover
    • A union bound over all routing patterns adds the factor $\binom{M}{k}^{L_T \ell}$

    Key covering inequality:

    \[\log \mathcal{N}(\delta, \mathcal{T}_{\text{MoE}}, \|\cdot\|_{\infty}) \leq C_1 N_{\text{act}} \log\left(\frac{C_2\kappa R M_0}{\delta}\right) + C_3 L_T \ell k \log\left(\frac{eM}{k}\right)\]

  3. Generalization bound (Theorem 3.4): Apply standard ERM analysis:

    • Condition on the routing pattern to get a smooth active subnetwork
    • Apply a Rademacher complexity bound with covering numbers
    • Take a union bound over routing patterns

    Key generalization bound:

    \[\mathbb{E}\|\hat{T}_n - f\|_{L^2(Q)}^2 \leq \varepsilon^2 + \tilde{O}\left(\frac{N_{\text{act}}}{n} + \frac{L_T \ell k \log(eM/k)}{n}\right)\]

    The technical insight is that conditioning on routing eliminates the combinatorial complexity during approximation/estimation analysis, relegating it to a multiplicative factor in covering numbers.

Experiments & Validation

Empirical evaluation on three text corpora: TinyStories (synthetic children’s stories), WikiText-103 (Wikipedia articles), OpenWebText (Reddit-linked web pages).

Key experimental components:

  1. Intrinsic dimension estimation using Levina-Bickel MLE on GPT-2 representations: $d \approx 23$ (TinyStories), $d \approx 32$ (WikiText-103), $d \approx 45$ (OpenWebText)

  2. Model scaling: sweep $(L_T, d_{ff}) \in \{2,3,4,5,6,8\} \times \{256,\ldots,1536\}$ with MoE configs $M \in \{4,8,16\}$, $k \in \{1,2\}$. Empirical exponent $\hat{\alpha}_N \approx 0.06$ to $0.09$, close to theoretical $\alpha_N = 2\beta/d$.

  3. Data scaling: vary token budget $D \in \{5 \times 10^4, \ldots, 8 \times 10^5\}$ with fixed architecture. Empirical $\hat{\alpha}_D \approx 0.04$ to $0.08$, consistent with theoretical $\alpha_D = 2\beta/(2\beta + d)$.

  4. Routing ablations: for moderate regime $M/k \leq 8$, validation loss increases with routing term $k\log(eM/k)$ as predicted. For larger $M/k$, performance improves due to specialization effects beyond worst-case analysis.

    Main findings: scaling exponents match intrinsic-dimension theory when measured against active parameters $N_{\text{act}}$, with routing contributing overhead in moderate regime but specialization gains in large expert pool regime.
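The theoretical exponents can be evaluated for the estimated intrinsic dimensions with a short sketch. The smoothness value $\beta = 1$ is an assumption made purely for illustration (the digest does not report the value used), and the function names are hypothetical.

```python
def model_exponent(beta, d):
    """alpha_N = 2 * beta / d."""
    return 2.0 * beta / d


def data_exponent(beta, d):
    """alpha_D = 2 * beta / (2 * beta + d)."""
    return 2.0 * beta / (2.0 * beta + d)


# Intrinsic dimension estimates quoted above.
intrinsic_dims = {"TinyStories": 23, "WikiText-103": 32, "OpenWebText": 45}
beta = 1.0  # assumed smoothness, not a value taken from the paper

exponents = {
    name: (model_exponent(beta, d), data_exponent(beta, d))
    for name, d in intrinsic_dims.items()
}
```

Under this assumption the predicted exponents are small, of the same order as the empirical ranges reported above, and they shrink as the intrinsic dimension grows.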

Limitations & Open Problems

Limitations:

  1. RESTRICTIVE: Hard top-$k$ routing assumption - practical systems use soft routing with load balancing
  2. RESTRICTIVE: Bounded parameter assumption $\|\theta\|_{\infty} \leq \kappa$ - modern large models have unbounded parameters
  3. TECHNICAL: Union bound over all routing patterns is conservative - practical routing uses much fewer effective patterns
  4. RESTRICTIVE: Squared loss analysis - language modeling uses cross-entropy loss
  5. TECHNICAL: Router expressivity assumption (can implement $k$-sparse partition of unity) is strong
  6. RESTRICTIVE: Analysis limited to statistical complexity, ignores optimization dynamics and expert specialization

Open problems:

  1. Develop data-dependent routing analysis that captures expert specialization effects observed empirically for large $M/k$
  2. Establish lower bounds for MoE-specific complexity to determine tightness of current upper bounds