# Source code for tensorquantlib.tt.decompose

"""
Tensor-Train decomposition algorithms.

Implements:
    - tt_svd:   TT-SVD decomposition (Oseledets, 2011)
    - tt_round: TT-rounding via orthogonalization + truncated SVD
    - tt_cross: Black-box TT-Cross approximation (Oseledets & Tyrtyshnikov, 2010)
              Builds a TT decomposition without forming the full tensor,
              making 6+ asset problems feasible.
"""

from __future__ import annotations

from collections.abc import Callable

import numpy as np
from scipy.linalg import qr


def tt_svd(
    tensor: np.ndarray,
    eps: float = 1e-6,
    max_rank: int | None = None,
) -> list[np.ndarray]:
    """Tensor-Train SVD decomposition.

    Decomposes a d-dimensional tensor A of shape (n1, n2, ..., nd) into a
    list of TT-cores [G1, G2, ..., Gd] where G_k has shape
    (r_{k-1}, n_k, r_k) and r_0 = r_d = 1.

    When ``max_rank`` does not bind, the reconstruction satisfies
    ``||A - A_TT||_F <= eps * ||A||_F``.

    Algorithm:
        Sequential left-to-right unfolding with truncated SVD.
        Per-step truncation threshold:
        ``delta = eps * ||A||_F / sqrt(d-1)``.

    Args:
        tensor: Input tensor (or array-like), shape (n1, n2, ..., nd).
            It is converted to a float64 working copy; the input itself is
            never modified.
        eps: Relative truncation tolerance (must be >= 0).
        max_rank: Maximum TT-rank (optional safety cap, must be >= 1).

    Returns:
        List of TT-cores. cores[k].shape = (r_{k-1}, n_k, r_k).

    Raises:
        ValueError: If the tensor has fewer than 2 dimensions, ``eps`` is
            negative, or ``max_rank`` is smaller than 1.
    """
    tensor = np.asarray(tensor)
    d = tensor.ndim
    if d < 2:
        raise ValueError(f"Tensor must have at least 2 dimensions, got {d}")
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps}")
    if max_rank is not None and max_rank < 1:
        raise ValueError(f"max_rank must be >= 1, got {max_rank}")

    shape = tensor.shape

    # One float64 working copy (the unfolding matrix that gets peeled apart).
    C = np.array(tensor, dtype=np.float64)
    norm_A = float(np.linalg.norm(C))

    # Exactly-zero tensor: rank-1 zero cores. The test is exact on purpose:
    # the tolerance is *relative*, so a tensor with tiny but non-zero entries
    # must still be decomposed rather than flushed to zero.
    if norm_A == 0.0:
        return [np.zeros((1, n, 1)) for n in shape]

    # Per-step truncation threshold (guarantees total error <= eps * ||A||_F)
    delta_sq = (eps * norm_A / np.sqrt(d - 1)) ** 2

    cores: list[np.ndarray] = []
    r_prev = 1
    for k in range(d - 1):
        n_k = shape[k]

        # Unfold: (r_prev * n_k) x (all remaining dimensions)
        C = C.reshape(r_prev * n_k, -1)
        U, S, Vt = np.linalg.svd(C, full_matrices=False)

        # tail_sq[i] = sum(S[i:]**2). Keep the smallest rank r >= 1 whose
        # discarded tail satisfies tail_sq[r] <= delta^2; keep everything if
        # no such r exists.
        tail_sq = np.cumsum((S**2)[::-1])[::-1]
        admissible = np.flatnonzero(tail_sq[1:] <= delta_sq)
        r_k = int(admissible[0]) + 1 if admissible.size else len(S)

        if max_rank is not None:
            r_k = min(r_k, max_rank)
        r_k = max(r_k, 1)

        # Left singular vectors become core k
        cores.append(U[:, :r_k].reshape(r_prev, n_k, r_k))

        # Carry diag(S) @ Vt to the next step; scaling the rows of Vt avoids
        # materialising the dense diagonal matrix.
        C = S[:r_k, None] * Vt[:r_k, :]
        r_prev = r_k

    # Whatever is left is the last core: (r_prev, n_d, 1)
    cores.append(C.reshape(r_prev, shape[-1], 1))
    return cores
def tt_round(
    cores: list[np.ndarray],
    eps: float = 1e-6,
    max_rank: int | None = None,
) -> list[np.ndarray]:
    """Reduce TT-ranks via orthogonalization + truncated SVD sweep.

    Two-pass algorithm:
        1. Right-to-left QR sweep (right-orthogonalize)
        2. Left-to-right SVD sweep with truncation

    This is used after TT arithmetic (e.g., tt_add) which inflates ranks.

    After pass 1 every core except the first has orthonormal rows in its
    (r_left, n * r_right) unfolding, so the Frobenius norm of the whole
    tensor equals the Frobenius norm of the first core. That value sets the
    per-step threshold ``delta = eps * ||A||_F / sqrt(d-1)``.

    Args:
        cores: List of TT-cores, cores[k].shape = (r_{k-1}, n_k, r_k).
            The input list and its arrays are not modified.
        eps: Relative truncation tolerance (must be >= 0).
        max_rank: Maximum allowed rank (must be >= 1 when given).

    Returns:
        New list of TT-cores with reduced ranks. Fewer than two cores are
        returned as plain copies. A tensor whose norm is exactly zero is
        returned as rank-1 zero cores.

    Raises:
        ValueError: If ``eps`` is negative or ``max_rank`` is smaller than 1.
    """
    # Same argument contract as tt_svd. Without this check a negative eps
    # would silently behave like |eps|, because only delta**2 is used.
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps}")
    if max_rank is not None and max_rank < 1:
        raise ValueError(f"max_rank must be >= 1, got {max_rank}")

    # Work with copies
    cores = [c.copy() for c in cores]
    d = len(cores)
    if d < 2:
        return cores

    # ---- Pass 1: Right-to-left QR sweep ----
    for k in range(d - 1, 0, -1):
        r_left, n_k, r_right = cores[k].shape
        # QR of the transposed unfolding: M = Q R with
        # Q: (n_k * r_right, new_r), R: (new_r, r_left)
        M = cores[k].reshape(r_left, n_k * r_right).T
        Q, R = np.linalg.qr(M)
        new_r = Q.shape[1]
        cores[k] = Q.T.reshape(new_r, n_k, r_right)

        # Absorb R into the previous core by contracting its right bond:
        # new_prev[i, j, l] = sum_m prev[i, j, m] * R[l, m]
        r_left_prev, n_prev, _ = cores[k - 1].shape
        absorbed = cores[k - 1].reshape(r_left_prev * n_prev, r_left) @ R.T
        cores[k - 1] = absorbed.reshape(r_left_prev, n_prev, new_r)

    # All the norm now sits in the first core.
    norm = float(np.linalg.norm(cores[0]))
    if norm == 0.0:
        return [np.zeros((1, c.shape[1], 1)) for c in cores]

    delta_sq = (eps * norm / np.sqrt(d - 1)) ** 2

    # ---- Pass 2: Left-to-right SVD sweep with truncation ----
    for k in range(d - 1):
        r_left, n_k, r_right = cores[k].shape
        M = cores[k].reshape(r_left * n_k, r_right)
        U, S, Vt = np.linalg.svd(M, full_matrices=False)

        # Smallest rank r >= 1 whose discarded tail sum(S[r:]**2) <= delta^2
        tail_sq = np.cumsum((S**2)[::-1])[::-1]
        admissible = np.flatnonzero(tail_sq[1:] <= delta_sq)
        r_new = int(admissible[0]) + 1 if admissible.size else len(S)

        if max_rank is not None:
            r_new = min(r_new, max_rank)
        r_new = max(r_new, 1)

        cores[k] = U[:, :r_new].reshape(r_left, n_k, r_new)

        # Push diag(S) @ Vt (r_new x r_right) into the left bond of the next core
        SV = S[:r_new, None] * Vt[:r_new, :]
        cores[k + 1] = np.tensordot(SV, cores[k + 1], axes=(1, 0))

    return cores
def _tt_norm(cores: list[np.ndarray]) -> float:
    """Compute the Frobenius norm of a tensor in TT format.

    Uses the transfer matrix approach: ||A||_F^2 = <A, A>_TT. A Gram matrix
    ``Z`` is carried from left to right; for core ``G`` the update is
    ``Z_new[p, q] = sum_{a, b, i} G[a, i, p] * Z[a, b] * G[b, i, q]``.

    The recursion starts from the identity on the left bond and ends with a
    trace over the right bond, so for the usual boundary ranks
    ``r_0 = r_d = 1`` this is exactly the squared norm.

    Complexity: O(d * n * r^3) where r is the max rank.
    """
    gram = np.eye(cores[0].shape[0])
    for core in cores:
        # (r_left, r_left) x (r_left, n, r_right) -> (r_left, n, r_right)
        half = np.tensordot(gram, core, axes=(1, 0))
        # contract left bond and physical index -> (r_right, r_right)
        gram = np.tensordot(core, half, axes=([0, 1], [0, 1]))
    # Round-off can leave a mathematically non-negative value marginally
    # below zero; clamp so the square root never yields NaN for that reason.
    return float(np.sqrt(max(float(np.trace(gram)), 0.0)))


# ======================================================================
# TT-Cross (black-box approximation — no full tensor needed)
# ======================================================================


def _maxvol_greedy(A: np.ndarray, r: int, rng: np.random.Generator) -> np.ndarray:
    """Approximate maximum-volume row subset of A (n × k, n ≥ k).

    Returns r row indices forming an approximate maximum-volume
    (r × k) submatrix of A.  Uses greedy pivoting based on QR.

    Algorithm
    ---------
    1. Find first r pivots via QR with column pivoting on A^T.
    2. Iteratively swap rows to increase the determinant of the selected
       submatrix until convergence (maxvol criterion).

    ``r`` is clipped to ``min(r, n, k)``.  ``rng`` is not used by the
    current implementation; it is kept so the call signature stays stable.
    """
    n, k_cols = A.shape
    r = min(r, n, k_cols)
    if r == 0:
        return np.array([], dtype=int)

    # Initial pivot rows from QR with column pivoting on A^T
    _, _, piv = qr(A.T, pivoting=True, mode="economic")
    idx = piv[:r].copy()

    # B solves B @ A[idx, :] ~= A in the least-squares sense; an entry with
    # |B[i, j]| > 1 means that putting row i into slot j enlarges the volume.
    try:
        B = np.linalg.lstsq(A[idx, :].T, A.T, rcond=None)[0].T  # (n, r)
    except np.linalg.LinAlgError:
        return idx

    tol = 1.0 + 1e-4
    for _ in range(min(100, n)):
        i_best, j_best = np.unravel_index(np.argmax(np.abs(B)), B.shape)
        if abs(B[i_best, j_best]) <= tol:
            break
        # Swap row i_best into position j_best
        idx[j_best] = i_best
        try:
            B = np.linalg.lstsq(A[idx, :].T, A.T, rcond=None)[0].T
        except np.linalg.LinAlgError:
            break

    return idx


def _eval_fiber(
    fn: Callable[..., float],
    left_idx: np.ndarray,  # shape (r_l, k) — left multi-indices
    k: int,  # current mode position (0-based)
    n_k: int,  # size of mode k
    right_idx: np.ndarray,  # shape (r_r, d-k-1) — right multi-indices
    d: int,
) -> np.ndarray:
    """Evaluate fn on all (left × {0..n_k-1} × right) index combinations.

    ``fn`` always receives plain Python ints.

    Returns
    -------
    np.ndarray of shape ``(r_l * n_k, r_r)``
        C[il * n_k + ik, ir] = fn(*left_idx[il], ik, *right_idx[ir])

        For the last mode (``k == d - 1``) there is no right part and the
        result has a single column.
    """
    r_l = left_idx.shape[0]
    # An empty right set (shape (1, 0)) still stands for one column.
    r_r = right_idx.shape[0] if right_idx.ndim > 0 and right_idx.size > 0 else 1
    is_last_mode = k == d - 1
    has_right = (d - k - 1) > 0

    C = np.zeros((r_l * n_k, r_r))
    for il in range(r_l):
        left_part = left_idx[il].tolist() if k > 0 else []
        for ik in range(n_k):
            row = il * n_k + ik
            if is_last_mode:
                C[row, 0] = fn(*left_part, ik)
                continue
            for ir in range(r_r):
                right_part = right_idx[ir].tolist() if has_right else []
                C[row, ir] = fn(*left_part, ik, *right_part)
    return C


def _eval_interface(
    fn: Callable[..., float],
    left_idx: np.ndarray,  # shape (r_l, k+1) — left pivots at next boundary
    right_idx: np.ndarray,  # shape (r_r, d-k-1) — right pivots at current boundary
    d: int,
) -> np.ndarray:
    """Evaluate fn on all (left × right) combinations.

    Indices are converted with ``tolist()`` so that ``fn`` receives plain
    Python ints, exactly as it does from ``_eval_fiber``.

    Returns
    -------
    np.ndarray of shape ``(r_l, r_r)``
        Z[il, ir] = fn(*left_idx[il], *right_idx[ir])
    """
    r_l = left_idx.shape[0]
    r_r = right_idx.shape[0] if right_idx.ndim > 0 and right_idx.size > 0 else 1
    n_right_dims = right_idx.shape[1] if right_idx.ndim > 1 else 0

    Z = np.zeros((r_l, r_r))
    for il in range(r_l):
        left_part = left_idx[il].tolist()
        for ir in range(r_r):
            right_part = right_idx[ir].tolist() if n_right_dims > 0 else []
            Z[il, ir] = fn(*left_part, *right_part)
    return Z


def _initial_right_indices(
    dims: list[int],
    max_rank: int,
    rng: np.random.Generator,
) -> np.ndarray:
    """Return starting right multi-indices over modes of sizes ``dims``.

    If the modes admit at most ``max_rank`` index combinations, all of them
    are returned.  Otherwise ``max_rank`` combinations are drawn at random
    and duplicates are removed.  Result shape: ``(r, len(dims))``, r >= 1.
    """
    total = 1
    for n in dims:
        total *= n
        if total > max_rank:
            # Only "more than max_rank" matters; stop before the product
            # becomes astronomically large.
            break
    if total <= max_rank:
        grid = np.unravel_index(np.arange(total), tuple(dims))
        return np.stack(grid, axis=1).astype(int)
    draws = np.stack([rng.integers(0, n, size=max_rank) for n in dims], axis=1)
    return np.unique(draws, axis=0).astype(int)


def _select_pivots(
    C: np.ndarray,
    max_rank: int,
    rank_tol: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Pick well-conditioned pivot rows of ``C``.

    The numerical rank is the number of diagonal entries of the
    column-pivoted QR factor R exceeding ``rank_tol * |R[0, 0]|`` (at least
    1, at most ``max_rank``).  That many rows are then chosen by greedy
    maxvol on the leading columns of Q.
    """
    Q, R, _ = qr(C, mode="economic", pivoting=True)
    diag = np.abs(np.diag(R))
    if diag[0] > 0:
        rank = int(np.count_nonzero(diag > rank_tol * diag[0]))
    else:
        rank = 1  # identically zero cross: one pivot is enough
    rank = max(1, min(rank, max_rank))
    return _maxvol_greedy(Q[:, :rank], rank, rng)


def tt_cross(
    fn: Callable[..., float],
    shape: tuple[int, ...],
    eps: float = 1e-4,
    max_rank: int = 20,
    n_sweeps: int = 8,
    seed: int = 42,
) -> list[np.ndarray]:
    """TT-Cross black-box approximation (Oseledets & Tyrtyshnikov, 2010).

    Constructs a Tensor-Train decomposition of a *d*-dimensional function
    **without forming the full tensor**.  ``fn`` is only queried on crosses
    of the form (left pivots × one full mode × right pivots), i.e. on the
    order of ``d · r² · n`` points per half-sweep instead of the ``n^d``
    points a full-grid TT-SVD needs (15^6 = 11,390,625 for 6 assets with
    15 points per axis).

    Algorithm
    ---------
    1. **Initialise** the right index sets J_k: every combination when there
       are at most ``max_rank`` of them, otherwise ``max_rank`` random draws.
    2. **Left-to-right half-sweep**: for each mode k evaluate the cross
       C_k = f(I_k × {0..n_k-1} × J_k), estimate its numerical rank from a
       column-pivoted QR (relative tolerance ``eps``) and select that many
       rows (left index, i_k) by greedy maxvol.  They form I_{k+1}.
    3. **Right-to-left half-sweep**: for each mode k evaluate the same kind
       of cross, view it as an ``r_l × (n_k · r_r)`` matrix and select pivot
       *columns* (i_k, right index).  They form the right set J_{k-1}.
    4. Steps 2-3 are repeated ``n_sweeps`` times.
    5. **Build TT-cores** with the cross-interpolation formula
       Core_k = C_k @ pinv(Z_k), where Z_k = f(I_{k+1} × J_k) is the
       interface matrix.  The last core is the plain fiber f(I_{d-1} × i_{d-1}).

    Parameters
    ----------
    fn : callable
        Function accepting ``d`` integer arguments (grid indices, passed as
        Python ints) and returning a float::

            fn(i_0, i_1, ..., i_{d-1}) -> float

        Use ``functools.partial`` or a lambda to curry other parameters.
        ``fn`` may be called several times with the same indices.
    shape : tuple of int
        Mode sizes ``(n_0, n_1, ..., n_{d-1})``.  These are *index* sizes.
        To convert continuous axes to indices, wrap ``fn`` accordingly.
    eps : float
        Target relative accuracy.  Used as the relative tolerance for the
        rank estimate of every cross and for the pseudo-inverse of the
        interface matrices (floored at a small multiple of machine epsilon).
    max_rank : int
        Hard upper bound on TT-ranks.
    n_sweeps : int
        Number of full sweeps; each sweep is one left-to-right pass followed
        by one right-to-left pass.  Must be at least 1.
    seed : int
        Random seed for initialising right index sets.

    Returns
    -------
    list of np.ndarray
        TT-cores[k].shape = ``(r_{k-1}, n_k, r_k)`` with ``r_0 = r_d = 1``.

    Raises
    ------
    ValueError
        If ``len(shape) < 2``, ``eps < 0``, ``max_rank < 1``,
        ``n_sweeps < 1`` or any mode size is smaller than 1.

    Examples
    --------
    Compress a 6-asset basket payoff without forming the 15^6 grid::

        import numpy as np
        from tensorquantlib.tt.decompose import tt_cross

        # Suppose price_lookup(i0, i1, i2, i3, i4, i5) evaluates the
        # basket option price at the i-th point on each asset's price axis.
        axes = [np.linspace(80, 120, 15)] * 6

        def price_lookup(*indices):
            spots = [axes[k][i] for k, i in enumerate(indices)]
            return basket_mc(spots, ...)   # your existing pricer

        cores = tt_cross(price_lookup, shape=(15,)*6, max_rank=15, n_sweeps=6)

    Notes
    -----
    After calling ``tt_cross``, wrap the result in a ``TTSurrogate``::

        from tensorquantlib.tt.surrogate import TTSurrogate
        surr = TTSurrogate(cores=cores, axes=axes, eps=eps)
    """
    d = len(shape)
    if d < 2:
        raise ValueError(f"TT-Cross requires at least 2 dimensions, got {d}")
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps}")
    if max_rank < 1:
        raise ValueError(f"max_rank must be >= 1, got {max_rank}")
    if n_sweeps < 1:
        raise ValueError(f"n_sweeps must be >= 1, got {n_sweeps}")
    if any(n < 1 for n in shape):
        raise ValueError(f"all mode sizes must be >= 1, got {shape}")

    dims = [int(n) for n in shape]
    rng = np.random.default_rng(seed)

    # Relative tolerance shared by rank detection and the pseudo-inverse.
    # The floor keeps pure round-off directions out even when eps == 0.
    rank_tol = max(float(eps), 64.0 * float(np.finfo(np.float64).eps))

    # ------------------------------------------------------------------
    # Step 1: right index sets. J[k] has shape (r, d-k-1); each row is a
    # multi-index over modes k+1 .. d-1 (the interface k -> k+1).
    # ------------------------------------------------------------------
    J: list[np.ndarray] = [
        _initial_right_indices(dims[k + 1:], max_rank, rng) for k in range(d - 1)
    ]

    # The left boundary and the right boundary both have rank 1, represented
    # by a single empty multi-index.
    root = np.zeros((1, 0), dtype=int)
    no_right = np.zeros((1, 0), dtype=int)

    # left_pivots[k] has shape (r_k, k): multi-indices over modes 0 .. k-1.
    left_pivots: list[np.ndarray] = [root]

    # ------------------------------------------------------------------
    # Step 2: alternating sweeps
    # ------------------------------------------------------------------
    for _ in range(n_sweeps):
        # ---- Left-to-right: choose rows (il, ik) -> left_pivots[k + 1] ----
        left_pivots = [root]
        for k in range(d - 1):
            n_k = dims[k]
            C = _eval_fiber(fn, left_pivots[k], k, n_k, J[k], d)  # (r_l*n_k, r_r)
            rows = _select_pivots(C, max_rank, rank_tol, rng)
            # Row index of C is il * n_k + ik
            il, ik = np.divmod(rows, n_k)
            grown = np.column_stack([left_pivots[k][il], ik]).astype(int)
            left_pivots.append(grown)

        # ---- Right-to-left: choose columns (ik, ir) -> J[k - 1] ----
        for k in range(d - 1, 0, -1):
            n_k = dims[k]
            right = J[k] if k < d - 1 else no_right
            r_l = left_pivots[k].shape[0]
            r_r = right.shape[0]
            C = _eval_fiber(fn, left_pivots[k], k, n_k, right, d)  # (r_l*n_k, r_r)
            # Same data viewed as rows = left pivots, columns = (ik, ir)
            # with column index ik * r_r + ir.
            M = C.reshape(r_l, n_k * r_r)
            cols = _select_pivots(M.T, max_rank, rank_tol, rng)
            ik, ir = np.divmod(cols, r_r)
            J[k - 1] = np.column_stack([ik, right[ir]]).astype(int)

    # ------------------------------------------------------------------
    # Step 3: Build final TT-cores using the cross-interpolation formula.
    #   Core_k = C_k @ pinv(Z_k)  reshaped to (r_{k-1}, n_k, r_k)
    # ------------------------------------------------------------------
    cores: list[np.ndarray] = []
    for k in range(d - 1):
        n_k = dims[k]
        r_l = left_pivots[k].shape[0]
        C = _eval_fiber(fn, left_pivots[k], k, n_k, J[k], d)  # (r_l*n_k, r_r)
        Z = _eval_interface(fn, left_pivots[k + 1], J[k], d)  # (r_next, r_r)
        # Regularised pseudo-inverse: singular values of Z below
        # rank_tol * sigma_max are treated as zero instead of being inverted.
        core_mat = C @ np.linalg.pinv(Z, rcond=rank_tol)  # (r_l*n_k, r_next)
        cores.append(core_mat.reshape(r_l, n_k, core_mat.shape[1]))

    # Last core: fiber only, no interface
    n_last = dims[-1]
    r_l = left_pivots[d - 1].shape[0]
    C = _eval_fiber(fn, left_pivots[d - 1], d - 1, n_last, no_right, d)  # (r_l*n_d, 1)
    cores.append(C.reshape(r_l, n_last, 1))

    return cores