#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import torch
from botorch.exceptions.errors import BotorchError
from botorch.posteriors.base_samples import _reshape_base_samples_non_interleaved
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions.multitask_multivariate_normal import (
    MultitaskMultivariateNormal,
)
from linear_operator.operators import BlockDiagLinearOperator, LinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.errors import NanError, NotPSDError
from torch import Tensor
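

def extract_batch_covar(mt_mvn: MultitaskMultivariateNormal) -> LinearOperator:
    r"""Extract a batched independent covariance matrix from an MTMVN.

    Args:
        mt_mvn: A multi-task multivariate normal with a block diagonal
            covariance matrix.

    Returns:
        A lazy covariance matrix consisting of a batch of the blocks of the
        covariance matrix (one block per task).
    """
    lazy_covar = mt_mvn.lazy_covariance_matrix
    if not isinstance(lazy_covar, BlockDiagLinearOperator):
        raise BotorchError(
            f"Expected BlockDiagLinearOperator, but got {type(lazy_covar)}."
        )
    return lazy_covar.base_linear_op
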
def _reshape_base_samples(
    base_samples: Tensor, sample_shape: torch.Size, posterior: GPyTorchPosterior
) -> Tensor:
    r"""Manipulate shape of base_samples to match `MultivariateNormal.rsample`.
    This ensure that base_samples are used in the same way as in
    gpytorch.distributions.MultivariateNormal. For CBD, it is important to ensure
    that the same base samples are used for the in-sample points here and in the
    cached box decompositions.
    Args:
        base_samples: The base samples.
        sample_shape: The sample shape.
        posterior: The joint posterior is over (X_baseline, X).
    Returns:
        Reshaped and expanded base samples.
    """
    mvn = posterior.mvn
    loc = mvn.loc
    peshape = posterior.event_shape
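    # View the base samples as `sample_shape x 1 x ... x 1 x (n + q) x m` and
    # expand the singleton dims to the batch shape of the posterior mean.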
    base_samples = base_samples.view(
        sample_shape + torch.Size([1 for _ in range(loc.ndim - 1)]) + peshape[-2:]
    ).expand(sample_shape + loc.shape[:-1] + peshape[-2:])
    if posterior._is_mt:
        base_samples = _reshape_base_samples_non_interleaved(
            mvn=posterior.mvn, base_samples=base_samples, sample_shape=sample_shape
        )
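    # Collapse the sample dims into a single leading dim and flatten the event
    # dims to match the size of the covariance matrix.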
    base_samples = base_samples.reshape(
        -1, *loc.shape[:-1], mvn.lazy_covariance_matrix.shape[-1]
    )
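    # Move the flattened sample dim to the end and unflatten the event dims,
    # yielding a `batch_shape x m x (n + q) x sample_shape`-dim tensor.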
    base_samples = base_samples.permute(*range(1, loc.dim() + 1), 0)
    return base_samples.reshape(
        *peshape[:-2],
        peshape[-1],
        peshape[-2],
        *sample_shape,
    )


def sample_cached_cholesky(
    posterior: GPyTorchPosterior,
    baseline_L: Tensor,
    q: int,
    base_samples: Tensor,
    sample_shape: torch.Size,
    max_tries: int = 6,
) -> Tensor:
    r"""Get posterior samples at the `q` new points from the joint multi-output
    posterior.
    Args:
        posterior: The joint posterior is over (X_baseline, X).
        baseline_L: The baseline lower triangular cholesky factor.
        q: The number of new points in X.
        base_samples: The base samples.
        sample_shape: The sample shape.
        max_tries: The number of tries for computing the Cholesky
            decomposition with increasing jitter.
    Returns:
        A `sample_shape x batch_shape x q x m`-dim tensor of posterior
            samples at the new points.
    """
    # Get the covariance matrix, with the per-task blocks batched for MTMVNs
    if isinstance(posterior.mvn, MultitaskMultivariateNormal):
        lazy_covar = extract_batch_covar(mt_mvn=posterior.mvn)
    else:
        lazy_covar = posterior.mvn.lazy_covariance_matrix
    # Get the `q` new rows of the batched covariance matrix
    bottom_rows = lazy_covar[..., -q:, :].to_dense()
    # The covariance in block form is:
    # [K(X_baseline, X_baseline), K(X_baseline, X)]
    # [K(X, X_baseline), K(X, X)]
    # bl := K(X, X_baseline)
    # br := K(X, X)
    # Get bottom right block of new covariance
    bl, br = bottom_rows.split([bottom_rows.shape[-1] - q, q], dim=-1)
    # Solve `L x = b`, where `L = baseline_L` is the lower triangular Cholesky
    # factor of K(X_baseline, X_baseline) and b = K(X, X_baseline)^T.
    # Then `bl_chol := x^T` is the new `(batch_shape) x q x n`-dim bottom left
    # block of the Cholesky decomposition.
    try:
        bl_chol = torch.linalg.solve_triangular(
            baseline_L, bl.transpose(-2, -1), upper=False
        ).transpose(-2, -1)
    except RuntimeError as e:
        if "singular" in str(e):
            raise NotPSDError(
                f"triangular_solve failed with RuntimeError: {e}"
            ) from e
        raise e
    # Compute the new bottom right block of the Cholesky
    # decomposition via:
    # Cholesky(K(X, X) - bl_chol @ bl_chol^T)
    br_to_chol = br - bl_chol @ bl_chol.transpose(-2, -1)
    # TODO: technically we should make sure that we add a
    # consistent nugget to the cached covariance and the new block
    br_chol = psd_safe_cholesky(br_to_chol, max_tries=max_tries)
    # Create a `(batch_shape) x q x (n+q)`-dim tensor containing the
    # `q` new bottom rows of the Cholesky decomposition
    new_Lq = torch.cat([bl_chol, br_chol], dim=-1)
    mean = posterior.mvn.mean
    base_samples = _reshape_base_samples(
        base_samples=base_samples,
        sample_shape=sample_shape,
        posterior=posterior,
    )
    if not isinstance(posterior.mvn, MultitaskMultivariateNormal):
        # add output dim
        mean = mean.unsqueeze(-1)
        # add batch dim corresponding to output dim
        new_Lq = new_Lq.unsqueeze(-3)
    new_mean = mean[..., -q:, :]
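    # Multiply the new Cholesky rows by the base samples and add the posterior
    # mean at the new points, then move the sample dim to the front to obtain
    # a `sample_shape x batch_shape x q x m`-dim tensor.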
    res = (
        new_Lq.matmul(base_samples)
        .add(new_mean.transpose(-1, -2).unsqueeze(-1))
        .permute(-1, *range(posterior.mvn.loc.dim() - 1), -2, -3)
        .contiguous()
    )
    contains_nans = torch.isnan(res).any()
    contains_infs = torch.isinf(res).any()
    if contains_nans or contains_infs:
        suffix_args = []
        if contains_nans:
            suffix_args.append("nans")
        if contains_infs:
            suffix_args.append("infs")
        suffix = " and ".join(suffix_args)
        raise NanError(f"Samples contain {suffix}.")
    return res
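

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the module API): cache the
    # Cholesky factor of the covariance at the baseline points once, then draw
    # posterior samples at `q` new points with `sample_cached_cholesky`. The
    # unfitted `SingleTaskGP` below is a stand-in model for demonstration.
    from botorch.models import SingleTaskGP

    torch.manual_seed(0)
    train_X = torch.rand(8, 2, dtype=torch.float64)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)

    # Cache the Cholesky factor of the posterior covariance at the baseline
    # points (in acquisition functions this is computed once and reused).
    baseline_posterior = model.posterior(train_X)
    baseline_L = psd_safe_cholesky(baseline_posterior.mvn.covariance_matrix)

    # Build the joint posterior over (X_baseline, X) for q = 3 new points.
    q = 3
    X = torch.rand(q, 2, dtype=torch.float64)
    joint_posterior = model.posterior(torch.cat([train_X, X], dim=-2))

    # Base samples must have shape `sample_shape x event_shape`, so that the
    # same draws can be reused across evaluations.
    sample_shape = torch.Size([16])
    base_samples = torch.randn(
        *sample_shape, *joint_posterior.event_shape, dtype=torch.float64
    )
    samples = sample_cached_cholesky(
        posterior=joint_posterior,
        baseline_L=baseline_L,
        q=q,
        base_samples=base_samples,
        sample_shape=sample_shape,
    )
    print(samples.shape)  # sample_shape x q x m, i.e. torch.Size([16, 3, 1])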