# Source code for tensorquantlib.viz.plots

"""Plotting functions for TensorQuantLib.

All functions return ``(fig, ax)`` tuples so callers can customise further.
Matplotlib is imported lazily — the rest of the library works without it.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:  # pragma: no cover
    import matplotlib.axes
    import matplotlib.figure


def _import_mpl() -> tuple[Any, Any]:
    """Import matplotlib on demand.

    Returns:
        ``(plt, matplotlib)`` — the ``matplotlib.pyplot`` module followed by
        the top-level ``matplotlib`` package.

    Raises:
        ImportError: If matplotlib cannot be imported.  The original import
            error is attached as the cause.
    """
    try:
        import matplotlib as mpl
        import matplotlib.pyplot as pyplot
    except ImportError as err:
        message = (
            "matplotlib is required for plotting.  "
            "Install it with:  pip install 'tensorquantlib[dev]'"
        )
        raise ImportError(message) from err
    return pyplot, mpl

# ====================================================================== #
# Pricing Surface
# ====================================================================== #


def plot_pricing_surface(
    grid: np.ndarray,
    axis_values: Sequence[np.ndarray],
    dims: tuple[int, int] = (0, 1),
    fixed_indices: dict[int, int] | None = None,
    title: str = "Pricing Surface",
    xlabel: str | None = None,
    ylabel: str | None = None,
    cmap: str = "viridis",
    figsize: tuple[float, float] = (8, 6),
    mode: str = "heatmap",
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Plot a 2D slice of a pricing grid as a heatmap or 3D surface.

    Args:
        grid: N-dimensional pricing grid (NumPy array).
        axis_values: 1D arrays giving the tick values along each grid axis.
        dims: The two distinct grid axes to plot; negative values count from
            the last axis.  ``dims[0]`` is drawn along x and ``dims[1]``
            along y.  Every other axis is reduced to a single index.
        fixed_indices: ``{axis: index}`` overrides for the reduced axes,
            keyed by non-negative axis number.  Axes not listed are sliced
            at their midpoint ``grid.shape[axis] // 2``.
        title: Plot title.
        xlabel: Label for the x-axis.  ``None`` selects ``"Axis {dims[0]}"``;
            an empty string leaves the axis unlabelled.
        ylabel: Label for the y-axis.  ``None`` selects ``"Axis {dims[1]}"``;
            an empty string leaves the axis unlabelled.
        cmap: Matplotlib colour-map name.
        figsize: Figure size in inches ``(width, height)``.
        mode: ``"heatmap"`` (default) or ``"surface"`` (3D).

    Returns:
        ``(fig, ax)`` tuple.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values, or
            ``dims`` does not name two distinct axes of ``grid``.
        ImportError: If matplotlib is not installed.
    """
    # Validate up front so bad arguments fail with a clear message instead of
    # an obscure shape error from deep inside matplotlib.
    if mode not in ("heatmap", "surface"):
        raise ValueError(f"mode must be 'heatmap' or 'surface', got {mode!r}")

    ndim = grid.ndim
    if len(dims) != 2:
        raise ValueError(f"dims must name exactly two axes, got {dims!r}")
    if any(not -ndim <= d < ndim for d in dims):
        raise ValueError(f"dims {dims!r} out of range for a {ndim}-dimensional grid")
    # Normalise negative axis numbers so the membership test below matches.
    x_axis, y_axis = (d % ndim for d in dims)
    if x_axis == y_axis:
        raise ValueError(f"dims must name two distinct axes, got {dims!r}")

    plt, _ = _import_mpl()
    fixed_indices = fixed_indices or {}

    # Keep the two plotted axes whole; collapse every other axis to one index.
    plotted = (x_axis, y_axis)
    index = tuple(
        slice(None) if axis in plotted else fixed_indices.get(axis, grid.shape[axis] // 2)
        for axis in range(ndim)
    )
    Z = grid[index]

    # Slicing leaves the kept axes in ascending order; flip so that rows
    # follow dims[0] and columns follow dims[1].
    if x_axis > y_axis:
        Z = Z.T

    X = axis_values[x_axis]
    Y = axis_values[y_axis]

    # Compare against None so an explicit empty label is respected.
    if xlabel is None:
        xlabel = f"Axis {x_axis}"
    if ylabel is None:
        ylabel = f"Axis {y_axis}"

    if mode == "surface":
        # Imported only for its side effect: older matplotlib releases need
        # it to make the "3d" projection available.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection="3d")
        # "ij" indexing gives meshes shaped (len(X), len(Y)), matching Z.
        Xm, Ym = np.meshgrid(X, Y, indexing="ij")
        ax.plot_surface(Xm, Ym, Z, cmap=cmap, edgecolor="none", alpha=0.9)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_zlabel("Price")
        ax.set_title(title)
    else:
        fig, ax = plt.subplots(figsize=figsize)
        # imshow draws rows vertically, hence the transpose.  Only the first
        # and last tick values are used for the extent.
        im = ax.imshow(
            Z.T,
            origin="lower",
            aspect="auto",
            extent=[X[0], X[-1], Y[0], Y[-1]],
            cmap=cmap,
        )
        fig.colorbar(im, ax=ax, label="Price")
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)

    return fig, ax


# ====================================================================== #
# Greeks Surface
# ====================================================================== #


def plot_greeks_surface(
    greek_grids: dict[str, np.ndarray],
    axis_values: Sequence[np.ndarray],
    dims: tuple[int, int] = (0, 1),
    fixed_indices: dict[int, int] | None = None,
    cmap: str = "RdBu_r",
    figsize: tuple[float, float] = (14, 4),
) -> tuple[matplotlib.figure.Figure, list[matplotlib.axes.Axes]]:
    """Plot multiple Greeks as side-by-side heatmaps.

    One panel is drawn per entry of ``greek_grids``, in the mapping's
    iteration order, each with its own colour bar.

    Args:
        greek_grids: ``{"Delta": array, "Gamma": array, ...}``; must not be
            empty.
        axis_values: 1D arrays giving the tick values along each grid axis.
        dims: Pair of distinct axes to plot; ``dims[0]`` runs along x and
            ``dims[1]`` along y.
        fixed_indices: ``{axis: index}`` overrides for the remaining axes.
            Axes not listed are sliced at their midpoint.
        cmap: Colour map.
        figsize: Figure size for the whole row.

    Returns:
        ``(fig, axes)`` tuple, where ``axes`` is a list with one entry per
        Greek.

    Raises:
        ValueError: If ``greek_grids`` is empty or ``dims`` repeats an axis.
        ImportError: If matplotlib is not installed.
    """
    # Fail early with a clear message; matplotlib would otherwise reject a
    # zero-column layout or a 1D image with a much less helpful error.
    if not greek_grids:
        raise ValueError("greek_grids must contain at least one grid")
    if dims[0] == dims[1]:
        raise ValueError(f"dims must name two distinct axes, got {dims!r}")

    plt, _ = _import_mpl()
    fixed_indices = fixed_indices or {}

    # squeeze=False always returns a 2D array of Axes, so a single panel
    # needs no special handling.
    fig, axes_grid = plt.subplots(1, len(greek_grids), figsize=figsize, squeeze=False)
    axes = list(axes_grid[0])

    X = axis_values[dims[0]]
    Y = axis_values[dims[1]]
    extent = [X[0], X[-1], Y[0], Y[-1]]

    for ax, (name, grid) in zip(axes, greek_grids.items()):
        # Keep the two plotted axes whole; collapse every other axis.
        index = tuple(
            slice(None) if i in dims else fixed_indices.get(i, grid.shape[i] // 2)
            for i in range(grid.ndim)
        )
        Z = grid[index]
        # Slicing leaves the kept axes in ascending order; flip so that rows
        # follow dims[0].
        if dims[0] > dims[1]:
            Z = Z.T

        im = ax.imshow(
            Z.T,
            origin="lower",
            aspect="auto",
            extent=extent,
            cmap=cmap,
        )
        fig.colorbar(im, ax=ax)
        ax.set_title(name)
        ax.set_xlabel(f"Axis {dims[0]}")
        ax.set_ylabel(f"Axis {dims[1]}")

    fig.tight_layout()
    return fig, axes


# ====================================================================== #
# TT Rank Profile
# ====================================================================== #


def plot_tt_ranks(
    cores: list[np.ndarray],
    title: str = "TT-Rank Profile",
    figsize: tuple[float, float] = (6, 4),
    color: str = "#2563eb",
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Draw the TT-ranks of a set of cores as a bar chart.

    The ranks come from ``tensorquantlib.tt.ops.tt_ranks(cores)``; one bar
    is drawn per returned rank, positioned at its index.

    Args:
        cores: List of TT-cores.
        title: Plot title.
        figsize: Figure size.
        color: Bar colour.

    Returns:
        ``(fig, ax)`` tuple.
    """
    plt, _ = _import_mpl()

    from tensorquantlib.tt.ops import tt_ranks

    rank_values = tt_ranks(cores)
    positions = list(range(len(rank_values)))

    figure, axes = plt.subplots(figsize=figsize)
    axes.bar(positions, rank_values, color=color, edgecolor="white", linewidth=0.5)
    # One tick per bar.
    axes.set_xticks(positions)
    axes.set_title(title)
    axes.set_xlabel("Bond index")
    axes.set_ylabel("Rank")
    return figure, axes
def plot_rank_profile(
    rank_lists: dict[str, list[int]],
    title: str = "Rank Profiles",
    figsize: tuple[float, float] = (7, 4),
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Draw several rank profiles on shared axes for comparison.

    Each entry becomes one marked line, labelled in the legend with its key
    and plotted against the position of each rank in its list.

    Args:
        rank_lists: ``{"eps=1e-4": [r0, r1, ...], ...}``.
        title: Plot title.
        figsize: Figure size.

    Returns:
        ``(fig, ax)`` tuple.
    """
    plt, _ = _import_mpl()
    figure, axes = plt.subplots(figsize=figsize)

    for series_name, series in rank_lists.items():
        bond_positions = range(len(series))
        axes.plot(bond_positions, series, "o-", label=series_name, markersize=5)

    axes.set_title(title)
    axes.set_xlabel("Bond index")
    axes.set_ylabel("Rank")
    axes.legend()
    return figure, axes


# ====================================================================== #
# Compression vs Tolerance
# ====================================================================== #


def plot_compression_vs_tolerance(
    epsilons: Sequence[float],
    compression_ratios: Sequence[float],
    errors: Sequence[float] | None = None,
    title: str = "Compression vs Tolerance",
    figsize: tuple[float, float] = (7, 4),
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Plot compression ratio and (optionally) error vs SVD tolerance.

    The compression ratio is drawn on the left y-axis against a
    logarithmic x-axis.  When ``errors`` is given, it is drawn on a second,
    logarithmic y-axis on the right and a combined legend is added.

    Args:
        epsilons: SVD tolerance values.
        compression_ratios: Matching compression ratios.
        errors: Optional matching relative errors.
        title: Plot title.
        figsize: Figure size.

    Returns:
        ``(fig, ax)`` where ``ax`` is always the primary (compression-ratio)
        axes.  The twin error axis is not returned; it stays attached to
        ``fig`` and can be reached through ``fig.axes``.
    """
    plt, _ = _import_mpl()

    fig, ax1 = plt.subplots(figsize=figsize)
    color1, color2 = "#2563eb", "#dc2626"

    ax1.semilogx(epsilons, compression_ratios, "o-", color=color1, label="Compression ratio")
    ax1.set_xlabel("SVD tolerance (ε)")
    ax1.set_ylabel("Compression ratio", color=color1)
    ax1.tick_params(axis="y", labelcolor=color1)
    ax1.set_title(title)

    if errors is not None:
        ax2 = ax1.twinx()
        ax2.loglog(epsilons, errors, "s--", color=color2, label="Relative error")
        ax2.set_ylabel("Relative error", color=color2)
        ax2.tick_params(axis="y", labelcolor=color2)

        # Each axes only knows its own artists, so merge the handles by hand
        # to get a single legend covering both curves.
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc="center right")

    fig.tight_layout()
    return fig, ax1


# ====================================================================== #
# Convergence
# ====================================================================== #


def plot_convergence(
    iterations: Sequence[int],
    values: Sequence[float],
    ylabel: str = "Error",
    title: str = "Convergence",
    log_y: bool = True,
    figsize: tuple[float, float] = (7, 4),
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Plot a convergence metric against the iteration number.

    Args:
        iterations: Iteration indices.
        values: Corresponding metric values.
        ylabel: Y-axis label.
        title: Plot title.
        log_y: Use log scale for y-axis.
        figsize: Figure size.

    Returns:
        ``(fig, ax)`` tuple.
    """
    plt, _ = _import_mpl()
    figure, axes = plt.subplots(figsize=figsize)

    axes.plot(iterations, values, "o-", color="#2563eb", markersize=4)
    if log_y:
        axes.set_yscale("log")

    axes.set_title(title)
    axes.set_ylabel(ylabel)
    axes.set_xlabel("Iteration")
    # Faint grid, added once the y-scale is final.
    axes.grid(True, alpha=0.3)

    figure.tight_layout()
    return figure, axes