Flow Matching
Key Points
- •Flow matching learns a time-dependent vector field v_θ(x, t, c) whose ODE transports simple noise to complex data, enabling fast, deterministic sampling.
- •Conditional flow matching (CFM) conditions the vector field on side information c (e.g., class labels), so each condition learns its own flow within one shared model.
- •Training uses synthetic trajectories between base noise x_0 and data x_1; for a straight-line path x_t = (1 - t) x_0 + t x_1, the target velocity is u_t = x_1 - x_0.
- •Minimizing the mean squared error between v_θ(x_t, t, c) and u_t yields the correct marginal vector field at each (x, t, c) without computing scores or divergences.
- •At test time, you integrate dx/dt = v_θ(x, t, c) from t = 0 to t = 1 (e.g., with RK4) starting from x_0 ~ N(0, I) to generate samples.
- •CFM avoids stochastic reverse-time SDEs and often needs far fewer steps than diffusion sampling while remaining stable.
- •If desired, the same vector field can be used for continuous normalizing flow likelihoods via the divergence term, though CFM training itself is likelihood-free.
- •Good conditioning, time input, and stable ODE integration (step size control) are critical to sample quality.
Prerequisites
- →Ordinary Differential Equations (ODEs) — Flows evolve points via dx/dt = v_t(x); understanding integration and stability is fundamental.
- →Multivariable Calculus (Gradient and Divergence) — The continuity equation and CNF likelihoods use ∇·v and Jacobians.
- →Probability and Random Variables — CFM samples from base and target distributions and reasons about pushforwards.
- →Conditional Probability — CFM conditions the flow on side information c and couples (x_0, x_1).
- →Neural Networks and Backpropagation — The vector field v_θ is parameterized by a neural network trained with MSE.
- →Numerical Methods (Euler, RK4) — Sampling requires stable and accurate ODE solvers.
- →Data Normalization/Standardization — Proper scaling reduces stiffness and improves training stability.
- →Linear Algebra — Network operations and potential CNF divergence estimation rely on vector-Jacobian products.
- →PyTorch/LibTorch Basics — The provided C++ implementations use the LibTorch API for training and inference.
- →Optimization (SGD/Adam) — Training minimizes MSE between predicted and target velocities.
Detailed Explanation
01 Overview
Flow matching is a generative modeling framework where we learn a time-dependent vector field that continuously transports samples from a simple base distribution (like a standard Gaussian) to the target data distribution. Instead of directly learning a mapping f: noise → data in one jump, we define an ordinary differential equation (ODE) dx/dt = v_t(x) and integrate it from time t = 0 to t = 1. The solution curve carries the initial sample x_0 from the base distribution to x_1 that follows the data distribution. Conditional flow matching (CFM) extends this by allowing side information c (e.g., class labels, text embeddings) to modulate the vector field so the same model generates different kinds of outputs depending on c. Crucially, flow matching trains the vector field using synthetic, easy-to-compute “target velocities” along simple interpolating paths between pairs (x_0, x_1). For the commonly used straight-line interpolation x_t = (1 - t) x_0 + t x_1, the path’s time derivative is u_t = x_1 - x_0. We train a neural network v_θ(x, t, c) to match u_t at randomly sampled points along these paths. The remarkable result is that the best L2 predictor under this training recovers the correct marginal vector field that transports the entire distribution. Sampling is deterministic and efficient: starting from noise, numerically integrate the learned ODE using a few solver steps (e.g., Runge–Kutta) to obtain data-like samples.
02 Intuition & Analogies
Imagine moving a crowd from an open field (noise) into well-organized seats in a stadium (data). If you try to teleport everyone at once, it’s chaos. Instead, you assign ushers (the vector field) who, at every moment t, point each person in the direction they should walk. If the ushers give consistent directions over time, everyone ends up in the right seats. Flow matching trains these ushers. How? We show them pairs of “starting spot” and “target seat” for many people. Then we pick a time t between 0 and 1, place each person at the spot they’d occupy if they walked in a straight line from start to seat for t fraction of the time, and tell the usher the correct direction right there: “point from the start toward the seat.” Over enough examples, ushers learn to point correctly for any person they encounter along the way. Conditional flow matching is like having different seating charts (conditions c). The same usher team can read the chart label and adapt directions accordingly—families to family sections, VIPs to VIP rows, students to student sections—without changing the general strategy. Crucially, ushers don’t need to know the crowd density or compute complex global plans. They only need to imitate local motion along simple straight paths, and—thanks to the mathematics of conditional expectations—this local imitation produces the globally consistent crowd flow that rearranges the entire distribution correctly. Finally, when it’s game time (sampling), you just follow the ushers’ directions forward in time starting from random entrances, and you’ll end up with the stadium filled like your training data.
03 Formal Definition
Let π_0 be a simple base distribution (e.g., N(0, I)) and π_1 the data distribution, optionally conditioned on side information c. Sample x_0 ~ π_0, x_1 ~ π_1(·|c), and t ~ Uniform(0, 1), and define the straight-line path x_t = (1 - t) x_0 + t x_1 with velocity u_t = x_1 - x_0. Conditional flow matching trains v_θ by minimizing L(θ) = E_{t, c, x_0, x_1} ||v_θ(x_t, t, c) - u_t||². The population minimizer is v*(x, t, c) = E[x_1 - x_0 | x_t = x, c], and integrating dx/dt = v*(x, t, c) from t = 0 to t = 1 pushes π_0 forward to π_1(·|c).
04 When to Use
Use conditional flow matching when you need a fast, stable, and deterministic generative model that can be conditioned on side information. Typical use cases include class-conditional image or audio generation, conditional molecular conformations, and structured data synthesis where labels or attributes guide outputs. It shines when you want diffusion-like quality but with fewer sampling steps since you integrate an ODE rather than simulating a stochastic SDE. CFM is also attractive when you prefer likelihood-free training (pure regression) without computing score functions or reverse-time drifts. In low- to mid-dimensional settings (e.g., 2D–3D point clouds, simulation states) CFM can be very efficient and easy to visualize. In high dimensions (images, audio), pair CFM with expressive neural architectures (UNets, transformers) and modern ODE solvers. If you also care about exact or approximate likelihoods, the learned field can be used within a continuous normalizing flow (by estimating divergence), giving both sampling and density evaluation in one model. Finally, CFM is ideal for settings where drawing pairs (x_0, x_1) and simple interpolations are easy, such as base Gaussian to empirical dataset examples.
⚠️ Common Mistakes
- Ignoring time input: The vector field must depend on t; using a time-independent v(x, c) collapses flexibility and often fails to match endpoints.
- Mismatched base distribution: Training assumes x_0 ~ π_0 (often N(0, I)); sampling from a different base leads to poor results. Standardize data and match scales.
- Not sampling t uniformly: Using only endpoints or a fixed t biases training and harms transport quality. Sample t ~ Uniform(0, 1).
- Forgetting conditioning: If you condition during training but omit c at sampling (or vice versa), generations will be wrong. Ensure c is consistently encoded in both phases.
- Too-large ODE steps: Euler steps with large step size cause drift or instability. Prefer RK4 or smaller steps; monitor trajectory norms over t.
- Poor coupling/interpolation: For some tasks, straight lines may cross low-density regions, making learning harder. Consider alternative bridges (e.g., Gaussian bridges) if convergence is slow.
- Data not normalized: Unscaled features lead to stiff dynamics and training instability. Normalize inputs per feature and encode t at a similar scale.
- Capacity/regularization mismatch: Underpowered networks underfit vector fields; overpowered ones may overfit noisy targets. Use residual connections, weight decay, and validation flows to tune.
Key Formulas
Probability Flow ODE
dx/dt = v_θ(x, t, c)
Explanation: The learned vector field moves a sample over time. Integrating from t = 0 to t = 1 transports noise to data conditioned on c.
Displacement Interpolation
x_t = (1 - t) x_0 + t x_1
Explanation: A simple straight-line path between a noise sample x_0 and a data sample x_1. It is widely used in flow matching for constructing targets.
Target Velocity (Straight Path)
u_t = dx_t/dt = x_1 - x_0
Explanation: The instantaneous velocity along the straight-line path is constant and equals the displacement. This provides the supervised signal during training.
Conditional Flow Matching Loss
L(θ) = E_{t, c, x_0, x_1} [ ||v_θ(x_t, t, c) - (x_1 - x_0)||² ]
Explanation: We regress the neural vector field onto the target velocity at random points along the path between x_0 and x_1. Minimizing this mean squared error recovers the correct marginal vector field.
Optimality as Conditional Expectation
v*(x, t, c) = E[ x_1 - x_0 | x_t = x, c ]
Explanation: The MSE minimizer equals the conditional expectation of the target velocity given the current location x and time t. This is the field whose flow matches the data.
Continuity Equation
∂p_t(x)/∂t + ∇·(p_t(x) v_t(x)) = 0
Explanation: Probability density evolves under the vector field like a fluid, with the local density change governed by the divergence of the probability flux p_t v_t. It ensures mass conservation along the flow.
CNF Log-Density Evolution
d/dt log p_t(x_t) = -∇·v_t(x_t)
Explanation: When tracking likelihoods, the log-density along a trajectory changes by minus the divergence of the vector field. Integrating this gives exact CNF likelihoods.
Integral Form of Sampling
x_1 = x_0 + ∫_0^1 v_θ(x_t, t, c) dt
Explanation: To generate a sample, integrate the learned velocity along time. Numerical solvers approximate this integral with finite steps.
Mean Squared Error
MSE = (1/N) Σ_i ||v_θ(x_t^(i), t^(i), c^(i)) - u^(i)||²
Explanation: The standard regression loss used to fit v_θ to u_t over a batch of samples. It penalizes squared deviations.
Pushforward Equality
(Φ_1)_# π_0 = π_1, where Φ_1 is the time-1 flow map of dx/dt = v_θ(x, t, c)
Explanation: The flow at t = 1 pushes the base distribution forward to the target distribution. This expresses the goal of generative modeling via ODE flows.
Complexity Analysis
Training costs O(iters × B × C_net), where C_net is the cost of one forward/backward pass through v_θ; each iteration requires a single network evaluation per sample, with no ODE simulation in the training loop. Sampling costs O(steps × C_net) per batch; RK4 uses four network evaluations per step, so 64 RK4 steps require 256 forward passes. CNF likelihood evaluation adds a divergence estimate per step (exact in low dimensions, or via stochastic trace estimation in high dimensions).
Code Examples
// Compile with: g++ -O2 -std=c++17 cfm_train.cpp -I${LIBTORCH}/include -I${LIBTORCH}/include/torch/csrc/api/include -L${LIBTORCH}/lib -ltorch -lc10 -o cfm_train -Wl,-rpath,${LIBTORCH}/lib
// This example trains v_θ(x, t, c) on a simple 2D Gaussian mixture with class labels.
#include <torch/torch.h>
#include <cmath>    // for M_PI
#include <iostream>
#include <random>
#include <vector>

// Simple MLP for v_θ(x, t, c)
struct VelocityNetImpl : torch::nn::Module {
    torch::nn::Sequential net;
    VelocityNetImpl(int in_dim, int hidden, int out_dim) {
        net = torch::nn::Sequential(
            torch::nn::Linear(in_dim, hidden),
            torch::nn::SiLU(),
            torch::nn::Linear(hidden, hidden),
            torch::nn::SiLU(),
            torch::nn::Linear(hidden, out_dim)
        );
        register_module("net", net);
    }
    torch::Tensor forward(torch::Tensor x) {
        return net->forward(x);
    }
};
TORCH_MODULE(VelocityNet);

// One-hot encode class labels 0..K-1
torch::Tensor one_hot(torch::Tensor labels, int K) {
    auto B = labels.size(0);
    auto oh = torch::zeros({B, K}, labels.options().dtype(torch::kFloat32));
    oh.scatter_(1, labels.view({B, 1}), 1.0f);
    return oh;
}

// Sample a batch from a 2D class-conditional Gaussian mixture:
// K classes with means on a circle; small isotropic noise
std::pair<torch::Tensor, torch::Tensor> sample_data(int B, int K, float radius, float stddev, torch::Device device) {
    auto labels = torch::randint(0, K, {B}, torch::TensorOptions().dtype(torch::kLong).device(device));
    auto theta = labels.to(torch::kFloat32) * (2.0f * M_PI / K);
    auto means = torch::stack({radius * torch::cos(theta), radius * torch::sin(theta)}, 1); // Bx2
    auto eps = stddev * torch::randn({B, 2}, torch::TensorOptions().dtype(torch::kFloat32).device(device));
    auto x = means + eps;
    return {x, labels};
}

int main() {
    torch::manual_seed(42);
    torch::Device device(torch::kCPU);

    const int d = 2;              // data dimension
    const int K = 4;              // number of classes (conditions)
    const int in_dim = d + 1 + K; // x (2) + time t (1) + one-hot c (K)
    const int hidden = 128;
    const int out_dim = d;        // velocity vector in R^2

    VelocityNet model(in_dim, hidden, out_dim);
    model->to(device);

    torch::optim::Adam opt(model->parameters(), torch::optim::AdamOptions(1e-3));

    const int B = 512;            // batch size
    const int iters = 2000;       // training iterations
    const float radius = 2.5f;
    const float data_std = 0.15f;

    for (int it = 1; it <= iters; ++it) {
        model->train();
        // 1) Sample conditional data x1 ~ π1(·|c)
        auto [x1, y] = sample_data(B, K, radius, data_std, device); // x1: Bx2, y: B
        // 2) Sample base noise x0 ~ N(0, I)
        auto x0 = torch::randn_like(x1);
        // 3) Sample time t ~ Uniform(0, 1)
        auto t = torch::rand({B, 1}, torch::TensorOptions().dtype(torch::kFloat32).device(device));
        // 4) Build straight path x_t and target velocity u_t
        auto x_t = (1.0f - t) * x0 + t * x1; // Bx2
        auto u_t = x1 - x0;                  // Bx2 (constant over t for straight paths)
        // 5) Encode inputs: [x_t, t, one_hot(c)]
        auto c_oh = one_hot(y, K);
        auto input = torch::cat({x_t, t, c_oh}, 1); // B x (2+1+K)
        // 6) Predict velocity and compute MSE loss
        auto v_pred = model->forward(input);
        auto loss = torch::mse_loss(v_pred, u_t);
        // 7) Optimize
        opt.zero_grad();
        loss.backward();
        opt.step();

        if (it % 200 == 0) {
            std::cout << "Iter " << it << ", loss = " << loss.item<float>() << std::endl;
        }
    }

    // Save the model
    torch::save(model, "cfm_velocity.pt");
    std::cout << "Training complete. Model saved to cfm_velocity.pt\n";
    return 0;
}
This program trains a neural vector field v_θ(x, t, c) on a simple 2D Gaussian mixture with K class labels (conditions). Each batch draws data samples x1 by label, noise x0 from a standard Gaussian, and times t from Uniform(0,1). We construct the straight-line interpolation x_t and the target velocity u_t = x1 - x0, feed [x_t, t, one_hot(c)] to an MLP, and minimize MSE to match u_t. The population minimizer recovers the correct marginal field whose ODE transports N(0, I) to the class-conditional data distribution. The example logs training loss and saves the model.
// Compile with: g++ -O2 -std=c++17 cfm_sample.cpp -I${LIBTORCH}/include -I${LIBTORCH}/include/torch/csrc/api/include -L${LIBTORCH}/lib -ltorch -lc10 -o cfm_sample -Wl,-rpath,${LIBTORCH}/lib
#include <torch/torch.h>
#include <iostream>
#include <vector>

struct VelocityNetImpl : torch::nn::Module {
    torch::nn::Sequential net;
    VelocityNetImpl(int in_dim, int hidden, int out_dim) {
        net = torch::nn::Sequential(
            torch::nn::Linear(in_dim, hidden),
            torch::nn::SiLU(),
            torch::nn::Linear(hidden, hidden),
            torch::nn::SiLU(),
            torch::nn::Linear(hidden, out_dim)
        );
        register_module("net", net);
    }
    torch::Tensor forward(torch::Tensor x) { return net->forward(x); }
};
TORCH_MODULE(VelocityNet);

// One-hot encode labels 0..K-1
torch::Tensor one_hot(torch::Tensor labels, int K) {
    auto B = labels.size(0);
    auto oh = torch::zeros({B, K}, labels.options().dtype(torch::kFloat32));
    oh.scatter_(1, labels.view({B, 1}), 1.0f);
    return oh;
}

// Evaluate v_θ(x, t, c) given x (Bxd), scalar t, and labels y
torch::Tensor eval_velocity(VelocityNet &model, torch::Tensor x, float t_scalar, torch::Tensor y, int K) {
    auto B = x.size(0);
    auto t = torch::full({B, 1}, t_scalar, x.options());
    auto c_oh = one_hot(y, K).to(x.device());
    auto inp = torch::cat({x, t, c_oh}, 1);
    return model->forward(inp);
}

// One RK4 step: x_{n+1} = x_n + (h/6)(k1 + 2k2 + 2k3 + k4),
// where k1 = v(t_n, x_n), k2 = v(t_n + h/2, x_n + h k1/2), etc.
torch::Tensor rk4_step(VelocityNet &model, torch::Tensor x, float t, float h, torch::Tensor y, int K) {
    auto k1 = eval_velocity(model, x, t, y, K);
    auto k2 = eval_velocity(model, x + 0.5f * h * k1, t + 0.5f * h, y, K);
    auto k3 = eval_velocity(model, x + 0.5f * h * k2, t + 0.5f * h, y, K);
    auto k4 = eval_velocity(model, x + h * k3, t + h, y, K);
    return x + (h / 6.0f) * (k1 + 2.0f * k2 + 2.0f * k3 + k4);
}

int main() {
    torch::manual_seed(0);
    torch::Device device(torch::kCPU);

    const int d = 2;
    const int K = 4;
    const int in_dim = d + 1 + K;
    const int hidden = 128;
    const int out_dim = d;

    // Load trained model (must match architecture used during training)
    VelocityNet model(in_dim, hidden, out_dim);
    try {
        torch::load(model, "cfm_velocity.pt");
    } catch (const c10::Error& e) {
        std::cerr << "Failed to load model cfm_velocity.pt. Did you run training first?\n";
        return 1;
    }
    model->to(device);
    model->eval();

    const int B = 16; // number of samples to generate
    auto y = torch::randint(0, K, {B}, torch::TensorOptions().dtype(torch::kLong).device(device));

    // Start from base noise x0 ~ N(0, I)
    auto x = torch::randn({B, d}, torch::TensorOptions().dtype(torch::kFloat32).device(device));

    // Integrate from t=0 to t=1 with RK4
    const int steps = 64; // increase for higher fidelity
    const float h = 1.0f / steps;
    float t = 0.0f;

    torch::NoGradGuard no_grad;
    for (int s = 0; s < steps; ++s) {
        x = rk4_step(model, x, t, h, y, K);
        t += h;
    }

    // Print generated samples and labels
    auto x_cpu = x.cpu();
    auto y_cpu = y.cpu();
    auto x_a = x_cpu.accessor<float, 2>();
    auto y_a = y_cpu.accessor<long, 1>();
    for (int i = 0; i < B; ++i) {
        std::cout << "sample " << i << ": label=" << y_a[i]
                  << ", x=[" << x_a[i][0] << ", " << x_a[i][1] << "]\n";
    }

    return 0;
}
This program loads the trained velocity network and generates samples by integrating the ODE dx/dt = v_θ(x, t, c) from t = 0 to 1 using a 4th-order Runge–Kutta solver. It starts from base noise x0 ~ N(0, I) and conditions on random class labels. Increasing the number of steps improves accuracy at the cost of runtime. The output is a list of 2D samples paired with their labels.