# Source code for botorch.utils.multi_objective.box_decompositions.box_decomposition
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Box decomposition algorithms."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
import torch
from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError
from botorch.utils.multi_objective.box_decompositions.utils import (
    _expand_ref_point,
    _pad_batch_pareto_frontier,
)
from torch import Tensor
from torch.nn import Module
class BoxDecomposition(Module, ABC):
    r"""An abstract class for box decompositions.

    Note: Internally, we store the negative reference point (minimization).
    The outcomes and the Pareto frontier are likewise stored negated; the
    public `ref_point`, `Y` and `pareto_Y` properties undo the negation.

    Subclasses implement `partition_space_2d` and `get_hypercell_bounds`, and
    must also provide a general-purpose `_partition_space` method, which
    `partition_space` uses whenever the 2-objective algorithm does not apply.
    """

    def __init__(
        self, ref_point: Tensor, sort: bool, Y: Optional[Tensor] = None
    ) -> None:
        """Initialize BoxDecomposition.

        Args:
            ref_point: A `m`-dim tensor containing the reference point.
            sort: A boolean indicating whether to sort the Pareto frontier.
            Y: A `(batch_shape) x n x m`-dim tensor of outcomes. If given,
                `update` is called immediately to compute the Pareto frontier
                and the decomposition.
        """
        super().__init__()
        # Registered as buffers so they are part of the module state
        # (e.g. moved along with the module by `.to()`).
        self.register_buffer("_neg_ref_point", -ref_point)
        self.register_buffer("sort", torch.tensor(sort, dtype=torch.bool))
        self.num_outcomes = ref_point.shape[-1]
        if Y is not None:
            self.update(Y=Y)

    @property
    def pareto_Y(self) -> Tensor:
        r"""This returns the non-dominated set.

        Returns:
            A `(batch_shape) x n_pareto x m`-dim tensor of outcomes.

        Raises:
            BotorchError: If the frontier has not been computed yet (i.e.
                `update` has never been called).
        """
        try:
            return -self._neg_pareto_Y
        except AttributeError:
            # The internal AttributeError carries no useful information for
            # callers, so it is not chained.
            raise BotorchError("pareto_Y has not been initialized") from None

    @property
    def ref_point(self) -> Tensor:
        r"""Get the reference point.

        Returns:
            A `m`-dim tensor of outcomes.
        """
        return -self._neg_ref_point

    @property
    def Y(self) -> Tensor:
        r"""Get the raw outcomes.

        Returns:
            A `(batch_shape) x n x m`-dim tensor of outcomes, as last passed
            to `update`.
        """
        return -self._neg_Y

    def _update_pareto_Y(self) -> bool:
        r"""Update the non-dominated front.

        The (negated) front is recomputed from the current outcomes and is
        (re-)registered as the `_neg_pareto_Y` buffer only if it differs from
        the stored one.

        Returns:
            A boolean indicating whether the Pareto frontier has changed.
        """
        if self._neg_Y.shape[-2] == 0:
            # No observations: the (empty) outcome tensor is the frontier.
            pareto_Y = self._neg_Y
        else:
            # The frontier helper assumes maximization, so it receives the
            # raw outcomes; its result is negated for internal storage.
            pareto_Y = -_pad_batch_pareto_frontier(
                Y=self.Y,
                ref_point=_expand_ref_point(
                    ref_point=self.ref_point, batch_shape=self.batch_shape
                ),
            )
            if self.sort:
                # Sort by the first objective (ascending in the internal,
                # negated representation).
                if len(self.batch_shape) > 0:
                    order = torch.argsort(pareto_Y[..., :1], dim=-2)
                    pareto_Y = pareto_Y.gather(
                        index=order.expand(pareto_Y.shape), dim=-2
                    )
                else:
                    pareto_Y = pareto_Y[torch.argsort(pareto_Y[:, 0])]
        if not hasattr(self, "_neg_pareto_Y") or not torch.equal(
            pareto_Y, self._neg_pareto_Y
        ):
            self.register_buffer("_neg_pareto_Y", pareto_Y)
            return True
        return False

    def partition_space(self) -> None:
        r"""Compute box decomposition.

        The specialized `partition_space_2d` is only attempted when there are
        exactly two outcomes. If it raises a `BotorchTensorDimensionError`, or
        if the number of outcomes is not two, the general `_partition_space`
        is used instead.
        """
        if self.num_outcomes != 2:
            # The 2-objective algorithm does not apply; do not rely on the
            # subclass rejecting the input via an exception.
            self._partition_space()
            return
        try:
            self.partition_space_2d()
        except BotorchTensorDimensionError:
            # The 2-objective algorithm declined this input; fall back.
            self._partition_space()

    @abstractmethod
    def partition_space_2d(self) -> None:
        r"""Compute box decomposition for 2 objectives.

        Implementations may raise `BotorchTensorDimensionError` to make
        `partition_space` fall back to the general `_partition_space`.
        """
        pass  # pragma: no cover

    @abstractmethod
    def get_hypercell_bounds(self) -> Tensor:
        r"""Get the bounds of each hypercell in the decomposition.

        Returns:
            A `2 x num_cells x num_outcomes`-dim tensor containing the
                lower and upper vertices bounding each hypercell.
        """
        pass  # pragma: no cover

    def update(self, Y: Tensor) -> None:
        r"""Update non-dominated front and decomposition.

        The decomposition is only recomputed when the Pareto frontier changed.

        Args:
            Y: A `(batch_shape) x n x m`-dim tensor of outcomes.

        Raises:
            NotImplementedError: If `Y` has more than one batch dimension, or
                is batched while there are more than two outcomes.
        """
        batch_shape = Y.shape[:-2]
        # Validate before touching any state so that a rejected `Y` leaves
        # the existing outcomes, frontier and `batch_shape` consistent.
        if len(batch_shape) > 1:
            raise NotImplementedError(
                f"{type(self).__name__} only supports a single "
                f"batch dimension, but got {len(batch_shape)} "
                "batch dimensions."
            )
        elif len(batch_shape) > 0 and self.num_outcomes > 2:
            raise NotImplementedError(
                f"{type(self).__name__} only supports a batched box "
                f"decompositions in the 2-objective setting."
            )
        self.batch_shape = batch_shape
        # multiply by -1, since internally we minimize.
        # NOTE(review): `_neg_Y` is a plain attribute, not a buffer, so it is
        # not moved by `.to()` like `_neg_pareto_Y` is -- confirm intended.
        self._neg_Y = -Y
        is_new_pareto = self._update_pareto_Y()
        # Update decomposition if the Pareto front changed
        if is_new_pareto:
            self.partition_space()