# Source code for botorch.utils.containers
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Containers to standardize inputs into models and acquisition functions.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
import torch
from botorch.exceptions.errors import UnsupportedError
from torch import Tensor
@dataclass
class TrainingData:
    r"""Standardized container of model training data for models.

    Whether the data forms a "block design" (all entries of `Xs` are equal)
    is determined once, at construction time; later in-place changes to `Xs`
    are not reflected by `is_block_design`.

    Properties:
        Xs: A list of tensors, each of shape `batch_shape x n_i x d`,
            where `n_i` is the number of training inputs for the i-th model.
        Ys: A list of tensors, each of shape `batch_shape x n_i x 1`,
            where `n_i` is the number of training observations for the i-th
            (single-output) model.
        Yvars: A list of tensors, each of shape `batch_shape x n_i x 1`,
            where `n_i` is the number of training observations of the
            observation noise for the i-th (single-output) model.
            If `None`, the observation noise level is unobserved.
    """

    Xs: List[Tensor]  # `batch_shape x n_i x d`
    Ys: List[Tensor]  # `batch_shape x n_i x 1`
    Yvars: Optional[List[Tensor]] = None  # `batch_shape x n_i x 1`

    def __post_init__(self) -> None:
        r"""Cache whether all input tensors in `Xs` are equal."""
        # `Xs[0]` is only evaluated lazily inside the generator, so a list
        # with zero or one entries is trivially treated as a block design.
        inputs = self.Xs
        self._is_block_design = all(
            torch.equal(X, inputs[0]) for X in inputs[1:]
        )

    @classmethod
    def from_block_design(
        cls, X: Tensor, Y: Tensor, Yvar: Optional[Tensor] = None
    ) -> "TrainingData":
        r"""Construct a TrainingData object from a block design description.

        Args:
            X: A `batch_shape x n x d` tensor of training points (shared across
                all outcomes).
            Y: A `batch_shape x n x m` tensor of training observations.
            Yvar: A `batch_shape x n x m` tensor of training noise variance
                observations, or `None`.

        Returns:
            The `TrainingData` object (with `is_block_design=True`).
        """
        # One single-output slice per outcome (last dimension of `Y`); the
        # very same `X` tensor object is shared by every outcome.
        num_outcomes = Y.shape[-1]
        return cls(
            Xs=[X] * num_outcomes,
            Ys=list(torch.split(Y, 1, dim=-1)),
            Yvars=None if Yvar is None else list(torch.split(Yvar, 1, dim=-1)),
        )

    @property
    def is_block_design(self) -> bool:
        r"""Indicates whether training data is a "block design".

        Block designs are designs in which all outcomes are observed
        at the same training inputs.
        """
        return self._is_block_design

    @property
    def X(self) -> Tensor:
        r"""The training inputs (block-design only).

        This raises an `UnsupportedError` in the non-block-design case.
        """
        if not self.is_block_design:
            raise UnsupportedError("`X` is only available for block designs.")
        return self.Xs[0]

    @property
    def Y(self) -> Tensor:
        r"""The training observations (block-design only).

        The per-outcome tensors in `Ys` are concatenated along the last
        dimension. This raises an `UnsupportedError` in the
        non-block-design case.
        """
        if not self.is_block_design:
            raise UnsupportedError("`Y` is only available for block designs.")
        return torch.cat(self.Ys, dim=-1)

    @property
    def Yvar(self) -> Optional[Tensor]:
        r"""The training observations' noise variance (block-design only).

        Returns `None` if no noise observations are available (`Yvars` is
        `None`), regardless of the design. Otherwise the per-outcome tensors
        in `Yvars` are concatenated along the last dimension; this raises an
        `UnsupportedError` in the non-block-design case.
        """
        if self.Yvars is None:
            return None
        if not self.is_block_design:
            raise UnsupportedError("`Yvar` is only available for block designs.")
        return torch.cat(self.Yvars, dim=-1)

    def __eq__(self, other: object) -> bool:
        r"""Deep (element-wise) equality of all contained tensors.

        Returns `NotImplemented` for objects that are not `TrainingData`, so
        that Python falls back to its default comparison handling.
        """
        if not isinstance(other, TrainingData):
            return NotImplemented
        # Cheap structural checks first: presence of noise observations and
        # list lengths must agree before any tensor is compared.
        if (self.Yvars is None) != (other.Yvars is None):
            return False
        if len(self.Xs) != len(other.Xs) or len(self.Ys) != len(other.Ys):
            return False
        if self.Yvars is not None and len(self.Yvars) != len(other.Yvars):
            return False
        # Deep-check equality of attributes.
        if not all(torch.equal(a, b) for a, b in zip(self.Xs, other.Xs)):
            return False
        if not all(torch.equal(a, b) for a, b in zip(self.Ys, other.Ys)):
            return False
        if self.Yvars is None:
            return True
        return all(torch.equal(a, b) for a, b in zip(self.Yvars, other.Yvars))