# Source code for botorch.acquisition.penalized
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Modules to add regularization to acquisition functions.
"""
from __future__ import annotations
import math
from typing import List, Optional
import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.exceptions import UnsupportedError
from torch import Tensor
class L2Penalty(torch.nn.Module):
    r"""L2 penalty class to be added to any arbitrary acquisition function.

    For each q-batch, the penalty is the largest squared Euclidean distance
    between any of its q points and the reference point ``init_point``.
    """

    def __init__(self, init_point: Tensor):
        r"""Initializing L2 regularization.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
        """
        super().__init__()
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""Compute the L2 penalty of each q-batch.

        Args:
            X: A "batch_shape x q x dim" representing the points to be evaluated.

        Returns:
            A tensor of size "batch_shape" holding, for each q-batch, the maximum
            over its q points of the squared L2 distance to ``init_point``.
        """
        # Euclidean distance of every point to the reference: batch_shape x q.
        distances = torch.norm((X - self.init_point), p=2, dim=-1)
        # Penalize the farthest point in each q-batch, then square it.
        regularization_term = distances.max(dim=-1).values ** 2
        return regularization_term
class GaussianPenalty(torch.nn.Module):
    r"""Gaussian penalty class to be added to any arbitrary acquisition function.

    For each q-batch, the penalty is the maximum over its q points of
    ``exp(||x - init_point||^2 / (2 * sigma^2))``. The exponent is positive, so
    this is not a Gaussian density. The penalty equals 1 at ``init_point`` and
    grows with distance from it.
    """

    def __init__(self, init_point: Tensor, sigma: float):
        r"""Initializing Gaussian regularization.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
            sigma: The length-scale parameter used in the Gaussian-shaped
                function; larger values make the penalty grow more slowly.
        """
        super().__init__()
        self.init_point = init_point
        self.sigma = sigma

    def forward(self, X: Tensor) -> Tensor:
        r"""Compute the Gaussian-shaped penalty of each q-batch.

        Args:
            X: A "batch_shape x q x dim" representing the points to be evaluated.

        Returns:
            A tensor of size "batch_shape" holding, for each q-batch, the
            largest per-point value of ``exp(sq_dist / (2 * sigma^2))``.
        """
        # Squared Euclidean distance to the reference point: batch_shape x q.
        sq_diff = torch.norm((X - self.init_point), p=2, dim=-1) ** 2
        # Positive exponent: the value increases with distance (it penalizes).
        per_point_penalty = torch.exp(sq_diff / 2 / self.sigma ** 2)
        regularization_term = per_point_penalty.max(dim=-1).values
        return regularization_term
class GroupLassoPenalty(torch.nn.Module):
    r"""Group lasso penalty class to be added to any arbitrary acquisition function.

    The penalty is the group-lasso norm (computed by ``group_lasso_regularizer``)
    of the offset ``X - init_point``. Only q=1 is supported.
    """

    def __init__(self, init_point: Tensor, groups: List[List[int]]):
        r"""Initializing Group-Lasso regularization.

        Args:
            init_point: The "1 x dim" reference point against which we want
                to regularize.
            groups: Groups of feature indices used in group lasso.
        """
        super().__init__()
        self.init_point = init_point
        self.groups = groups

    def forward(self, X: Tensor) -> Tensor:
        r"""Compute the group-lasso penalty of each (single-point) batch.

        Args:
            X: A "batch_shape x 1 x dim" tensor. Evaluation for q-batches
                with q > 1 is not implemented yet.

        Returns:
            A tensor of size "batch_shape" with the group-lasso norm of
            ``X - init_point`` for each batch element.

        Raises:
            NotImplementedError: If ``X.shape[-2] != 1``.
        """
        if X.shape[-2] != 1:
            raise NotImplementedError(
                "group-lasso has not been implemented for q>1 yet."
            )
        # Drop the q=1 dimension so the offset is batch_shape x dim.
        regularization_term = group_lasso_regularizer(
            X=X.squeeze(-2) - self.init_point, groups=self.groups
        )
        return regularization_term
class PenalizedAcquisitionFunction(AcquisitionFunction):
    r"""Single-outcome acquisition function regularized by the given penalty.

    The value at ``X`` is
    ``raw_acqf(X) - regularization_parameter * penalty_func(X)``.

    The usage is similar to:
        raw_acqf = NoisyExpectedImprovement(...)
        penalty = GroupLassoPenalty(...)
        acqf = PenalizedAcquisitionFunction(
            raw_acqf, penalty, regularization_parameter=...
        )
    """

    def __init__(
        self,
        raw_acqf: AcquisitionFunction,
        penalty_func: torch.nn.Module,
        regularization_parameter: float,
    ) -> None:
        r"""Initializing a penalized acquisition function.

        Args:
            raw_acqf: The raw acquisition function that is going to be regularized.
                Its ``model`` is reused as the model of this acquisition function.
            penalty_func: The regularization function, called on the same ``X``
                as ``raw_acqf``.
            regularization_parameter: Regularization parameter used in optimization;
                it scales the penalty before it is subtracted.
        """
        super().__init__(model=raw_acqf.model)
        self.raw_acqf = raw_acqf
        self.penalty_func = penalty_func
        self.regularization_parameter = regularization_parameter

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalized acquisition value.

        Args:
            X: Candidate points, passed unchanged to both ``raw_acqf`` and
                ``penalty_func``.

        Returns:
            ``raw_acqf(X) - regularization_parameter * penalty_func(X)``. The two
            terms are assumed to have matching (or broadcastable) shapes, e.g.
            both "batch_shape" -- confirm for custom penalties.
        """
        raw_value = self.raw_acqf(X=X)
        penalty_term = self.penalty_func(X)
        return raw_value - self.regularization_parameter * penalty_term

    @property
    def X_pending(self) -> Optional[Tensor]:
        r"""The pending points of the wrapped raw acquisition function."""
        return self.raw_acqf.X_pending

    def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
        r"""Forward ``X_pending`` to the wrapped raw acquisition function.

        Args:
            X_pending: The pending points to set (may be ``None``).

        Raises:
            UnsupportedError: If ``raw_acqf`` is an ``AnalyticAcquisitionFunction``
                (raised even when ``X_pending`` is ``None``).
        """
        if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
            self.raw_acqf.set_X_pending(X_pending=X_pending)
        else:
            raise UnsupportedError(
                "The raw acquisition function is Analytic and does not account "
                "for X_pending yet."
            )
def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
    r"""Computes the group lasso regularization function for the given point.

    For each point ``x`` this is ``sum_g sqrt(len(g)) * ||x[g]||_2``.

    Args:
        X: A "batch_shape x d" tensor (e.g. "b x d") representing the points
            to evaluate the regularization at.
        groups: List of indices of different groups.

    Returns:
        A tensor of size "batch_shape" with the group lasso norm at each of the
        given points. If ``groups`` is empty, this is all zeros.
    """
    if not groups:
        # An empty sum is zero; torch.stack would raise on an empty list.
        return X.new_zeros(X.shape[:-1])
    return torch.sum(
        torch.stack(
            [math.sqrt(len(g)) * torch.norm(X[..., g], p=2, dim=-1) for g in groups],
            dim=-1,
        ),
        dim=-1,
    )