Mamba & Selective State Spaces
Key Points
- Mamba uses a state-space model whose parameters are selected (gated) by the current input token, letting the model adapt its memory dynamics at each step.
- A continuous-time SSM is discretized per step using an input-dependent step size \(\Delta_k\) to produce matrices \(\bar{A}_k\) and \(\bar{B}_k\) that update the hidden state.
- Selective scan computes the sequence recurrence with time-varying matrices; it can be expressed as an associative composition of affine transformations, enabling parallel prefix-scan style evaluation.
- For diagonal or structured \(A\), the matrix exponential and integral simplify, giving fast formulas: \(\bar{A} = e^{A\Delta}\) and \(\bar{B} = A^{-1}(e^{A\Delta} - I)B\) (with safe handling when eigenvalues are near zero).
- Unlike attention, SSMs scale linearly with sequence length and support streaming; selectivity recovers context sensitivity by modulating \(B\), \(C\), and \(\Delta\) per token.
- Stability matters: \(A\) must have negative real parts (e.g., negative diagonal) and \(\Delta_k\) should be positive and bounded to avoid exploding states.
- The selective scan can be implemented as a serial loop or via an associative block-matrix operator that is parallelizable on hardware.
- In C++, a practical starting point is a diagonal-\(A\) SSM with sigmoid/softplus gates for \(B\), \(C\), and \(\Delta\), using numerically safe exponentials.
Prerequisites
- Linear algebra (vectors, matrices, diagonals) — State updates, matrix exponentials, and affine compositions rely on matrix operations.
- Ordinary differential equations (basic) — Discretization formulas come from solving linear ODEs over an interval of length Δ.
- Time complexity and scan algorithms — Selective scan benefits from understanding parallel prefix and associativity.
- Numerical stability and floating-point basics — Safe use of exp, softplus, and limits is crucial to avoid overflow/underflow.
- Probability and activations (sigmoid, softplus) — Gating functions use these to constrain parameters like B and Δ.
Detailed Explanation
Overview
Mamba and selective state spaces are modern sequence-modeling techniques built on state-space models (SSMs). An SSM evolves a hidden state through time according to linear dynamics driven by the input, and it emits an output from that state. Classic SSMs have fixed parameters, but Mamba introduces selectivity: the parameters that control the dynamics (like how fast the state decays and how strongly the input drives the state) depend on the current input token. Concretely, at each time step k, the model picks input-dependent matrices (or vectors in a structured form) A_k, B_k, C_k, and a step size Δ_k, then updates the hidden state x_{k+1} = \bar{A}_k x_k + \bar{B}_k u_k and outputs y_k = C_k x_k. Here, \bar{A}_k and \bar{B}_k come from discretizing a continuous-time SSM using the per-step Δ_k. This input-dependent discretization is the "selective scan."

Why is this powerful? Traditional SSMs are fast and memory-efficient but can be too rigid; attention is flexible but quadratic. Selectivity injects flexibility into linear-time SSMs: the model can, for instance, slow down its memory when tokens are important, or gate specific channels of the state when certain patterns appear.

The key algorithmic insight is that even with time-varying parameters, the per-step update is an affine map on the state, and compositions of affine maps form an associative operation. Therefore, we can evaluate long sequences either serially (O(n)) or with parallel-prefix (scan) techniques that exploit associativity, enabling high throughput on modern hardware.
Intuition & Analogies
Imagine your notes app with a smart "memory knob" that you twist as you read: when a sentence matters, you dial up persistence so the memory fades slowly; when it’s filler, you dial it down so the old state decays quickly. That knob is like the input-dependent step size \(\Delta_k\) and decay \(A\) in a selective SSM.

Consider a factory conveyor belt (the hidden state) carrying parts along stations. Each new part (token) arrives with an instruction card that says: slow the belt a bit (change A), open gate #3 wider (increase B’s third channel so the input injects more into that state dimension), and read from chamber #1 (C selects a readout). After applying these instructions for the current part, the belt moves, and the next part brings its own instructions. This is the essence of selectivity: the sequence itself controls how the memory updates.

Another analogy is a camera with adaptive exposure. In bright scenes, you shorten the exposure (small \(\Delta_k\)), preventing washout; in dim scenes, you lengthen it (large \(\Delta_k\)), capturing more light. Translating to SSMs, \(\Delta_k\) sets the effective time discretization: larger \(\Delta_k\) allows more of the input to flow into the state (via \(\bar{B}_k\)) and increases how much the current state evolves (via \(\bar{A}_k\)).

Finally, think of stacking Lego blocks of transformations. Each step constructs a block that says: “first shrink and rotate your current state a bit (\(\bar{A}_k\)), then add this input-shaped bump (\(\bar{B}_k u_k\)).” Snapping blocks together is associative: no matter how you group them, the final transformation is the same. That’s why we can evaluate the whole sequence with a parallel scan that composes these blocks efficiently, even though each block is different and depends on the token.
Formal Definition
A selective SSM starts from the continuous-time linear system \( \dot{x}(t) = A\,x(t) + B\,u(t) \), \( y(t) = C\,x(t) \). At step k, input-dependent parameters \(A_k, B_k, C_k\) and a step size \(\Delta_k > 0\) are produced from the token \(u_k\). Discretizing over an interval of length \(\Delta_k\) gives \( \bar{A}_k = e^{A_k \Delta_k} \) and \( \bar{B}_k = \big( \int_0^{\Delta_k} e^{A_k s}\,ds \big) B_k \), and the model computes the recurrence \( x_{k+1} = \bar{A}_k x_k + \bar{B}_k u_k \), \( y_k = C_k x_k \).
When to Use
Use selective SSMs when you need long-context modeling with linear time and memory complexity but want context sensitivity approaching attention. They shine in language modeling, audio and speech processing, and multivariate time series where the importance of tokens changes over time. Streaming and low-latency inference are natural fits: the recurrence is causal and can be advanced one token at a time with O(d) work and memory for d-dimensional state.

Choose a diagonal or low-rank-plus-diagonal parameterization of A to keep exponentials fast. Let B and C be gated by the token via lightweight linear layers and activations (e.g., sigmoid for B-gates and softplus for \(\Delta\)). When hardware parallelism is available (GPUs/TPUs), express the scan as an associative composition of affine maps to unlock parallel prefix algorithms. For small to medium state sizes, a serial loop can already be extremely efficient.

Avoid selective SSMs if your task relies heavily on arbitrary, non-local pairwise interactions that benefit from explicit attention maps, or if your environment forbids even lightweight matrix operations per token. In those cases, hybrids (e.g., mixing attention heads with SSM channels) can be more robust.
⚠️Common Mistakes
- Ignoring stability: If \(A_k\) has eigenvalues with positive real parts or \(\Delta_k\) becomes very large, states can explode. Parameterize \(A\) with negative diagonals (e.g., \(-\operatorname{softplus}(\omega)\)) and bound \(\Delta_k\) via softplus plus clamping.
- Numerical issues in discretization: Directly computing \(\bar{B}_k = A_k^{-1}(\bar{A}_k - I)B_k\) is unstable when eigenvalues are near zero. Use the stable limit \(\frac{e^{\lambda \Delta} - 1}{\lambda} \approx \Delta\) when \(|\lambda \Delta|\) is small.
- Misplaced gating: Applying C_k before updating the state or mixing time indices leads to off-by-one errors. The common convention is y_k = C_k x_k using the state after k updates (choose and stick to a convention consistently in code).
- Assuming standard parallel prefix applies to raw states: The state update with time-varying parameters is not a simple sum or product. You must lift it to affine maps and compose them associatively.
- Overly dense matrices: Full dense A_k, B_k, C_k per step are expensive. Use diagonal or low-rank-plus-diagonal forms and channel-wise gates.
- Forgetting causality: Using u_{k+1} to produce parameters for step k leaks future information. Gates should depend only on current/past inputs and state.
- Inconsistent units/scales: If inputs u_k have wildly varying scales, Δ_k and B_k gates may saturate. Normalize inputs and consider per-channel scaling.
Key Formulas
Continuous-time SSM
\( \dot{x}(t) = A\,x(t) + B\,u(t), \qquad y(t) = C\,x(t) \)
Explanation: This differential equation defines how the hidden state evolves under linear dynamics driven by input \(u(t)\), and how the output is read from the state.
Per-step discretization
\( \bar{A}_k = e^{A_k \Delta_k}, \qquad \bar{B}_k = \left( \int_0^{\Delta_k} e^{A_k s}\, ds \right) B_k \)
Explanation: Given input-dependent parameters and step size \(\Delta_k\), these define the discrete update matrices used in the recurrence.
Selective recurrence
\( x_{k+1} = \bar{A}_k x_k + \bar{B}_k u_k, \qquad y_k = C_k x_k \)
Explanation: This is the core update of a selective SSM: the state advances using token-specific dynamics and input drive; the output reads from the current state.
Closed form for \(\bar{B}\)
\( \bar{B}_k = A_k^{-1}\left( \bar{A}_k - I \right) B_k = A_k^{-1}\left( e^{A_k \Delta_k} - I \right) B_k \)
Explanation: When \(A_k\) is invertible, the integral can be evaluated in closed form. Numerically, use the series limit when eigenvalues are near zero.
Small-eigenvalue limit
\( \lim_{\lambda \to 0} \frac{e^{\lambda \Delta} - 1}{\lambda} = \Delta \)
Explanation: This limit provides a numerically stable substitute for \((e^{\lambda \Delta} - 1)/\lambda\) when \(|\lambda \Delta|\) is very small.
Per-step affine map
\( f_k(x) = \bar{A}_k x + \bar{B}_k u_k \)
Explanation: Each step is an affine transformation on the state, where the bias term \(\bar{B}_k u_k\) encodes the input’s effect for that step.
Affine composition (associative)
\( (A_2, b_2) \circ (A_1, b_1) = (A_2 A_1,\; A_2 b_1 + b_2) \)
Explanation: Composing two steps yields another affine map; this operation is associative and underlies parallel scan implementations.
Prefix state
\( x_k = (f_{k-1} \circ f_{k-2} \circ \cdots \circ f_0)(x_0) \)
Explanation: The state at time \(k\) is the composition of all previous step maps applied to the initial state. This justifies using a prefix-scan algorithm.
Stable diagonal parameterization
\( A = \operatorname{diag}(\lambda_1, \ldots, \lambda_d), \qquad \lambda_i = -\operatorname{softplus}(\omega_i) < 0 \)
Explanation: A negative diagonal ensures contraction per channel, providing stability for the discretized dynamics.
Block-matrix lifting
\( \begin{pmatrix} x_{k+1} \\ 1 \end{pmatrix} = \begin{pmatrix} \bar{A}_k & \bar{B}_k u_k \\ 0 & 1 \end{pmatrix} \begin{pmatrix} x_k \\ 1 \end{pmatrix} \)
Explanation: Lifting states to homogeneous coordinates turns each step into a linear map; products of these block matrices across steps compute prefixes and are associative.
Complexity Analysis
The serial recurrence costs O(d) work per token for a d-dimensional diagonal state, giving O(n·d) time and O(d) memory for a length-n sequence, which is linear in n (attention is quadratic). Lifting the steps to affine maps and running a parallel prefix scan keeps total work at O(n·d) while reducing the critical-path depth to O(log n) on parallel hardware.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Utility: stable softplus and sigmoid
static inline float softplus(float x) {
    if (x > 20.0f) return x;         // exp overflow guard
    if (x < -20.0f) return expf(x);  // log(1+e^x) ~ e^x when x is very negative
    return log1pf(expf(x));
}
static inline float sigmoid(float x) {
    if (x >= 0) {
        float z = expf(-x);
        return 1.0f / (1.0f + z);
    } else {
        float z = expf(x);
        return z / (1.0f + z);
    }
}

// Numerically safe ratio r = (exp(lambda*dt) - 1) / lambda, with limit ~ dt when |lambda*dt| is small
static inline float expm1_over_lambda(float lambda, float dt) {
    float z = lambda * dt;
    if (fabsf(z) < 1e-4f) return dt;  // first-order limit
    return expm1f(z) / lambda;        // expm1 keeps precision when z is small
}

// A simple linear layer: y = W x + b (row-major W: out x in)
struct Linear {
    int out_dim, in_dim;
    vector<float> W; // size out_dim * in_dim
    vector<float> b; // size out_dim
    Linear(int out_dim, int in_dim)
        : out_dim(out_dim), in_dim(in_dim), W(out_dim * in_dim), b(out_dim) {}
    vector<float> operator()(const vector<float>& x) const {
        vector<float> y(out_dim, 0.0f);
        for (int o = 0; o < out_dim; ++o) {
            float acc = b[o];
            const float* wrow = &W[o * in_dim];
            for (int i = 0; i < in_dim; ++i) acc += wrow[i] * x[i];
            y[o] = acc;
        }
        return y;
    }
};

// Selective SSM forward pass with diagonal A = -softplus(omega) (stable).
// Gates: B_k = sigmoid(Wb u_k), dt_k = softplus(wd^T u_k) + eps.
// Readout: y_k = Wc x_k; a fully selective C_k would also gate Wc by u_k.
struct SelectiveSSM {
    int d;               // state dimension
    int m;               // input dimension
    Linear Wb;           // produces B gate (size d)
    Linear Wc;           // linear readout (d x d here for simplicity)
    vector<float> wd;    // dt projection (size m)
    vector<float> omega; // base diagonal params for A (size d); A = -softplus(omega), constant over time
    float dt_eps;

    SelectiveSSM(int d, int m)
        : d(d), m(m), Wb(d, m), Wc(d, d), wd(m), omega(d), dt_eps(1e-3f) {}

    // One full sequence forward. Returns outputs Y[k][o] with o in [0,d)
    vector<vector<float>> forward(const vector<vector<float>>& U) const {
        int n = (int)U.size();
        vector<vector<float>> Y(n, vector<float>(d, 0.0f));
        vector<float> x(d, 0.0f); // initial state x0 = 0

        // Precompute stable diagonal A entries: lambda_i = -softplus(omega_i)
        vector<float> lambda(d);
        for (int i = 0; i < d; ++i) lambda[i] = -softplus(omega[i]);

        for (int k = 0; k < n; ++k) {
            const vector<float>& u = U[k]; // input at step k (size m)

            // Input-dependent B gate
            vector<float> Bgate = Wb(u);
            for (int i = 0; i < d; ++i) Bgate[i] = sigmoid(Bgate[i]);

            // Δ_k: positive step size
            float dt_raw = inner_product(wd.begin(), wd.end(), u.begin(), 0.0f);
            float dt = softplus(dt_raw) + dt_eps;

            // Discretization for diagonal A:
            //   \bar{A}_k (diagonal): a_i = exp(lambda_i * dt)
            //   \bar{B}_k (diagonal): b_i = ((exp(lambda_i * dt) - 1)/lambda_i) * Bgate_i
            vector<float> a(d), b(d);
            for (int i = 0; i < d; ++i) {
                a[i] = expf(lambda[i] * dt);
                b[i] = expm1_over_lambda(lambda[i], dt) * Bgate[i];
            }

            // State update: x = a ⊙ x + b * s, with scalar drive s = mean(u).
            // Real models use a learned per-channel projection of u instead.
            float s = 0.0f;
            for (float val : u) s += val;
            s /= max(1, m);
            for (int i = 0; i < d; ++i) x[i] = a[i] * x[i] + b[i] * s;

            // Output y_k from the updated state via a simple linear readout.
            // (If C is gated by u, replace Wc*x with an elementwise-gated readout.)
            Y[k] = Wc(x);
        }
        return Y;
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 8; // sequence length
    int m = 4; // input dimension
    int d = 6; // state dimension

    SelectiveSSM model(d, m);

    // Initialize weights randomly but small
    std::mt19937 rng(42);
    std::normal_distribution<float> N(0.0f, 0.1f);

    for (auto& w : model.Wb.W) w = N(rng);
    for (auto& b : model.Wb.b) b = 0.0f;
    for (auto& w : model.Wc.W) w = N(rng);
    for (auto& b : model.Wc.b) b = 0.0f;
    for (auto& w : model.wd) w = N(rng);
    for (auto& o : model.omega) o = fabsf(N(rng)); // positive, so -softplus(omega) is negative

    // Build a toy input sequence
    vector<vector<float>> U(n, vector<float>(m));
    for (int k = 0; k < n; ++k)
        for (int j = 0; j < m; ++j) U[k][j] = 0.5f * sinf(0.3f * k + 0.2f * j);

    auto Y = model.forward(U);

    // Print outputs
    cout << fixed << setprecision(4);
    for (int k = 0; k < n; ++k) {
        cout << "y[" << k << "]: ";
        for (int i = 0; i < d; ++i) cout << Y[k][i] << (i + 1 == d ? '\n' : ' ');
    }
    return 0;
}
```
This program implements a selective SSM with a diagonal, stable A = -softplus(omega). For each token u_k it computes input-dependent gates: B_k via a sigmoid, Δ_k via softplus, and uses a simple readout. Discretization uses elementwise exponentials to obtain \bar{A}_k and a numerically stable formula for \bar{B}_k. The state is updated serially (causal), and outputs are produced by a linear readout. Although simplified, this captures the core selective scan: input-dependent dynamics per step with O(d) cost.
```cpp
#include <bits/stdc++.h>
using namespace std;

// Helper: stable (exp(lambda*dt) - 1) / lambda with first-order limit for small |lambda*dt|
static inline float expm1_over_lambda(float lambda, float dt) {
    float z = lambda * dt;
    if (fabsf(z) < 1e-4f) return dt;
    return expm1f(z) / lambda;
}

// Affine map in R^d: x' = A ⊙ x + b, where A is diagonal (stored as a vector)
struct AffineDiag {
    vector<float> A; // size d: diagonal entries
    vector<float> b; // size d: bias term
    AffineDiag() {}
    explicit AffineDiag(int d) : A(d, 1.0f), b(d, 0.0f) {} // identity map
};

// Associative composition: (A2,b2) ∘ (A1,b1) = (A2*A1, A2⊙b1 + b2); g1 is applied first
static inline AffineDiag compose(const AffineDiag& g2, const AffineDiag& g1) {
    int d = (int)g1.A.size();
    AffineDiag g(d);
    for (int i = 0; i < d; ++i) {
        g.A[i] = g2.A[i] * g1.A[i];
        g.b[i] = g2.A[i] * g1.b[i] + g2.b[i];
    }
    return g;
}

// Build per-step affine map from gates for diagonal A
AffineDiag step_from_gates(const vector<float>& lambda, float dt, const vector<float>& Bgate, float s) {
    int d = (int)lambda.size();
    AffineDiag g(d);
    for (int i = 0; i < d; ++i) {
        g.A[i] = expf(lambda[i] * dt);
        g.b[i] = expm1_over_lambda(lambda[i], dt) * Bgate[i] * s; // b_k = \bar{B}_k u_k (scalar drive s)
    }
    return g;
}

// Apply affine map to state
static inline void apply(const AffineDiag& g, vector<float>& x) {
    int d = (int)x.size();
    for (int i = 0; i < d; ++i) x[i] = g.A[i] * x[i] + g.b[i];
}

// Divide-and-conquer prefix scan over affine maps (single-threaded demonstration).
// prefix[k] = G[k] ∘ G[k-1] ∘ ... ∘ G[0]; earlier steps are composed on the right.
void prefix_scan_affine(const vector<AffineDiag>& G, vector<AffineDiag>& prefix) {
    int n = (int)G.size();
    prefix.resize(n);
    if (n == 0) return;
    int d = (int)G[0].A.size();
    // Up-sweep: segment tree whose nodes hold the composition of their leaf range
    int p = 1; while (p < n) p <<= 1;
    vector<AffineDiag> seg(2 * p, AffineDiag(d)); // pad unused leaves with identity maps
    for (int i = 0; i < n; ++i) seg[p + i] = G[i];
    // The right child is later in time, so it composes on the left
    for (int i = p - 1; i >= 1; --i) seg[i] = compose(seg[2*i + 1], seg[2*i]);
    // Down-sweep: each node receives the composition of everything earlier than it
    AffineDiag id(d); // identity: A = 1, b = 0
    function<void(int, const AffineDiag&)> dfs = [&](int idx, const AffineDiag& left_prefix) {
        if (idx >= p) {
            int leaf = idx - p;
            if (leaf < n) prefix[leaf] = compose(seg[idx], left_prefix); // leaf applied after its prefix
            return;
        }
        dfs(2*idx, left_prefix);                          // left child keeps the incoming prefix
        dfs(2*idx + 1, compose(seg[2*idx], left_prefix)); // right child also sees the left subtree
    };
    dfs(1, id);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 8; // sequence length
    int d = 4; // state dimension

    // Stable diagonal lambdas (negative for stability)
    vector<float> lambda(d);
    for (int i = 0; i < d; ++i) lambda[i] = -0.2f * (i + 1);

    // Toy gates per step
    vector<vector<float>> Bgate(n, vector<float>(d));
    vector<float> dt(n), s(n);
    for (int k = 0; k < n; ++k) {
        for (int i = 0; i < d; ++i)
            Bgate[k][i] = 1.0f / (1.0f + expf(-(0.5f * sinf(0.3f*k + 0.1f*i))));
        dt[k] = log1pf(expf(0.2f + 0.1f * cosf(0.4f*k))) + 1e-3f; // softplus + eps
        s[k] = 0.5f * sinf(0.5f * k);                             // scalar drive from input
    }

    // Build per-step affine maps
    vector<AffineDiag> G(n);
    for (int k = 0; k < n; ++k) G[k] = step_from_gates(lambda, dt[k], Bgate[k], s[k]);

    // Serial baseline
    vector<float> x_serial(d, 0.0f);
    for (int k = 0; k < n; ++k) apply(G[k], x_serial);

    // Scan prefixes: prefix[k] = composition of steps 0..k (inclusive)
    vector<AffineDiag> P;
    prefix_scan_affine(G, P);

    // Apply the full-sequence prefix to the initial state x0 = 0
    vector<float> x_scan(d, 0.0f);
    apply(P[n - 1], x_scan);

    cout << fixed << setprecision(4);
    cout << "Serial state after n steps:\n";
    for (int i = 0; i < d; ++i) cout << x_serial[i] << (i + 1 == d ? '\n' : ' ');
    cout << "Scan-composed state after n steps:\n";
    for (int i = 0; i < d; ++i) cout << x_scan[i] << (i + 1 == d ? '\n' : ' ');

    // Verify the two evaluation paths agree
    for (int i = 0; i < d; ++i) assert(fabsf(x_serial[i] - x_scan[i]) < 1e-4f);

    return 0;
}
```
This example shows how to lift per-step selective updates into affine maps over the state with diagonal A. The composition (A2,b2)∘(A1,b1) = (A2*A1, A2⊙b1 + b2) is associative, so a tree-based divide-and-conquer scan computes all prefixes. We verify that composing all steps equals the result of the serial loop. On parallel hardware, the same operator enables an efficient parallel prefix (Blelloch scan) with reduced span.