Gradient-based Planning for World Models at Longer Horizons
GRASP is a new gradient-based planner for learned dynamics (a “world model”) that makes long-horizon planning practical by (1) lifting the trajectory into virtual states so optimization is parallel across time, (2) adding stochasticity directly to the state iterates for exploration, and (3) reshaping gradients so actions get clean signals while we avoid brittle “state-input” gradients through high-dimensional vision models.
Large, learned world models are becoming increasingly capable. They can predict long sequences of future observations in high-dimensional visual spaces and generalize across tasks in ways that were difficult to imagine a few years ago. As these models scale, they start to look less like task-specific predictors and more like general-purpose simulators.
But having a powerful predictive model is not the same as being able to use it effectively for control, learning, or planning. In practice, long-horizon planning with modern world models remains fragile: optimization becomes ill-conditioned, non-greedy structure creates bad local minima, and high-dimensional latent spaces introduce subtle failure modes.
In this blog post, I describe the problems that motivated this project and our approach to address them: why planning with modern world models can be surprisingly fragile, why long horizons are the real stress test, and what we changed to make gradient-based planning much more robust.
This blog post discusses work done with Mike Rabbat, Aditi Krishnapriyan, Yann LeCun, and Amir Bar, where we propose GRASP.
What is a world model?
These days, the term “world model” is quite overloaded, and depending on the context can either mean an explicit dynamics model or some implicit, reliable internal state that a generative model relies on (e.g. when an LLM generates chess moves, whether there is some internal representation of the board). We give our loose working definition below.
Suppose you take actions $a_t \in \mathcal{A}$ and observe states $s_t \in \mathcal{S}$ (images, latent vectors, proprioception). A world model is a learned model that, given the current state and a sequence of future actions, predicts what will happen next. Formally, it defines a predictive distribution over the next state, conditioned on a history of observed states $s_{t-h:t}$ and the current action $a_t$:
\[P_\theta(s_{t+1} \mid s_{t-h:t},\; a_t)\]
that approximates the environment’s true conditional $P(s_{t+1} \mid s_{t-h:t},\; a_t)$. For this blog post, we’ll assume a Markovian model $P_\theta(s_{t+1} \mid s_t,\; a_t)$ for simplicity (all results here extend to the more general history-conditioned case), and when the model is deterministic it reduces to a map over states:
\[s_{t+1} = F_\theta(s_t, a_t).\]
In practice the state $s_t$ is often a learned latent representation (e.g., encoded from pixels), so the model operates in a (theoretically) compact, differentiable space. The key point is that a world model gives you a differentiable simulator; you can roll it forward under hypothetical action sequences and backpropagate through the predictions.
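To make this concrete, here is a minimal sketch of a deterministic world model in PyTorch. The MLP architecture and all dimensions are illustrative placeholders, not the model used in the paper.

```python
import torch
import torch.nn as nn

# Minimal deterministic world model F_theta: (s_t, a_t) -> s_{t+1}.
# Architecture and dimensions are illustrative placeholders; in practice
# s_t is often a learned latent encoded from pixels.
class WorldModel(nn.Module):
    def __init__(self, state_dim: int = 32, action_dim: int = 4, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, state_dim),
        )

    def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        # Predict the next state from the current state and action.
        return self.net(torch.cat([s, a], dim=-1))
```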
Planning: choosing actions by optimizing through the model
Given a start $s_0$ and a goal $g$, the simplest planner chooses an action sequence $\mathbf{a}=(a_0,\dots,a_{T-1})$ by rolling out the model and minimizing terminal error:
\[\min_{\mathbf{a}} \; \| s_T(\mathbf{a}) - g \|_2^2, \quad \text{where } s_T(\mathbf{a}) = \mathcal{F}_{\theta}^{T}(s_0,\mathbf{a}).\]
Here we use $\mathcal{F}^T$ as shorthand for the full rollout through the world model (dependence on model parameters $\theta$ is implicit):
\[\mathcal{F}_{\theta}^{T}(s_0, \mathbf{a}) = F_\theta(F_\theta(\cdots F_\theta(s_0, a_0), \cdots, a_{T-2}), a_{T-1}).\]
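Here is a hedged sketch of this baseline “shooting” planner, assuming a differentiable world model like the one above; `F`, `s0`, `g`, and the hyperparameters are placeholders.

```python
import torch

# Naive "shooting" planner: serial rollout plus backprop through time.
# F is a differentiable world model (e.g. the WorldModel sketch above);
# s0, g, and all hyperparameters are placeholders.
def plan_shooting(F, s0, g, T=50, a_dim=4, steps=200, lr=1e-2):
    actions = torch.zeros(T, a_dim, requires_grad=True)
    opt = torch.optim.Adam([actions], lr=lr)
    for _ in range(steps):
        s = s0
        for t in range(T):               # depth-T computation graph
            s = F(s, actions[t])
        loss = ((s - g) ** 2).sum()      # terminal error only
        opt.zero_grad()
        loss.backward()                  # BPTT through all T steps
        opt.step()
    return actions.detach()
```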
At short horizons and in low-dimensional systems, this can work reasonably well. But as horizons grow and models become larger and more expressive, its weaknesses become amplified.
So why doesn’t this just work at scale?
Why long-horizon planning is hard (even when everything is differentiable)
There are two pain points that apply to any world model, plus a third that is specific to learned, deep learning-based models.
1) Long-horizon rollouts create deep, ill-conditioned computation graphs
Those familiar with backpropagation through time (BPTT) may notice that we’re differentiating through a model applied to itself repeatedly, which leads to the exploding/vanishing gradients problem. Namely, if we take derivatives with respect to earlier actions (e.g. $a_0$; note we’re differentiating vector-valued functions, so these are Jacobians, denoted $D_x(\cdots)$):
\[D_{a_0} \mathcal{F}_{\theta}^{T}(s_0, \mathbf{a}) = \Bigl(\prod_{t=1}^{T-1} D_s F_\theta(s_t, a_t)\Bigr) D_{a_0}F_\theta(s_0, a_0).\]
We see that the Jacobian’s conditioning scales exponentially with time $T$:
\[\sigma_{\text{max/min}}(D_{a_0}\mathcal{F}_{\theta}^{T}) \sim \sigma_{\text{max/min}}(D_s F_\theta)^{T-1},\]
leading to exploding or vanishing gradients.
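A toy numerical check of this scaling: for linear dynamics $s_{t+1} = A s_t + B a_t$, the Jacobian of $s_T$ with respect to $a_0$ is exactly $A^{T-1}B$. The matrices below are arbitrary placeholders chosen so that every singular value of $A$ is $1.1$.

```python
import torch

# For linear dynamics s_{t+1} = A s_t + B a_t, the Jacobian of s_T with
# respect to a_0 is exactly A^(T-1) B, so its norm scales like
# sigma(A)^(T-1). A and B below are arbitrary placeholders.
torch.manual_seed(0)
Q, _ = torch.linalg.qr(torch.randn(8, 8))
A = 1.1 * Q                              # all singular values equal 1.1
B = torch.randn(8, 2)
for T in (5, 20, 80):
    J = torch.linalg.matrix_power(A, T - 1) @ B
    print(T, J.norm().item())            # grows like 1.1**(T-1)
```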
2) The landscape is non-greedy and full of traps
At short horizons, the greedy solution, where we move straight toward the goal at every step, is often good enough. If you only need to plan a few steps ahead, the optimal trajectory usually doesn’t deviate much from “head toward $g$” at each step.
As horizons grow, two things happen. First, longer tasks are more likely to require non-greedy behavior: going around a wall, repositioning before pushing, backing up to take a better path. And as horizons grow, more of these non-greedy steps are typically needed. Second, the optimization space itself scales with horizon: $\mathrm{dim}(\mathcal{A} \times \cdots \times \mathcal{A}) = T\mathrm{dim}(\mathcal{A})$, further expanding the space of local minima for the optimization problem.
Distance to goal along the optimal path is non-monotonic, and the resulting loss landscape can be rough.
A long-horizon fix: lifting the dynamics constraint
Suppose we treat the dynamics constraint $s_{t+1} = F_{\theta}(s_t, a_t)$ as a soft constraint, and we instead optimize the following penalty function over both actions $(a_0,\ldots,a_{T-1})$ and states $(s_0,\ldots,s_T)$:
\[\min_{\mathbf{s},\mathbf{a}} \mathcal{L}(\mathbf{s}, \mathbf{a}) = \sum_{t=0}^{T-1} \big\|F_\theta(s_t,a_t) - s_{t+1}\big\|_2^2,
\quad \text{with } s_0 \text{ fixed and } s_T=g.\]
This is sometimes called collocation in the planning and robotics literature. Note that the lifted formulation shares the same global minimizers as the original rollout objective (both are zero exactly when the trajectory is dynamically feasible), but the optimization landscapes are very different, and we get two immediate benefits:
Each world model evaluation $F_{\theta}(s_t,a_t)$ depends only on local variables, so all $T$ terms can be computed in parallel across time, resulting in a huge speed-up for longer horizons, and
You no longer backpropagate through a single deep $T$-step composition to get a learning signal, since the previous product of Jacobians now splits into a sum, e.g.:
\[D_{a_0} \mathcal{L} = 2\,(F_\theta(s_0, a_0) - s_1)^\top D_{a_0}F_\theta(s_0, a_0).\]
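To make the parallel structure concrete, here is a minimal sketch of the lifted objective, assuming a batched world model `F`; `states` holds the virtual states, with `states[0]` fixed to $s_0$ and `states[-1]` pinned to $g$.

```python
import torch

# Lifted (collocation) penalty: all T dynamics terms are evaluated in one
# batched call to F, so there is no depth-T computation graph.
# states: (T+1, state_dim) virtual states, with states[0] = s0 fixed and
# states[-1] = g pinned; actions: (T, action_dim). F is a placeholder.
def collocation_loss(F, states, actions):
    pred = F(states[:-1], actions)          # parallel across time
    return ((pred - states[1:]) ** 2).sum()
```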
Being able to optimize states directly also helps with exploration, as we can temporarily navigate through unphysical domains to find the optimal plan:
Collocation-based planning allows us to directly perturb states and explore midpoints more effectively.
However, lunch is never free: for deep learning-based world models in particular, there is a critical issue that makes the above optimization quite difficult in practice.
An issue for deep learning-based world models: sensitivity of state-input gradients
The tl;dr of this section is: directly optimizing states through a deep learning-based $F_{\theta}$ is incredibly brittle, à la adversarial robustness. Even if you train your world model in a lower-dimensional state space, the training process leaves the landscape very sharp away from the training data, whether at a genuinely unseen state or simply along a direction normal/orthogonal to the data manifold.
Adversarial robustness and the “dimpled manifold” model
Adversarial robustness research originally looked at classification models $f_\theta : \mathbb{R}^{w\times h \times c} \to \mathbb{R}^K$, and showed that by following the gradient of a particular logit $\nabla f_\theta^k$ from a base image $x$ (not of class $k$), you did not have to move far along $x' = x + \epsilon\nabla f_\theta^k$ to make $f_\theta$ classify $x'$ as $k$ (Szegedy et al., 2014; Goodfellow et al., 2015):
Depiction of the classic example from (Goodfellow et al., 2015).
Later work has painted a geometric picture of what’s going on: for data near a low-dimensional manifold $\mathcal{M}$, the training process controls behavior in tangential directions, but does not regularize behavior in orthogonal directions, leading to sensitive behavior (Stutz et al., 2019). Stated another way: $f_\theta$ has a reasonable Lipschitz constant along directions tangential to the data manifold $\mathcal{M}$, but can have very high Lipschitz constants in normal directions. In fact, it often benefits the model to be sharper in these normal directions, so it can fit more complicated functions more precisely.
As a result, such adversarial examples are incredibly common even for a single given model. Further, this is not just a computer vision phenomenon; adversarial examples also appear in LLMs (Wallace et al., 2019) and in RL (Gleave et al., 2019).
While there are methods to train more adversarially robust models, there is a known trade-off between model performance and adversarial robustness (Tsipras et al., 2019): especially in the presence of many weakly-correlated variables, the model must be sharper to achieve higher performance. Indeed, most modern training algorithms, whether in computer vision or LLMs, do not train this sensitivity away. Thus, at least until deep learning sees a major regime change, this is a problem we’re stuck with.
Why is adversarial robustness an issue for world model planning?
Consider a single component of the dynamics loss we’re optimizing in the lifted state approach:
\[\min_{s_t, a_t, s_{t+1}} \|F_\theta(s_t, a_t) - s_{t+1}\|_2^2\]
Let’s further focus on just the base state:
\[\min_{s_t} \|F_\theta(s_t, a_t) - s_{t+1}\|_2^2.\]
Since world models are typically trained on state/action trajectories $(s_1, a_1, s_2, a_2, \ldots)$, the state-data manifold for $F_{\theta}$ has dimensionality bounded by the action space:
\[\mathrm{dim}(\mathcal{M}_s) \le \mathrm{dim}(\mathcal{A}) + 1 + \mathrm{dim}(\mathcal{R}),\]
where $\mathcal{R}$ is an optional space of augmentations (e.g. translations/rotations). We can therefore expect $\mathrm{dim}(\mathcal{M}_s)$ to be much lower than $\mathrm{dim}(\mathcal{S})$, which means it is very easy to find adversarial examples that hack any state into any other desired state.
As a result, the dynamics optimization
\[\sum_{t=0}^{T-1} \big\|F_\theta(s_t,a_t) - s_{t+1}\big\|_2^2\]
feels incredibly “sticky,” as the base points $s_t$ can easily trick $F_{\theta}$ into thinking it’s already made its local goal.1
1. This adversarial robustness issue, while particularly bad for lifted-state approaches, is not unique to them. Even serial optimization methods that optimize through the full rollout map $\mathcal{F}^T$ can wander into unseen states, where a component normal to the data manifold gets fed into the sensitive normal directions of $D_s F_{\theta}$. The action Jacobian’s chain rule expansion is
\[\Bigl(\prod_{t=1}^{T-1} D_s F_\theta(s_t, a_t)\Bigr) D_{a_0}F_\theta(s_0, a_0).\]
Consider what happens if any stage of this product picks up a component normal to the data manifold. ↩
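Here is a hedged sketch of this “sticky” failure mode: plain gradient descent on the state input of a trained model can crush a single dynamics term without ever finding a feasible transition. All names and hyperparameters are placeholders.

```python
import torch

# Sketch of "state hacking": descending one dynamics term through the
# state input. With a trained deep F, this typically drives the loss
# toward zero by pushing s_t off the data manifold (an adversarial
# state), not by finding a physically meaningful transition.
def hack_state(F, s_t, a_t, s_next, steps=300, lr=0.05):
    s = s_t.clone().requires_grad_(True)
    opt = torch.optim.SGD([s], lr=lr)
    for _ in range(steps):
        loss = ((F(s, a_t) - s_next) ** 2).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return s.detach()   # often nothing like any state seen in training
```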
Our fix
This is where our new planner GRASP comes in. The main observation: while $D_s F_{\theta}$ is untrustworthy and adversarial, the action space is usually low-dimensional and exhaustively trained, so $D_a F_{\theta}$ is actually reasonable to optimize through and doesn’t suffer from the adversarial robustness issue!
The action input is usually lower-dimensional and densely trained (the model has seen every action direction), so action gradients are much better behaved.
At its core, GRASP builds a first-order lifted-state / collocation-based planner that depends only on action Jacobians through the world model. We thus exploit the differentiability of learned world models $F_{\theta}$ without falling victim to the inherent sensitivity of the state Jacobians $D_s F_{\theta}$.
GRASP: Gradient RelAxed Stochastic Planner
As noted before, we start with the collocation planning objective, where we lift the states and relax dynamics into a penalty:
\[\min_{\mathbf{s},\mathbf{a}} \mathcal{L}(\mathbf{s}, \mathbf{a}) = \sum_{t=0}^{T-1} \big\|F_\theta(s_t,a_t) - s_{t+1}\big\|_2^2,
\quad \text{with } s_0 \text{ fixed and } s_T=g.\]
We then make two key additions.
Ingredient 1: Exploration by noising the state iterates
Even with a smoother objective, planning is nonconvex. We introduce exploration by injecting Gaussian noise into the virtual state updates during optimization.
A simple version:
\[s_t \leftarrow s_t - \eta_s \nabla_{s_t}\mathcal{L} + \sigma_{\text{state}} \xi, \qquad \xi\sim\mathcal{N}(0,I).\]
Actions are still updated by non-stochastic descent:
\[a_t \leftarrow a_t - \eta_a \nabla_{a_t}\mathcal{L}.\]
The state noise helps you “hop” between basins in the lifted space, while the actions remain guided by gradients. We found that noising specifically the states (as opposed to the actions) strikes a good balance between exploration and the ability to settle into sharper minima.2
2. Because we only noise the states (and not the actions), the corresponding dynamics are not truly Langevin dynamics. ↩
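One inner iteration of this update, sketched in PyTorch; `loss_fn` stands in for the lifted objective, `states` and `actions` are assumed to be leaf tensors with `requires_grad=True`, and the step sizes are placeholders.

```python
import torch

# One step of Ingredient 1: noisy descent on the virtual states, plain
# descent on the actions. loss_fn is the lifted objective; states and
# actions are leaf tensors with requires_grad=True; step sizes are
# placeholders.
def noisy_step(loss_fn, states, actions, eta_s=1e-2, eta_a=1e-2, sigma=1e-2):
    loss = loss_fn(states, actions)
    g_s, g_a = torch.autograd.grad(loss, [states, actions])
    with torch.no_grad():
        states[1:-1] -= eta_s * g_s[1:-1]   # s_0 and s_T stay pinned
        states[1:-1] += sigma * torch.randn_like(states[1:-1])
        actions -= eta_a * g_a
```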
Ingredient 2: Reshape gradients: stop brittle state-input gradients, keep action gradients
As discussed, the fragile pathway is the gradient that flows into the state input of the world model, \(D_s F_{\theta}\). The most straightforward remedy is to simply stop state gradients from entering \(F_{\theta}\):
Let $\bar{s}_t$ be the same value as $s_t$, but with gradients stopped.
Define the stop-gradient dynamics loss:
\[\mathcal{L}_{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a})
= \sum_{t=0}^{T-1} \big\|F_\theta(\bar{s}_t, a_t) - s_{t+1}\big\|_2^2.\]
This alone does not work. With gradients stopped, each state only chases the prediction from the previous (frozen) state, and nothing forces the base states to move toward the rest of the trajectory. As a result, there are trivial minima where the virtual states simply stall in place and only the final action tries to reach the goal in one step.
Dense goal shaping
We can view the above issue as the goal’s signal being cut off entirely from previous states. One way to fix this is to simply add a dense goal term throughout prediction:
\[\mathcal{L}_{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a})
= \sum_{t=0}^{T-1} \big\|F_\theta(\bar{s}_t, a_t) - g\big\|_2^2.\]
In normal settings this would over-bias towards the greedy solution of chasing the goal directly at every step, but in our setting this is balanced by the stop-gradient dynamics loss’s bias towards feasible dynamics. The final objective is then:
\[\mathcal{L}(\mathbf{s},\mathbf{a}) = \mathcal{L}_{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a}) + \gamma \, \mathcal{L}_{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a}).\]
The result is a planning objective with no dependence on the brittle state-input gradients $D_s F_{\theta}$.
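A minimal sketch of this final objective, where `detach()` plays the role of the stop-gradient $\bar{s}_t$; `F`, `g`, and `gamma` are placeholders.

```python
import torch

# GRASP's reshaped objective. detach() implements the stop-gradient
# \bar{s}_t: no gradient flows into the brittle state input of F, while
# gradients still reach the actions (through F) and the target states
# s_{t+1}. F, g, and gamma are placeholders.
def grasp_loss(F, states, actions, g, gamma=0.1):
    pred = F(states[:-1].detach(), actions)   # \bar{s}_t
    dyn = ((pred - states[1:]) ** 2).sum()    # L_dyn^sg
    goal = ((pred - g) ** 2).sum()            # L_goal^sg (dense shaping)
    return dyn + gamma * goal
```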
Periodic “sync”: briefly return to true rollout gradients
The lifted stop-gradient objective is great for fast, guided exploration, but it’s still an approximation of the original serial rollout objective.
So every $K_{\text{sync}}$ iterations, GRASP does a short refinement phase:
Roll out from $s_0$ using current actions $\mathbf{a}$, and take a few small gradient steps on the original serial loss:
\[\mathbf{a} \leftarrow \mathbf{a} - \eta_{\text{sync}}\,\nabla_{\mathbf{a}}\,\|s_T(\mathbf{a})-g\|_2^2.\]
The lifted-state optimization still provides the core of the method, while this refinement step keeps states and actions grounded in real trajectories. The refinement can of course be replaced with a serial planner of your choice (e.g. CEM); the idea is to recover some of the full-path synchronization of serial planners while keeping most of the benefits of lifted-state planning.
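A sketch of the sync phase under the same placeholder interface; it is just a few steps of the serial planner from earlier, with `actions` assumed to have `requires_grad=True`.

```python
import torch

# Periodic sync: a few gradient steps on the original serial rollout
# loss, re-grounding the actions in a real trajectory from s_0.
# actions must have requires_grad=True; hyperparameters are placeholders.
def sync_refine(F, s0, g, actions, steps=5, lr=1e-3):
    opt = torch.optim.SGD([actions], lr=lr)
    for _ in range(steps):
        s = s0
        for a in actions:                # true serial rollout
            s = F(s, a)
        loss = ((s - g) ** 2).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
```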
How GRASP addresses long-range planning
Collocation-based planners offer a natural fix for long-horizon planning, but this optimization is quite difficult through modern world models due to adversarial robustness issues. GRASP offers a simple route to a smoother collocation-based planner, alongside stable stochasticity for exploration. As a result, longer-horizon planning not only succeeds more often, but also finds those successes faster:
Push-T demo: longer-horizon planning with GRASP.
| Horizon | CEM | GD | LatCo | GRASP |
|---|---|---|---|---|
| H=40 | **61.4% / 35.3s** | 51.0% / 18.0s | 15.0% / 598.0s | 59.0% / 8.5s |
| H=50 | 30.2% / 96.2s | 37.6% / 76.3s | 4.2% / 1114.7s | **43.4% / 15.2s** |
| H=60 | 7.2% / 83.1s | 16.4% / 146.5s | 2.0% / 231.5s | **26.2% / 49.1s** |
| H=70 | 7.8% / 156.1s | 12.0% / 103.1s | 0.0% / — | **16.0% / 79.9s** |
| H=80 | 2.8% / 132.2s | 6.4% / 161.3s | 0.0% / — | **10.4% / 58.9s** |
Push-T results. Success rate (%) / median time to success. Bold = best success rate in row. Note that median time to success biases upward as success rate increases; GRASP is nonetheless faster despite its higher success rates.
What’s next?
There is still plenty of work to be done for modern world model planners. We want to exploit the gradient structure of learned world models, and collocation (lifted-state optimization) is a natural approach for long-horizon planning, but it’s crucial to understand typical gradient structure here: smooth and informative action gradients and brittle state gradients. We view GRASP as an initial iteration for such planners.
Extending GRASP to diffusion-based world models (deeper latent timesteps can be viewed as smoothed versions of the world model itself), exploring more sophisticated optimizers and noising strategies, and integrating GRASP into a closed-loop system or RL policy learning for adaptive long-horizon planning are all natural and interesting next steps.
I do genuinely think it’s an exciting time to be working on world model planners. It’s a funny sweet spot where the background literature (planning and control overall) is incredibly mature and well-developed, but the current setting (pure planning optimization over modern, large-scale world models) is still heavily underexplored. But, once we figure out all the right ideas, world model planners will likely become as commonplace as RL.
For more details, read the full paper or visit the project website.
Citation
@article{psenka2026grasp,
title={Parallel Stochastic Gradient-Based Planning for World Models},
author={Michael Psenka and Michael Rabbat and Aditi Krishnapriyan and Yann LeCun and Amir Bar},
year={2026},
eprint={2602.00475},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2602.00475}
}
[2026-04-20]
Identifying Interactions at Scale for LLMs
Understanding the behavior of complex machine learning systems, particularly Large Language Models (LLMs), is a critical challenge in modern artificial intelligence. Interpretability research aims to make the decision-making process more transparent to model builders and impacted humans, a step toward safer and more trustworthy AI. To gain a comprehensive understanding, we can analyze these systems through different lenses: feature attribution, which isolates the specific input features driving a prediction (Lundberg & Lee, 2017; Ribeiro et al., 2022); data attribution, which links model behaviors to influential training examples (Koh & Liang, 2017; Ilyas et al., 2022); and mechanistic interpretability, which dissects the functions of internal components (Conmy et al., 2023; Sharkey et al., 2025).
Across these perspectives, the same fundamental hurdle persists: complexity at scale. Model behavior is rarely the result of isolated components; rather, it emerges from complex dependencies and patterns. To achieve state-of-the-art performance, models synthesize complex feature relationships, find shared patterns from diverse training examples, and process information through highly interconnected internal components.
Therefore, grounded or reality-checked interpretability methods must also be able to capture these influential interactions. As the number of features, training data points, and model components grow, the number of potential interactions grows exponentially, making exhaustive analysis computationally infeasible. In this blog post, we describe the fundamental ideas behind SPEX and ProxySPEX, algorithms capable of identifying these critical interactions at scale.
Attribution through Ablation
Central to our approach is the concept of ablation, measuring influence by observing what changes when a component is removed.
Feature Attribution: We mask or remove specific segments of the input prompt and measure the resulting shift in the predictions.
Data Attribution: We train models on different subsets of the training set, assessing how the model’s output on a test point shifts in the absence of specific training data.
Model Component Attribution (Mechanistic Interpretability): We intervene on the model’s forward pass by removing the influence of specific internal components, determining which internal structures are responsible for the model’s prediction.
In each case, the goal is the same: to isolate the drivers of a decision by systematically perturbing the system, in hopes of discovering influential interactions. Since each ablation incurs a significant cost, whether through expensive inference calls or retrainings, we aim to compute attributions with the fewest possible ablations.
Masking different parts of the input, we measure the difference between the original and ablated outputs.
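As a concrete (hypothetical) sketch for the feature-attribution case, the quantity being measured is a set function over kept features; `model` and `MASK` below are placeholders for your scoring function and mask token.

```python
# Ablation as a set function for feature attribution. For a keep-set S,
# v(S) scores the model on the prompt with all other tokens masked.
# `model` and `MASK` are hypothetical placeholders.
def value(model, tokens, keep):
    ablated = [t if i in keep else MASK for i, t in enumerate(tokens)]
    return model(ablated)   # e.g. log-probability of the original answer
```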
SPEX and ProxySPEX Framework
To discover influential interactions with a tractable number of ablations, we developed SPEX (Spectral Explainer), a framework that draws on signal processing and coding theory to advance interaction discovery to scales orders of magnitude greater than prior methods. SPEX circumvents the combinatorial explosion by exploiting a key structural observation: while the number of total interactions is prohibitively large, the number of influential interactions is actually quite small.
We formalize this through two observations: sparsity (relatively few interactions truly drive the output) and low-degreeness (influential interactions typically involve only a small subset of features). These properties allow us to reframe the difficult search problem into a solvable sparse recovery problem. Drawing on powerful tools from signal processing and coding theory, SPEX uses strategically selected ablations to combine many candidate interactions together. Then, using efficient decoding algorithms, we disentangle these combined signals to isolate the specific interactions responsible for the model’s behavior.
In a subsequent algorithm, ProxySPEX, we identified another structural property common in complex machine learning models: hierarchy. That is, when a higher-order interaction is important, its lower-order subsets are likely to be important as well. This additional structural observation yields a dramatic improvement in computational cost: ProxySPEX matches the performance of SPEX with around 10x fewer ablations. Collectively, these frameworks enable efficient interaction discovery, unlocking new applications in feature, data, and model component attribution.
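To illustrate the sparse-recovery framing, here is a toy stand-in (not SPEX's actual coding-theoretic sampler or decoder): sample random ablation masks, featurize all interactions up to degree 2, and let a sparse linear fit pick out the few influential ones. The synthetic response `v` is a placeholder for a real model's ablation behavior.

```python
import numpy as np
from itertools import combinations
from sklearn.linear_model import Lasso

# Toy sparse recovery of low-degree interactions (a simplified stand-in
# for SPEX, which uses strategically coded masks and fast decoding).
# v is a synthetic stand-in for a model's ablation response.
n, n_samples = 10, 300
rng = np.random.default_rng(0)
v = lambda m: 2.0 * m[1] * m[4] - 1.5 * m[7] + 0.5 * m[2] * m[3]

masks = rng.integers(0, 2, size=(n_samples, n))
terms = [(i,) for i in range(n)] + list(combinations(range(n), 2))
X = np.array([[m[list(t)].prod() for t in terms] for m in masks])
y = np.array([v(m) for m in masks])

fit = Lasso(alpha=0.01).fit(X, y)
for t, w in zip(terms, fit.coef_):
    if abs(w) > 0.1:
        print(t, round(w, 2))   # recovers (1,4), (7,), and (2,3)
```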
Feature Attribution
Feature attribution techniques assign importance scores to input features based on their influence on the model’s output. For example, if an LLM were used to make a medical diagnosis, this approach could identify exactly which symptoms led the model to its conclusion. While attributing importance to individual features can be valuable, the true power of sophisticated models lies in their ability to capture complex relationships between features. The figure below illustrates examples of these influential interactions: from a double negative changing sentiment (left) to the necessary synthesis of multiple documents in a RAG task (right).
The figure below illustrates the feature attribution performance of SPEX on a sentiment analysis task. We evaluate performance using faithfulness: a measure of how accurately the recovered attributions can predict the model’s output on unseen test ablations. We find that SPEX matches the high faithfulness of existing interaction techniques (Faith-Shap, Faith-Banzhaf) on short inputs, but uniquely retains this performance as the context scales to thousands of features. In contrast, while marginal approaches (LIME, Banzhaf) can also operate at this scale, they exhibit significantly lower faithfulness because they fail to capture the complex interactions driving the model’s output.
SPEX was also applied to a modified version of the trolley problem, where the moral ambiguity of the problem is removed, making “True” the clear correct answer. Given the modified prompt, GPT-4o mini answered correctly only 8% of the time. When we applied standard feature attribution (SHAP), it identified individual instances of the word trolley as the primary factors driving the incorrect response. However, replacing trolley with synonyms such as tram or streetcar had little impact on the model’s prediction. SPEX revealed a much richer story, identifying a dominant high-order synergy between the two instances of trolley, as well as the words pulling and lever, a finding that aligns with human intuition about the core components of the dilemma. When these four words were replaced with synonyms, the model’s failure rate dropped to near zero.
Data Attribution
Data attribution identifies which training data points are most responsible for a model’s prediction on a new test point. Identifying influential interactions between these data points is key to explaining unexpected model behaviors. Redundant interactions, such as semantic duplicates, often reinforce specific (and possibly incorrect) concepts, while synergistic interactions are essential for defining decision boundaries that no single sample could form alone. To demonstrate this, we applied ProxySPEX to a ResNet model trained on CIFAR-10, identifying the most significant examples of both interaction types for a variety of difficult test points, as shown in the figure below.
As illustrated, synergistic interactions (left) often involve semantically distinct classes working together to define a decision boundary. For example, grounding the synergy in human perception, the automobile (bottom left) shares visual traits with the provided training images, including the low-profile chassis of the sports car, the boxy shape of the yellow truck, and the horizontal stripe of the red delivery vehicle. On the other hand, redundant interactions (right) tend to capture visual duplicates that reinforce a specific concept. For instance, the horse prediction (middle right) is heavily influenced by a cluster of dog images with similar silhouettes. This fine-grained analysis allows for the development of new data selection techniques that preserve necessary synergies while safely removing redundancies.
Attention Head Attribution (Mechanistic Interpretability)
The goal of model component attribution is to identify which internal parts of the model, such as specific layers or attention heads, are most responsible for a particular behavior. Here too, ProxySPEX uncovers the responsible interactions between different parts of the architecture. Understanding these structural dependencies is vital for architectural interventions, such as task-specific attention head pruning. On an MMLU dataset (high-school-us-history), we demonstrate that a ProxySPEX-informed pruning strategy not only outperforms competing methods, but can actually improve model performance on the target task.
On this task, we also analyzed the interaction structure across the model’s depth. We observe that early layers function in a predominantly linear regime, where heads contribute largely independently to the target task. In later layers, the role of interactions between attention heads becomes more pronounced, with most of the contribution coming from interactions among heads in the same layer.
What’s Next?
The SPEX framework represents a significant step forward for interpretability, extending interaction discovery from dozens to thousands of components. We have demonstrated the versatility of the framework across the entire model lifecycle: exploring feature attribution on long-context inputs, identifying synergies and redundancies among training data points, and discovering interactions between internal model components. Moving forward, many interesting research questions remain around unifying these different perspectives, providing a more holistic understanding of a machine learning system. It is also of great interest to systematically evaluate interaction discovery methods against existing scientific knowledge in fields such as genomics and materials science, serving to both ground model findings and generate new, testable hypotheses.
We invite the research community to join us in this effort: the code for both SPEX and ProxySPEX is fully integrated and available within the popular SHAP-IQ repository (links below).
https://github.com/mmschlk/shapiq (SHAP-IQ Github)
https://openreview.net/forum?id=KI8qan2EA7 (ProxySPEX NeurIPS 2025)
https://openreview.net/forum?id=pRlKbAwczl (SPEX ICML 2025)
https://openreview.net/forum?id=glGeXu1zG4 (Learning to Understand NeurIPS 2024)
[2026-03-13]