Group Theory for Neural Networks
Key Points
- Group theory gives a precise language for symmetries, and neural networks can exploit these symmetries to learn faster and generalize better.
- An operation is equivariant if transforming the input and then applying the operation equals applying the operation and then transforming the output.
- Classic CNNs are built on the translation group; more general groups (permutations, rotations, reflections) lead to powerful equivariant architectures.
- Group convolutions and weight-tying via group averaging (Reynolds operator) are practical tools to enforce equivariance.
- Permutation-equivariant layers for sets and circular convolutions for cyclic groups are easy to implement in C++ and illustrate the core ideas.
- Equivariance reduces parameters by sharing them across symmetry-related parts, improving sample efficiency and stability.
- Testing equivariance is straightforward: check f(g·x) ≈ g·f(x) numerically for random group elements.
- Naive group convolution is O(|G|^2), but for abelian groups like C_n we can use the FFT to reach O(|G| log |G|).
Prerequisites
- Linear algebra (vectors, matrices, matrix multiplication): Group actions on data are implemented as linear maps; equivariance conditions are expressed as commutation relations.
- Basic abstract algebra (groups, permutations): Understanding the axioms and examples of groups is essential to reason about symmetry and actions.
- Convolution and Fourier transform (discrete): Group convolution generalizes classical convolution; the FFT accelerates computations for abelian groups like C_n.
- Neural network basics (layers, parameters, backpropagation): You must know how layers are composed and trained to apply symmetry constraints effectively.
- Numerical computing and floating point: Equivariance tests require tolerances; implementation details can break symmetry via rounding or padding.
- Graph theory (optional): Graph neural networks rely on permutation symmetries and automorphisms.
- 3D geometry and rotations (optional): Equivariant models for molecules/physics use rotation/reflection groups and their representations.
Detailed Explanation
01 Overview
Hook → Concept → Example. Hook: Imagine a classifier that recognizes a rotated cat as a cat without ever seeing that exact rotation. Concept: Group theory formalizes such regularities as symmetries: transformations (like shifts, rotations, or permutations) that preserve problem structure. In neural networks, we use group actions to build layers that are either invariant (output unchanged) or equivariant (output transforms predictably) under these symmetries. This lets us share weights across symmetric situations, reducing parameters and improving generalization. Example: CNNs use translation symmetry via convolution, making features shift equivariant and ultimately producing shift-invariant predictions via pooling.
Group theory for neural networks studies how to encode a chosen symmetry group G into architectures. The key ingredients are (1) a representation of G that acts on inputs and outputs, and (2) maps f that satisfy f(ρ_X(g) x) = ρ_Y(g) f(x). This principle extends CNNs (translations) to sets (permutation symmetry), molecules (3D rotations/reflections), graphs (automorphisms), and more. Practically, we implement group convolutions, tie or average weights across group orbits, and verify equivariance numerically. By aligning models with known symmetries, we obtain sample efficiency, robustness, and interpretability while often reducing training time and data needs.
02 Intuition & Analogies
Hook → Concept → Example. Hook: Think of a jigsaw puzzle. Turning a piece doesn't change whether it fits, only its orientation. If your strategy understood that rotations are "allowed moves," you'd try far fewer options. Concept: A group is just a set of allowed moves (transformations) you can apply, like rotating, shifting, or permuting, that can be combined, undone, and that include a "do nothing" move. A group action says how those moves apply to your data (images, point sets, graphs). If your model's behavior is consistent with these moves, it doesn't have to relearn the same pattern every time it appears in a new pose.
Example 1: Rotating a photo doesn't change whether it contains a dog. A rotation-invariant classifier should output the same label regardless of rotation. Example 2: A bag of marbles has no ordering; swapping positions shouldn't change the summary. A permutation-invariant function like the mean treats all permutations equally. Example 3: Shifting an ECG time series left or right shouldn't change where the model detects a heartbeat, just the position; that is shift equivariance, exactly what 1D convolution provides.
The big payoff: When we bake symmetries into a network, each learned pattern automatically applies in all symmetric situations. This is like learning "a wheel" once, not separately for every rotation. Convolutions, orbit-averaging weights, and carefully designed layers are the tools that turn this symmetry intuition into working code.
03 Formal Definition
A group G is a set with an associative composition, an identity element e, and an inverse g^{-1} for each element g. A group action of G on a space X assigns to each g a transformation of X such that e acts as the identity and composing group elements matches composing transformations. A (linear) representation ρ assigns to each g an invertible matrix with ρ(g h) = ρ(g) ρ(h). Given representations ρ_X on inputs and ρ_Y on outputs, a map f is equivariant if f(ρ_X(g) x) = ρ_Y(g) f(x) for all g and x, and invariant in the special case ρ_Y(g) = I.
04 When to Use
Hook → Concept → Example. Hook: If you can describe how your data should look after a transformation and how the target should change (or not), you likely have a symmetry to exploit. Concept: Use group-theoretic models when the task is structured by known symmetries (translation, rotation, reflection, or permutations) so that weight sharing and equivariance reduce redundancy.
Use cases:
- Images/audio/time series with shift symmetry → standard CNNs (C_n or Z actions). Example: ECG beat detection should shift with the signal.
- Sets/point clouds with permutation symmetry (S_n) → Deep Sets or permutation-equivariant layers. Example: Predict total mass from unordered particles.
- Molecules/3D physics with rotations/reflections (SO(3), O(3), or discrete subgroups) → steerable/equivariant networks. Example: Predict energy invariant to global rotation.
- Graphs with automorphism symmetry → message passing that is permutation-equivariant at node level and invariant at graph level. Example: Classify molecular graphs.
- Data augmentation alternative → Instead of sampling many transformed inputs, build the symmetry into the architecture via group convolutions or weight-tying (Reynolds operator). Example: Replace exhaustive rotations by a rotation-equivariant layer.
Choose finite/discrete groups for simple, fast implementations (e.g., C_n, D_n, S_n); use continuous groups when physics demands it, often via specialized libraries for steerable filters or spherical harmonics.
⚠️ Common Mistakes
Hook → Concept → Example. Hook: Many bugs come from mixing up what should change and what shouldn't. Concept: Be precise about invariance vs. equivariance, and ensure the action you code is really a group action. Example: Padding an image inconsistently breaks translation equivariance even if the kernel is fine.
Common pitfalls:
- Confusing invariance with equivariance: If labels should move with inputs (segmentation), enforce equivariance, not invariance.
- Implementing a non-action: Your transformation set must be closed with an identity and inverses. Forgetting inverses (e.g., using only forward rotations) invalidates proofs and tests.
- Breaking symmetry in preprocessing: Non-circular padding destroys circular shift equivariance; inconsistent interpolation breaks rotation equivariance.
- Assuming commutativity: Many groups (e.g., D_n, S_n) are non-abelian; averaging or composing assuming abelian properties gives wrong results.
- Ignoring representation choice: ρ_X and ρ_Y must be compatible; otherwise, no linear map W can commute with the action.
- Numerical checks without tolerance: Floating-point roundoff can make exact equality tests fail; use norms with epsilons.
- Over-enumeration: Explicitly listing large groups is exponential; prefer generators or structure (e.g., cyclic shifts) and fast algorithms (FFT) when possible.
- Weight-tying mistakes: Averaging W over the group must use P W P^{-1}, not just P W or W P, to ensure commutation.
Key Formulas
Equivariance
f(ρ_X(g) x) = ρ_Y(g) f(x) for all g ∈ G
Explanation: This equation states that transforming the input by g and then applying f equals transforming the output by g after applying f. Use it to define and test equivariant layers.
Invariance
f(ρ_X(g) x) = f(x) for all g ∈ G
Explanation: An invariant function ignores the group transformation entirely. Use this when the task's label should not change under the symmetry (e.g., classification).
Representation Homomorphism
ρ(g h) = ρ(g) ρ(h), ρ(e) = I
Explanation: A representation maps group elements to invertible linear transformations consistently with group multiplication. It guarantees that compositions of actions behave as expected.
Group Convolution
(k * x)(g) = Σ_{h ∈ G} k(h^{-1} g) x(h)
Explanation: Convolution generalized to a group uses the group's multiplication and inverse. For cyclic groups, this reduces to circular convolution on indices.
Reynolds Operator (Projection)
W_sym = (1/|G|) Σ_{g ∈ G} ρ(g) W ρ(g)^{-1}
Explanation: Averaging a linear map over the group projects it onto the space of equivariant maps. After this operation, W commutes with every group action element.
Orbit-Stabilizer Theorem
|Orb(x)| · |Stab(x)| = |G|
Explanation: For a finite group, the size of the orbit of x times the size of its stabilizer equals the group's size. It helps count distinct symmetry-related configurations.
S_n-Equivariant Linear Maps on R^n
y = a x + b (Σ_i x_i) 1, where 1 is the all-ones vector
Explanation: Any linear map commuting with all permutations must be a linear combination of the identity and the all-ones projector. This characterizes permutation-equivariant linear layers for sets.
Shift Equivariance of Circular Convolution
k * (R_s x) = R_s (k * x)
Explanation: Shifting the input of a circular convolution shifts the output by the same amount. This is the core symmetry exploited by CNNs on periodic domains.
Group Convolution Complexity (Abelian case)
Naive: O(|G|^2); FFT-based (abelian G): O(|G| log |G|)
Explanation: Naive group convolution sums over all pairs of group elements, while FFT-based methods for abelian groups reduce complexity to near-linear in group size.
Stabilizer Definition
Stab(x) = { g ∈ G : g · x = x }
Explanation: The stabilizer of x contains all symmetries that leave x unchanged. It can explain when features collapse under symmetry constraints.
Complexity Analysis
Naive group convolution sums over all pairs of group elements, costing O(|G|^2) multiply-adds. For abelian groups such as C_n, the convolution theorem lets FFT-based implementations reach O(|G| log |G|). The Reynolds projection averages |G| conjugates P_g W P_g^{-1}; for permutation actions each conjugate is an index rearrangement, so symmetrizing an n x n weight matrix costs O(|G| n^2).
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Apply a permutation p to vector x: y[i] = x[p[i]]
vector<double> apply_permutation(const vector<int>& p, const vector<double>& x) {
    int n = (int)x.size();
    vector<double> y(n);
    for (int i = 0; i < n; ++i) y[i] = x[p[i]];
    return y;
}

// Equivariant linear map for S_n: y = a*x + b*(sum x)*1, with 1 the all-ones vector
vector<double> sn_equivariant_linear(const vector<double>& x, double a, double b) {
    double s = accumulate(x.begin(), x.end(), 0.0);
    vector<double> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) y[i] = a * x[i] + b * s;
    return y;
}

// Helper: max absolute difference between two vectors
double max_abs_diff(const vector<double>& a, const vector<double>& b) {
    double m = 0.0;
    for (size_t i = 0; i < a.size(); ++i) m = max(m, fabs(a[i] - b[i]));
    return m;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 8;
    vector<double> x(n);
    mt19937 rng(42);
    uniform_real_distribution<double> dist(-1.0, 1.0);
    for (int i = 0; i < n; ++i) x[i] = dist(rng);

    // Random permutation p
    vector<int> p(n);
    iota(p.begin(), p.end(), 0);
    shuffle(p.begin(), p.end(), rng);

    // Parameters for the equivariant layer
    double a = 1.7, b = -0.3;

    // Check equivariance: f(Px) == P f(x)
    vector<double> lhs = sn_equivariant_linear(apply_permutation(p, x), a, b); // f(Px)
    vector<double> rhs = apply_permutation(p, sn_equivariant_linear(x, a, b)); // P f(x)

    double err = max_abs_diff(lhs, rhs);
    cout << fixed << setprecision(6);
    cout << "Max |f(Px) - P f(x)| = " << err << "\n";
    return 0;
}
```
For sets, the symmetric group S_n acts by permuting elements. All linear S_n-equivariant maps on R^n have the form y = a x + b (sum x) 1, where 1 is the all-ones vector. The code builds such a layer and numerically verifies equivariance for a random permutation p by checking f(Px) ≈ P f(x).
```cpp
#include <bits/stdc++.h>
using namespace std;

// Rotate vector x by s positions (positive s rotates to the right)
vector<double> rotate_cyclic(const vector<double>& x, int s) {
    int n = (int)x.size();
    vector<double> y(n);
    s = ((s % n) + n) % n;
    for (int i = 0; i < n; ++i) {
        y[(i + s) % n] = x[i];
    }
    return y;
}

// Naive circular convolution (group convolution on C_n)
// y[g] = sum_h k[h^{-1} g] x[h]; with C_n, h^{-1} g corresponds to (g - h) mod n
vector<double> circular_convolution(const vector<double>& k, const vector<double>& x) {
    int n = (int)x.size();
    vector<double> y(n, 0.0);
    for (int g = 0; g < n; ++g) {
        double acc = 0.0;
        for (int h = 0; h < n; ++h) {
            int idx = (g - h) % n; if (idx < 0) idx += n; // h^{-1} g
            acc += k[idx] * x[h];
        }
        y[g] = acc;
    }
    return y;
}

// Helper: max absolute difference
double max_abs_diff(const vector<double>& a, const vector<double>& b) {
    double m = 0.0;
    for (size_t i = 0; i < a.size(); ++i) m = max(m, fabs(a[i] - b[i]));
    return m;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 16;
    mt19937 rng(123);
    normal_distribution<double> dist(0.0, 1.0);

    vector<double> x(n), k(n);
    for (int i = 0; i < n; ++i) { x[i] = dist(rng); k[i] = dist(rng); }

    // Compute y = k * x
    vector<double> y = circular_convolution(k, x);

    // Pick a random shift s and test equivariance: conv(k, rotate(x,s)) == rotate(conv(k,x), s)
    int s = uniform_int_distribution<int>(0, n - 1)(rng);
    vector<double> lhs = circular_convolution(k, rotate_cyclic(x, s));
    vector<double> rhs = rotate_cyclic(y, s);

    double err = max_abs_diff(lhs, rhs);
    cout << fixed << setprecision(6);
    cout << "n = " << n << ", shift s = " << s << "\n";
    cout << "Max |conv(k, R_s x) - R_s conv(k, x)| = " << err << "\n";
    return 0;
}
```
This implements group convolution on the cyclic group C_n, which is classical circular convolution. The test verifies shift equivariance: rotating the input by s rotates the output by s. The naive implementation directly follows the group convolution formula, using modular indexing for h^{-1} g.
```cpp
#include <bits/stdc++.h>
using namespace std;

// Apply left multiplication by permutation p: permute rows of matrix M
vector<vector<double>> left_apply_perm(const vector<int>& p, const vector<vector<double>>& M) {
    int n = (int)p.size();
    vector<vector<double>> R(n, vector<double>(n));
    for (int i = 0; i < n; ++i) R[i] = M[p[i]]; // row i becomes row p[i]
    return R;
}

// Apply right multiplication by inverse of p: permute columns by p^{-1}
vector<vector<double>> right_apply_perm_inv(const vector<int>& p, const vector<vector<double>>& M) {
    int n = (int)p.size();
    vector<int> pinv(n);
    for (int i = 0; i < n; ++i) pinv[p[i]] = i;
    vector<vector<double>> R = M;
    for (int j = 0; j < n; ++j) {
        int src = pinv[j];
        for (int i = 0; i < n; ++i) R[i][j] = M[i][src];
    }
    return R;
}

// Add two matrices
vector<vector<double>> add_mat(const vector<vector<double>>& A, const vector<vector<double>>& B) {
    int n = (int)A.size();
    vector<vector<double>> C(n, vector<double>(n));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            C[i][j] = A[i][j] + B[i][j];
    return C;
}

// Scale matrix in place
void scale_mat(vector<vector<double>>& A, double s) {
    int n = (int)A.size();
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            A[i][j] *= s;
}

// Reynolds operator: W_sym = (1/|G|) sum_{g in G} P_g W P_g^{-1}
vector<vector<double>> symmetrize_by_group(const vector<vector<int>>& group, const vector<vector<double>>& W) {
    int n = (int)W.size();
    vector<vector<double>> Acc(n, vector<double>(n, 0.0));
    for (const auto& p : group) {
        auto L = left_apply_perm(p, W);
        auto R = right_apply_perm_inv(p, L);
        Acc = add_mat(Acc, R);
    }
    scale_mat(Acc, 1.0 / (double)group.size());
    return Acc;
}

// Check commutation: P W P^{-1} == W for all p in group
bool check_commutes(const vector<vector<int>>& group, const vector<vector<double>>& W, double tol = 1e-9) {
    for (const auto& p : group) {
        auto L = left_apply_perm(p, W);
        auto R = right_apply_perm_inv(p, L);
        int n = (int)W.size();
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                if (fabs(R[i][j] - W[i][j]) > tol) return false;
    }
    return true;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Example group: C3 acting by cyclic permutations on indices {0,1,2}
    vector<vector<int>> group = {
        {0, 1, 2}, // identity
        {2, 0, 1}, // rotate right by 1
        {1, 2, 0}  // rotate right by 2
    };

    // Random 3x3 weight matrix W
    int n = 3;
    mt19937 rng(7);
    uniform_real_distribution<double> dist(-1.0, 1.0);
    vector<vector<double>> W(n, vector<double>(n));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            W[i][j] = dist(rng);

    // Project W to the space of C3-equivariant maps
    auto Wsym = symmetrize_by_group(group, W);

    cout << fixed << setprecision(6);
    cout << "Commutes before: " << (check_commutes(group, W) ? "yes" : "no") << "\n";
    cout << "Commutes after:  " << (check_commutes(group, Wsym) ? "yes" : "no") << "\n";

    // Print Wsym
    cout << "W_sym =\n";
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) cout << setw(10) << Wsym[i][j] << ' ';
        cout << '\n';
    }

    return 0;
}
```
Averaging W over the group via W_sym = (1/|G|) Σ_g P_g W P_g^{-1} projects W onto the subspace of matrices that commute with every group element, ensuring equivariance of the linear map y = W x. We demonstrate this for C_3 permutations represented as index arrays and verify commutation numerically.