Physics-Aware Representation Learning for Physical Systems

New York University
NYU CSCI-GA 2572 Deep Learning — Spring 2026

*All authors contributed equally; ordered alphabetically by last name.
H-JEPA computational graph
H-JEPA. The online encoder (orange) processes the corrupted context $\tilde{x} = x + \sigma\varepsilon$ bottom-up into shallow and deep latents $\tilde{z}_1, \tilde{z}_2$; EMA target encoders (dashed) encode the clean context $x$ and the paired next window $y$ into targets $z_1^{\text{tgt}}, z_2^{\text{tgt}}$. Predictors $g_{\phi_1}, g_{\phi_2}$ produce $\hat{z}_1, \hat{z}_2$, which are paired against the targets through losses $\mathcal{L}_1$ (latent denoising) and $\mathcal{L}_2$ (next-window prediction).
FA-JEPA physics-prior attention schematic
FA-JEPA physics-prior attention. Orientation-field patch features $q$ feed two branches: the standard learned attention $A_{\text{data}} = \mathrm{softmax}(QK^\top/\sqrt{d})$ and a physics prior $A_{\text{phys}} = \mathrm{softmax}(\tau S)$ built from pairwise dot products of unit-normalized orientation vectors. The two are mixed by a learnable convex combination $A = (1-\lambda)A_{\text{data}} + \lambda A_{\text{phys}}$ with $\lambda = \sigma(\eta)$ and $\tau = \tau_{\max}\sigma(\rho)$, so the model adopts the prior only where it helps.

We extend Joint-Embedding Predictive Architectures (JEPA) with two complementary physics-aware designs. H-JEPA (left) decomposes spatial and temporal modeling into a hierarchical inner/outer pathway; FA-JEPA (right) introduces field-aware attention together with a physics-prior attention bias drawn from the nematic order tensor. Both are trained without labels on the active matter dataset from the Polymathic Well benchmark and evaluated by frozen linear and kNN probes regressing the parameters (α,ζ).

Abstract

Learning representations aligned with physical dynamics remains a central challenge in scientific machine learning. We study this problem within the JEPA framework on the active matter dataset from the Well benchmark, and identify preservation of temporal structure during encoding as the key factor driving downstream physics-prediction quality.

Building on this insight, we propose two complementary approaches. H-JEPA decomposes spatial and temporal modeling through its hierarchical structure, reducing total test MSE on the regressed parameters (α,ζ) by up to 70% relative to the baseline. FA-JEPA introduces field-aware attention together with a physics-prior attention mechanism that injects domain knowledge directly into the attention process, lowering total test MSE by roughly 20% over the prior-free variant on both linear and kNN probes.

Why physics-aware JEPA?

Self-supervised pixel-space surrogate models trained to predict the next frame are expensive and often misaligned with what scientific applications actually need: estimates of the hidden physical parameters that govern dynamics. JEPAs avoid pixel reconstruction by predicting latent futures, but vanilla JEPA encoders trained on physical fields often lose the temporal structure that downstream parameter inference depends on.

We work on Polymathic Well's active matter benchmark, which models rod-like self-propelled particles in a Stokes fluid governed by two scalar parameters: the active dipole strength α and the alignment interaction strength ζ. Each trajectory is observed through 11 physical channels (concentration, 2-D velocity, rank-2 orientation tensor D, rank-2 strain-rate tensor E). The task is to pretrain an encoder without labels, freeze it, and recover (α,ζ) with a single linear layer or kNN: no fine-tuning, no MLP heads, no attention pooling.

H-JEPA: Hierarchical JEPA

H-JEPA structurally decomposes the encoder into two pathways supervised by complementary self-supervised pretexts:

  • An inner pathway (shallow encoder $f_{\theta_1}$ plus inner predictor $g_{\phi_1}$) performs latent denoising of a Gaussian-corrupted context $\tilde{x} = x + \sigma\varepsilon$, retaining instantaneous local field structure.
  • An outer pathway (deep encoder $f_{\theta_2}$ plus outer predictor $g_{\phi_2}$) performs standard next-window latent prediction, capturing trajectory-level dynamics.
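To make the wiring of the two pathways concrete, here is a minimal NumPy sketch of one H-JEPA forward pass. It is an illustration under stated assumptions, not the project's implementation: the encoders and predictors are hypothetical single-layer maps, the widths and noise level σ are placeholders, a plain MSE stands in for the VICReg pair loss, and the combined objective $w_1\mathcal{L}_1 + \mathcal{L}_2$ (outer weight fixed to 1) is our assumption. Since NumPy has no autodiff, the stop-gradient on the EMA targets is implicit; in a real framework the target branch would be detached.

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, D1, D2 = 32, 16, 8  # hypothetical input / shallow / deep widths


def init(n_in, n_out):
    return rng.standard_normal((n_in, n_out)) / np.sqrt(n_in)


# Online parameters: shallow encoder f1, deep encoder f2, predictors g1, g2.
online = {"f1": init(D_IN, D1), "f2": init(D1, D2),
          "g1": init(D1, D1), "g2": init(D2, D2)}
# EMA target encoders start as copies of the online encoders (no predictors).
target = {k: online[k].copy() for k in ("f1", "f2")}


def encode(p, x):
    z1 = np.tanh(x @ p["f1"])   # shallow latent: local field structure
    z2 = np.tanh(z1 @ p["f2"])  # deep latent, stacked on the shared shallow trunk
    return z1, z2


def hjepa_forward(x, y, sigma=0.1, w1=1.0):
    """x: clean context window, y: paired next window, both (batch, D_IN)."""
    x_tilde = x + sigma * rng.standard_normal(x.shape)  # corrupted context
    z1, z2 = encode(online, x_tilde)
    z1_hat, z2_hat = z1 @ online["g1"], z2 @ online["g2"]
    z1_tgt, _ = encode(target, x)   # inner target: EMA encoding of clean x
    _, z2_tgt = encode(target, y)   # outer target: EMA encoding of next window
    loss1 = np.mean((z1_hat - z1_tgt) ** 2)  # latent denoising (stand-in loss)
    loss2 = np.mean((z2_hat - z2_tgt) ** 2)  # next-window prediction
    return loss1, loss2, w1 * loss1 + loss2


x, y = rng.standard_normal((4, D_IN)), rng.standard_normal((4, D_IN))
l1, l2, total = hjepa_forward(x, y)
```

The point of the sketch is the dependency structure: $\mathcal{L}_1$ only touches the shallow encoder, while $\mathcal{L}_2$ reaches it through the deep encoder, so both pretexts shape the same shallow trunk.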

Both pathways share parameters in the shallow encoder, so the two pretexts shape a common trunk along complementary axes. EMA target encoders provide stop-gradient targets and the loss is a VICReg pair loss with coefficients (2,40,2). The full model has ~36.5 M parameters; only ~4.0 M are used during frozen probing.
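The two training ingredients named here can be sketched in NumPy as follows. Several details are assumptions not stated on this page: that (2, 40, 2) weight the invariance, variance and covariance terms in that order; that the variance and covariance regularizers follow the original VICReg formulation and are applied to both the prediction and the target branch; and the EMA momentum of 0.996, which is a hypothetical value.

```python
import numpy as np


def vicreg_pair_loss(pred, target, coeffs=(2.0, 40.0, 2.0), eps=1e-4):
    """VICReg-style loss between predictions and (stop-gradient) targets, both (n, d).

    Assumes coeffs = (invariance, variance, covariance) weights.
    """
    sim_w, var_w, cov_w = coeffs
    n, d = pred.shape
    invariance = np.mean((pred - target) ** 2)

    def variance(z):  # hinge penalizing any dimension whose std falls below 1
        std = np.sqrt(z.var(axis=0, ddof=1) + eps)
        return np.mean(np.maximum(0.0, 1.0 - std))

    def covariance(z):  # penalize off-diagonal covariance (decorrelation)
        zc = z - z.mean(axis=0)
        cov = zc.T @ zc / (n - 1)
        off_diag = cov - np.diag(np.diag(cov))
        return np.sum(off_diag ** 2) / d

    return (sim_w * invariance
            + var_w * (variance(pred) + variance(target)) / 2
            + cov_w * (covariance(pred) + covariance(target)))


def ema_update(target_params, online_params, momentum=0.996):
    """Move each target array toward its online counterpart; mutates the dict."""
    for name, w in online_params.items():
        target_params[name] = momentum * target_params[name] + (1.0 - momentum) * w
    return target_params
```

The large variance weight makes collapse expensive: two identical all-zero batches have zero invariance error but still cost about 40 × 0.99 ≈ 39.6 under these coefficients.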

What helps. Ablations show that (i) keeping the inner denoising pretext is necessary (removing it inflates kNN total MSE by 26%); (ii) paired next-window supervision is the dominant signal, since redirecting the outer target to the corrupted current window almost doubles linear MSE; (iii) heavier shallow supervision ($w_1 = 1.0$) consistently helps rather than competing with the outer loss.

FA-JEPA: Field-Aware JEPA with a physics prior

FA-JEPA uses a ViViT-style encoder and predictor. The 11 channels are grouped by physical type (scalar / vector / rank-2 tensor) into four fields, and each field is patched with its own 3-D tubelets of size 2×16×16. Attention then operates in three modes: field (across the four physical fields at a fixed space-time location), spatial (across patches at a fixed time within a field), and temporal (across time at a fixed location within a field).
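A minimal NumPy sketch of the tokenization and the three attention axes, under explicit assumptions: the channel ordering in `FIELDS` is hypothetical, the per-field embeddings are random linear maps of a hypothetical width `d_model=64`, and `attend_along` is a parameter-free softmax self-attention that omits the learned Q/K/V projections and MLPs of a real ViViT block. It only shows which axis each mode mixes over once tokens are arranged as (field, time, patch, feature).

```python
import numpy as np

# Hypothetical channel layout of the 11-channel input.
FIELDS = {
    "concentration": slice(0, 1),   # scalar
    "velocity": slice(1, 3),        # 2-D vector
    "orientation_D": slice(3, 7),   # rank-2 tensor, 2x2 entries flattened
    "strain_rate_E": slice(7, 11),  # rank-2 tensor, 2x2 entries flattened
}


def tubelets(field, t=2, p=16):
    """Cut a (T, H, W, C) field into non-overlapping t x p x p tubelets.

    Returns (T//t, (H//p)*(W//p), t*p*p*C): time steps x patches x flat tubelet.
    """
    T, H, W, C = field.shape
    assert T % t == 0 and H % p == 0 and W % p == 0
    x = field.reshape(T // t, t, H // p, p, W // p, p, C)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)  # (nT, nH, nW, t, p, p, C)
    return x.reshape(T // t, (H // p) * (W // p), t * p * p * C)


def embed_fields(x, d_model=64, seed=0):
    """Field-specific linear embeddings -> tokens of shape (F, nT, N, d_model)."""
    rng = np.random.default_rng(seed)
    tokens = []
    for sl in FIELDS.values():
        tub = tubelets(x[..., sl])
        w = rng.standard_normal((tub.shape[-1], d_model)) / np.sqrt(tub.shape[-1])
        tokens.append(tub @ w)  # each field gets its own projection
    return np.stack(tokens)


def attend_along(tokens, axis):
    """Parameter-free softmax self-attention along one axis of (F, nT, N, d)."""
    x = np.moveaxis(tokens, axis, -2)  # (..., L, d) with L the attended axis
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.moveaxis(weights @ x, -2, axis)


x = np.random.default_rng(1).standard_normal((4, 32, 32, 11))  # toy (T, H, W, C)
tok = embed_fields(x)                     # (4 fields, 2 time steps, 4 patches, 64)
field_out = attend_along(tok, axis=0)     # across the 4 fields
spatial_out = attend_along(tok, axis=2)   # across patches at a fixed time
temporal_out = attend_along(tok, axis=1)  # across time at a fixed location
```

Grouping channels before patching means each token carries one physical quantity, so the field mode can learn couplings between quantities explicitly instead of mixing them inside a single patch embedding.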

On top of field attention, we add a physics-prior attention bias built from the nematic order tensor $Q$. Orientation components are pooled to the patch grid and unit-normalized, and their pairwise dot products $S_{ij} = \hat{q}_i \cdot \hat{q}_j$ give a similarity that rewards attending to patches with similar nematic alignment. The final attention is the convex combination $A = (1-\lambda)A_{\text{data}} + \lambda A_{\text{phys}}$ with $A_{\text{phys}} = \mathrm{softmax}(\tau S)$, where $\lambda = \sigma(\eta)$ and $\tau = \tau_{\max}\sigma(\rho)$ are both learned and initialized small, so training begins close to standard attention and adopts the prior only where it is useful.
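The blend described above can be sketched in a few lines of NumPy. Assumptions to flag: the orientation vector per patch is the average of the four entries of D over the patch (the page only says "orientation components"), `tau_max=10.0` and the small initial value `eta=-4.0` are hypothetical, and the queries and keys come from a single head with no learned projections.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def pool_to_patches(orient, p=16):
    """Average-pool an (H, W, K) orientation field to (H//p * W//p, K)."""
    H, W, K = orient.shape
    blocks = orient.reshape(H // p, p, W // p, p, K).mean(axis=(1, 3))
    return blocks.reshape(-1, K)


def physics_prior_attention(queries, keys, orient, eta, rho, tau_max=10.0, p=16):
    """Blend learned attention with a nematic-alignment prior.

    queries, keys: (N, d) patch features; orient: (H, W, K) orientation components.
    """
    a_data = softmax(queries @ keys.T / np.sqrt(queries.shape[-1]))
    v = pool_to_patches(orient, p)
    q_hat = v / np.maximum(np.linalg.norm(v, axis=-1, keepdims=True), 1e-8)
    S = q_hat @ q_hat.T             # S_ij = dot(q_hat_i, q_hat_j), in [-1, 1]
    tau = tau_max * sigmoid(rho)    # learned temperature, bounded by tau_max
    a_phys = softmax(tau * S)
    lam = sigmoid(eta)              # learned mixing weight in (0, 1)
    return (1.0 - lam) * a_data + lam * a_phys, a_data, a_phys


rng = np.random.default_rng(0)
queries, keys = rng.standard_normal((4, 8)), rng.standard_normal((4, 8))
orient = rng.standard_normal((32, 32, 4))  # toy field -> 4 patches of 16x16
A, a_data, a_phys = physics_prior_attention(queries, keys, orient, eta=-4.0, rho=0.0)
```

Because both branches are row-stochastic and the mix is convex, $A$ stays a valid attention matrix for any learned λ, and with η = −4 it differs from the learned attention by at most σ(−4) ≈ 0.018 per entry at the start of training.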

Field-attention enrichment matrix
Cross-field coupling. Field-attention enrichment matrix (mean attention minus uniform 1/4). Velocity, orientation, and strain-rate predominantly attend to other fields rather than themselves — the model captures cross-field coupling, with the strongest directed enrichment from orientation to velocity, consistent with the kinetic theory of active suspensions.
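The enrichment statistic in this figure is simple to compute once field-attention weights are collected. A minimal sketch, assuming the weights are stacked as (..., 4, 4) over blocks, heads and tokens, with rows indexing the querying field and columns the attended field (the row/column convention is our assumption); `field_enrichment` is a hypothetical helper name.

```python
import numpy as np


def field_enrichment(field_attn):
    """field_attn: (..., F, F) field-attention weights, rows = querying field,
    each row summing to 1. Returns mean attention minus the uniform level 1/F."""
    n_fields = field_attn.shape[-1]
    mean_attn = field_attn.reshape(-1, n_fields, n_fields).mean(axis=0)
    return mean_attn - 1.0 / n_fields
```

Each row of the result sums to zero, so a positive off-diagonal entry directly shows a field sending more than its uniform share of attention to another field.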
Effect of physics-prior attention on orientation field
Where the prior shifts attention. All-block average of orientation-field spatial attention: learned data attention (left), physics-prior attention (middle), and the difference between the effective blended attention and the learned data attention (right). The prior introduces a localized, orientation-aware structure that the effective attention partially adopts.
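Because the blend is linear, the right-hand panel has a closed form that follows directly from $A = (1-\lambda)A_{\text{data}} + \lambda A_{\text{phys}}$:

$$
A - A_{\text{data}} = (1-\lambda)\,A_{\text{data}} + \lambda\,A_{\text{phys}} - A_{\text{data}} = \lambda\,\bigl(A_{\text{phys}} - A_{\text{data}}\bigr).
$$

The difference is therefore nonzero only where λ is non-negligible and the prior disagrees with the learned attention, and each of its rows sums to zero because both attention matrices are row-stochastic: the prior redistributes attention mass rather than adding it. If λ is learned separately per block (an assumption), the all-block average weights each block's gap by that block's λ.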

Results

We report z-score-normalized test MSE for (α,ζ) per-parameter and averaged ("Total"). The encoder is frozen and used as a feature extractor; probes are a closed-form least-squares linear regressor (torch.linalg.lstsq) and inverse-distance-weighted kNN. The supervised reference predicts (α,ζ) directly without any probing.
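A minimal NumPy sketch of this evaluation protocol, under assumptions not spelled out on the page: targets are z-scored with training-set statistics, the linear probe includes a bias column (NumPy's `np.linalg.lstsq` plays the role of `torch.linalg.lstsq`), the kNN probe uses Euclidean distance with a small `eps` to guard against zero distances, and the features are assumed to come from the frozen encoder. The k values in the table (5, 20, 4) would be passed as `k`.

```python
import numpy as np


def zscore(train_y, test_y):
    """Normalize targets per parameter using training-set statistics."""
    mu, sd = train_y.mean(axis=0), train_y.std(axis=0)
    return (train_y - mu) / sd, (test_y - mu) / sd


def with_bias(f):
    return np.hstack([f, np.ones((f.shape[0], 1))])


def linear_probe(train_x, train_y, test_x):
    """Closed-form least squares on frozen features (plus a bias column)."""
    w, *_ = np.linalg.lstsq(with_bias(train_x), train_y, rcond=None)
    return with_bias(test_x) @ w


def knn_probe(train_x, train_y, test_x, k=20, eps=1e-8):
    """Inverse-distance-weighted average of the k nearest training targets."""
    d = np.linalg.norm(test_x[:, None, :] - train_x[None, :, :], axis=-1)
    idx = np.argsort(d, axis=1)[:, :k]
    w = 1.0 / (np.take_along_axis(d, idx, axis=1) + eps)
    w /= w.sum(axis=1, keepdims=True)
    return np.einsum("nk,nkp->np", w, train_y[idx])


def per_param_mse(pred, target):
    mse = ((pred - target) ** 2).mean(axis=0)  # one value per parameter (alpha, zeta)
    return mse, mse.mean()                     # "Total" = mean over parameters


rng = np.random.default_rng(0)
X = rng.standard_normal((50, 6))               # stand-in for frozen features
Y = X @ rng.standard_normal((6, 2)) + 0.5      # toy (alpha, zeta) targets
Ytr, Yte = zscore(Y[:40], Y[40:])
mse, total = per_param_mse(linear_probe(X[:40], Ytr, X[40:]), Yte)
```

Both probes have no trainable head beyond a closed-form solve or a lookup, so differences in probe MSE reflect the geometry of the frozen features rather than probe capacity.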

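Method | Linear α ↓ | Linear ζ ↓ | Linear Total ↓ | kNN α ↓ | kNN ζ ↓ | kNN Total ↓ | k
--- | ---: | ---: | ---: | ---: | ---: | ---: | ---:
Baseline (Qu et al., 2026) | 0.059 | 0.460 | 0.260 | 0.081 | 0.570 | 0.325 | 5
Temporal: ViT3D-d6 + FFT (ours) | 0.016 | 0.197 | 0.107 | **0.009** | 0.231 | 0.120 | 20
H-JEPA (ours) | **0.014** | **0.173** | **0.093** | **0.009** | **0.189** | **0.099** | 20
FA-JEPA (ours) | 0.033 | 0.291 | 0.162 | 0.033 | 0.550 | 0.292 | 4
Supervised reference (no probe) | 0.027 | 0.077 | 0.052 | n/a | n/a | n/a | n/a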

z-score MSE on the test set per configuration and probe. Rows marked (ours) are our contributions (Temporal, H-JEPA, FA-JEPA); the top row is the JEPA baseline of Qu et al. 2026 that we build on. Bold marks the best self-supervised result per column. H-JEPA reduces total test MSE by ~64% (linear) and ~70% (kNN) versus the baseline, while frozen probing uses only ~4.0 M of its ~36.5 M parameters.
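The headline reductions follow directly from the Total columns (each Total is the mean of the α and ζ MSEs, e.g. (0.014 + 0.173) / 2 ≈ 0.093). A quick check, using only numbers copied from the table:

```python
# "Total" z-score test MSE copied from the results table.
total_mse = {
    "baseline": {"linear": 0.260, "knn": 0.325},
    "h_jepa": {"linear": 0.093, "knn": 0.099},
}


def relative_reduction(before, after):
    """Fractional drop in MSE: 1 - after / before."""
    return 1.0 - after / before


for probe in ("linear", "knn"):
    r = relative_reduction(total_mse["baseline"][probe], total_mse["h_jepa"][probe])
    print(f"{probe}: {100 * r:.1f}% lower total MSE")
# linear: 64.2% lower total MSE
# knn: 69.5% lower total MSE
```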

Linear probe metrics: physics-prior vs data-only attention
FA-JEPA linear probe. Adding the physics-prior attention (orange) lowers the combined (α,ζ) MSE versus the data-attention-only variant (blue) across training; the gain on ζ dominates the small regression on α.
kNN probe metrics: physics-prior vs data-only attention
FA-JEPA kNN probe. Under kNN probing, the physics-prior variant improves on both parameters, consistent with the prior preferentially exchanging information between patches with similar nematic alignment and thereby sharpening the learned feature geometry.

Takeaways

  • Tokenization, not depth, drives the gains over the JEPA baseline. A six-block ViT3D over 4×16×16 spatiotemporal patches outperforms every CNN and Conv+Attn variant we tried at matched depth; the gain comes from a tokenizer that encodes temporal structure before pooling collapses it.
  • Hierarchical pretexts compose. H-JEPA's inner denoising and outer next-window prediction shape a shared trunk along complementary axes; removing either degrades the probe.
  • Physics priors belong inside attention. Adding a soft, learnable nematic-alignment bias to FA-JEPA's attention improves probe MSE without overriding the learned data attention — the model adopts the prior where it helps and ignores it elsewhere.
  • Aligning representation learning with the structure of the underlying physics is the lever — bigger generic models are a weaker substitute.

BibTeX

@misc{ji2026physicsaware,
  title  = {Physics-Aware Representation Learning for Physical Systems},
  author = {Ji, Charles Cheng and Shi, Zhanhe and Wang, Richard and Yalovetzky, Romina},
  year   = {2026},
  note   = {NYU CSCI-GA 2572 Deep Learning, Spring 2026},
  url    = {https://saltfish-len.github.io/Physics-Aware-Project-Page/}
}