# Source code for botorch.utils.rounding
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Discretization (rounding) functions for acquisition optimization.

References

.. [Daulton2022bopr]
    S. Daulton, X. Wan, D. Eriksson, M. Balandat, M. A. Osborne, E. Bakshy.
    Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic
    Reparameterization. Advances in Neural Information Processing Systems
    35, 2022.
"""
from __future__ import annotations
import torch
from torch import Tensor
from torch.autograd import Function
from torch.nn.functional import one_hot
def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
    r"""Differentiable approximate rounding function.

    This method is a piecewise approximation of a rounding function where
    each piece is a hyperbolic tangent function.

    Args:
        X: The tensor to round to the nearest integer (element-wise).
        tau: A temperature hyperparameter. The centered remainder is divided
            by `tau`, so smaller values give a sharper transition between
            consecutive integers.

    Returns:
        The approximately rounded input tensor.
    """
    # Integer part of each element; every unit interval gets its own piece.
    offset = X.floor()
    # Center the fractional part at the interval midpoint and apply the
    # temperature.
    scaled_remainder = (X - offset - 0.5) / tau
    # Shift and scale tanh from (-1, 1) to (0, 1).
    rounding_component = (torch.tanh(scaled_remainder) + 1) / 2
    return offset + rounding_component
class IdentitySTEFunction(Function):
    """Base class for functions using straight through gradient estimators.

    This class approximates the gradient with the identity function. It only
    defines `backward`; it does not define a `forward` of its own.
    """

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tensor:
        r"""Use a straight-through estimator for the gradient.

        This uses the identity function.

        Args:
            ctx: The autograd context object (unused).
            grad_output: A tensor of gradients.

        Returns:
            The provided tensor.
        """
        return grad_output
class RoundSTE(IdentitySTEFunction):
    r"""Round the input tensor and use a straight-through gradient estimator.

    [Daulton2022bopr]_ proposes using this in acquisition optimization.
    """

    @staticmethod
    def forward(ctx, X: Tensor) -> Tensor:
        r"""Round the input tensor element-wise.

        Args:
            ctx: The autograd context object (unused).
            X: The tensor to be rounded.

        Returns:
            A tensor where each element is rounded to the nearest integer.
        """
        return X.round()
class OneHotArgmaxSTE(IdentitySTEFunction):
    r"""Discretize a continuous relaxation of a one-hot encoded categorical.

    This returns a one-hot encoded categorical and uses a straight-through
    gradient estimator via an identity function.

    [Daulton2022bopr]_ proposes using this in acquisition optimization.
    """

    @staticmethod
    def forward(ctx, X: Tensor) -> Tensor:
        r"""Discretize the input tensor.

        This applies an argmax along the last dimension of the input tensor
        and one-hot encodes the result.

        Args:
            ctx: The autograd context object (unused).
            X: The tensor to be discretized.

        Returns:
            A tensor with the same shape as `X` that is one-hot along the
            last dimension (a one at the argmax position, zeros elsewhere),
            converted to the dtype and device of `X`.
        """
        # `one_hot` yields an integer tensor; `.to(X)` converts it to match X.
        return one_hot(X.argmax(dim=-1), num_classes=X.shape[-1]).to(X)