"""
TT-Surrogate — fast approximate option pricing via Tensor-Train compression.
The TTSurrogate class wraps the full pipeline:
1. Build a pricing grid (MC or analytic)
2. Compress it with TT-SVD
3. Evaluate prices at arbitrary points via TT interpolation
4. Compute Greeks (Delta, Gamma) via central finite differences on the surrogate
For 6+ assets, use from_function() which applies TT-Cross and never
builds the full grid.
Typical usage::

    import numpy as np
    from tensorquantlib.tt.surrogate import TTSurrogate

    # ≤5 assets — TT-SVD on a full Monte-Carlo pricing grid
    surr = TTSurrogate.from_basket_mc(
        S0_ranges=[(80, 120)] * 3,
        K=100, T=1.0, r=0.05, sigma=[0.2] * 3,
        corr=np.eye(3), weights=[1 / 3] * 3,
        n_points=30, eps=1e-4,
    )

    # 6+ assets — TT-Cross (no full grid)
    surr6 = TTSurrogate.from_function(
        fn=my_pricer,  # fn(*integer_indices) -> float
        axes=[np.linspace(80, 120, 15)] * 6,
        max_rank=15, eps=1e-4, n_sweeps=6,
    )

    price = surr.evaluate([100, 105, 95])
    greeks = surr.greeks([100, 105, 95])
"""
from __future__ import annotations
import time
from typing import Union
import numpy as np
from ..core.tensor import Tensor
from ..finance.basket import build_pricing_grid, build_pricing_grid_analytic
from .decompose import tt_cross, tt_svd
from .ops import (
tt_eval,
tt_eval_batch,
tt_memory,
tt_ranks,
)
class TTSurrogate:
    """Tensor-Train surrogate pricing model.

    Stores a TT-compressed pricing grid and provides fast evaluation by
    mapping continuous spot prices to grid cells and applying multi-linear
    interpolation between the neighbouring grid points.

    Attributes:
        cores: List of TT-cores.
        axes: List of 1D arrays — grid points along each asset axis, sorted
            in ascending order.
        n_assets: Number of assets (``len(axes)``).
        build_time: Time (sec) to build the pricing grid; 0.0 when the grid
            was supplied by the caller or never formed (TT-Cross).
        compress_time: Time (sec) spent in TT-SVD or TT-Cross.
        eps: Compression tolerance used.
    """

    def __init__(
        self,
        cores: list[np.ndarray],
        axes: list[np.ndarray],
        eps: float,
        build_time: float = 0.0,
        compress_time: float = 0.0,
        original_shape: tuple[int, ...] | None = None,
        original_nbytes: int | None = None,
    ):
        self.cores = cores
        self.axes = axes
        self.n_assets = len(axes)
        self.eps = eps
        self.build_time = build_time
        self.compress_time = compress_time
        # Shape / byte size of the uncompressed grid; None when it was never
        # formed (e.g. TT-Cross via from_function).
        self._original_shape = original_shape
        self._original_nbytes = original_nbytes

    # ── constructors ────────────────────────────────────────────────────

    @staticmethod
    def _check_axes_ascending(axes: list[np.ndarray]) -> None:
        """Raise ValueError if any axis decreases.

        ``_spot_to_indices`` locates cells with ``np.searchsorted``, which
        requires ascending axes. Repeated points (zero-width cells) are
        tolerated.
        """
        for i, ax in enumerate(axes):
            if np.any(np.diff(np.asarray(ax, dtype=float)) < 0):
                raise ValueError(f"Axis {i} must be sorted in ascending order")

    @classmethod
    def _compress_grid(
        cls,
        grid: np.ndarray,
        axes: list[np.ndarray],
        eps: float,
        max_rank: int | None,
        build_time: float,
    ) -> TTSurrogate:
        """Compress a full grid with TT-SVD (timed) and wrap it in a surrogate."""
        # Record the uncompressed footprint before handing the grid to TT-SVD.
        original_shape = grid.shape
        original_nbytes = grid.nbytes
        t0 = time.perf_counter()
        cores = tt_svd(grid, eps=eps, max_rank=max_rank)
        compress_time = time.perf_counter() - t0
        return cls(
            cores=cores,
            axes=axes,
            eps=eps,
            build_time=build_time,
            compress_time=compress_time,
            original_shape=original_shape,
            original_nbytes=original_nbytes,
        )

    @classmethod
    def from_grid(
        cls,
        grid: np.ndarray,
        axes: list[np.ndarray],
        eps: float = 1e-4,
        max_rank: int | None = None,
    ) -> TTSurrogate:
        """Build surrogate from a pre-computed pricing grid.

        Args:
            grid: Full tensor of prices, shape (n1, n2, ..., nd). Any
                array-like is accepted.
            axes: List of 1D ascending arrays, one per grid dimension.
            eps: TT-SVD tolerance (must be positive).
            max_rank: Maximum TT-rank.

        Returns:
            TTSurrogate instance (``build_time`` is 0.0).

        Raises:
            ValueError: If the grid is not at least 2D, the number or lengths
                of the axes do not match the grid, an axis is not ascending,
                or ``eps`` is not positive.
        """
        grid = np.asarray(grid)
        if grid.ndim < 2:
            raise ValueError(f"Grid must be at least 2D, got {grid.ndim}D")
        if len(axes) != grid.ndim:
            raise ValueError(
                f"Number of axes ({len(axes)}) must match grid dimensions ({grid.ndim})"
            )
        for i, (ax, n) in enumerate(zip(axes, grid.shape)):
            if len(ax) != n:
                raise ValueError(f"Axis {i} length ({len(ax)}) doesn't match grid size ({n})")
        cls._check_axes_ascending(axes)
        if eps <= 0:
            raise ValueError(f"eps must be positive, got {eps}")
        return cls._compress_grid(grid, axes, eps, max_rank, build_time=0.0)

    @classmethod
    def from_basket_analytic(
        cls,
        S0_ranges: list[tuple[float, float]],
        K: float,
        T: float,
        r: float,
        sigma: list[float],
        weights: list[float],
        n_points: int = 30,
        eps: float = 1e-4,
        max_rank: int | None = None,
    ) -> TTSurrogate:
        """Build surrogate from analytic basket pricing grid.

        Uses weighted Black-Scholes approximation — fast but approximate.

        Args:
            S0_ranges: [(lo, hi)] per asset.
            K: Strike.
            T: Maturity.
            r: Risk-free rate.
            sigma: Volatilities per asset.
            weights: Portfolio weights.
            n_points: Grid points per axis.
            eps: TT-SVD tolerance.
            max_rank: Maximum TT-rank.

        Returns:
            TTSurrogate instance.
        """
        t0 = time.perf_counter()
        grid, axes = build_pricing_grid_analytic(
            S0_ranges=S0_ranges,
            K=K,
            T=T,
            r=r,
            sigma=np.asarray(sigma),
            weights=np.asarray(weights),
            n_points=n_points,
        )
        build_time = time.perf_counter() - t0
        return cls._compress_grid(grid, axes, eps, max_rank, build_time=build_time)

    @classmethod
    def from_basket_mc(
        cls,
        S0_ranges: list[tuple[float, float]],
        K: float,
        T: float,
        r: float,
        sigma: list[float],
        corr: np.ndarray,
        weights: list[float],
        n_points: int = 30,
        n_mc_paths: int = 50_000,
        eps: float = 1e-4,
        max_rank: int | None = None,
    ) -> TTSurrogate:
        """Build surrogate from Monte-Carlo basket pricing grid.

        Slow but accurate. Suitable for validation.

        Args:
            S0_ranges: [(lo, hi)] per asset.
            K: Strike.
            T: Maturity.
            r: Risk-free rate.
            sigma: Volatilities per asset.
            corr: Asset correlation matrix, forwarded to ``build_pricing_grid``.
            weights: Portfolio weights.
            n_points: Grid points per axis.
            n_mc_paths: Monte-Carlo paths, forwarded to ``build_pricing_grid``.
            eps: TT-SVD tolerance.
            max_rank: Maximum TT-rank.

        Returns:
            TTSurrogate instance.
        """
        t0 = time.perf_counter()
        grid, axes = build_pricing_grid(
            S0_ranges=S0_ranges,
            K=K,
            T=T,
            r=r,
            sigma=np.asarray(sigma),
            corr=corr,
            weights=np.asarray(weights),
            n_points=n_points,
            n_mc_paths=n_mc_paths,
        )
        build_time = time.perf_counter() - t0
        return cls._compress_grid(grid, axes, eps, max_rank, build_time=build_time)

    @classmethod
    def from_function(
        cls,
        fn: object,
        axes: list[np.ndarray],
        eps: float = 1e-4,
        max_rank: int = 20,
        n_sweeps: int = 6,
        seed: int = 42,
    ) -> TTSurrogate:
        """Build surrogate via TT-Cross — **no full grid needed**.

        This is the recommended constructor for **6+ asset** problems.
        TT-Cross samples the pricing function at O(d · r² · n) selected
        index combinations instead of the full n^d grid, making
        high-dimensional problems feasible.

        Args:
            fn: Callable accepting ``d`` integer grid-index arguments and
                returning a float price::

                    fn(i_0, i_1, ..., i_{d-1}) -> float

                The simplest way to build this is to pre-compute an axis
                array for continuous spots and index into it inside ``fn``.
                Example::

                    axes = [np.linspace(80, 120, 15)] * 6

                    def my_pricer(*indices):
                        spots = [axes[k][i] for k, i in enumerate(indices)]
                        return basket_mc(spots, K, T, r, sigma, corr)

                    surr = TTSurrogate.from_function(my_pricer, axes)

            axes: 1D ascending grid arrays, one per asset. ``len(axes)`` is
                the number of assets; ``axes[k][i]`` is the spot price at
                index ``i`` for asset ``k``.
            eps: Relative accuracy target passed to TT-Cross.
            max_rank: Hard upper bound on TT-ranks. Increase if accuracy is
                insufficient; decrease if speed is the priority.
            n_sweeps: Number of left-to-right + right-to-left alternating
                sweeps, forwarded to TT-Cross.
            seed: Random seed for TT-Cross initialisation.

        Returns:
            TTSurrogate (``build_time`` is 0.0 and no full-grid size is
            recorded, since the full grid is never formed).

        Raises:
            TypeError: If ``fn`` is not callable.
            ValueError: If fewer than 2 axes are given or an axis is not
                ascending.
        """
        if not callable(fn):
            raise TypeError(f"fn must be callable, got {type(fn)}")
        if len(axes) < 2:
            raise ValueError("from_function requires at least 2 axes (2 assets)")
        cls._check_axes_ascending(axes)
        shape = tuple(len(a) for a in axes)
        t0 = time.perf_counter()
        cores = tt_cross(
            fn=fn,  # type: ignore[arg-type]
            shape=shape,
            eps=eps,
            max_rank=max_rank,
            n_sweeps=n_sweeps,
            seed=seed,
        )
        compress_time = time.perf_counter() - t0
        return cls(
            cores=cores,
            axes=axes,
            eps=eps,
            build_time=0.0,  # No separate grid build step
            compress_time=compress_time,
            original_shape=None,  # Full grid was never formed
            original_nbytes=None,
        )

    # ── evaluation ──────────────────────────────────────────────────────

    def _spot_to_indices(self, spots: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Map continuous spot prices to grid cells and interpolation weights.

        Each coordinate is clamped to ``[axis[0], axis[-1]]``. The lower cell
        index ``i`` then satisfies ``axis[i] <= s <= axis[i + 1]``.

        Args:
            spots: Array of shape (d,) or (n, d) with ``d == n_assets``.

        Returns:
            ``(indices_lo, weights)``, both of shape (n, d); a 1D input is
            treated as a single row. ``weights[:, k]`` is the fractional
            position inside the cell (0 at ``axis[i]``, 1 at ``axis[i + 1]``,
            0 for zero-width cells).

        Raises:
            ValueError: If ``spots`` does not have ``n_assets`` columns.
        """
        spots = np.atleast_2d(spots)  # (n, d)
        if spots.ndim != 2 or spots.shape[1] != self.n_assets:
            raise ValueError(
                f"Expected spots of shape (d,) or (n, d) with d={self.n_assets}, "
                f"got shape {spots.shape}"
            )
        indices_lo = np.zeros_like(spots, dtype=int)
        weights = np.zeros_like(spots, dtype=float)
        for k, axis in enumerate(self.axes):
            n_k = len(axis)
            s = np.clip(spots[:, k], axis[0], axis[-1])
            # Last node <= s, capped at n_k - 2 so that idx + 1 stays valid.
            idx = np.clip(np.searchsorted(axis, s, side="right") - 1, 0, n_k - 2)
            lo = axis[idx]
            hi = axis[idx + 1]
            # np.where evaluates both branches; silence the 0/0 of zero-width cells.
            with np.errstate(divide="ignore", invalid="ignore"):
                w = np.where(hi > lo, (s - lo) / (hi - lo), 0.0)
            indices_lo[:, k] = idx
            weights[:, k] = w
        return indices_lo, weights

    def evaluate(self, spots: Union[np.ndarray, list[float]]) -> Union[float, np.ndarray]:
        """Evaluate the surrogate price at given spot prices.

        Uses multi-linear interpolation on the TT grid (a sum over the 2^d
        corners of the enclosing cell). Spots outside the grid are clamped to
        its edge.

        Args:
            spots: Spot prices — shape (d,) for single point, (n, d) for batch.

        Returns:
            Price(s) — scalar for single, array of shape (n,) for batch.

        Raises:
            ValueError: If the number of spot values per point is not ``n_assets``.
        """
        spots = np.asarray(spots, dtype=float)
        single = spots.ndim == 1
        spots = np.atleast_2d(spots)
        n_points = spots.shape[0]
        indices_lo, weights = self._spot_to_indices(spots)

        d = self.n_assets
        last_index = [len(ax) - 1 for ax in self.axes]
        result = np.zeros(n_points)
        for corner in range(2**d):
            # Bit k of `corner` selects the upper (1) or lower (0) node on axis k.
            idx = indices_lo.copy()
            w = np.ones(n_points)
            for k in range(d):
                if corner & (1 << k):
                    idx[:, k] = np.minimum(idx[:, k] + 1, last_index[k])
                    w *= weights[:, k]
                else:
                    w *= 1.0 - weights[:, k]
            result += w * tt_eval_batch(self.cores, idx)
        return float(result[0]) if single else result

    def evaluate_tensor(self, spots: Union[np.ndarray, list[float]]) -> Tensor:
        """Evaluate the surrogate price at a single point as a ``Tensor``.

        Performs the same multi-linear interpolation as :meth:`evaluate`,
        including clamping to the grid edge, but with ``Tensor`` arithmetic.
        Outside the grid each weight is a constant, so no gradient flows
        through clamped coordinates.

        NOTE(review): the spot leaf tensors are created inside this method
        and are not returned, so callers cannot read d(price)/d(spot) back
        from them after ``backward()``. ``greeks`` therefore uses finite
        differences instead.

        Args:
            spots: Spot prices — shape (d,).

        Returns:
            Tensor with computed price.

        Raises:
            ValueError: If ``spots`` is not 1D or has the wrong length.
        """
        spots_arr = np.asarray(spots, dtype=float)
        if spots_arr.ndim != 1:
            raise ValueError("evaluate_tensor expects a single point (1D)")
        indices_lo, weights_np = self._spot_to_indices(spots_arr.reshape(1, -1))
        indices_lo = indices_lo[0]  # (d,)
        weights_np = weights_np[0]  # (d,)

        # Per-axis interpolation weights as Tensors.
        weight_tensors = []
        for k, axis in enumerate(self.axes):
            idx = int(indices_lo[k])
            lo_val = axis[idx]
            hi_val = axis[min(idx + 1, len(axis) - 1)]
            s_k = spots_arr[k]
            if hi_val > lo_val and axis[0] <= s_k <= axis[-1]:
                spot_t = Tensor(np.array([s_k]))
                wt = (spot_t - Tensor(np.array([lo_val]))) / Tensor(
                    np.array([hi_val - lo_val])
                )
            else:
                # Clamped or zero-width cell: same constant weight evaluate() uses.
                wt = Tensor(np.array([weights_np[k]]))
            weight_tensors.append(wt)

        # Multi-linear interpolation with Tensor arithmetic.
        d = self.n_assets
        result = Tensor(np.array([0.0]))
        for corner in range(2**d):
            idx = indices_lo.copy()
            w = Tensor(np.array([1.0]))
            for k in range(d):
                if corner & (1 << k):
                    idx[k] = min(idx[k] + 1, len(self.axes[k]) - 1)
                    w = w * weight_tensors[k]
                else:
                    w = w * (Tensor(np.array([1.0])) - weight_tensors[k])
            val = tt_eval(self.cores, tuple(int(i) for i in idx))
            result = result + w * Tensor(np.array([val]))
        return result

    def greeks(self, spots: Union[np.ndarray, list[float]], h: float = 1e-4) -> dict[str, object]:
        """Compute Delta and Gamma by central finite differences on the surrogate.

        For each asset ``k`` the spot is bumped by ``h_abs = max(S_k * h, 1e-6)``::

            delta_k = (P(S + h_abs e_k) - P(S - h_abs e_k)) / (2 h_abs)
            gamma_k = (P(S + h_abs e_k) - 2 P(S) + P(S - h_abs e_k)) / h_abs²

        All prices, including ``P(S)``, come from :meth:`evaluate`. Bumps that
        leave the grid are therefore clamped to its edge exactly like the
        centre price.

        Note:
            The surrogate is multilinear (linear in each coordinate inside a
            grid cell). The second difference is therefore ~0 when the whole
            stencil lies in one cell and spikes when it straddles a grid node.
            For a meaningful Gamma, choose ``h`` so that ``h_abs`` is
            comparable to the grid spacing.

        Args:
            spots: Spot prices, shape (d,).
            h: Relative bump size (h_abs = max(S_k * h, 1e-6)).

        Returns:
            Dict with 'price' (float), 'delta' (array of shape (d,)) and
            'gamma' (array of shape (d,)).

        Raises:
            ValueError: If ``spots`` is not 1D or has the wrong length.
        """
        spots = np.asarray(spots, dtype=float)
        if spots.ndim != 1:
            raise ValueError("greeks expects a single point (1D)")
        price = self.evaluate(spots)

        h_abs = np.maximum(spots * h, 1e-6)
        # Row k of `bumps` shifts only asset k, so each batch call prices all
        # d up- (or down-) bumped points in one pass.
        bumps = np.diag(h_abs)
        p_up = self.evaluate(spots + bumps)
        p_dn = self.evaluate(spots - bumps)

        delta = (p_up - p_dn) / (2 * h_abs)
        gamma = (p_up - 2 * price + p_dn) / (h_abs**2)
        return {"price": price, "delta": delta, "gamma": gamma}

    # ── visualization ────────────────────────────────────────────────────

    def plot_surface(
        self,
        dims: tuple[int, int] = (0, 1),
        fixed_indices: dict[int, int] | None = None,
        title: str = "Pricing Surface",
        mode: str = "heatmap",
        **kwargs: object,
    ) -> object:
        """Plot a 2D pricing surface slice.

        Reconstructs the full pricing grid from the TT-cores (``tt_to_full``)
        and plots a 2D heatmap or 3D surface. The reconstruction holds every
        grid value in memory, so avoid it for very large TT-Cross surrogates.
        Any extra keyword arguments are forwarded to ``plot_pricing_surface``.

        Args:
            dims: Which two asset axes to plot (default: first two).
            fixed_indices: Override slice indices for remaining axes.
            title: Plot title.
            mode: ``"heatmap"`` (default) or ``"surface"`` (3D).

        Returns:
            ``(fig, ax)`` matplotlib tuple.
        """
        from tensorquantlib.tt.ops import tt_to_full
        from tensorquantlib.viz.plots import plot_pricing_surface

        grid = tt_to_full(self.cores)
        return plot_pricing_surface(
            grid,
            self.axes,
            dims=dims,
            fixed_indices=fixed_indices,
            title=title,
            mode=mode,
            **kwargs,  # type: ignore[arg-type]
        )

    def plot_greeks(
        self,
        dims: tuple[int, int] = (0, 1),
        fixed_indices: dict[int, int] | None = None,
        h: float = 1e-2,
        **kwargs: object,
    ) -> object:
        """Plot Delta surfaces for the selected axes.

        Reconstructs the full grid from the TT-cores and computes Delta along
        each axis in ``dims`` with ``np.gradient`` on the grid nodes. Only
        these Delta grids are passed to ``plot_greeks_surface``, together with
        any extra keyword arguments.

        Args:
            dims: Which two axes to plot.
            fixed_indices: Override slice indices for remaining axes.
            h: Currently unused. The derivatives come from ``np.gradient``
                on the grid nodes. Kept for API compatibility.

        Returns:
            ``(fig, axes)`` matplotlib tuple.
        """
        from tensorquantlib.tt.ops import tt_to_full
        from tensorquantlib.viz.plots import plot_greeks_surface

        grid = tt_to_full(self.cores)
        delta_grids: dict[str, np.ndarray] = {}
        for axis_idx in dims[: grid.ndim]:
            # Central differences in the interior, one-sided at the edges.
            delta_grids[f"Delta (axis {axis_idx})"] = np.gradient(
                grid, self.axes[axis_idx], axis=axis_idx
            )
        return plot_greeks_surface(
            delta_grids,
            self.axes,
            dims=dims,
            fixed_indices=fixed_indices,
            **kwargs,  # type: ignore[arg-type]
        )

    def plot_ranks(self, **kwargs: object) -> object:
        """Bar chart of the TT-ranks on the bonds between consecutive cores.

        Keyword arguments are forwarded to ``plot_tt_ranks``.

        Returns:
            ``(fig, ax)`` matplotlib tuple.
        """
        from tensorquantlib.viz.plots import plot_tt_ranks

        return plot_tt_ranks(self.cores, **kwargs)  # type: ignore[arg-type]

    # ── diagnostics ─────────────────────────────────────────────────────

    def summary(self) -> dict[str, object]:
        """Return diagnostic summary of the surrogate model.

        Returns:
            Dict with ranks, memory, timings, etc. The ``full_memory_*`` and
            ``compression_ratio`` keys are present only when the size of the
            uncompressed grid is known; it is unknown for ``from_function``.
        """
        ranks = tt_ranks(self.cores)
        tt_mem = tt_memory(self.cores)
        info = {
            "n_assets": self.n_assets,
            "grid_shape": tuple(len(a) for a in self.axes),
            "tt_ranks": ranks,
            "max_rank": max(ranks),
            "tt_memory_bytes": tt_mem,
            "tt_memory_KB": tt_mem / 1024,
            "eps": self.eps,
            "build_time_s": self.build_time,
            "compress_time_s": self.compress_time,
        }
        if self._original_nbytes is not None:
            info["full_memory_bytes"] = self._original_nbytes
            info["full_memory_KB"] = self._original_nbytes / 1024
            info["compression_ratio"] = (
                self._original_nbytes / tt_mem if tt_mem > 0 else float("inf")
            )
        return info

    def print_summary(self) -> None:
        """Print a formatted diagnostic summary."""
        s = self.summary()
        print("=" * 60)
        print("TT-Surrogate Summary")
        print("=" * 60)
        print(f" Assets: {s['n_assets']}")
        print(f" Grid shape: {s['grid_shape']}")
        print(f" TT-ranks: {s['tt_ranks']}")
        print(f" Max TT-rank: {s['max_rank']}")
        print(f" TT memory: {s['tt_memory_KB']:.2f} KB")
        if "full_memory_KB" in s:
            print(f" Full grid memory: {s['full_memory_KB']:.2f} KB")
            print(f" Compression: {s['compression_ratio']:.1f}×")
        print(f" TT-SVD tolerance: {s['eps']}")
        print(f" Grid build time: {s['build_time_s']:.3f} s")
        print(f" TT-SVD time: {s['compress_time_s']:.3f} s")
        print("=" * 60)