Source code for botorch.posteriors.posterior
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Abstract base module for all botorch posteriors.
"""
from __future__ import annotations
from abc import ABC, abstractmethod, abstractproperty
from typing import Optional
import torch
from torch import Tensor
class Posterior(ABC):
    r"""Abstract base class for botorch posteriors.

    Concrete subclasses must implement `device`, `dtype`, `event_shape` and
    `rsample`; instantiating a subclass that leaves any of these abstract raises
    `TypeError`. `mean` and `variance` are optional and raise
    `NotImplementedError` unless overridden. `base_sample_shape` defaults to
    `event_shape`, and `sample` is derived from `rsample`.
    """

    @property
    def base_sample_shape(self) -> torch.Size:
        r"""The shape of a base sample used for constructing posterior samples.

        This function may be overwritten by subclasses in case `base_sample_shape`
        and `event_shape` do not agree (e.g. if the posterior is a Multivariate
        Gaussian that is not full rank).
        """
        return self.event_shape

    # NOTE: `@property` stacked on `@abstractmethod` is the supported spelling
    # of `abc.abstractproperty`, which has been deprecated since Python 3.3.
    # Subclasses are still prevented from being instantiated until they
    # override these members.
    @property
    @abstractmethod
    def device(self) -> torch.device:
        r"""The torch device of the posterior."""
        pass  # pragma: no cover

    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the posterior."""
        pass  # pragma: no cover

    @property
    @abstractmethod
    def event_shape(self) -> torch.Size:
        r"""The event shape (i.e. the shape of a single sample)."""
        pass  # pragma: no cover

    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a `(b) x n x m`-dim Tensor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"Property `mean` not implemented for {self.__class__.__name__}"
        )

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a `(b) x n x m`-dim Tensor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"Property `variance` not implemented for {self.__class__.__name__}"
        )

    @abstractmethod
    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: An (optional) Tensor of `N(0, I)` base samples of
                appropriate dimension, typically obtained from a `Sampler`.
                This is used for deterministic optimization.

        Returns:
            A `sample_shape x event_shape`-dim Tensor of samples from the posterior.
        """
        pass  # pragma: no cover

    def sample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Sample from the posterior (without gradients).

        This is a thin wrapper that forwards both arguments unchanged to
        `rsample` inside a `torch.no_grad()` context, so no autograd graph is
        recorded for the returned samples.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: An (optional) Tensor of `N(0, I)` base samples of
                appropriate dimension, typically obtained from a `Sampler` object.
                This is used for deterministic optimization.

        Returns:
            A `sample_shape x event_shape`-dim Tensor of samples from the posterior.
        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape, base_samples=base_samples)