Research Deep Dive · arXiv · December 2025
CausalTraj: Coherent Multi-Agent Trajectory Forecasting
A model that predicts where ALL players will move together as a team, not just where each player might go individually.
Step-by-Step Prediction · Team-Level Accuracy · Multiple Possible Futures · 45 min read
Author: Wei Zhen Teoh
What's This Paper Actually About?

The simple version: When we try to predict where basketball or football players will move in the next few seconds, most models predict each player separately. Player A might go here, Player B might go there. But here's the problem: those individual predictions might not make sense together as a team.

Imagine a model predicts that Player A will run to the left corner, and separately predicts that the ball will fly to the right corner. Those predictions might each be reasonable on their own, but they don't work together: the ball isn't going to magically fly away from the ball carrier!

CausalTraj fixes this by predicting all players together as one unit. Instead of asking "where might each player go?" it asks "what might this whole team do next?" The result: predictions where players actually move in coordination, passes look realistic, and the team behaves like an actual team.

The Problem: Measuring Players Separately Misses the Point

How Most Models Are Measured (And Why It's Flawed)

When researchers build trajectory prediction models, they need a way to measure how good the predictions are. The standard approach uses metrics called minADE and minFDE. Let's break down what these actually mean:

Quick Explanation: What Are These Metrics?
ADE (Average Displacement Error): How far off was the prediction on average across all time steps? If you predicted a player would be at position X but they were actually at position Y, how big was that gap on average?
FDE (Final Displacement Error): How far off was the prediction at the very end? Where did you predict the player would end up vs where they actually ended up?
The "min" part: Models usually generate multiple possible futures (say, 20 different predictions). The "min" means we only count the best one β€” the prediction that was closest to what actually happened.

The Hidden Problem With Per-Player Metrics

Here's the catch that most people miss: when calculating minADE for Player A, we pick Player A's best prediction out of 20 samples. When calculating minADE for Player B, we pick Player B's best prediction, but it might come from a completely different sample!

❌ Per-Player Metrics: The Cherry-Picking Problem

Each player gets their "best" prediction, but from different samples:

Player A: best prediction from Sample #3
Player B: best prediction from Sample #7
Player C: best prediction from Sample #1
Ball: best prediction from Sample #12
→ These 4 "best" predictions came from 4 different samples!
→ They might never actually occur together!
✓ Joint Metrics: Evaluate the Whole Team

Pick the best complete scenario, then measure everyone in that scenario:

Sample #3 (all players together):
- Player A: 0.8m error
- Player B: 1.2m error
- Player C: 0.9m error
- Ball: 0.5m error
→ All from the same sample!
→ This actually represents a coherent team movement!
Why Does This Matter?

A model could score really well on per-player metrics by generating diverse predictions for each player, even if those predictions never make sense together. But when you actually want to USE the model (to simulate a game, analyze tactics, etc.), you need predictions that work as a coherent team. You can't mix and match player predictions at runtime; you need to pick one complete scenario.
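To make the cherry-picking concrete, here's a small NumPy sketch. The shapes, noise model, and sample count are invented for illustration; it shows that per-agent selection and joint selection can pick different samples, and that the joint error can never be lower than the cherry-picked one:

```python
import numpy as np

# Toy setup: 20 joint samples for 4 agents (3 players + the ball), 10 steps
rng = np.random.default_rng(1)
K, A, T = 20, 4, 10
gt = rng.normal(size=(A, T, 2))                   # what actually happened
preds = gt + rng.normal(size=(K, A, T, 2))        # 20 noisy joint samples

err = np.linalg.norm(preds - gt, axis=-1)         # (K, A, T) per-step errors
per_agent_best = err.mean(axis=2).argmin(axis=0)  # best sample PER agent
joint_best = err.mean(axis=(1, 2)).argmin()       # ONE best sample for all

min_ade = err.mean(axis=2).min(axis=0).mean()     # cherry-picks per agent
min_jade = err.mean(axis=(1, 2)).min()            # whole-scene error
print(per_agent_best, joint_best)  # per-agent picks often differ by agent
print(min_ade, min_jade)
```

By construction, `min_jade` can never be lower than `min_ade`: averaging each agent's personal best is always at least as flattering as committing to one complete scenario.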

What Goes Wrong When Models Optimize for Per-Player Metrics?

When models are trained and tuned to maximize per-player accuracy, they often produce predictions that look fine individually but are obviously wrong when you watch them together:

Common Problems:
  • Curved ball trajectories: The ball arcs through the air unrealistically instead of traveling in a straight line for a pass
  • Everyone moves in the same direction: All players drift toward the same area because each player independently predicts that area is likely
  • Ball-player disconnect: The ball carrier runs one way while the ball goes another
  • No spacing: Teammates don't create space for each other like real players would
What Good Predictions Look Like:
  • Straight passes: Ball travels quickly in a direct line between players
  • Coordinated movement: When one player moves, teammates adjust their positions in response
  • Ball follows carrier: The ball stays with whoever is controlling it
  • Strategic spacing: Players spread out or bunch up in ways that make tactical sense
CausalTraj's Key Insight

Instead of trying to optimize per-player accuracy and hoping coherence emerges, CausalTraj directly models the joint distribution: the probability of all players moving together in a particular way. When you model the whole team as a unit, coherent team behavior and reasonable individual predictions both emerge naturally.

The Key Idea: Predict One Moment at a Time
Why step-by-step prediction works better than predicting everything at once

Two Ways to Predict the Future

There are two fundamentally different approaches to predicting where players will be in the next 4 seconds:

❌ Approach 1: Predict Everything At Once

Look at the current situation, compress it into a summary, then predict all future positions simultaneously.

Step 1: Look at current player positions
Step 2: Create a "summary" of the situation
Step 3: From that summary, output positions at t+1, t+2, t+3... t+20 all at once
Problem: The prediction at t+10 doesn't know what happened at t+5. Each future timestep is predicted independently from the same summary.
✓ Approach 2: Predict Step-by-Step (Causal)

Predict the next moment, then use that prediction to predict the moment after, and so on.

Step 1: Look at current positions → predict t+1
Step 2: Look at current + t+1 → predict t+2
Step 3: Look at current + t+1 + t+2 → predict t+3
...and so on
Benefit: Each prediction knows what happened in the previous moments. Players can "react" to what other players did in the previous step.
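A minimal sketch of this autoregressive loop, using a stand-in constant-velocity "model" in place of the paper's learned network (the function names and toy data are my assumptions, not the authors' code):

```python
import numpy as np

def predict_next(history):
    """Hypothetical one-step model: all agents' next positions given the
    full history so far (here, a constant-velocity placeholder)."""
    last, prev = history[-1], history[-2]
    return last + (last - prev)            # (A, 2): each agent keeps moving

def rollout(history, horizon):
    """Autoregressive decoding: feed each prediction back in as context."""
    traj = [h for h in history]
    for _ in range(horizon):
        traj.append(predict_next(np.stack(traj)))
    return np.stack(traj[len(history):])   # (horizon, A, 2)

obs = np.array([[[0.0, 0.0], [5.0, 5.0]],   # t-1: two agents
                [[1.0, 0.0], [5.0, 6.0]]])  # t:   moving right / moving up
future = rollout(obs, horizon=3)
print(future[-1])  # [[4. 0.] [5. 9.]]: each agent continues its motion
```

The key property is structural: every call to `predict_next` sees everything predicted so far, so a learned model in this slot can make later steps react to earlier ones.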
Why "Causal"?

The term "causal" here means respecting the order of time β€” you can only use information from the past and present to predict the future, not the other way around.

This is the same idea behind large language models (LLMs) like ChatGPT: they predict one word at a time, where each new word depends on all the words that came before. CausalTraj applies this idea to player positions: predict one timestep at a time, where each timestep depends on all previous timesteps.

A Simple Analogy: Writing a Story

Parallel Approach (Bad for Stories)

"Given the opening sentence, write sentences 2, 3, 4, 5, 6, 7, 8, 9, and 10 all at once."

Each sentence is written independently; sentence 5 doesn't know what sentence 3 says. The result is often incoherent.

Causal Approach (How We Actually Write)

"Write sentence 2 based on sentence 1. Then write sentence 3 based on sentences 1-2. Then write sentence 4 based on sentences 1-3..."

Each sentence builds on everything before it. The result flows naturally.

The Mathematical Idea (Simplified)

CausalTraj factors the prediction problem as a chain of conditional probabilities:

In plain English:
Probability(all future positions | observed history) =
Probability(position at t+1 | history) ×
Probability(position at t+2 | history + t+1) ×
Probability(position at t+3 | history + t+1 + t+2) ×
... and so on
Each step multiplies the probability of the next position, given everything that happened before. This is how we build up a probability for the entire future trajectory.
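The same chain can be written in a few lines of NumPy: the log-probability of a trajectory is the sum of per-step log-probabilities. The "model" here is a deliberately lazy stand-in that predicts "stay where you were" with unit-variance Gaussian noise; only the chain-rule structure reflects the paper:

```python
import numpy as np

def gaussian_logpdf(x, mean, sigma):
    """Log-density of an isotropic 2-D Gaussian."""
    return float(-np.log(2 * np.pi * sigma**2)
                 - np.sum((x - mean)**2) / (2 * sigma**2))

# One agent, 3 observed steps (made-up numbers)
traj = np.array([[0.0, 0.0], [1.0, 0.1], [2.1, 0.0]])

log_p = 0.0
for t in range(1, len(traj)):
    # p(x_t | x_1..x_{t-1}): here the "prediction" is just x_{t-1}
    log_p += gaussian_logpdf(traj[t], mean=traj[t - 1], sigma=1.0)

print(np.exp(log_p))  # probability of the whole trajectory: product of steps
```

Summing log-probabilities instead of multiplying raw probabilities is the standard numerically stable way to evaluate this chain.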
Why This Helps Team Coordination

Because each timestep prediction sees what all players did in the previous timestep, players can "react" to each other. If Player A moved left at t+3, Player B's prediction at t+4 can take that into account. This is how coordinated team behavior emerges: players responding to each other step by step.

Handling Uncertainty: Multiple Possible Futures
How CausalTraj represents the fact that many different things could happen next

Sports are inherently unpredictable. From the same starting position, a player might go left OR right OR stay still. A good prediction model needs to represent this uncertainty: there isn't just one correct answer, there are multiple plausible futures.

The Mixture of Gaussians Approach

CausalTraj represents uncertainty using something called a "Mixture of Gaussians." Let's break this down:

What's a Gaussian?

A Gaussian (also called a "normal distribution" or "bell curve") is a way to say "the player will probably be around here, but might be a bit to the left or right." It gives you a center point (the most likely position) and a spread (how uncertain we are).

What's a Mixture?

Sometimes one bell curve isn't enough. A player might go left OR right; those are two completely different directions, not just "somewhere in between." A mixture combines multiple bell curves, each representing a different possible outcome.

In CausalTraj:

The model outputs 8 different "scenarios" (called components). Each scenario says "here's where all the players might go" with a different tactical possibility. One might represent "team pushes forward," another "team defends deep," etc.

The Clever Part: Shared Mixture Weights

Here's what makes CausalTraj special: when the model chooses which scenario is most likely, it makes that choice for THE WHOLE TEAM at once. Not separately for each player.

How it works:
The model outputs 8 possible scenarios, each with a probability weight:
Scenario 1: 25% likely (team pushes left)
Scenario 2: 35% likely (team pushes right) ← most likely
Scenario 3: 10% likely (team holds position)
...
When we sample a prediction, we pick ONE scenario (maybe Scenario 2), and ALL players' positions come from that same scenario.
Result: Players are automatically coordinated because they're all drawn from the same "version of the future."
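In code, shared-weight sampling is just one categorical draw followed by one Gaussian draw per agent. This NumPy sketch uses invented weights, means, and a fixed spread; only the 8-component structure and the single shared scenario index come from the paper:

```python
import numpy as np

rng = np.random.default_rng(42)
K, A = 8, 11                      # 8 scenarios, 11 agents (10 players + ball)
weights = rng.dirichlet(np.ones(K))       # scenario probabilities, sum to 1
means = rng.normal(size=(K, A, 2))        # per-scenario displacement means
sigma = 0.1                               # illustrative per-agent spread

# Shared-weight sampling: ONE scenario index for the whole team...
k = rng.choice(K, p=weights)
# ...then every agent's displacement is drawn from that SAME scenario.
team_step = means[k] + sigma * rng.normal(size=(A, 2))

print(k, team_step.shape)  # one scenario index, one (11, 2) team move
```

Contrast this with sampling a scenario index per agent, which would let Player A come from "push left" while the ball comes from "push right", exactly the incoherence the paper is trying to avoid.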
Technical Detail: Why Not Full Covariance?

In theory, you could model the EXACT correlations between all pairs of players (if Player A goes left, Player B is more likely to go right, etc.). But with 22 players in football, that's 22×22 = 484 correlation values to learn at every timestep, which is computationally expensive.

CausalTraj uses a simpler approach: within each scenario, players are independent, but the shared scenario selection couples them together. This captures most of the coordination without the computational cost.

How the Model Works: A Four-Step Pipeline
Breaking down the architecture into understandable pieces

CausalTraj processes player tracking data through four main stages. Each stage builds on the previous one:

1
Stage 1: Summarize Each Player's History

Goal: For each player individually, create a summary of where they've been up to the current moment. No interaction between players yet.

Option A: Causal PointNet

A neural network that looks at a player's positions over time and creates a compressed representation. The "causal" part means at each moment it only looks at past positions; it can't peek ahead.

Option B: Mamba2

A newer type of sequence model (similar in spirit to transformers but more efficient) that processes each player's trajectory. Also respects the time ordering β€” only uses past information.

Input: Each player's (x, y) positions over time
Output: A "history embedding" for each player β€” a learned summary of their movement pattern
2
Stage 2: Let Players "See" Each Other

Goal: Now we let each player's representation "attend to" other players. This is where the model learns about player interactions.

How Attention Works Here

Each player looks at all other players at the same moment in time and decides which ones are most relevant. A player near the ball might pay more attention to nearby defenders. A goalkeeper might focus on the ball carrier. The model learns these attention patterns.

Special Feature: Spatial Relation Transformer (SRTE)

Most attention mechanisms only look at "what" each player is (their learned representation). SRTE also explicitly includes "where" each player is relative to others. It computes the exact distance and direction between every pair of players and feeds that geometric information into the attention calculation.

Key insight: After this stage, each player's representation contains information about what other players are doing, not just their own history.
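Here's a toy version of geometry-aware attention in NumPy. The real SRTE encodes relative distance and direction in a learned way; this sketch just subtracts a distance penalty from the attention scores, which is my own simplification to show where the geometry enters:

```python
import numpy as np

def geometric_attention(feats, pos, scale=1.0):
    """Toy attention over agents where pairwise distance biases the scores.

    feats: (A, D) per-agent features     pos: (A, 2) positions on the pitch
    """
    scores = feats @ feats.T / np.sqrt(feats.shape[1])  # content similarity
    rel = pos[:, None, :] - pos[None, :, :]             # (A, A, 2) offsets
    dist = np.linalg.norm(rel, axis=-1)                 # (A, A) distances
    scores = scores - scale * dist                      # nearby agents weigh more
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)       # row-wise softmax
    return weights @ feats                              # mixed features

A, D = 5, 4
rng = np.random.default_rng(3)
out = geometric_attention(rng.normal(size=(A, D)), rng.normal(size=(A, 2)))
print(out.shape)  # (5, 4): each agent's features now mix in its neighbours'
```

The design point survives the simplification: without `pos`, attention can only ask "who is similar to me?"; with the geometric term it can also ask "who is near me?".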
3
Stage 3: Combine Into a Scene Representation

Goal: Take all the individual player representations and combine them into one representation of the entire scene.

Process:
  1. Add current position and velocity information back to each player's representation
  2. Compress each player's representation through a small neural network
  3. Stack all players together into one big scene vector
  4. Process through another neural network to get the final scene representation
Result: A single vector that captures "what's happening on the pitch right now", including all player positions, movements, and their relationships to each other.
4
Stage 4: Output the Prediction

Goal: From the scene representation, predict where all players will move in the next timestep.

What Gets Predicted

The model doesn't predict absolute positions; it predicts displacements (how far each player moves from their current position). This makes learning easier because displacements are typically small and centered around zero.
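In code, recovering positions from predicted displacements is a cumulative sum onto the current positions (the numbers here are made up):

```python
import numpy as np

current = np.array([[30.0, 40.0], [52.0, 34.0]])    # (A=2, 2) positions now
# Hypothetical predicted displacements per future step: small, near zero
deltas = np.array([[[0.4, 0.1], [-0.3, 0.0]],
                   [[0.5, 0.0], [-0.2, 0.1]]])      # (T=2, A, 2)

# Positions come back by accumulating displacements onto the current pose
positions = current + np.cumsum(deltas, axis=0)     # (T, A, 2)
print(positions[-1])  # [[30.9 40.1] [51.5 34.1]]: where each agent ends up
```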

The Mixture Output

For each of the 8 scenarios, the model outputs:

  • Probability weight: How likely is this scenario? (8 numbers that sum to 1)
  • Mean displacement: For each player in this scenario, where do we expect them to move? (8 × N_players × 2 numbers)
  • Uncertainty: How confident are we about each player's movement in this scenario? (8 × N_players × 3 numbers for covariance)

The coordination mechanism: All players share the same 8 scenario weights. When sampling, we first pick a scenario (maybe #3 with 35% probability), then sample ALL player positions from that scenario. This is how players end up coordinated: they're all drawn from the same "version of the future."

Training: How the Model Learns

The Training Objective

CausalTraj is trained using maximum likelihood, a fancy way of saying "make the model assign high probability to what actually happened."

The Training Process in Plain English
Step 1: Show the model a sequence of real player positions from an actual game.
Step 2: At each timestep, ask the model "what's the probability of the players moving to where they actually moved?"
Step 3: Adjust the model's parameters to make that probability higher.
Step 4: Repeat millions of times with different game sequences.

Preventing Collapse: The Entropy Regularizer

There's a common problem with mixture models: the model might learn to use only ONE of the 8 scenarios and ignore the other 7. This is called "mode collapse": the model collapses to a single mode.

The Fix: Encourage Diversity

CausalTraj adds an "entropy regularizer" to the training loss. In simple terms, this penalizes the model if it puts all the probability weight on one scenario. It encourages the model to use all 8 scenarios, at least during early training.

Think of it like a teacher telling students "you can't just write the same answer for every question; I want to see you use different approaches."
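A sketch of the regularizer: compute the Shannon entropy of the 8 scenario weights and subtract a small multiple of it from the loss. The `beta` knob and the example weight vectors are my assumptions, not values from the paper:

```python
import numpy as np

def entropy(weights, eps=1e-12):
    """Shannon entropy of the mixture weights (higher = more spread out)."""
    w = np.clip(weights, eps, 1.0)
    return -np.sum(w * np.log(w))

collapsed = np.array([0.99] + [0.01 / 7] * 7)   # almost all weight on one mode
balanced = np.full(8, 1 / 8)                    # all 8 scenarios in use

# Hypothetical training objective:
#   loss = nll - beta * entropy(weights)
# Minimizing it pushes entropy up, i.e. penalizes collapsed weight vectors.
print(entropy(collapsed), entropy(balanced))    # balanced has higher entropy
```

With uniform weights the entropy hits its maximum, log 8, so the regularizer rewards keeping every scenario alive, at least early in training.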

Training vs Inference: A Clever Trick

The "Teacher Forcing" Trick

Even though CausalTraj predicts step-by-step during inference, training can be done efficiently in parallel using "teacher forcing."

During Training:

We know the real positions at every timestep (it's historical data). So we can train all timesteps in parallel: at each step, we use the REAL previous positions as input, not the model's predictions.

During Inference:

We don't know the future, so we use the model's own predictions. Predict t+1, then use that prediction to predict t+2, and so on. This is slower but necessary.
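The two modes can share the exact same one-step function; only where the inputs come from changes. A toy 1-D sketch (the constant-velocity "model" and the numbers are invented):

```python
import numpy as np

def step(prev, last):
    """Shared one-step predictor (constant-velocity placeholder model)."""
    return last + (last - prev)

gt = np.array([[0.0], [1.0], [2.5], [3.0], [5.0]])  # real positions, 1-D toy

# Training (teacher forcing): every step conditions on the REAL past,
# so all targets can be scored in parallel from the ground-truth sequence.
teacher_preds = np.array([step(gt[t - 1], gt[t]) for t in range(1, 4)])

# Inference (free running): each step conditions on the model's OWN output.
prev, last = gt[0], gt[1]
rollout = []
for _ in range(3):
    nxt = step(prev, last)
    rollout.append(nxt)
    prev, last = last, nxt
rollout = np.array(rollout)

print(teacher_preds.ravel(), rollout.ravel())  # identical at step 1, then diverge
```

The divergence after the first step is the well-known gap between teacher-forced training and autoregressive inference: during rollout, errors feed back into the model's own context.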

Understanding the Metrics

The paper evaluates models on both per-agent metrics (the traditional approach) and joint metrics (the new approach). Here's what each measures:

Per-Agent Metrics (Traditional)
minADE₂₀ (minimum Average Displacement Error over 20 samples)

For each agent, look at the 20 predicted trajectories. Pick the one closest to reality. Measure the average error across all timesteps. Then average across all players.

minFDE₂₀ (minimum Final Displacement Error over 20 samples)

Same as above, but only measure the error at the final timestep (where did the player end up vs where we predicted they'd end up).

⚠️ Remember: Each player's "best" prediction can come from a different sample!
Joint Metrics (CausalTraj's Focus)
minJADE₂₀ (minimum Joint Average Displacement Error over 20 samples)

Look at the 20 predicted scenarios (each containing all players). Pick the scenario whose average error across all players and timesteps is lowest. Report that error.

minJFDE₂₀ (minimum Joint Final Displacement Error over 20 samples)

Same as above, but only measure error at the final timestep. All players must still come from the same scenario.

✓ All players come from the same sample, so this measures real coherent predictions!
Why Both Metrics Matter

Per-agent metrics tell you "can this model produce good predictions for individual players?" Joint metrics tell you "can this model produce good predictions where the team moves coherently?" A model good at the first might be bad at the second (and vice versa). For applications like game simulation, you need good joint metrics.

Results: How Well Does It Work?
Comparing CausalTraj to other methods

Datasets Used

NBA SportVU
  • Basketball tracking data
  • 10 players + ball per frame
  • 5 frames per second
  • Task: See 2 seconds (10 frames), predict 4 seconds (20 frames)
Basketball-U
  • Another basketball dataset
  • Derived from NBA data
  • 50-frame sequences
  • Task: See 30 frames, predict 20 frames
Football-U
  • NFL tracking data (American football)
  • 22 players + ball per frame
  • 10 frames per second
  • Task: See 30 frames, predict 20 frames

NBA SportVU Results (4-second prediction horizon)

| Model | minADE₂₀ (m) | minFDE₂₀ (m) | minJADE₂₀ (m) | minJFDE₂₀ (m) |
|---|---|---|---|---|
| GroupNet | 0.95 | 1.22 | 2.12 | 3.72 |
| LED | 0.81 | 1.10 | 1.63 | 2.99 |
| MoFlow (default) | 0.71 | 0.87 | 1.69 | 3.31 |
| CausalTraj (Mamba2) | 0.77 | 1.02 | 1.38 | 2.57 |
How to Read This Table

Lower numbers are better (less prediction error). The table shows that MoFlow is best on per-agent metrics (minADE, minFDE) but CausalTraj is best on joint metrics (minJADE, minJFDE). This exactly matches what we'd expect: models optimized for per-agent metrics win on per-agent metrics, but CausalTraj's focus on joint prediction pays off when we measure the whole team together.

Key Takeaway

CausalTraj reduces joint error (minJADE₂₀) by about 18% compared to MoFlow, a significant improvement. It's slightly worse on per-agent metrics (0.77 vs 0.71 meters average error), but that's because it's optimizing for a different goal: coherent team predictions rather than maximizing individual prediction diversity.

What Do the Predictions Actually Look Like?

Numbers only tell part of the story. The paper includes visualizations that show WHY joint metrics matter: CausalTraj's predictions look qualitatively different from other models'.

❌ What Other Models Often Produce
  • Homogeneous trajectories: In each predicted scenario, all players tend to move in similar directions. It's like the model picks a "theme" (everyone goes left) rather than diverse individual actions.
  • Curved ball paths: The ball arcs through the air in unrealistic ways, as if it has a mind of its own rather than following physics.
  • Marginal modes: Players independently gravitate toward "likely positions" without coordinating with teammates.
✓ What CausalTraj Produces
  • Coordinated movement: When one player moves, others respond. If a forward cuts right, a midfielder might shift to cover.
  • Realistic passes: Ball trajectories are straight and fast, like real passes between players.
  • Reactive behavior: Players sometimes change direction abruptly; the model captures moments where a player adjusts to track someone else.
  • Strategic spacing: Teammates position themselves in formations that make tactical sense.
Honest Limitations

CausalTraj isn't perfect. The paper notes that occasionally:

  • Ball carriers can end up with unrealistic gaps between them and the ball
  • The ball sometimes "collides" with court boundaries (the model doesn't enforce physical constraints)
  • Very fast movements or abrupt stops can still be unrealistic

The author attributes these issues to limited covariance modeling between the ball and the ball carrier, an area for future improvement.

Which Components Actually Help?
Ablation studies: removing parts to see what matters

To understand which design choices actually matter, the authors ran "ablation studies": tests where they remove or simplify one component at a time and measure the impact.

| What Was Changed | minJADE₂₀ | minJFDE₂₀ | Effect |
|---|---|---|---|
| Full CausalTraj (Mamba2) | 0.97 | 1.77 | Best performance |
| No Spatial Relation Transformer | 0.99 | 1.81 | Slightly worse |
| Single Gaussian (no mixture) | 1.03 | 1.86 | Noticeably worse |
| Sample from means only (not full distribution) | 1.05 | 2.13 | Significantly worse |
Spatial Relations Matter

Removing the explicit spatial encoding (SRTE) makes performance worse. Knowing exactly where players are relative to each other helps the model make better predictions.

Multiple Scenarios Matter

Using only one Gaussian (one possible future) instead of a mixture of 8 hurts performance. The future is uncertain; we need to represent multiple possibilities.

Full Sampling Matters

Sampling only from the center of each Gaussian (not the full distribution) is worst of all. The spread/uncertainty in predictions captures real variation in player movement.

How Does CausalTraj Compare to Other Approaches?
| Aspect | TranSPORTmer | Diffoot | CausalTraj |
|---|---|---|---|
| How it predicts over time | All timesteps at once | All timesteps at once | One step at a time (causal) |
| Output format | Single trajectory | Many samples via diffusion | Mixture of Gaussians |
| Handles multiple futures? | ❌ No | ✓ Yes | ✓ Yes |
| Optimizes for team coherence? | ❌ Per-player loss | ❌ Per-player loss | ✓ Joint likelihood |
| Can do multiple tasks? | ✓ Yes (4 tasks) | ❌ Prediction only | ❌ Prediction only |
| Inference speed | ✓ Fast | ❌ Slow (many steps) | ⚠️ Medium |
| Best use case | Real-time, multi-task apps | "What if" analysis | Coherent game simulation |
Limitations and Future Directions

Current Limitations

1. Ball-Player Synchronization Gaps

The model treats players as independent within each scenario (only the scenario selection couples them). This means the tight synchronization between a ball carrier and the ball isn't perfectly captured. Sometimes the ball "floats" away from the carrier.

2. No Physical Constraints

The model can predict that the ball goes outside the court, or that a player moves impossibly fast. It learned from data patterns but doesn't know the rules of physics or the boundaries of the playing area.

3. Slower Than Parallel Models at Long Horizons

Because CausalTraj predicts step-by-step, inference takes longer than models that predict everything at once. For real-time applications, this might be a drawback.

Future Directions (from the paper)

Better Physics Modeling

Learn to respect physical constraints: the ball can't go through walls, players can't teleport, passes should obey gravity. Could add constraint layers or physics-based losses.

Controllable Generation

"Show me what happens if this team plays a high press" or "Generate a counter-attack." Condition the model on tactical concepts, not just observed history.

Better Coherence Metrics

Joint error (minJADE) is better than per-agent error, but still doesn't directly measure "does this look realistic?" Could develop perceptual metrics or user studies.

Full Covariance Modeling

Model exact correlations between all player pairs, not just shared scenario weights. Could use low-rank approximations to make this computationally feasible.

Summary: What CausalTraj Does and Doesn't Do
✓ What It DOES Well
  • Predicts teams, not individuals: All players' futures come from the same scenario
  • Step-by-step prediction: Each moment reacts to the previous one
  • Multiple futures: 8 possible scenarios capture uncertainty
  • Explicit spatial reasoning: Knows exactly where players are relative to each other
  • Best joint metrics: State-of-the-art on minJADE/minJFDE
  • Realistic coordination: Passes are straight, players react to each other
  • Works across sports: Tested on both basketball and American football
❌ What It DOESN'T Do
  • No physics engine: Can predict impossible movements
  • No boundary awareness: Ball can leave the court
  • Imperfect ball-player sync: Sometimes gaps appear
  • Single task only: Just prediction, not imputation or classification
  • Not the fastest: Step-by-step is slower than parallel
  • Trades per-player accuracy: MoFlow beats it on minADE at long horizons
Key Takeaways

Measure What Matters

Per-agent metrics let you cherry-pick. Joint metrics evaluate what you'd actually use: complete team predictions.

Step-by-Step Beats All-at-Once

Causal prediction lets each moment react to the previous one, naturally creating coordination.

Shared Weights = Shared Fate

All players share the same scenario selection. Pick scenario #3, and everyone's prediction comes from scenario #3.

Geometry Helps

Explicitly feeding in 'Player A is 5 meters left of Player B' helps the model reason about interactions.

Likelihood Training

Train by asking "what's the probability of what actually happened?" and pushing it up: simple, principled, and effective.

Coherence vs Individual Accuracy

You can optimize for one or the other. CausalTraj chooses coherence, accepting slightly worse per-player metrics.