The simple version: When we try to predict where basketball or football players will move in the next few seconds, most models predict each player separately. Player A might go here, Player B might go there. But here's the problem: those individual predictions might not make sense together as a team.
Imagine a model predicts that Player A will run to the left corner, and separately predicts that the ball will fly to the right corner. Those predictions might each be reasonable on their own, but they don't work together: the ball isn't going to magically fly away from the ball carrier!
CausalTraj fixes this by predicting all players together as one unit. Instead of asking "where might each player go?" it asks "what might this whole team do next?" The result: predictions where players actually move in coordination, passes look realistic, and the team behaves like an actual team.
How Most Models Are Measured (And Why It's Flawed)
When researchers build trajectory prediction models, they need a way to measure how good the predictions are. The standard approach uses metrics called minADE and minFDE: generate (say) 20 sample predictions, keep the one closest to what actually happened, and measure its average error over all timesteps (minADE) or its error at just the final timestep (minFDE).
The Hidden Problem With Per-Player Metrics
Here's the catch that most people miss: when calculating minADE for Player A, we pick Player A's best prediction out of 20 samples. When calculating minADE for Player B, we pick Player B's best prediction, but it might come from a completely different sample!
Per-agent metrics: each player gets their "best" prediction, but possibly from different samples.
Joint metrics: pick the best complete scenario, then measure everyone in that scenario.
A model could score really well on per-player metrics by generating diverse predictions for each player, even if those predictions never make sense together. But when you actually want to USE the model (to simulate a game, analyze tactics, etc.), you need predictions that work as a coherent team. You can't mix-and-match player predictions at runtime; you need to pick one complete scenario.
What Goes Wrong When Models Optimize for Per-Player Metrics?
When models are trained and tuned to maximize per-player accuracy, they often produce predictions that look fine individually but are obviously wrong when you watch them together:
- Curved ball trajectories: the ball arcs through the air unrealistically instead of traveling in a straight line for a pass
- Everyone moves in the same direction: all players drift toward the same area because each player independently predicts that area is likely
- Ball-player disconnect: the ball carrier runs one way while the ball goes another
- No spacing: teammates don't create space for each other like real players would
Coherent joint predictions look the opposite:

- Straight passes: the ball travels quickly in a direct line between players
- Coordinated movement: when one player moves, teammates adjust their positions in response
- Ball follows carrier: the ball stays with whoever is controlling it
- Strategic spacing: players spread out or bunch up in ways that make tactical sense
Instead of trying to optimize per-player accuracy and hoping coherence emerges, CausalTraj directly models the joint distribution: the probability of all players moving together in a particular way. When you model the whole team as a unit, coherent team behavior and reasonable individual predictions both emerge naturally.
Two Ways to Predict the Future
There are two fundamentally different approaches to predicting where players will be in the next 4 seconds:
Approach 1, all-at-once: look at the current situation, compress it into a summary, then predict all future positions simultaneously.
Step 1: Look at the current situation
Step 2: Create a "summary" of the situation
Step 3: From that summary, output positions at t+1, t+2, t+3... t+20 all at once
Approach 2, step-by-step (causal): predict the next moment, then use that prediction to predict the moment after, and so on.
Step 1: Look at current → predict t+1
Step 2: Look at current + t+1 → predict t+2
Step 3: Look at current + t+1 + t+2 → predict t+3
...and so on
The term "causal" here means respecting the order of time β you can only use information from the past and present to predict the future, not the other way around.
This is the same idea behind large language models (LLMs) like ChatGPT: they predict one word at a time, where each new word depends on all the words that came before. CausalTraj applies this idea to player positions: predict one timestep at a time, where each timestep depends on all previous timesteps.
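The step-by-step loop can be sketched in a few lines. This is a toy illustration, not the paper's model: `predict_next_step` here is just constant-velocity extrapolation standing in for a learned network.

```python
import numpy as np

def predict_next_step(history):
    """Toy stand-in for the model: next positions = last positions
    plus the most recent velocity (constant-velocity extrapolation)."""
    velocity = history[-1] - history[-2]   # (n_players, 2)
    return history[-1] + velocity

def rollout(observed, n_future_steps):
    """Causal (autoregressive) prediction: each new step is appended
    to the history and conditions the next prediction."""
    history = list(observed)
    for _ in range(n_future_steps):
        history.append(predict_next_step(np.stack(history)))
    return np.stack(history[len(observed):])   # only the predicted part

# 2 observed frames for 3 players, each moving +1 in x per frame
observed = np.array([[[0., 0.], [1., 0.], [2., 0.]],
                     [[1., 0.], [2., 0.], [3., 0.]]])
future = rollout(observed, n_future_steps=4)   # shape (4, 3, 2)
```

The key property is that every predicted frame is fed back in as history for the next one, exactly like an LLM feeding generated tokens back into its context.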
A Simple Analogy: Writing a Story
"Given the opening sentence, write sentences 2, 3, 4, 5, 6, 7, 8, 9, and 10 all at once."
Each sentence is written independently β sentence 5 doesn't know what sentence 3 says. The result is often incoherent.
"Write sentence 2 based on sentence 1. Then write sentence 3 based on sentences 1-2. Then write sentence 4 based on sentences 1-3..."
Each sentence builds on everything before it. The result flows naturally.
The Mathematical Idea (Simplified)
CausalTraj factors the prediction problem as a chain of conditional probabilities:
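In symbols (a standard autoregressive factorization; the notation here is mine, not necessarily the paper's): with $X$ the observed history and $Y_t$ the joint positions of all agents at future step $t$,

```latex
p(Y_{1:T} \mid X) \;=\; \prod_{t=1}^{T} p\left(Y_t \mid Y_{1:t-1},\, X\right)
```

Each factor conditions on everything generated so far, which is exactly the step-by-step scheme described above.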
Because each timestep's prediction sees what all players did at every previous timestep, players can "react" to each other. If Player A moved left at t+3, Player B's prediction at t+4 can take that into account. This is how coordinated team behavior emerges: players responding to each other step by step.
Sports are inherently unpredictable. From the same starting position, a player might go left OR right OR stay still. A good prediction model needs to represent this uncertainty: there isn't just one correct answer, there are multiple plausible futures.
The Mixture of Gaussians Approach
CausalTraj represents uncertainty using something called a "Mixture of Gaussians." Let's break this down:
A Gaussian (also called a "normal distribution" or "bell curve") is a way to say "the player will probably be around here, but might be a bit to the left or right." It gives you a center point (the most likely position) and a spread (how uncertain we are).
Sometimes one bell curve isn't enough. A player might go left OR right; those are two completely different directions, not just "somewhere in between." A mixture combines multiple bell curves, each representing a different possible outcome.
The model outputs 8 different "scenarios" (called components). Each scenario says "here's where all the players might go" with a different tactical possibility. One might represent "team pushes forward," another "team defends deep," etc.
The Clever Part: Shared Mixture Weights
Here's what makes CausalTraj special: when the model chooses which scenario is most likely, it makes that choice for THE WHOLE TEAM at once. Not separately for each player.
Scenario 2: 35% likely (team pushes right) ← most likely
Scenario 3: 10% likely (team holds position)
...
In theory, you could model the EXACT correlations between all pairs of players (if Player A goes left, Player B is more likely to go right, etc.). But with 22 players in football, that's 22 × 22 = 484 correlation values to learn at every timestep, which is computationally expensive.
CausalTraj uses a simpler approach: within each scenario, players are independent, but the shared scenario selection couples them together. This captures most of the coordination without the computational cost.
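The shared-weight sampling step can be sketched as follows. The numbers are random stand-ins (not trained outputs), and the per-player spread is simplified to an isotropic value; the point is that one scenario index is drawn for the whole team.

```python
import numpy as np

rng = np.random.default_rng(0)
n_modes, n_players = 8, 11   # 8 scenarios; e.g. 10 players + ball

# Hypothetical head outputs for a single timestep (random stand-ins)
weights = rng.dirichlet(np.ones(n_modes))            # (8,) shared by ALL players
means = rng.normal(scale=0.5, size=(n_modes, n_players, 2))
spreads = np.full((n_modes, n_players, 2), 0.1)      # simplified isotropic spread

# Coordination: draw ONE scenario index for the whole team...
k = rng.choice(n_modes, p=weights)
# ...then sample every player's displacement from that same scenario.
displacements = rng.normal(means[k], spreads[k])     # (n_players, 2)
```

Within scenario `k` the players are sampled independently, but because they all come from the same `k`, their futures stay tactically consistent.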
CausalTraj processes player tracking data through four main stages. Each stage builds on the previous one:
Goal: For each player individually, create a summary of where they've been up to the current moment. No interaction between players yet.
A neural network that looks at a player's positions over time and creates a compressed representation. The "causal" part means at each moment, it only looks at past positions β it can't peek ahead.
Mamba2 (the variant used in the reported results): a newer type of sequence model, similar in spirit to transformers but more efficient, that processes each player's trajectory. It also respects the time ordering, using only past information.
Output: A "history embedding" for each player β a learned summary of their movement pattern
Goal: Now we let each player's representation "attend to" other players. This is where the model learns about player interactions.
Each player looks at all other players at the same moment in time and decides which ones are most relevant. A player near the ball might pay more attention to nearby defenders. A goalkeeper might focus on the ball carrier. The model learns these attention patterns.
Most attention mechanisms only look at "what" each player is (their learned representation). The Spatial Relation Transformer (SRTE) also explicitly includes "where" each player is relative to others. It computes the exact distance and direction between every pair of players and feeds that geometric information into the attention calculation.
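The pairwise geometry itself is cheap to compute with broadcasting; a small sketch with made-up positions:

```python
import numpy as np

# Positions (x, y) of 4 players at one timestep -- illustrative values
pos = np.array([[0., 0.], [3., 4.], [6., 0.], [0., 5.]])

# rel[i, j] is the offset vector from player i to player j
rel = pos[None, :, :] - pos[:, None, :]      # shape (4, 4, 2)
dist = np.linalg.norm(rel, axis=-1)          # shape (4, 4), pairwise distances

# An SRTE-style layer would feed these offsets and distances into its
# attention computation alongside the learned player representations.
```

Player 0 and player 1 here are a 3-4-5 triangle apart, so `dist[0, 1]` is exactly 5.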
Goal: Take all the individual player representations and combine them into one representation of the entire scene.
- Add current position and velocity information back to each player's representation
- Compress each player's representation through a small neural network
- Stack all players together into one big scene vector
- Process through another neural network to get the final scene representation
Goal: From the scene representation, predict where all players will move in the next timestep.
The model doesn't predict absolute positions; it predicts displacements (how far each player moves from their current position). This makes learning easier because displacements are typically small and centered around zero.
For each of the 8 scenarios, the model outputs:
- Probability weight: how likely is this scenario? (8 numbers that sum to 1)
- Mean displacement: for each player in this scenario, where do we expect them to move? (8 × N_players × 2 numbers)
- Uncertainty: how confident are we about each player's movement in this scenario? (8 × N_players × 3 numbers for covariance)
The coordination mechanism: all players share the same 8 scenario weights. When sampling, we first pick a scenario (maybe #3 with 35% probability), then sample ALL player positions from that scenario. This is how players end up coordinated: they're all drawn from the same "version of the future."
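The output shapes above can be made concrete with a shape-only sketch. The values are random stand-ins, not a trained network; only the tensor shapes match the description.

```python
import numpy as np

n_modes, n_players = 8, 11

def mixture_head(rng):
    """Shape-only sketch of the per-timestep output head (illustrative
    random values, not a trained network)."""
    logits = rng.normal(size=n_modes)
    weights = np.exp(logits) / np.exp(logits).sum()        # (8,) sums to 1
    mean_disp = rng.normal(size=(n_modes, n_players, 2))   # mean displacements
    # 3 covariance numbers per 2-D Gaussian: two variances + one correlation
    cov_params = rng.normal(size=(n_modes, n_players, 3))
    return weights, mean_disp, cov_params

weights, mean_disp, cov_params = mixture_head(np.random.default_rng(0))
```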
The Training Objective
CausalTraj is trained using maximum likelihood, a fancy way of saying "make the model assign high probability to what actually happened."
Preventing Collapse: The Entropy Regularizer
There's a common problem with mixture models: the model might learn to use only ONE of the 8 scenarios and ignore the other 7. This is called "mode collapse": the model collapses to a single mode.
CausalTraj adds an "entropy regularizer" to the training loss. In simple terms, this penalizes the model if it puts all the probability weight on one scenario. It encourages the model to use all 8 scenarios, at least during early training.
Think of it like a teacher telling students "you can't just write the same answer for every question β I want to see you use different approaches."
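The two loss terms can be illustrated with a 1-D toy (my simplification, not the paper's exact loss; `beta` is a hypothetical regularization strength):

```python
import numpy as np

def mixture_nll(x, weights, means, stds):
    """Negative log-likelihood of scalar x under a 1-D Gaussian mixture."""
    dens = weights * np.exp(-0.5 * ((x - means) / stds) ** 2) / (stds * np.sqrt(2 * np.pi))
    return -np.log(dens.sum())

def entropy(weights):
    """Entropy of the mixture weights: high when all modes get used."""
    return -np.sum(weights * np.log(weights + 1e-12))

weights = np.array([0.7, 0.1, 0.1, 0.1])
means = np.array([0.0, 1.0, -1.0, 2.0])
stds = np.ones(4)

beta = 0.1   # regularization strength (hyperparameter; value illustrative)
loss = mixture_nll(0.1, weights, means, stds) - beta * entropy(weights)

# A collapsed weight vector has low entropy, so subtracting the entropy
# term penalizes it more than a vector that spreads probability around.
collapsed = np.array([0.97, 0.01, 0.01, 0.01])
```

Subtracting `beta * entropy` means the optimizer prefers weight vectors that keep several scenarios alive, which is the anti-collapse behavior described above.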
Training vs Inference: A Clever Trick
Even though CausalTraj predicts step-by-step during inference, training can be done efficiently in parallel using "teacher forcing."
During training, we know the real positions at every timestep (it's historical data), so we can train all timesteps in parallel: at each step, we use the REAL previous positions as input, not the model's predictions.
At inference time, we don't know the future, so we use the model's own predictions: predict t+1, then use that prediction to predict t+2, and so on. This is slower but necessary.
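The contrast between the two regimes fits in a few lines. The "model" here is a deliberately trivial stand-in (shift by one unit in x) so the mechanics are visible:

```python
import numpy as np

def step_model(prev_pos):
    """Toy one-step 'model': shift everything 1 unit in x (a stand-in)."""
    return prev_pos + np.array([1.0, 0.0])

# Ground-truth track for one player: 5 timesteps of (x, y)
truth = np.array([[0., 0.], [1., 0.], [2., 0.], [3., 0.], [4., 0.]])

# TRAINING (teacher forcing): every step conditions on the REAL previous
# position, so all targets can be predicted in one parallel call.
train_preds = step_model(truth[:-1])      # predicts truth[1:], in parallel

# INFERENCE (autoregressive): each step feeds on the model's OWN output,
# one step at a time -- slower, but the real future isn't available.
pos, infer_preds = truth[0], []
for _ in range(4):
    pos = step_model(pos)
    infer_preds.append(pos)
infer_preds = np.stack(infer_preds)
```

With a real network the training call is a single batched forward pass over all timesteps, while inference stays a sequential loop; that asymmetry is the whole trick.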
The paper evaluates models on both per-agent metrics (the traditional approach) and joint metrics (the new approach). Here's what each measures:
minADE (per-agent): for each agent, look at the 20 predicted trajectories, pick the one closest to reality, and measure the average error across all timesteps. Then average across all players.
minFDE (per-agent): same as above, but only measure the error at the final timestep (where the player ended up vs. where we predicted they'd end up).
minJADE (joint): look at the 20 predicted scenarios (each containing all players), pick the scenario whose average error across all players and timesteps is lowest, and report that error.
minJFDE (joint): same as above, but only measure error at the final timestep. All players must still come from the same scenario.
Per-agent metrics tell you "can this model produce good predictions for individual players?" Joint metrics tell you "can this model produce good predictions where the team moves coherently?" A model good at the first might be bad at the second (and vice versa). For applications like game simulation, you need good joint metrics.
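The difference between the two metric families is just the order of `min` and `mean`; a sketch with random data (my illustration, shapes and seed arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
S, N, T = 20, 5, 10                        # samples, players, timesteps
preds = rng.normal(size=(S, N, T, 2))      # 20 sampled team scenarios
truth = rng.normal(size=(N, T, 2))         # what actually happened

err = np.linalg.norm(preds - truth, axis=-1)    # (S, N, T) per-step errors

# Per-agent minADE: pick the best sample SEPARATELY for each player,
# then average -- the winners may come from different samples.
min_ade = err.mean(axis=2).min(axis=0).mean()

# Joint minJADE: average over players and timesteps first, then pick ONE
# best scenario for the whole team -- everyone shares the same sample.
min_jade = err.mean(axis=(1, 2)).min()

# Per-agent cherry-picking can only help, so min_ade <= min_jade always.
```

This inequality is why a model can look strong on minADE while its complete scenarios are much worse.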
Datasets Used
- Basketball tracking data (NBA SportVU)
- 10 players + ball per frame
- 5 frames per second
- Task: see 2 seconds (10 frames), predict 4 seconds (20 frames)

- Another basketball dataset, derived from NBA data
- 50-frame sequences
- Task: see 30 frames, predict 20 frames

- NFL tracking data (American football)
- 22 players + ball per frame
- 10 frames per second
- Task: see 30 frames, predict 20 frames
NBA SportVU Results (4-second prediction horizon)
| Model | minADE₂₀ (m) | minFDE₂₀ (m) | minJADE₂₀ (m) | minJFDE₂₀ (m) |
|---|---|---|---|---|
| GroupNet | 0.95 | 1.22 | 2.12 | 3.72 |
| LED | 0.81 | 1.10 | 1.63 | 2.99 |
| MoFlow (default) | 0.71 | 0.87 | 1.69 | 3.31 |
| CausalTraj (Mamba2) | 0.77 | 1.02 | 1.38 | 2.57 |
Lower numbers are better (less prediction error). The table shows that MoFlow is best on per-agent metrics (minADE, minFDE) but CausalTraj is best on joint metrics (minJADE, minJFDE). This exactly matches what we'd expect: models optimized for per-agent metrics win on per-agent metrics, but CausalTraj's focus on joint prediction pays off when we measure the whole team together.
CausalTraj reduces joint error (minJADE₂₀) by about 18% compared to MoFlow (1.69 m → 1.38 m), a significant improvement. It's slightly worse on per-agent metrics (0.77 vs 0.71 meters average error), but that's because it's optimizing for a different goal: coherent team predictions rather than maximizing individual prediction diversity.
Numbers only tell part of the story. The paper includes visualizations that show WHY joint metrics matter β CausalTraj's predictions look qualitatively different from other models.
Other models' predictions tend to show:

- Homogeneous trajectories: in each predicted scenario, all players tend to move in similar directions. It's like the model picks a "theme" (everyone goes left) rather than diverse individual actions.
- Curved ball paths: the ball arcs through the air in unrealistic ways, as if it has a mind of its own rather than following physics.
- Marginal modes: players independently gravitate toward "likely positions" without coordinating with teammates.
CausalTraj's predictions, by contrast, show:

- Coordinated movement: when one player moves, others respond. If a forward cuts right, a midfielder might shift to cover.
- Realistic passes: ball trajectories are straight and fast, like real passes between players.
- Reactive behavior: players sometimes change direction abruptly; the model captures moments where a player adjusts to track someone else.
- Strategic spacing: teammates position themselves in formations that make tactical sense.
CausalTraj isn't perfect. The paper notes that occasionally:
- Ball carriers can end up with unrealistic gaps between them and the ball
- The ball sometimes "collides" with court boundaries (the model doesn't enforce physical constraints)
- Very fast movements or abrupt stops can still be unrealistic
The authors attribute these issues to limited covariance modeling between the ball and the ball carrier, an area for future improvement.
To understand which design choices actually matter, the authors ran "ablation studies" β tests where they remove or simplify one component at a time and measure the impact.
| What Was Changed | minJADE₂₀ | minJFDE₂₀ | Effect |
|---|---|---|---|
| Full CausalTraj (Mamba2) | 0.97 | 1.77 | Best performance |
| No Spatial Relation Transformer | 0.99 | 1.81 | Slightly worse |
| Single Gaussian (no mixture) | 1.03 | 1.86 | Noticeably worse |
| Sample from means only (not full distribution) | 1.05 | 2.13 | Significantly worse |
Removing the explicit spatial encoding (SRTE) makes performance worse. Knowing exactly where players are relative to each other helps the model make better predictions.
Using only one Gaussian (one possible future) instead of a mixture of 8 hurts performance. The future is uncertain β we need to represent multiple possibilities.
Sampling only from the center of each Gaussian (not the full distribution) is worst of all. The spread/uncertainty in predictions captures real variation in player movement.
| Aspect | TranSPORTmer | Diffoot | CausalTraj |
|---|---|---|---|
| How it predicts over time | Predicts all timesteps at once | Predicts all timesteps at once | Predicts one step at a time (causal) |
| Output format | Single trajectory | Many samples via diffusion | Mixture of Gaussians |
| Handles multiple futures? | ❌ No | ✅ Yes | ✅ Yes |
| Optimizes for team coherence? | ❌ Per-player loss | ❌ Per-player loss | ✅ Joint likelihood |
| Can do multiple tasks? | ✅ Yes (4 tasks) | ❌ Prediction only | ❌ Prediction only |
| Inference speed | ✅ Fast | ❌ Slow (many steps) | ⚠️ Medium |
| Best use case | Real-time, multi-task apps | "What if" analysis | Coherent game simulation |
Current Limitations
The model treats players as independent within each scenario (only the scenario selection couples them). This means the tight synchronization between a ball carrier and the ball isn't perfectly captured. Sometimes the ball "floats" away from the carrier.
The model can predict that the ball goes outside the court, or that a player moves impossibly fast. It learned from data patterns but doesn't know the rules of physics or the boundaries of the playing area.
Because CausalTraj predicts step-by-step, inference takes longer than models that predict everything at once. For real-time applications, this might be a drawback.
Future Directions (from the paper)
Learn to respect physical constraints: the ball can't go through walls, players can't teleport, passes should obey gravity. Could add constraint layers or physics-based losses.
"Show me what happens if this team plays a high press" or "Generate a counter-attack." Condition the model on tactical concepts, not just observed history.
Joint error (minJADE) is better than per-agent error, but still doesn't directly measure "does this look realistic?" Could develop perceptual metrics or user studies.
Model exact correlations between all player pairs, not just shared scenario weights. Could use low-rank approximations to make this computationally feasible.
Strengths:

- Predicts teams, not individuals: all players' futures come from the same scenario
- Step-by-step prediction: each moment reacts to the ones before it
- Multiple futures: 8 possible scenarios capture uncertainty
- Explicit spatial reasoning: knows exactly where players are relative to each other
- Best joint metrics: state-of-the-art on minJADE/minJFDE
- Realistic coordination: passes are straight, players react to each other
- Works across sports: tested on both basketball and American football
Limitations:

- No physics engine: can predict impossible movements
- No boundary awareness: the ball can leave the court
- Imperfect ball-player sync: gaps sometimes appear between ball and carrier
- Single task only: just prediction, not imputation or classification
- Not the fastest: step-by-step inference is slower than parallel prediction
- Trades per-player accuracy: MoFlow beats it on minADE at long horizons
Measure What Matters
Per-agent metrics let you cherry-pick. Joint metrics evaluate what you'd actually use: complete team predictions.
Step-by-Step Beats All-at-Once
Causal prediction lets each moment react to the previous one, naturally creating coordination.
Shared Weights = Shared Fate
All players share the same scenario selection. Pick scenario #3, and everyone's prediction comes from scenario #3.
Geometry Helps
Explicitly feeding in 'Player A is 5 meters left of Player B' helps the model reason about interactions.
Likelihood Training
Train by asking 'what's the probability of what actually happened?' Simple, principled, and effective.
Coherence vs Individual Accuracy
You can optimize for one or the other. CausalTraj chooses coherence, accepting slightly worse per-player metrics.