PPO & Trust Region Methods
Key Points
- Proximal Policy Optimization (PPO) stabilizes policy gradient learning by preventing each update from moving the policy too far from the previous one.
- The key idea is to optimize a clipped surrogate objective that caps how much the probability ratio r_t(θ) can change, acting like training wheels for policy updates.
- Trust region methods formalize this idea by bounding the average KL divergence between old and new policies, creating a safe update region.
- PPO is a practical, first-order approximation to trust region ideas that works well without expensive second-order optimization.
- Advantages (A_t) tell the policy which actions were better or worse than expected and are crucial for low-variance learning.
- Generalized Advantage Estimation (GAE) trades bias and variance to produce smoother advantage targets for PPO.
- Monitoring KL divergence and normalizing advantages are simple but essential tricks to keep PPO stable.
- In practice, you combine the clipped policy loss with a value-function loss (and optionally an entropy bonus), trained over several epochs on the same on-policy batch.
Prerequisites
- Markov Decision Processes (MDPs) — Defines states, actions, rewards, and transitions, which PPO optimizes over.
- Policy gradient basics — PPO builds on the REINFORCE surrogate and gradients of log probabilities.
- Advantage functions and baselines — Understanding A_t and variance reduction is central to PPO's objective.
- KL divergence — Trust region constraints and PPO diagnostics use KL as a distance between policies.
- Stochastic gradient descent — PPO uses first-order optimization over minibatches and epochs.
- Generalized Advantage Estimation (GAE) — Commonly used to compute stable advantages for PPO updates.
- Softmax and log-softmax — Discrete-action policies rely on stable probability computations.
- Basic calculus and chain rule — Needed to implement gradients of the surrogate and value loss.
Detailed Explanation
01 Overview
Proximal Policy Optimization (PPO) and trust region methods are techniques in reinforcement learning (RL) designed to make policy gradient training stable and efficient. In vanilla policy gradient, we directly maximize expected returns by pushing up the probability of actions that turned out well. However, if we push too hard, the policy can change drastically after a single update, causing instability or collapse. Trust region methods address this by limiting how far the new policy can move from the old one on each step, typically measured by KL divergence. PPO is a widely used, practical way to enforce this principle using a clipped surrogate objective: you optimize the policy gradient but cap the probability ratio so the update cannot overshoot. This gives you most of the benefits of trust region methods without complex second-order optimization. Practically, PPO combines three pieces: a policy (actor) trained with a clipped objective, a value function (critic) trained with regression to returns, and (optionally) an entropy bonus to encourage exploration. The method iterates between collecting on-policy rollouts, computing advantages (often with Generalized Advantage Estimation, GAE), and performing several epochs of minibatch gradient updates—carefully constrained to keep the new policy "close" to the old one.
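The clipped-ratio mechanism described above reduces to a single per-sample function. The sketch below is illustrative (the helper name `ppo_surrogate` and the default clip range are assumptions, not from the source):

```cpp
#include <algorithm>
#include <cmath>
#include <cassert>

// Per-sample PPO surrogate: logp_new/logp_old are log-probabilities of the taken
// action under the new and the frozen old policy; A is the advantage estimate.
// The min() makes the objective a pessimistic bound: large ratio changes that
// would further increase the objective are ignored.
double ppo_surrogate(double logp_new, double logp_old, double A, double eps = 0.2) {
    double ratio = std::exp(logp_new - logp_old);            // r_t(theta)
    double clipped = std::clamp(ratio, 1.0 - eps, 1.0 + eps);
    return std::min(ratio * A, clipped * A);
}
```

Note that for a positive advantage the value saturates once the ratio exceeds 1+ε, while for a negative advantage the unclipped (more negative) term is kept, so the policy can still correct over-enthusiastic updates.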
02 Intuition & Analogies
Imagine steering a boat along a river toward a goal. If you crank the wheel too hard after every small observation, you'll zigzag wildly and may capsize. Trust region methods give you a speed limit for steering: you can adjust course, but only within a safe lane each time. In RL terms, the policy is your steering strategy, and the trust region is a bound that ensures each update doesn't veer too far from what was previously working.
PPO's clipping is like installing bumpers on the steering wheel. The update tries to increase the probability of good actions (and decrease for bad ones), but if the change gets too large, the bumpers prevent more turning. The probability ratio r_t(θ) tells you how much the new policy differs from the old one for the chosen action. If r_t shoots above 1+ε (got too enthusiastic) or below 1−ε (over-penalized), the clip stops further change for that sample. This keeps learning stable even when advantages are noisy.
Another helpful metaphor: price negotiation with a contract. You want to adjust prices based on new info (advantages), but the contract limits per-update changes (trust region/KL bound). You still move toward better prices, but within safe increments that avoid breaking the deal. GAE is your accountant smoothing out profits/losses across time so that your updates aren't swayed by a single lucky or unlucky sale. With these safeguards—clipping/contract limits and smoothed accounting—PPO tends to make steady, reliable progress without expensive mathematics behind the scenes.
03 Formal Definition
Let (\pi_\theta) be the current policy and (\pi_{\theta_{\text{old}}}) the policy that collected the data. PPO maximizes the clipped surrogate objective ( L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\big(r_t(\theta)\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\varepsilon,\,1+\varepsilon)\hat{A}_t\big)\right] ), where ( r_t(\theta) = \pi_\theta(a_t \mid s_t)/\pi_{\theta_{\text{old}}}(a_t \mid s_t) ) is the probability ratio, (\hat{A}_t) is an advantage estimate, and (\varepsilon) is the clip range. Trust region methods instead maximize the unclipped surrogate ( \hat{\mathbb{E}}_t[r_t(\theta)\hat{A}_t] ) subject to a bound ( \hat{\mathbb{E}}_s\left[D_{\mathrm{KL}}\big(\pi_{\theta_{\text{old}}}(\cdot \mid s)\,\|\,\pi_\theta(\cdot \mid s)\big)\right] \le \delta ) on the average KL divergence. The Key Formulas section spells out each component.
04 When to Use
Use PPO when you want a robust, on-policy policy gradient method that is easy to implement and tune across many environments. It is particularly effective in continuous control (robotics, locomotion) and also works in discrete-action tasks where stability is a concern. If your environment is non-stationary or highly stochastic, PPO's clipping and GAE help control variance and prevent destructive updates. When batch collection is expensive but you still prefer on-policy updates, PPO's multiple epochs over the same data improve sample efficiency compared to vanilla REINFORCE.
Prefer explicit trust region methods (like TRPO) if you need stricter theoretical guarantees on step size in policy space or if catastrophic policy updates are very costly. If you can afford extra computation for second-order optimization and Fisher-vector products, TRPO's KL-constrained steps can be attractive. If you rely on small models or need extremely fast iteration with simple first-order code, PPO's clipped surrogate (and optional KL early stopping) gives you most benefits at much lower implementation complexity.
If your policies are very large and off-policy data is abundant, consider algorithms like SAC or TD3 instead. If your state abstraction is simple and tabular, classic policy iteration or actor-critic without trust regions may suffice.
⚠️ Common Mistakes
- Using a large learning rate or a large (\varepsilon) so the clip rarely engages; this defeats the purpose and often destabilizes training. Start with small steps and (\varepsilon\in[0.1,0.3]).
- Forgetting to treat (\pi_{\theta_{\text{old}}}) as a constant baseline during an epoch. You must freeze old log-probabilities when computing ratios; recomputing "old" with updated parameters breaks the surrogate.
- Not normalizing advantages. Unnormalized, skewed advantages can dominate gradients and cause erratic updates; normalize to zero mean and unit variance per batch.
- Ignoring KL monitoring. Even with clipping, watch the empirical KL; if it spikes, reduce learning rate, reduce (\varepsilon), or stop the epoch early.
- Incorrect sign conventions in the loss. Remember we maximize the clipped surrogate but minimize total loss in code: implement as negative surrogate plus value loss (and minus entropy bonus).
- Value loss overpowering the policy loss (or vice versa). Tune (c_1) and (c_2) so the gradients have comparable magnitudes.
- Leaking gradients into advantage targets or the "old" policy. Detach returns/advantages and store old log-probs/probabilities at rollout time.
- Using too small a batch or too many epochs, which can lead to overfitting the on-policy batch and poor generalization. Balance batch size, epochs, and minibatch size.
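Several of these pitfalls, advantage normalization in particular, come down to a few lines of code. A minimal sketch of per-batch normalization to zero mean and unit variance (the helper name is illustrative):

```cpp
#include <vector>
#include <cmath>
#include <cassert>

// Normalize a batch of advantages to zero mean and unit variance.
// The small epsilon terms guard against a degenerate (near-constant) batch.
std::vector<double> normalize_advantages(std::vector<double> adv) {
    double mean = 0.0;
    for (double a : adv) mean += a;
    mean /= adv.size();
    double var = 0.0;
    for (double a : adv) var += (a - mean) * (a - mean);
    double stdev = std::sqrt(std::max(var / adv.size(), 1e-8));
    for (double& a : adv) a = (a - mean) / (stdev + 1e-8);
    return adv;
}
```

Both full programs below perform exactly this step before the PPO epochs.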
Key Formulas
Probability Ratio
( r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} )
Explanation: This measures how much the new policy changes the probability of the taken action compared to the old policy. Ratios far from 1 indicate large updates that PPO aims to limit.
PPO Clipped Objective
( L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\big(r_t(\theta)\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\varepsilon,\,1+\varepsilon)\hat{A}_t\big)\right] )
Explanation: This objective maximizes the standard surrogate when the ratio is within bounds, but becomes constant outside, preventing overly large policy updates.
Combined PPO Loss
( L_t(\theta) = \hat{\mathbb{E}}_t\left[ L_t^{\text{CLIP}}(\theta) - c_1\,\big(V_\theta(s_t) - V_t^{\text{targ}}\big)^2 + c_2\, S[\pi_\theta](s_t) \right] )
Explanation: In practice, PPO trains both the policy and value function together, adding a value regression term and optionally an entropy bonus for exploration.
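As a sketch, the combined per-sample loss (to be minimized, per the sign-convention note in Common Mistakes) can be assembled from its three parts. The function name and default weights are illustrative; the text suggests tuning (c_1) and (c_2) so gradients have comparable magnitudes:

```cpp
#include <cmath>
#include <cassert>

// Combined per-sample PPO loss to minimize: negative clipped surrogate,
// plus weighted value regression, minus weighted entropy bonus.
double ppo_total_loss(double surr, double v_pred, double v_target,
                      double entropy, double c1 = 0.5, double c2 = 0.01) {
    double value_loss = 0.5 * (v_pred - v_target) * (v_pred - v_target);
    return -surr + c1 * value_loss - c2 * entropy;
}
```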
Trust Region Constraint
( \hat{\mathbb{E}}_s\left[ D_{\mathrm{KL}}\big(\pi_{\theta_{\text{old}}}(\cdot \mid s)\,\|\,\pi_\theta(\cdot \mid s)\big) \right] \le \delta )
Explanation: Trust region methods constrain the average KL divergence to keep the new policy close to the old one. This makes each update safe and avoids policy collapse.
KL Divergence (Discrete)
( D_{\mathrm{KL}}(p \,\|\, q) = \sum_x p(x)\,\log\frac{p(x)}{q(x)} )
Explanation: This computes how different distribution q is from p. In PPO/TRPO, p is the old policy and q is the new policy at a state.
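A direct implementation of the discrete formula (illustrative helper; terms with (p(x) = 0) contribute nothing by the usual convention):

```cpp
#include <vector>
#include <cmath>
#include <cassert>

// D_KL(p || q) for discrete distributions over the same support.
// In PPO diagnostics, p is the stored old policy and q the current one.
double kl_divergence(const std::vector<double>& p, const std::vector<double>& q) {
    double kl = 0.0;
    for (size_t i = 0; i < p.size(); ++i)
        if (p[i] > 0.0) kl += p[i] * std::log(p[i] / q[i]);  // 0*log(0/q) := 0
    return kl;
}
```

Note that KL is not symmetric: (D_{\mathrm{KL}}(p\|q) \ne D_{\mathrm{KL}}(q\|p)) in general, which is why the direction (old policy first) matters in the trust region constraint.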
Discounted Return
( G_t = \sum_{k=0}^{\infty} \gamma^k\, r_{t+k} )
Explanation: The target return is the discounted sum of future rewards. It acts as a regression target for the value function.
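The discounted return is typically computed with one backward pass over an episode's rewards (illustrative sketch for a finite episode):

```cpp
#include <vector>
#include <cmath>
#include <cassert>

// G_t = r_t + gamma * G_{t+1}, computed backward over a finished episode.
std::vector<double> discounted_returns(const std::vector<double>& rewards, double gamma) {
    std::vector<double> G(rewards.size());
    double running = 0.0;
    for (int t = (int)rewards.size() - 1; t >= 0; --t) {
        running = rewards[t] + gamma * running;
        G[t] = running;
    }
    return G;
}
```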
TD Residual
( \delta_t = r_t + \gamma\, V(s_{t+1}) - V(s_t) )
Explanation: The temporal-difference residual is the one-step error signal used inside GAE to compose advantages.
GAE
( \hat{A}_t^{\mathrm{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l} )
Explanation: GAE smooths advantages by exponentially weighting multi-step TD residuals. λ controls the bias-variance tradeoff.
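Combining TD residuals with the exponential (\gamma\lambda) weighting gives a compact backward recursion. A single-episode sketch (an assumption here: `values` has length T+1, with the last entry the bootstrap value of the final state, 0 if terminal):

```cpp
#include <vector>
#include <cmath>
#include <cassert>

// A_t = delta_t + gamma*lambda*A_{t+1}, computed backward over one episode.
// lambda = 0 recovers one-step TD advantages; lambda = 1 recovers
// discounted-return-minus-baseline advantages.
std::vector<double> gae_advantages(const std::vector<double>& rewards,
                                   const std::vector<double>& values,  // size T+1
                                   double gamma, double lam) {
    size_t T = rewards.size();
    std::vector<double> adv(T);
    double gae = 0.0;
    for (int t = (int)T - 1; t >= 0; --t) {
        double delta = rewards[t] + gamma * values[t + 1] - values[t];  // TD residual
        gae = delta + gamma * lam * gae;
        adv[t] = gae;
    }
    return adv;
}
```

With lam = 1, adv[0] equals the discounted return minus V(s_0), matching the bias-variance endpoints described above.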
Policy Gradient Estimator
( \nabla_\theta J(\theta) = \hat{\mathbb{E}}_t\left[ \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t \right] )
Explanation: The gradient of expected return can be estimated by weighting log-probability gradients with advantages. PPO modifies this with ratios and clipping.
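For a softmax policy the score function has a simple closed form in the logits, (\partial \log \pi(a)/\partial z_k = \mathbb{1}[k=a] - p_k), which both example programs below rely on. An illustrative two-action sketch:

```cpp
#include <array>
#include <cmath>
#include <cassert>

// Numerically stable softmax over 2 logits.
std::array<double, 2> softmax2(const std::array<double, 2>& z) {
    double m = std::max(z[0], z[1]);
    double e0 = std::exp(z[0] - m), e1 = std::exp(z[1] - m);
    return {e0 / (e0 + e1), e1 / (e0 + e1)};
}

// d log pi(a) / d z_k = 1[k == a] - p_k for a softmax policy.
double grad_logp_wrt_logit(const std::array<double, 2>& z, int a, int k) {
    auto p = softmax2(z);
    return (k == a ? 1.0 : 0.0) - p[k];
}
```

Multiplying this score by the advantage (and, in PPO, by the ratio when the clip is inactive) yields the per-sample policy gradient.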
Policy Entropy
( H\big(\pi_\theta(\cdot \mid s)\big) = -\sum_a \pi_\theta(a \mid s)\,\log \pi_\theta(a \mid s) )
Explanation: Entropy measures randomness in the policy. Adding it to the objective encourages exploration by preventing premature determinism.
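The discrete entropy is a one-liner (illustrative helper); it is maximal for a uniform policy and zero for a deterministic one:

```cpp
#include <vector>
#include <cmath>
#include <cassert>

// Shannon entropy of a discrete policy at one state (natural log).
double entropy(const std::vector<double>& p) {
    double h = 0.0;
    for (double pi : p)
        if (pi > 0.0) h -= pi * std::log(pi);
    return h;
}
```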
KL Penalty Variant
( L^{\text{KLPEN}}(\theta) = \hat{\mathbb{E}}_t\left[ r_t(\theta)\hat{A}_t - \beta\, D_{\mathrm{KL}}\big(\pi_{\theta_{\text{old}}}(\cdot \mid s_t)\,\|\,\pi_\theta(\cdot \mid s_t)\big) \right] )
Explanation: An alternative to clipping is to penalize KL directly in the objective. β is tuned (or adapted) to keep KL near a target.
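A sketch of the penalized per-sample objective together with the simple adaptive-β rule from the PPO paper (double β when KL overshoots the target, halve it when KL undershoots; the 1.5 thresholds follow that common heuristic):

```cpp
#include <cmath>
#include <cassert>

// Per-sample KL-penalized surrogate: ratio*A minus beta times KL(old || new).
double kl_penalty_objective(double ratio, double A, double kl, double beta) {
    return ratio * A - beta * kl;
}

// Adaptive beta: strengthen the penalty when KL is too large, relax it when small.
double adapt_beta(double beta, double kl, double target_kl) {
    if (kl > 1.5 * target_kl) return beta * 2.0;
    if (kl < target_kl / 1.5) return beta / 2.0;
    return beta;
}
```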
Complexity Analysis
Each iteration collects (B) environment steps (here (B = 2000)) and then performs (E) epochs of minibatch updates, so the optimization cost per iteration is (O(E \cdot B)) forward/backward passes through the policy and value networks. Memory is (O(B)) for the rollout buffer (features, actions, rewards, old log-probabilities, and value predictions). For the linear models in the examples below, each forward/backward pass is (O(d)) in the feature dimension (d).
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Simple 1D random-walk environment: states are integer positions in [-GOAL, GOAL].
// Actions: 0 = left (-1), 1 = right (+1).
// Reward: +1 on reaching +GOAL, -1 on reaching -GOAL; small step penalty -0.01;
// episode ends on reaching a goal or after max_steps.
struct RandomWalkEnv {
    int GOAL = 5;
    int max_steps = 50;
    int pos = 0;
    int steps = 0;
    mt19937 rng;

    RandomWalkEnv(unsigned seed = 42) : rng(seed) { reset(); }

    vector<double> reset() {
        pos = 0; steps = 0;
        return get_state();
    }

    vector<double> get_state() const {
        // 2 features: normalized position and a constant bias of 1
        return { (double)pos / (double)GOAL, 1.0 };
    }

    tuple<vector<double>, double, bool> step(int action) {
        steps++;
        pos += (action == 0 ? -1 : +1);
        double r = -0.01;                    // step penalty
        bool done = false;
        if (pos >= GOAL)  { r = +1.0; done = true; }
        if (pos <= -GOAL) { r = -1.0; done = true; }
        if (steps >= max_steps) done = true;
        return { get_state(), r, done };
    }
};

// Numerically stable log(exp(a) + exp(b))
static inline double logsumexp2(double a, double b) {
    double m = max(a, b);
    return m + log(exp(a - m) + exp(b - m));
}

struct PolicyValue {
    // Linear policy (2 actions) and linear value function over 2-dim features.
    // Policy: logits z = W * phi, W is 2x2 (one row per action).
    // Value:  V = v^T phi
    array<array<double,2>,2> W;  // policy weights
    array<double,2> v;           // value weights

    mt19937 rng;
    uniform_real_distribution<double> unif{-0.01, 0.01};

    PolicyValue(unsigned seed = 123) : rng(seed) {
        for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) W[i][j] = unif(rng);
        for (int j = 0; j < 2; j++) v[j] = unif(rng);
    }

    // Compute logits and softmax probabilities for the 2 actions
    void policy_forward(const array<double,2>& phi, array<double,2>& logits, array<double,2>& probs) const {
        for (int a = 0; a < 2; a++) logits[a] = W[a][0]*phi[0] + W[a][1]*phi[1];
        double lse = logsumexp2(logits[0], logits[1]);
        probs[0] = exp(logits[0] - lse);
        probs[1] = exp(logits[1] - lse);
    }

    double value_forward(const array<double,2>& phi) const {
        return v[0]*phi[0] + v[1]*phi[1];
    }

    int sample_action(const array<double,2>& phi, array<double,2>& probs_out, double& logp_out) {
        array<double,2> logits, probs;
        policy_forward(phi, logits, probs);
        probs_out = probs;
        // Categorical sample over 2 actions
        double u = generate_canonical<double, 10>(rng);
        int a = (u < probs[0]) ? 0 : 1;
        logp_out = log(probs[a] + 1e-12);
        return a;
    }
};

struct Transition {
    array<double,2> phi;    // features
    int a;                  // action
    double r;               // reward
    bool done;              // terminal flag
    double logp_old;        // log prob of action under old policy
    double vpred;           // value prediction under old value net
    array<double,2> p_old;  // full old policy probs (for diagnostics if needed)
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Hyperparameters
    const double gamma = 0.99;
    const double lambda_gae = 0.95;
    const double clip_eps = 0.2;    // PPO clip
    const double lr = 5e-3;         // learning rate (SGD)
    const double c1_value = 0.5;    // value loss weight
    const int batch_size = 2000;    // steps per iteration
    const int epochs = 5;           // PPO epochs per iteration
    const int minibatch = 250;      // minibatch size
    const int updates = 60;         // training iterations

    RandomWalkEnv env(7);
    PolicyValue net(1234);
    mt19937 rng(2024);

    auto to_phi = [](const vector<double>& s){ return array<double,2>{s[0], s[1]}; };

    double avg_return = 0.0;

    for (int it = 0; it < updates; ++it) {
        // 1) Collect on-policy rollouts until we have batch_size steps
        vector<Transition> buf; buf.reserve(batch_size + 100);
        int steps_collected = 0;
        int episodes = 0;
        double ret_sum = 0.0;
        while (steps_collected < batch_size) {
            vector<double> s = env.reset();
            array<double,2> phi = to_phi(s);
            double ep_ret = 0.0;
            for (int t = 0; t < env.max_steps && steps_collected < batch_size; ++t) {
                array<double,2> probs; double logp;
                int a = net.sample_action(phi, probs, logp);
                double vpred = net.value_forward(phi);
                auto [s2, r, done] = env.step(a);
                buf.push_back({phi, a, r, done, logp, vpred, probs});
                steps_collected++;
                ep_ret += r;
                phi = to_phi(s2);
                if (done) break;
            }
            episodes++;
            ret_sum += ep_ret;
        }
        avg_return = ret_sum / (double)episodes;

        // 2) Compute returns and advantages (GAE).
        // The done mask zeroes both the bootstrap value and the GAE carry-over at
        // episode boundaries, so no extra reset is needed.
        int N = (int)buf.size();
        vector<double> returns(N), adv(N);
        double next_value = 0.0;
        double gae = 0.0;
        for (int i = N - 1; i >= 0; --i) {
            double mask = buf[i].done ? 0.0 : 1.0;
            double delta = buf[i].r + gamma * next_value * mask - buf[i].vpred;
            gae = delta + gamma * lambda_gae * mask * gae;
            adv[i] = gae;
            returns[i] = buf[i].vpred + adv[i];
            next_value = buf[i].vpred;  // bootstrap for the previous step
        }
        // Normalize advantages to zero mean, unit variance
        double meanA = 0, varA = 0;
        for (double x : adv) meanA += x;
        meanA /= N;
        for (double x : adv) { double d = x - meanA; varA += d * d; }
        varA = max(varA / N, 1e-8);
        double stdA = sqrt(varA);
        for (double& x : adv) x = (x - meanA) / (stdA + 1e-8);

        // 3) PPO updates over multiple epochs
        vector<int> idx(N); iota(idx.begin(), idx.end(), 0);
        for (int e = 0; e < epochs; ++e) {
            shuffle(idx.begin(), idx.end(), rng);
            for (int start = 0; start < N; start += minibatch) {
                int end = min(start + minibatch, N);
                int m = end - start;
                // Accumulate gradients
                array<array<double,2>,2> gW{};  // zero-init
                array<double,2> gv{};

                double pg_obj = 0.0;  // for monitoring (not required)
                double v_loss = 0.0;

                for (int ii = start; ii < end; ++ii) {
                    int i = idx[ii];
                    const auto& tr = buf[i];
                    // Forward new policy/value
                    array<double,2> phi = tr.phi;
                    array<double,2> logits, pnew;
                    for (int a = 0; a < 2; a++) logits[a] = net.W[a][0]*phi[0] + net.W[a][1]*phi[1];
                    double lse = logsumexp2(logits[0], logits[1]);
                    pnew[0] = exp(logits[0] - lse); pnew[1] = exp(logits[1] - lse);
                    double logp_new = log(pnew[tr.a] + 1e-12);
                    double ratio = exp(logp_new - tr.logp_old);
                    double A = adv[i];

                    // Clipped PPO objective per sample
                    double unclipped = ratio * A;
                    double clipped = clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * A;
                    double surr = min(unclipped, clipped);
                    pg_obj += surr;

                    // Policy gradient coefficient: zero when the clipped term is the active minimum
                    bool clip_active = false;
                    if (A >= 0.0 && ratio > 1.0 + clip_eps) clip_active = true;
                    if (A <  0.0 && ratio < 1.0 - clip_eps) clip_active = true;
                    double pg_coef = clip_active ? 0.0 : ratio * A;

                    // grad of log pi(a|s) wrt logits is (onehot - p); wrt W is the outer product with phi
                    for (int a = 0; a < 2; a++) {
                        double dlogp_dz = (a == tr.a ? 1.0 : 0.0) - pnew[a];
                        for (int k = 0; k < 2; k++) {
                            gW[a][k] += -(pg_coef / m) * dlogp_dz * phi[k];  // negative: we minimize -surr
                        }
                    }

                    // Value loss grad: 0.5*(V - R)^2
                    double V = net.v[0]*phi[0] + net.v[1]*phi[1];
                    double R = returns[i];
                    double dv = V - R;
                    v_loss += 0.5 * dv * dv;
                    for (int k = 0; k < 2; k++) gv[k] += (c1_value / m) * dv * phi[k];
                }
                (void)pg_obj; (void)v_loss;  // monitoring only

                // SGD step
                for (int a = 0; a < 2; a++) for (int k = 0; k < 2; k++) net.W[a][k] -= lr * gW[a][k];
                for (int k = 0; k < 2; k++) net.v[k] -= lr * gv[k];
            }
        }

        if ((it + 1) % 5 == 0) {
            cout << "Iter " << (it + 1) << ": avg return = " << avg_return << "\n";
        }
    }

    // Quick evaluation after training
    RandomWalkEnv eval_env(99);
    double eval_ret = 0.0; int eval_episodes = 50;
    for (int ep = 0; ep < eval_episodes; ++ep) {
        auto s = eval_env.reset();
        array<double,2> phi = {s[0], s[1]};
        double R = 0.0;
        for (int t = 0; t < eval_env.max_steps; t++) {
            array<double,2> logits, p;
            for (int a = 0; a < 2; a++) logits[a] = net.W[a][0]*phi[0] + net.W[a][1]*phi[1];
            double lse = logsumexp2(logits[0], logits[1]);
            p[0] = exp(logits[0] - lse); p[1] = exp(logits[1] - lse);
            int a = (p[0] > p[1]) ? 0 : 1;  // greedy
            auto [s2, r, done] = eval_env.step(a);
            R += r; phi = {s2[0], s2[1]};
            if (done) break;
        }
        eval_ret += R;
    }
    cout << "Post-training average return (greedy over 50 eps): " << (eval_ret / eval_episodes) << "\n";
    return 0;
}
```
This program trains a simple PPO agent with a linear policy and value function on a small 1D random-walk environment. It collects on-policy rollouts, computes advantages with GAE, normalizes them, and performs multiple epochs of minibatch updates using the clipped PPO objective. The gradient for the policy uses the probability ratio but turns off when the clipped value is the active minimum, implementing the trust-region-like behavior of PPO. For simplicity, entropy regularization is omitted and the value loss is un-clipped. After training, it evaluates the learned policy greedily.
```cpp
#include <bits/stdc++.h>
using namespace std;

struct RandomWalkEnv {
    int GOAL = 5, max_steps = 50, pos = 0, steps = 0;
    mt19937 rng;
    RandomWalkEnv(unsigned s = 1) : rng(s) { reset(); }
    vector<double> reset() { pos = 0; steps = 0; return {(double)pos / GOAL, 1.0}; }
    tuple<vector<double>, double, bool> step(int a) {
        steps++;
        pos += (a == 0 ? -1 : +1);
        double r = -0.01; bool d = false;
        if (pos >= GOAL)  { r = +1; d = true; }
        if (pos <= -GOAL) { r = -1; d = true; }
        if (steps >= max_steps) d = true;
        return {{(double)pos / GOAL, 1.0}, r, d};
    }
};

static inline double logsumexp2(double a, double b) {
    double m = max(a, b);
    return m + log(exp(a - m) + exp(b - m));
}

struct Net {  // linear policy (2 actions) + value
    array<array<double,2>,2> W;
    array<double,2> v;
    mt19937 rng;
    uniform_real_distribution<double> unif{-0.01, 0.01};
    Net(unsigned s = 3) : rng(s) {
        for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) W[i][j] = unif(rng);
        for (int j = 0; j < 2; j++) v[j] = unif(rng);
    }
    void forward(const array<double,2>& phi, array<double,2>& p, double& V) const {
        array<double,2> z;
        for (int a = 0; a < 2; a++) z[a] = W[a][0]*phi[0] + W[a][1]*phi[1];
        double lse = logsumexp2(z[0], z[1]);
        p[0] = exp(z[0] - lse); p[1] = exp(z[1] - lse);
        V = v[0]*phi[0] + v[1]*phi[1];
    }
    int sample(const array<double,2>& phi, array<double,2>& p, double& logp) {
        double V; forward(phi, p, V);
        double u = generate_canonical<double, 10>(rng);
        int a = (u < p[0] ? 0 : 1);
        logp = log(p[a] + 1e-12);
        return a;
    }
};

struct Tr {
    array<double,2> phi; int a; double r; bool done;
    double logp_old; array<double,2> p_old; double vpred;
};

int main() {
    const double gamma = 0.99, lam = 0.95, lr = 5e-3, c1 = 0.5, target_kl = 0.01;
    const int batch_size = 2000, epochs = 10, minibatch = 250, iters = 60;
    RandomWalkEnv env(11); Net net(5); mt19937 rng(2025);

    auto to_phi = [](const vector<double>& s){ return array<double,2>{s[0], s[1]}; };

    for (int iter = 0; iter < iters; ++iter) {
        // Collect on-policy rollouts
        vector<Tr> D; D.reserve(batch_size + 100);
        int steps = 0;
        while (steps < batch_size) {
            auto s = env.reset();
            auto phi = to_phi(s);
            for (int t = 0; t < env.max_steps && steps < batch_size; ++t) {
                array<double,2> p; double logp;
                int a = net.sample(phi, p, logp);
                double V = net.v[0]*phi[0] + net.v[1]*phi[1];
                auto [s2, r, d] = env.step(a);
                D.push_back({phi, a, r, d, logp, p, V});
                phi = to_phi(s2);
                steps++;
                if (d) break;
            }
        }
        int N = (int)D.size();

        // GAE; the done mask handles episode boundaries
        vector<double> adv(N), ret(N);
        double nextV = 0, gae = 0;
        for (int i = N - 1; i >= 0; --i) {
            double mask = D[i].done ? 0.0 : 1.0;
            double delta = D[i].r + gamma * nextV * mask - D[i].vpred;
            gae = delta + gamma * lam * mask * gae;
            adv[i] = gae;
            ret[i] = D[i].vpred + adv[i];
            nextV = D[i].vpred;
        }
        // Normalize advantages
        double mA = 0, sA = 0;
        for (double x : adv) mA += x;
        mA /= N;
        for (double x : adv) { double d = x - mA; sA += d * d; }
        sA = sqrt(max(sA / N, 1e-8));
        for (double& x : adv) x = (x - mA) / (sA + 1e-8);

        vector<int> idx(N); iota(idx.begin(), idx.end(), 0);
        for (int e = 0; e < epochs; ++e) {
            shuffle(idx.begin(), idx.end(), rng);
            double mean_kl_epoch = 0.0; int kl_count = 0; bool stop = false;
            for (int st = 0; st < N && !stop; st += minibatch) {
                int en = min(st + minibatch, N);
                int m = en - st;
                array<array<double,2>,2> gW{}; array<double,2> gv{};
                double mean_kl = 0.0;
                for (int ii = st; ii < en; ++ii) {
                    int i = idx[ii]; auto& tr = D[i]; auto phi = tr.phi;
                    array<double,2> pnew; double V;
                    net.forward(phi, pnew, V);
                    // Unclipped surrogate with probability ratio
                    double logp_new = log(pnew[tr.a] + 1e-12);
                    double ratio = exp(logp_new - tr.logp_old);
                    double A = adv[i];
                    double pg_coef = ratio * A;  // always on (no clipping)
                    for (int a = 0; a < 2; a++) {
                        double dlogp_dz = (a == tr.a ? 1.0 : 0.0) - pnew[a];
                        for (int k = 0; k < 2; k++) gW[a][k] += -(pg_coef / m) * dlogp_dz * phi[k];
                    }
                    // Value grad
                    double dv = V - ret[i];
                    for (int k = 0; k < 2; k++) gv[k] += (c1 / m) * dv * phi[k];
                    // KL(old || new) for early stopping
                    double kl = 0.0;
                    for (int a = 0; a < 2; a++)
                        kl += tr.p_old[a] * (log(tr.p_old[a] + 1e-12) - log(pnew[a] + 1e-12));
                    mean_kl += kl;
                }
                mean_kl /= m; mean_kl_epoch += mean_kl; kl_count++;
                // Early stopping: skip this update if KL grew too large
                if (mean_kl > 1.5 * target_kl) { stop = true; break; }
                for (int a = 0; a < 2; a++) for (int k = 0; k < 2; k++) net.W[a][k] -= lr * gW[a][k];
                for (int k = 0; k < 2; k++) net.v[k] -= lr * gv[k];
            }
            mean_kl_epoch = (kl_count > 0 ? mean_kl_epoch / kl_count : 0.0);
            if (mean_kl_epoch > 1.5 * target_kl) break;  // stop remaining epochs
        }
        if ((iter + 1) % 5 == 0) cerr << "Iter " << (iter + 1) << " done (KL-early-stop).\n";
    }
    cout << "Training finished with KL early stopping variant." << "\n";
    return 0;
}
```
This variant removes clipping and instead uses the standard surrogate with the probability ratio. It monitors the empirical KL divergence between the stored old policy and the current policy and performs early stopping within an epoch if the KL grows beyond a threshold. This mimics a trust-region constraint without second-order methods, often improving stability compared to unconstrained policy gradients while avoiding TRPO's computational overhead.