
GAN Theory & Training Dynamics

Key Points

  • GANs frame learning as a two-player game where a generator tries to fool a discriminator, and the discriminator tries to detect fakes.
  • The classic objective is a minimax problem; with the optimal discriminator, training implicitly minimizes the Jensen–Shannon divergence between the real and generated distributions.
  • In practice, the non-saturating generator loss, careful optimization (TTUR), and regularization (e.g., spectral norm or gradient penalties) greatly stabilize training.
  • Training dynamics are delicate: if the discriminator is too strong or too weak, gradients vanish or explode, leading to mode collapse or failure to learn.
  • Using logits plus numerically stable binary cross-entropy (BCEWithLogits) avoids instability from direct log(1 āˆ’ D(Ā·)) computations.
  • Evaluation requires more than losses; track sample quality, diversity, and metrics like FID or coverage to diagnose collapse.
  • Compute cost per step scales with batch size and network size; extra discriminator steps and penalties (e.g., WGAN-GP) increase cost.
  • C++ implementations are practical with LibTorch; the examples show stable loss computation and a full toy 1D GAN training loop.

Prerequisites

  • →Probability distributions and expectations — GAN objectives are expectations over data and noise distributions; understanding p(x), p(z), and sampling is essential.
  • →Logistic regression and binary cross-entropy — The discriminator is a binary classifier; BCE on logits underpins stable loss computation.
  • →Neural networks and backpropagation — Both G and D are neural nets trained with gradient-based optimization.
  • →Stochastic gradient descent and Adam — GAN training alternates gradient steps; optimizer behavior affects dynamics and stability.
  • →Divergences (KL, JS) and distances — Theoretical grounding of GAN objectives relies on JS divergence and related measures.
  • →Autograd and computation graphs — Correctly detaching fake samples for D updates and propagating through D for G updates requires graph control.
  • →C++ with LibTorch (PyTorch C++ API) — Implementing GANs in C++ requires familiarity with tensor operations, modules, and optimizers in LibTorch.
  • →Numerical stability (log-sum-exp, logits) — Stable training relies on using logits and avoiding catastrophic under/overflow in losses.
  • →Game theory basics (Nash equilibrium) — GANs converge, in principle, to a Nash equilibrium where neither player can improve alone.

Detailed Explanation


01Overview

Think of Generative Adversarial Networks (GANs) as a cat-and-mouse game. One network (the generator) fabricates samples to look like real data, while another (the discriminator) judges whether a sample is real or fake. The generator improves by learning from the discriminator’s feedback, and the discriminator sharpens its judgment by seeing better and better forgeries. This interactive setup forms a minimax game: the generator minimizes a loss while the discriminator maximizes it. At equilibrium, the generator’s distribution should match the real data, making the discriminator indifferent. The original GAN objective connects beautifully to information theory: when the discriminator is optimal, the generator effectively minimizes the Jensen–Shannon divergence between the true data distribution and the model’s distribution. However, the path to that equilibrium is tricky in practice—training is unstable, gradients can vanish, and models can collapse to producing a few modes. To make GANs work reliably, practitioners use improved losses (like the non-saturating or hinge loss), better optimization (TTUR with Adam), and regularization (spectral normalization, gradient penalties). This resource builds intuition, formal footing, and concrete C++ implementations using LibTorch, so you can understand, diagnose, and implement GAN training dynamics from the ground up.

02Intuition & Analogies

Imagine a counterfeiter (generator) and a detective (discriminator). The counterfeiter prints fake currency trying to pass it as real. The detective studies both real bills and fakes, getting better at spotting subtle flaws. As the detective improves, the counterfeiter must craft more convincing bills. If the detective is too lenient, bad fakes pass and the counterfeiter learns little. If the detective is too strict too early, the counterfeiter only hears ā€œeverything is terribleā€ and learns nothing specific to improve. This push-and-pull is the heart of GAN training dynamics. Now swap money for images or sounds. The generator maps random noise (its ā€œcreative sparkā€) into a sample. The discriminator acts like a critic who returns a single score: how real does this look? Early on, the critic easily spots fakes; over time, the generator picks up patterns that fool the critic. When training goes well, they reach a stalemate: the critic can’t tell real from fake, and the generator’s outputs look real. But there are pitfalls. If feedback is phrased poorly (e.g., using log(1 āˆ’ D(.)) which saturates), the counterfeiter hears a muffled message and stops improving—this is vanishing gradients. If the detective overfits to superficial cues, the counterfeiter exploits loopholes and collapses to few outputs that trick the detective—this is mode collapse. To keep learning healthy, we tune the pace of both players (different learning rates, update counts), use robust grading rubrics (non-saturating or hinge losses), and add rules that prevent cheap tricks (spectral normalization, gradient penalties).

03Formal Definition

A GAN defines a two-player zero-sum game between a generator G and discriminator D. Let p_data(x) be the real data distribution and p_z(z) a simple prior (e.g., Normal). The generator induces a model distribution p_g by pushing z through G: x = G(z). The original minimax objective is V(D, G) = E_{x∼p_data}[log D(x)] + E_{z∼p_z}[log(1 āˆ’ D(G(z)))]. Training seeks min_G max_D V(D, G). For fixed G, the optimal discriminator is Dāˆ—(x) = p_data(x) / (p_data(x) + p_g(x)). Substituting Dāˆ— into V yields V(Dāˆ—, G) = āˆ’log 4 + 2Ā·JS(p_data ∄ p_g), so minimizing with respect to G is equivalent to minimizing the Jensen–Shannon divergence between p_data and p_g. In practice, one commonly uses the non-saturating generator loss L_G^NS = āˆ’E_z[log D(G(z))], which shares the same fixed points but provides stronger gradients early in training. Alternatives include hinge losses and Wasserstein losses, which modify the critic's role and regularization to improve stability. Optimization proceeds via stochastic gradient methods, alternating updates of D and G on minibatches. Convergence aims for a Nash equilibrium where neither player can unilaterally improve its objective.

04When to Use

Use GANs when you need high-fidelity sample generation from complex, high-dimensional data where explicit likelihoods are hard to model: photorealistic images, style transfer, super-resolution, and data augmentation. They shine when sample quality and sharpness matter, especially in vision tasks. GANs are also useful for domain mapping without paired data (CycleGAN) and for structured output refinement (inpainting, deblurring). If your goal is density estimation, likelihood-based evaluation, or stable training with guaranteed convergence, consider alternatives (VAEs, diffusion models). GANs can be brittle; they require careful hyperparameter tuning and strong regularization. Choose the classic minimax/non-saturating objective when you have moderate data, robust architectures (ResNet/UNet), and want sharp outputs. Prefer hinge or Wasserstein objectives with spectral normalization or gradient penalty when you notice instability, gradient saturation, or mode collapse. Use Two Time-Scale Update Rule (TTUR) with different learning rates for D and G when the discriminator learns faster and risks overpowering the generator. For small tabular or 1D problems, simpler MLPs suffice; for images, use convolutional generators/discriminators (DCGAN-style). Always monitor sample diversity and adopt early stopping if the discriminator overfits.

āš ļøCommon Mistakes

  • Using sigmoid outputs with plain BCE on probabilities rather than logits, leading to numerical issues with log(1 - D). Prefer BCEWithLogits losses that combine sigmoid and log-sum-exp stably.
  • Training the discriminator too much or too fast early on, causing vanishing gradients for the generator. Balance with TTUR, limit discriminator steps, or add regularization (spectral norm, gradient penalty).
  • Confusing minimax vs non-saturating loss and accidentally training with the saturating generator objective, which can stall learning. Implement loss functions explicitly and test on toy data first.
  • Ignoring evaluation of diversity (mode coverage). Low loss doesn’t guarantee variety; use metrics (e.g., FID for images) and qualitative inspections.
  • Noisy batch statistics or batch norm in the discriminator can leak information or destabilize training. Prefer spectral norm or layer norm in D; use batch norm mostly in G.
  • Not shuffling data or reusing fake batches for multiple updates, which biases gradients and can cause overfitting.
  • Single learning rate for both players when dynamics are unbalanced. Use different lrs and possibly more D steps per G step; monitor gradient norms.
  • Forgetting to detach fake samples when updating D (leads to unnecessary generator gradients and memory bloat). Detach G(z) for D updates, but not for G updates.

Key Formulas

Original GAN Minimax

min_G max_D E_{x∼p_data}[log D(x)] + E_{z∼p_z}[log(1 āˆ’ D(G(z)))]

Explanation: This is the classic adversarial game. The discriminator maximizes correct classification of real vs fake; the generator minimizes the same objective to fool the discriminator.

Optimal Discriminator

Dāˆ—(x) = p_data(x) / (p_data(x) + p_g(x))

Explanation: For a fixed generator, the discriminator that maximizes the value function outputs the probability that x came from the data distribution. It provides the theoretical link to JS divergence.

JS Divergence Connection

V(Dāˆ—, G) = āˆ’log 4 + 2Ā·JS(p_data ∄ p_g)

Explanation: Plugging the optimal discriminator into the objective shows that minimizing the GAN value is equivalent to minimizing the JS divergence between real and generated distributions.

Non-saturating Generator Loss

L_G^NS = āˆ’E_{z∼p_z}[log D(G(z))]

Explanation: This alternative generator loss shares fixed points with the minimax game but provides stronger gradients when the discriminator is confident, improving early training.

Hinge GAN Loss

L_D^hinge = E_x[max(0, 1 āˆ’ D(x))] + E_z[max(0, 1 + D(G(z)))]
L_G^hinge = āˆ’E_z[D(G(z))]

Explanation: Margin-based losses often stabilize training when paired with spectral normalization in the discriminator. The generator maximizes the critic score of fake samples.

Wasserstein Objective

max_{D ∈ F_{1-Lip}} E_{x∼p_data}[D(x)] āˆ’ E_{z∼p_z}[D(G(z))]

Explanation: WGAN replaces the probabilistic discriminator with a 1-Lipschitz critic and maximizes the difference of expected scores. This approximates the Earth Mover’s distance for better gradients.

Gradient Penalty

L_GP = Ī»Ā·E_xĢ‚[(ā€–āˆ‡_xĢ‚ D(xĢ‚)ā€–_2 āˆ’ 1)²]

Explanation: A regularizer encouraging the critic’s gradient norm to be near 1, helping enforce the Lipschitz constraint without weight clipping.

Stable BCE on Logits

BCEWithLogits(s, y) = (1/n) Ī£_{i=1}^{n} [max(s_i, 0) āˆ’ s_iĀ·y_i + log(1 + e^{āˆ’|s_i|})]

Explanation: This numerically stable form of binary cross-entropy combines the sigmoid and log in one expression, avoiding overflow/underflow that occurs with probabilities.

Two Time-Scale Step Sizes

TTUR conditions: Ī£_t a_t = āˆž, Ī£_t a_t² < āˆž, Ī£_t b_t = āˆž, Ī£_t b_t² < āˆž

Explanation: For stochastic approximation with two players, diminishing step sizes that are square-summable but not summable can guarantee convergence under assumptions. In practice, constant lrs with Adam approximate this behavior.

Complexity Analysis

Per training iteration with batch size B, discriminator parameters P_D, and generator parameters P_G, the dominant cost is forward and backward passes. A single discriminator update on real and fake batches costs roughly O(B Ā· (F_D + B_D)), where F_D and B_D denote per-sample forward and backward flops through D; similarly, a generator update costs O(B Ā· (F_G + B_G) + B Ā· F_D) because G's loss backpropagates through D(G(z)). Memory usage is O(P_D + P_G + A), where A is activation memory scaling with batch size and network depth; storing activations for backprop typically dominates parameter memory for deep models.

If the discriminator is updated k times per generator step (a common choice), total per-iteration compute grows by about a factor of k for D. Regularizers add overhead: spectral normalization requires a few power iterations per layer per step (a small constant factor), while gradient penalty (WGAN-GP) requires computing gradients with respect to inputs, adding an extra backward pass through D per batch and roughly doubling D's compute for that step. Hinge, non-saturating, and standard BCE losses have similar compute, differing mainly in numerical stability.

In LibTorch C++, data movement between CPU and GPU can dominate if not handled carefully; keep tensors on the same device and minimize host-device synchronization. Mixed precision can cut memory and improve throughput but requires loss scaling. Overall training wall-clock time scales linearly with dataset size, epochs, and the sum of per-step costs for D and G. Monitoring gradient norms and early stopping can save compute by avoiding divergent runs.

Code Examples

Stable GAN losses in C++ (minimax vs non-saturating, with label smoothing)
#include <torch/torch.h>
#include <iostream>

// Helper: create smoothed targets
torch::Tensor smooth_targets(const torch::Tensor& like, double target, double eps) {
    // Broadcast a scalar to the same shape as 'like' and apply smoothing
    auto y = torch::full_like(like, target);
    if (eps > 0.0) {
        // One-sided smoothing: move 1 -> 1 - eps, keep 0 as 0 (common for D real labels)
        y = (target == 1.0) ? (1.0 - eps) * torch::ones_like(like) : torch::zeros_like(like);
    }
    return y;
}

// Discriminator loss using BCEWithLogits on logits
// real_logits: D(x) before sigmoid; fake_logits: D(G(z)) before sigmoid
// eps_smooth applies to real labels only (e.g., 0.1 makes real targets 0.9)
torch::Tensor d_loss_bce_logits(const torch::Tensor& real_logits,
                                const torch::Tensor& fake_logits,
                                double eps_smooth = 0.0) {
    using namespace torch::nn::functional;
    auto real_targets = smooth_targets(real_logits, 1.0, eps_smooth);
    auto fake_targets = smooth_targets(fake_logits, 0.0, 0.0);
    auto real_loss = binary_cross_entropy_with_logits(real_logits, real_targets,
        BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kMean));
    auto fake_loss = binary_cross_entropy_with_logits(fake_logits, fake_targets,
        BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kMean));
    return real_loss + fake_loss;
}

// Generator loss: choose minimax (saturating) or non-saturating variant
// If non_saturating=true, L_G = -E[log D(G(z))]; else L_G = E[log(1 - D(G(z)))]
torch::Tensor g_loss_bce_logits(const torch::Tensor& fake_logits, bool non_saturating = true) {
    using namespace torch::nn::functional;
    if (non_saturating) {
        auto targets = smooth_targets(fake_logits, 1.0, 0.0); // pretend fakes are real
        return binary_cross_entropy_with_logits(fake_logits, targets,
            BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kMean));
    } else {
        // Saturating minimax loss: L_G = E[log(1 - D(G(z)))].
        // BCE with target 0 equals -log(1 - sigmoid(s)), so negate it to get
        // the quantity G minimizes in the original minimax game.
        auto targets = smooth_targets(fake_logits, 0.0, 0.0);
        return -binary_cross_entropy_with_logits(fake_logits, targets,
            BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kMean));
    }
}

// Optional hinge losses (often with spectral normalization)
torch::Tensor d_loss_hinge(const torch::Tensor& real_scores, const torch::Tensor& fake_scores) {
    auto loss_real = torch::relu(1.0 - real_scores).mean();
    auto loss_fake = torch::relu(1.0 + fake_scores).mean();
    return loss_real + loss_fake;
}

torch::Tensor g_loss_hinge(const torch::Tensor& fake_scores) {
    return (-fake_scores).mean();
}

int main() {
    // Demo with random logits (as if coming from a discriminator)
    torch::manual_seed(42);
    auto real_logits = torch::randn({64, 1});
    auto fake_logits = torch::randn({64, 1});

    auto d_bce = d_loss_bce_logits(real_logits, fake_logits, /*eps_smooth=*/0.1);
    auto g_ns = g_loss_bce_logits(fake_logits, /*non_saturating=*/true);
    auto g_mm = g_loss_bce_logits(fake_logits, /*non_saturating=*/false);

    // For hinge, treat logits as scores without sigmoid
    auto d_h = d_loss_hinge(real_logits, fake_logits);
    auto g_h = g_loss_hinge(fake_logits);

    std::cout << "D BCE loss (smoothed): " << d_bce.item<double>() << "\n";
    std::cout << "G non-saturating BCE loss: " << g_ns.item<double>() << "\n";
    std::cout << "G minimax (saturating) BCE loss: " << g_mm.item<double>() << "\n";
    std::cout << "D hinge loss: " << d_h.item<double>() << ", G hinge loss: " << g_h.item<double>() << "\n";
    return 0;
}

This example shows numerically stable GAN losses in C++. It uses logits and BCEWithLogits to avoid instability from log(1 āˆ’ D(.)) and demonstrates label smoothing for real labels. It also provides hinge losses, commonly paired with spectral normalization. The minimax and non-saturating variants share equilibria, but the non-saturating loss yields stronger generator gradients early in training.

Time: O(B) per forward loss computation, where B is batch size; dominated by upstream model passes.
Space: O(B) for storing logits and targets; dominated by upstream activations if part of a full model.
Toy 1D GAN in LibTorch with TTUR and non-saturating loss
#include <torch/torch.h>
#include <iostream>
#include <cmath>

// Simple MLP Generator: z -> x (1D)
struct GeneratorImpl : torch::nn::Module {
    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
    GeneratorImpl(int zdim, int h = 64) {
        fc1 = register_module("fc1", torch::nn::Linear(zdim, h));
        fc2 = register_module("fc2", torch::nn::Linear(h, h));
        fc3 = register_module("fc3", torch::nn::Linear(h, 1));
    }
    torch::Tensor forward(torch::Tensor z) {
        z = torch::leaky_relu(fc1->forward(z), 0.2);
        z = torch::leaky_relu(fc2->forward(z), 0.2);
        // No activation on output; 1D real values
        return fc3->forward(z);
    }
};
TORCH_MODULE(Generator);

// Simple MLP Discriminator: x (1D) -> logit
struct DiscriminatorImpl : torch::nn::Module {
    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
    DiscriminatorImpl(int h = 64) {
        fc1 = register_module("fc1", torch::nn::Linear(1, h));
        fc2 = register_module("fc2", torch::nn::Linear(h, h));
        fc3 = register_module("fc3", torch::nn::Linear(h, 1));
    }
    torch::Tensor forward(torch::Tensor x) {
        x = torch::leaky_relu(fc1->forward(x), 0.2);
        x = torch::leaky_relu(fc2->forward(x), 0.2);
        // Output is a logit; use BCEWithLogits for stability
        return fc3->forward(x);
    }
};
TORCH_MODULE(Discriminator);

// Stable BCEWithLogits wrappers
namespace F = torch::nn::functional;

torch::Tensor d_loss_bce_logits(const torch::Tensor& real_logits,
                                const torch::Tensor& fake_logits,
                                double eps_smooth = 0.1) {
    auto real_targets = torch::full_like(real_logits, 1.0 - eps_smooth);
    auto fake_targets = torch::zeros_like(fake_logits);
    auto real_loss = F::binary_cross_entropy_with_logits(real_logits, real_targets,
        F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kMean));
    auto fake_loss = F::binary_cross_entropy_with_logits(fake_logits, fake_targets,
        F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kMean));
    return real_loss + fake_loss;
}

torch::Tensor g_loss_non_saturating(const torch::Tensor& fake_logits) {
    auto targets = torch::ones_like(fake_logits); // pretend fakes are real
    return F::binary_cross_entropy_with_logits(fake_logits, targets,
        F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kMean));
}

int main() {
    torch::manual_seed(123);
    torch::Device device = torch::kCPU; // change to torch::kCUDA if available and linked

    const int zdim = 16;
    const int batch = 128;
    const int steps = 2000;
    const double lrD = 2e-4; // TTUR: D often gets a larger lr than G
    const double lrG = 1e-4;

    Generator G(zdim);
    Discriminator D;
    G->to(device);
    D->to(device);

    torch::optim::Adam optimD(D->parameters(), torch::optim::AdamOptions(lrD).betas({0.5, 0.999}));
    torch::optim::Adam optimG(G->parameters(), torch::optim::AdamOptions(lrG).betas({0.5, 0.999}));

    auto sample_real = [&](int n) {
        // Real data: 1D Gaussian with mean 4.0, std 1.0
        return torch::randn({n, 1}, device) * 1.0 + 4.0;
    };

    auto sample_noise = [&](int n) {
        return torch::randn({n, zdim}, device);
    };

    for (int t = 1; t <= steps; ++t) {
        // ========== Discriminator step ==========
        D->zero_grad();
        // Real batch
        auto x_real = sample_real(batch);
        auto real_logits = D->forward(x_real);
        // Fake batch (detach so G isn't updated here)
        auto z = sample_noise(batch);
        auto x_fake = G->forward(z).detach();
        auto fake_logits = D->forward(x_fake);

        auto lossD = d_loss_bce_logits(real_logits, fake_logits, /*eps_smooth=*/0.1);
        lossD.backward();
        // Optional gradient clipping for stability
        torch::nn::utils::clip_grad_norm_(D->parameters(), 10.0);
        optimD.step();

        // ========== Generator step ==========
        G->zero_grad();
        z = sample_noise(batch);
        auto fake_logits_for_G = D->forward(G->forward(z)); // do not detach here
        auto lossG = g_loss_non_saturating(fake_logits_for_G);
        lossG.backward();
        torch::nn::utils::clip_grad_norm_(G->parameters(), 10.0);
        optimG.step();

        if (t % 200 == 0) {
            // Monitor: losses and sample statistics (no grad tracking needed)
            torch::NoGradGuard no_grad;
            auto zvis = sample_noise(1000);
            auto xg = G->forward(zvis).cpu();
            double mean_g = xg.mean().item<double>();
            double std_g = xg.std().item<double>();
            std::cout << "Step " << t
                      << ": D_loss=" << lossD.item<double>()
                      << ", G_loss=" << lossG.item<double>()
                      << ", G(x) mean=" << mean_g
                      << ", std=" << std_g << "\n";
        }
    }

    // Generate a few samples at the end
    auto z = torch::randn({10, zdim}, device);
    auto samples = G->forward(z).cpu();
    std::cout << "Final samples:\n" << samples.squeeze() << "\n";
    return 0;
}

This complete LibTorch program trains a tiny 1D GAN to learn a Gaussian centered at 4. It uses logits with BCEWithLogits, non-saturating generator loss, label smoothing for real labels, Adam with TTUR, and gradient clipping. The logs report discriminator and generator losses plus the generator’s output mean and std to diagnose convergence (they should approach 4 and 1). The pattern scales to higher dimensions with architectural changes (e.g., CNNs).

Time: per step, O(B Ā· (cost(D fwd+back on real+fake) + cost(G fwd+back) + cost(D fwd on G(z)))) ā‰ˆ O(B Ā· (P_D + P_G)), where P_* reflects parameterized compute.
Space: O(P_D + P_G + A), where A is activation memory ~ O(B Ā· layers). The program uses additional O(B) buffers for targets and noise.
#gan#generator#discriminator#minimax#non-saturating loss#hinge loss#wasserstein#gradient penalty#ttur#libtorch#binary cross-entropy#logits#mode collapse#spectral normalization#training dynamics