# Dissecting JEPA
## Test 1 - Latent Memory / Information Retention
### Objective
Prove whether the encoder stored enough information about the features in its latent space
### Hypothesis
If the encoder has learned anything, there should be a correlation between the latents and the ground-truth features, and a linear probe should be able to recover them to some degree
### Method
Started with a linear regression with L2 regularization (Ridge) but had to evolve to a polynomial probe
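The probe setup can be sketched with scikit-learn. Everything here is a stand-in: `Z` is random noise in place of the real encoder latents, and `energy` is a synthetic non-linear target, just to show the Ridge-then-polynomial escalation.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

rng = np.random.default_rng(0)
Z = rng.normal(size=(1000, 32))                 # stand-in for encoder latents
energy = np.tanh(Z[:, 0]) + 0.5 * Z[:, 1] ** 2  # synthetic non-linear target

Z_tr, Z_te, y_tr, y_te = train_test_split(Z, energy, random_state=0)

# Linear probe: Ridge regression straight on the latents
ridge = Ridge(alpha=1.0).fit(Z_tr, y_tr)
ridge_r2 = r2_score(y_te, ridge.predict(Z_te))

# Polynomial probe: expand latents to degree-2 terms, then Ridge
poly = make_pipeline(PolynomialFeatures(degree=2), Ridge(alpha=1.0)).fit(Z_tr, y_tr)
poly_r2 = r2_score(y_te, poly.predict(Z_te))

print(f"Ridge R2 = {ridge_r2:.4f}, Poly R2 = {poly_r2:.4f}")
```

A Poly R2 well above the Ridge R2, as in the results below, is the signal that the feature is present in the latents but not linearly laid out.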
### Results

Running Linear Probes...
Energy: Ridge R2 = 0.5660
Energy: Poly R2 = 0.7335
X_Pos: Ridge R2 = 0.5949
X_Pos: Poly R2 = 0.8209
Can_Eat: Accuracy = 0.9893
Near_Food: Accuracy = 0.7910
### Interpretation
- Features show movement across the latent dims, but `energy` and `x_pos` required expanding into polynomial space.
- Binary features are spread across more dims but are linearly recoverable, which suggests there's a clear decision boundary in the internal geometry.

QED: Encoder contains meaningful representations in Z that are recoverable (via a simple decoder, i.e. the probe)
## Test 2 - Dynamics & Prediction Accuracy
### Objective
Prove the encoder + transition head have learnt the correct environment dynamics
### Hypothesis
- A good encoder + transition head will be able to predict the next state in both feature and latent space
### Method
- Train a supervised Oracle MLP on the ground-truth features to predict the next state
- Train a supervised MLP on the frozen encoder's latents to test if they can predict the next state
- Compare Frozen against Oracle and Identity (aka Do Nothing)
- Compare the full JEPA model against Oracle
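The baseline logic can be sketched on a toy one-dimensional dynamic (the decay rule and all numbers are made up). The point of the comparison: a useful predictor must beat Identity, and the Oracle fit on ground truth gives an upper bound on how learnable the task is.

```python
import numpy as np

rng = np.random.default_rng(0)
x_t = rng.uniform(0.0, 1.0, size=(5000, 1))  # current feature (toy "energy")
x_next = x_t - 0.01                          # true next-step dynamics: fixed decay

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

# Identity baseline: predict "no change"
identity_mse = mse(x_t, x_next)

# "Oracle": a model fit directly on ground-truth features (closed-form least
# squares stands in here for the supervised MLP)
A = np.hstack([x_t, np.ones_like(x_t)])
w, *_ = np.linalg.lstsq(A, x_next, rcond=None)
oracle_mse = mse(A @ w, x_next)

print(f"Identity MSE = {identity_mse:.6f}, Oracle MSE = {oracle_mse:.6f}")
```

A Frozen probe should land between those two; a Frozen MSE far above Identity (as with `x_pos` below) means the frozen latents are worse than doing nothing.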
### Results
| Feature | Identity | Oracle | Frozen | JEPA |
|---|---|---|---|---|
| Energy | 0.0000 | 0.0001 | 0.0149 | 0.5704 |
| X_Pos | 0.0008 | 0.0000 | 97.6893 | 1140.8027 |
| Can_Eat | 0.0054 | 0.0051 | 0.0099 | 0.0758 |
| Near_Food | 0.0044 | 0.0042 | 0.1126 | 0.3627 |
Note: these are all MSE, so lower is better
### Interpretation
- Oracle performs well, which confirms the task is learnable
- Frozen struggles to predict `x_pos` and `near_food` but does OK on the others
- Full JEPA is a disaster, not usable, but is this because of a bad transition head or a bad decoder?

QED: Encoder has learnt some of the dynamics needed to predict the next state but may need to be improved
## Test 3 - Action Table
### Objective
Prove the model uses the actions (A) when predicting
### Hypothesis
The transition head should predict the correct increase / decrease in `x_pos` and `energy` for each available action
### Method
Comparing average feature deltas between the ground truth and the model's predictions will show if actions are being used or ignored
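The delta table reduces to a group-by over actions. A minimal sketch with synthetic data (the action set, step sizes, and the broken-model distribution are all invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
actions = rng.integers(0, 3, size=10_000)      # 0=NOOP, 1=FORWARD, 2=EAT
true_step = np.array([0.0, 0.05, 0.0])         # ground-truth x_pos delta per action

# Per-transition deltas: ground truth follows the action; the "model" here
# mimics a broken predictor that outputs roughly the same delta regardless
x_pos_delta_gt = true_step[actions] + rng.normal(0, 0.001, size=actions.shape)
x_pos_delta_model = rng.normal(-29.3, 0.3, size=actions.shape)

for a, name in enumerate(["NOOP", "FORWARD", "EAT"]):
    mask = actions == a
    print(f"{name}: GT delta = {x_pos_delta_gt[mask].mean():+.4f}, "
          f"model delta = {x_pos_delta_model[mask].mean():+.4f}")
```

If the model column barely changes across rows while the GT column does, actions are being ignored somewhere in the pipeline.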
### Results
| Action | X_Pos_Delta_GT | X_Pos_Delta_Model | Energy_Delta_GT | Energy_Delta_Model |
|---|---|---|---|---|
| NOOP | 0.00 | -29.008549 | -0.000333 | -0.713546 |
| FORWARD | 0.05 | -29.322405 | -0.000833 | -0.708938 |
| EAT | 0.00 | -29.598597 | 0.000431 | -0.709099 |
### Interpretation
- Not good. `x_pos` is consistently wrong and `energy` goes down for EAT

QED: Looks like actions are ignored, BUT this doesn't help identify whether it's the transition head or the decoder that's responsible.
## Test 4 - Reconstruction Test
### Objective
Prove whether the decoder is bad
### Hypothesis
The decoder is bad!
### Method
Simple test: encode X to Z, decode back to X, and look at MSE + variance
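A sketch of the round trip, with a deliberately lossy stand-in encoder/decoder pair (the real ones are presumably trained networks). One useful yardstick: compare the reconstruction MSE to the spread of the feature itself, since an MSE near or above the feature's variance means the decoder is no better than predicting a constant.

```python
import numpy as np

rng = np.random.default_rng(0)
x_pos = rng.uniform(0, 60, size=5000)  # synthetic feature values

def encode(x):
    # Stand-in encoder: coarse quantisation deliberately loses detail
    return np.round(x / 60.0, 1)

def decode(z):
    # Stand-in decoder: maps the quantised latent back to feature scale
    return z * 60.0

x_hat = decode(encode(x_pos))
mse = float(np.mean((x_hat - x_pos) ** 2))
var = float(np.var(x_pos))
print(f"Reconstruction MSE = {mse:.2f}, Var(x_pos) = {var:.2f}")
```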
### Results
Energy: MSE = 0.5614
X_Pos: MSE = 1156.3208
Can_Eat: Accuracy = 0.9875
Near_Food: Accuracy = 0.7540
Energy Variance: 0.05883194
Energy Std Dev: 0.24255297
X_pos Variance: 297.66675
X_pos Std Dev: 17.253021
### Interpretation
- The earlier probe showed that a simple decoder can recover `x_pos` and `energy` from Z, so the failure here is in this decoder, not in the latents

QED: This decoder can't be trusted
## Test 5 - Latent Action Sensitivity (Decoder Free)
### Objective
Prove the transition head hasn't collapsed
### Hypothesis
A good transition head will produce a different prediction for each action
### Method
Repeat Z for each action and measure the distance between the predicted latents
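The decoder-free check can be sketched as below. The transition head here is a toy stand-in (a fixed per-action shift); the measurement is the real point: spread between per-action predictions, normalised by a typical state-to-state step in latent space.

```python
import numpy as np

rng = np.random.default_rng(0)
Z = rng.normal(size=(1000, 16))  # stand-in latents for sampled states

def transition(z, action):
    # Stand-in transition head: shifts every latent dim by a per-action amount
    shift = [0.0, 0.3, -0.3][action]
    return z + shift

# Predict the next latent for every state under every action: (3, n, d)
preds = np.stack([transition(Z, a) for a in range(3)])

# Spread between actions: mean pairwise distance between predicted latents
pairs = [(0, 1), (0, 2), (1, 2)]
action_spread = float(np.mean(
    [np.linalg.norm(preds[i] - preds[j], axis=1).mean() for i, j in pairs]))

# Typical step size: distance between latents of consecutive states
state_step = float(np.linalg.norm(Z[1:] - Z[:-1], axis=1).mean())

ratio = action_spread / state_step
print(f"action/state ratio = {ratio:.3f}")
```

A ratio of exactly 0 would mean a collapsed transition head that ignores the action; a ratio well above 0 but below 1, as in the results below, means actions matter but move the latent less than a typical environment step.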
### Results
Action / State ratio: [0.39865744 0.2414919 0.2973536 ]
90% of states ratio: [0.1856695 0.11809186 0.1157009 ]
### Interpretation
- Actions change the predicted next latent by 20-40% of a typical step
- In most cases there's still some difference

QED: The transition head is not ignoring the actions, but it might still be moving the latent in the wrong direction
## "Aha" Moment
This is like having to test a big mechanical system. Each part must be tested on its own first to ensure it works, then you test the parts that interact, and only at the end do you check the whole machine. You can’t skip steps because you can’t interpret the higher level failures.
You could say it’s similar to software testing except that behaviour is hidden and you have to infer correctness from indirect signals.
## What still feels messy
Knowing what range counts as good for MSE, L2, accuracy, etc. felt a lot like guesswork.
## Next step
In theory I have a decent encoder + transition head, but I'm still missing a good decoder, so I think that's the next step after the holidays 🎄