
Dissecting JEPA

This was an adventure. Reminded me of setting up growth at Fourwalls. The scientific method always prevails. That’s how I’ll format it.

Test 1 - Latent Memory / Information Retention

Objective

Prove whether the encoder stored enough information about the features $X$ in its latent space $Z$

Hypothesis

If the encoder learned $X$, then there would be a correlation between $X$ and $Z$, and a linear probe should be able to recover $X$ to some degree

Method

Started with a linear regression with L2 regularization (Ridge) but had to evolve to a polynomial probe
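A minimal sketch of the probing setup, assuming the latents $Z$ and the ground-truth features have already been collected as numpy arrays (the variable names, split, and regularisation strength are assumptions, not the exact code used):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import accuracy_score, r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

def probe_continuous(Z, y, degree=2):
    """Ridge probe on Z and on a polynomial expansion of Z; returns both R^2 scores."""
    Z_tr, Z_te, y_tr, y_te = train_test_split(Z, y, test_size=0.2, random_state=0)
    ridge = Ridge(alpha=1.0).fit(Z_tr, y_tr)
    poly = make_pipeline(PolynomialFeatures(degree), Ridge(alpha=1.0)).fit(Z_tr, y_tr)
    return r2_score(y_te, ridge.predict(Z_te)), r2_score(y_te, poly.predict(Z_te))

def probe_binary(Z, y):
    """Linear probe for the binary features (can_eat, near_food); returns accuracy."""
    Z_tr, Z_te, y_tr, y_te = train_test_split(Z, y, test_size=0.2, random_state=0)
    clf = LogisticRegression(max_iter=1000).fit(Z_tr, y_tr)
    return accuracy_score(y_te, clf.predict(Z_te))
```

The polynomial expansion is the cheapest way to check whether the information is present but stored non-linearly, without having to train a full decoder.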

Results

(Figure: correlation between features and latent dimensions)

Running Linear Probes...
Energy: Ridge R2 = 0.5660
Energy: Poly R2 = 0.7335
X_Pos: Ridge R2 = 0.5949
X_Pos: Poly R2 = 0.8209
Can_Eat: Accuracy = 0.9893
Near_Food: Accuracy = 0.7910

Interpretation

  • Features show up across the latent dims, but energy and x_pos required expanding into polynomial space to recover.
  • Binary features are spread across more dims but are linearly recoverable, which suggests there’s a clear decision boundary in the internal geometry

QED: The encoder contains meaningful representations in $Z$ that are recoverable (via a simple decoder)

Test 2 - Dynamics & Prediction Accuracy

Objective

Prove the encoder + transition head have learnt the correct environment dynamics

Hypothesis

  • A good encoder + transition head will be able to predict the next state in both feature and latent space

Method

  • Train a supervised Oracle MLP to predict $X_{next}$
  • Train a supervised Frozen MLP to test whether $Z_t + A_t$ predicts $X_{next}$
  • Compare Frozen against Oracle and Identity (aka Do Nothing)
  • Compare the full JEPA model against Oracle (a minimal sketch of these baselines follows below)
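A minimal sketch of the baseline comparison, assuming logged rollouts have been flattened into (X_t, A_t, Z_t, X_next) arrays; the array names, MLP size, input choices, and use of sklearn's MLPRegressor are assumptions, and in practice the MSEs should be computed on a held-out split:

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.neural_network import MLPRegressor

NAMES = ["energy", "x_pos", "can_eat", "near_food"]

def per_feature_mse(y_true, y_pred):
    return {n: mean_squared_error(y_true[:, i], y_pred[:, i]) for i, n in enumerate(NAMES)}

# Placeholder data; in practice these come from logged environment rollouts.
rng = np.random.default_rng(0)
N = 5000
X_t, X_next = rng.normal(size=(N, 4)), rng.normal(size=(N, 4))
A_t, Z_t = rng.normal(size=(N, 3)), rng.normal(size=(N, 32))

# Identity baseline: predict that nothing changes.
identity_mse = per_feature_mse(X_next, X_t)

# Oracle: supervised MLP fed the ground-truth features plus the action.
oracle = MLPRegressor(hidden_layer_sizes=(128, 128), max_iter=500).fit(np.hstack([X_t, A_t]), X_next)
oracle_mse = per_feature_mse(X_next, oracle.predict(np.hstack([X_t, A_t])))

# Frozen: the same supervised MLP, but fed the frozen encoder's latent plus the action.
frozen = MLPRegressor(hidden_layer_sizes=(128, 128), max_iter=500).fit(np.hstack([Z_t, A_t]), X_next)
frozen_mse = per_feature_mse(X_next, frozen.predict(np.hstack([Z_t, A_t])))
```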

Results

Feature     Identity   Oracle    Frozen     JEPA
Energy      0.0000     0.0001    0.0149     0.5704
X_Pos       0.0008     0.0000    97.6893    1140.8027
Can_Eat     0.0054     0.0051    0.0099     0.0758
Near_Food   0.0044     0.0042    0.1126     0.3627

Note: these are all MSE, so lower is better

Interpretation

  • Oracle performs well, which confirms the task is learnable
  • Frozen struggles to predict x_pos and near_food but does OK on the others
  • Full JEPA is a disaster and not usable, but is this because of a bad transition head or a bad decoder?

QED: The encoder has learnt some of the dynamics needed to predict the next state, but it may need to be improved

A side quest showed the encoder + transition head work well together to predict the next latent

Test 3 - Action Table

Objective

Prove the model uses the actions ($A$) when predicting $X_{next}$

Hypothesis

The transition head should predict the correct increase / decrease across x_pos and energy for the available actions

Method

Comparing average feature deltas between the ground truth and what the model predicts will show if actions are being used or ignored
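A minimal sketch of that comparison, assuming per-transition arrays (the action labels, ground-truth X_t and X_next, and the model's predicted next features) are already collected; the names and feature indexing are assumptions:

```python
import numpy as np

def action_delta_table(actions, X_t, X_next, X_pred_next, feature_idx):
    """Mean ground-truth vs model-predicted delta of one feature, grouped by action."""
    gt_delta = X_next[:, feature_idx] - X_t[:, feature_idx]
    model_delta = X_pred_next[:, feature_idx] - X_t[:, feature_idx]
    return {
        a: (gt_delta[actions == a].mean(), model_delta[actions == a].mean())
        for a in np.unique(actions)
    }

# e.g. action_delta_table(actions, X_t, X_next, X_pred_next, feature_idx=1)  # x_pos
```

If actions are being used, the model column should at least track the sign of the ground-truth column for each action.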

Results

Action    X_Pos_Delta_GT   X_Pos_Delta_Model   Energy_Delta_GT   Energy_Delta_Model
NOOP      0.00             -29.008549          -0.000333         -0.713546
FORWARD   0.05             -29.322405          -0.000833         -0.708938
EAT       0.00             -29.598597           0.000431         -0.709099

Interpretation

  • Not good. x_pos is consistently wrong, and energy goes down for EAT when the ground truth says it should go up

QED: Looks like actions are ignored, BUT this doesn’t help identify if it’s the transition head or decoder that’s responsible.

Test 4 - Reconstruction Test

Objective

Prove whether the decoder is bad

Hypothesis

The decoder is bad!

Method

A simple test: encode $X$ to $Z$, decode back to $X$, and look at MSE and variance
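A minimal sketch of the round trip, assuming `encoder` and `decoder` are the trained PyTorch modules and `X` is a tensor of logged features; the names and feature ordering are assumptions:

```python
import torch

@torch.no_grad()
def reconstruction_report(encoder, decoder, X, names=("energy", "x_pos", "can_eat", "near_food")):
    X_hat = decoder(encoder(X))  # X -> Z -> X round trip
    for i, name in enumerate(names):
        mse = torch.mean((X_hat[:, i] - X[:, i]) ** 2).item()
        var = torch.var(X_hat[:, i]).item()  # near-zero variance = decoder outputs a constant
        print(f"{name}: MSE = {mse:.4f}, reconstruction variance = {var:.4f}")
```

The variance check matters because a collapsed decoder that always predicts roughly the mean can still post a deceptively moderate MSE.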

Results

Energy: MSE = 0.5614
X_Pos: MSE = 1156.3208
Can_Eat: Accuracy = 0.9875
Near_Food: Accuracy = 0.7540
Energy Variance: 0.05883194
Energy Std Dev: 0.24255297
X_pos Variance: 297.66675
X_pos Std Dev: 17.253021

Interpretation

  • The earlier probe showed a simple decoder can recover x_pos and energy from $Z$, yet this decoder’s x_pos MSE is over 1000, so the failure lies in the decoder rather than the latent

QED: This decoder can’t be trusted

Test 5 - Latent Action Sensitivity (Decoder Free)

Objective

Prove the transition head hasn’t collapsed

Hypothesis

A good transition head will change $Z_{next}$ based on the action

Method

Repeat the same $Z$ for each action, run the transition head, and measure the distance between the predicted next latents
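A minimal sketch of this decoder-free check, assuming `transition` is the trained transition head taking (latent, action) and `actions` is the list of one-hot action tensors; the signature and names are assumptions:

```python
import torch

@torch.no_grad()
def action_sensitivity_ratio(transition, Z, actions):
    """Action-driven change in the predicted next latent, relative to a typical state-to-state step."""
    # Predict Z_next for the same batch of states under every action.
    preds = [transition(Z, a.expand(Z.shape[0], -1)) for a in actions]
    # Action-driven spread: mean pairwise distance between those predictions.
    action_dist = torch.stack([
        torch.norm(p_i - p_j, dim=1)
        for i, p_i in enumerate(preds) for p_j in preds[i + 1:]
    ]).mean()
    # Typical step size in latent space: distance between consecutive states in the batch.
    state_dist = torch.norm(Z[1:] - Z[:-1], dim=1).mean()
    return (action_dist / state_dist).item()  # ~0 would mean the actions are ignored
```

A ratio near zero would indicate the head ignores the action; a healthy head should give a clearly non-zero value.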

Results

Action / State ratio: [0.39865744 0.2414919  0.2973536 ]
90% of states ratio: [0.1856695  0.11809186 0.1157009 ]

Interpretation

  • Actions change the predicted next latent by 20-40% of a typical step
  • In almost all states there’s at least some action-driven difference

QED: The transition head is not ignoring the actions, but it might still be moving $Z$ in the wrong direction

“Aha” Moment

This is like having to test a big mechanical system. Each part must be tested on its own first to ensure it works, then you test the parts that interact, and only at the end do you check the whole machine. You can’t skip steps because you can’t interpret the higher level failures.

You could say it’s similar to software testing except that behaviour is hidden and you have to infer correctness from indirect signals.

What still feels messy

Understanding what ranges count as good for MSE, L2, accuracy, etc. felt a lot like guesswork.

Next step

In theory I have a decent encoder + transition head, but I’m still missing a good decoder, so I think that’s the next step after the holidays 🎄