# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
from typing import Callable, List, Optional
import torch
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import PosteriorList
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from torch import Tensor
MCMC_DIM = -3  # Location of the MCMC batch dimension
TOL = 1e-6  # Bisection tolerance


def batched_bisect(
    f: Callable, target: float, bounds: Tensor, tol: float = TOL, max_steps: int = 32
) -> Tensor:
    r"""Batched bisection with a fixed number of steps.
    Args:
        f: Target function that takes a `(b1 x ... x bk)`-dim tensor and returns a
            `(b1 x ... x bk)`-dim tensor.
        target: Scalar target value of type float.
        bounds: Lower and upper bounds, of size `2 x b1 x ... x bk`.
        tol: We termniate if all elements satisfy are within `tol` of the `target`.
        max_steps: Maximum number of bisection steps.
    Returns:
        Tensor X of size `b1 x ... x bk` such that `f(X) = target`.
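
    Example:
        An illustrative sketch (not part of the original module); the standard
        normal CDF stands in for the batched target function:

        >>> normal = torch.distributions.Normal(0.0, 1.0)
        >>> bounds = torch.stack(
        ...     [torch.full((4,), -5.0), torch.full((4,), 5.0)]
        ... )
        >>> root = batched_bisect(f=normal.cdf, target=0.5, bounds=bounds)
        >>> # root is approximately 0 in each batch element, the median of N(0, 1)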
    """
    # Make sure target is actually contained in the interval
    f1, f2 = f(bounds[0]), f(bounds[1])
    if not ((f1 <= target) & (target <= f2)).all():
        raise ValueError(
            "The target is not contained in the interval specified by the bounds"
        )
    bounds = bounds.clone()  # Will be modified in-place
    center = bounds.mean(dim=0)
    f_center = f(center)
    for _ in range(max_steps):
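        # If f(center) overshoots the target, the root lies in the left
        # half-interval; otherwise it lies in the right half-interval.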
        go_left = f_center > target
        bounds[1, go_left] = center[go_left]
        bounds[0, ~go_left] = center[~go_left]
        center = bounds.mean(dim=0)
        f_center = f(center)
        # Check convergence
        if (f_center - target).abs().max() <= tol:
            return center
    return center


class FullyBayesianPosterior(GPyTorchPosterior):
    r"""A posterior for a fully Bayesian model.
    The MCMC batch dimension that corresponds to the models in the mixture is located
    at `MCMC_DIM` (defined at the top of this file). Note that while each MCMC sample
    corresponds to a Gaussian posterior, the fully Bayesian posterior is rather a
    mixture of Gaussian distributions. We provide convenience properties/methods for
    computing the mean, variance, median, and quantiles of this mixture.
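
    Example:
        A usage sketch; it assumes `model` is an already-fitted fully Bayesian
        model whose `posterior` method returns a `FullyBayesianPosterior`:

        >>> posterior = model.posterior(test_X)
        >>> mean = posterior.mixture_mean
        >>> upper = posterior.mixture_quantile(q=0.975)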
    """
    def __init__(self, mvn: MultivariateNormal) -> None:
        r"""A posterior for a fully Bayesian model.
        Args:
            mvn: A GPyTorch MultivariateNormal (single-output case)
        """
        super().__init__(mvn=mvn)
        self._mean = mvn.mean if self._is_mt else mvn.mean.unsqueeze(-1)
        self._variance = mvn.variance if self._is_mt else mvn.variance.unsqueeze(-1)
    @property
    @lru_cache(maxsize=None)
    def mixture_mean(self) -> Tensor:
        r"""The posterior mean for the mixture of models."""
        return self._mean.mean(dim=MCMC_DIM)
    @property
    @lru_cache(maxsize=None)
    def mixture_variance(self) -> Tensor:
        r"""The posterior variance for the mixture of models."""
        num_mcmc_samples = self.mean.shape[MCMC_DIM]
        t1 = self._variance.sum(dim=MCMC_DIM) / num_mcmc_samples
        t2 = self._mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples
        t3 = -(self._mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)
        return t1 + t2 + t3
    @property
    @lru_cache(maxsize=None)
    def mixture_median(self) -> Tensor:
        r"""The posterior median for the mixture of models."""
        return self.mixture_quantile(q=0.5)
    @lru_cache(maxsize=None)
    def mixture_quantile(self, q: float) -> Tensor:
        r"""The posterior quantiles for the mixture of models."""
        if not isinstance(q, float):
            raise ValueError("q is expected to be a float.")
        if q <= 0 or q >= 1:
            raise ValueError("q is expected to be in the range (0, 1).")
        q_tensor = torch.tensor(q).to(self.mean)
        dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt())
        if self.mean.shape[MCMC_DIM] == 1:  # Analytical solution
            return dist.icdf(q_tensor).squeeze(MCMC_DIM)
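        # The q-quantile of the mixture lies between the smallest and largest
        # q-quantiles of the individual Gaussians, so these bracket the root.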
        icdf_val = dist.icdf(q_tensor)
        low = icdf_val.min(dim=MCMC_DIM).values - TOL
        high = icdf_val.max(dim=MCMC_DIM).values + TOL
        bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0)
        return batched_bisect(
            f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM),
            target=q,
            bounds=bounds,
        )  


class FullyBayesianPosteriorList(PosteriorList):
    r"""A Posterior represented by a list of independent Posteriors.
    This posterior should only be used when at least one posterior is a
    `FullyBayesianPosterior`. Posteriors that aren't of type `FullyBayesianPosterior`
    are automatically reshaped to match the size of the fully Bayesian posteriors
    to allow mixing, e.g., deterministic and fully Bayesian models.
    Args:
        *posteriors: A variable number of single-outcome posteriors.
    Example:
        >>> p_1 = model_1.posterior(test_X)
        >>> p_2 = model_2.posterior(test_X)
        >>> p_12 = FullyBayesianPosteriorList(p_1, p_2)
    """
    def _get_mcmc_batch_dimension(self) -> int:
        """Return the number of MCMC samples in the corresponding batch dimension."""
        mcmc_samples = [
            p.mean.shape[MCMC_DIM]
            for p in self.posteriors
            if isinstance(p, FullyBayesianPosterior)
        ]
        if len(set(mcmc_samples)) > 1:
            raise NotImplementedError(
                "All MCMC batch dimensions must have the same size, got shapes: "
                f"{mcmc_samples}."
            )
        return mcmc_samples[0]
    @staticmethod
    def _reshape_tensor(X: Tensor, mcmc_samples: int) -> Tensor:
        """Reshape a tensor without an MCMC batch dimension to match the shape."""
        X = X.unsqueeze(MCMC_DIM)
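        # `expand` creates a broadcast view along the new dimension, so no data
        # is copied.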
        return X.expand(*X.shape[:MCMC_DIM], mcmc_samples, *X.shape[MCMC_DIM + 1 :])
    def _reshape_and_cat(self, Xs: List[Tensor]) -> Tensor:
        r"""Reshape and cat a list of tensors."""
        mcmc_samples = self._get_mcmc_batch_dimension()
        return torch.cat(
            [
                x
                if isinstance(p, FullyBayesianPosterior)
                else self._reshape_tensor(x, mcmc_samples=mcmc_samples)
                for x, p in zip(Xs, self.posteriors)
            ],
            dim=-1,
        )
    @property
    def event_shape(self) -> torch.Size:
        r"""The event shape (i.e. the shape of a single sample)."""
        fully_bayesian_posteriors = [
            p for p in self.posteriors if isinstance(p, FullyBayesianPosterior)
        ]
        event_shape = fully_bayesian_posteriors[0].event_shape
        if not all(event_shape == p.event_shape for p in fully_bayesian_posteriors):
            # Make sure all fully Bayesian posteriors have the same event shape
            raise NotImplementedError(
                f"`{self.__class__.__name__}.event_shape` is only supported if all "
                "constituent posteriors have the same `event_shape`."
            )
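        # The posteriors are concatenated along the output dimension, so the
        # output dimensions of the (broadcast) event shapes add up.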
        event_shapes = [event_shape for _ in self.posteriors]
        batch_shapes = [es[:-1] for es in event_shapes]
        return batch_shapes[0] + torch.Size([sum(es[-1] for es in event_shapes)])
    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
        return self._reshape_and_cat(Xs=[p.mean for p in self.posteriors])
    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a `(b) x n x m`-dim Tensor."""
        return self._reshape_and_cat(Xs=[p.variance for p in self.posteriors])
    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).
        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: An (optional) Tensor of `N(0, I)` base samples of
                appropriate dimension, typically obtained from a `Sampler`.
                This is used for deterministic optimization.
        Returns:
            A `sample_shape x event`-dim Tensor of samples from the posterior.
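
        Example:
            An illustrative sketch; `p_12` is the posterior list from the
            class-level example above:

            >>> samples = p_12.rsample(sample_shape=torch.Size([64]))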
        """
        samples = super()._rsample(sample_shape=sample_shape, base_samples=base_samples)
        return self._reshape_and_cat(Xs=samples)