# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, Optional, Tuple, Union
import torch
from botorch.acquisition import AcquisitionFunction
from botorch.optim.homotopy import Homotopy
from botorch.optim.optimize import optimize_acqf
from torch import Tensor
[docs]
def prune_candidates(
    candidates: Tensor, acq_values: Tensor, prune_tolerance: float
) -> Tensor:
    r"""Greedily discard candidates that sit too close to better-scoring ones.

    Candidates are visited in order of decreasing acquisition value. The
    best candidate is always retained. Every later candidate is retained
    only if its (Euclidean, via `torch.cdist`) distance to the nearest
    already-retained candidate is strictly greater than `prune_tolerance`.

    Args:
        candidates: An `n x d` tensor of candidates.
        acq_values: An `n` tensor of candidate values.
        prune_tolerance: Non-negative distance threshold. A candidate whose
            nearest retained neighbour is at most this far away is dropped.

    Returns:
        An `m x d` tensor (`m <= n`) of the surviving candidates, ordered by
        decreasing acquisition value.

    Raises:
        ValueError: If `candidates` is not 2-D, if `acq_values` is not a 1-D
            tensor with one entry per candidate, or if `prune_tolerance` is
            negative.
    """
    if candidates.ndim != 2:
        raise ValueError("`candidates` must be of size `n x d`.")
    num_points = candidates.shape[0]
    if not (acq_values.ndim == 1 and len(acq_values) == num_points):
        raise ValueError("`acq_values` must be of size `n`.")
    if prune_tolerance < 0:
        raise ValueError("`prune_tolerance` must be >= 0.")

    ranked = candidates[acq_values.argsort(descending=True)]
    kept = ranked[:1, :]
    # Iterating over the `(n - 1) x 1 x d` view yields `1 x d` rows, the same
    # shape `torch.cdist` received from the original slice-based loop.
    for row in ranked[1:].unsqueeze(-2):
        if torch.cdist(row, kept).min() > prune_tolerance:
            kept = torch.cat([kept, row], dim=-2)
    return kept
[docs]
def optimize_acqf_homotopy(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    homotopy: Homotopy,
    num_restarts: int,
    raw_samples: Optional[int] = None,
    fixed_features: Optional[Dict[int, float]] = None,
    options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
    final_options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
    batch_initial_conditions: Optional[Tensor] = None,
    post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
    prune_tolerance: float = 1e-4,
) -> Tuple[Tensor, Tensor]:
    r"""Generate a set of candidates via multi-start optimization along a homotopy.

    Candidates are chosen one at a time. For each of the `q` candidates the
    homotopy is restarted. `optimize_acqf` is then run once per homotopy step,
    warm-started from the (pruned) solutions of the previous step. A final
    optimization pass with `final_options` follows, then the optional
    `post_processing_func` is applied and the best point is kept. When
    `q > 1`, already-selected candidates are appended to the acquisition
    function's pending points. The original pending points are restored
    before returning.

    Args:
        acq_function: An AcquisitionFunction.
        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
        q: The number of candidates. Must be at least 1.
        homotopy: Homotopy object that will make the necessary modifications to the
            problem when calling `step()`.
        num_restarts: The number of starting points for multistart acquisition
            function optimization.
        raw_samples: The number of samples for initialization. This is required
            if `batch_initial_conditions` is not specified.
        fixed_features: A map `{feature_index: value}` for features that
            should be fixed to a particular value during generation. These are
            applied in every homotopy step and in the final optimization pass.
        options: Options for candidate generation.
        final_options: Options for candidate generation in the last homotopy step.
        batch_initial_conditions: A tensor to specify the initial conditions. Set
            this if you do not want to use default initialization strategy.
        post_processing_func: Post processing function (such as rounding or
            clamping) that is applied before choosing the final candidate.
        prune_tolerance: Distance threshold used between homotopy steps. A
            restart point within this distance of a better-scoring one is
            dropped (see `prune_candidates`).

    Returns:
        A two-element tuple containing:
        - A `q x d` tensor of the selected candidates.
        - A `q`-dim tensor of the corresponding acquisition values.

    Raises:
        ValueError: If `q < 1`.
    """
    if q < 1:
        # Without this check `selected_candidates` would be unbound at return.
        raise ValueError("`q` must be a positive integer.")
    candidate_list, acq_value_list = [], []
    # Pending points are only modified for sequential (q > 1) selection.
    base_X_pending = acq_function.X_pending if q > 1 else None
    try:
        for _ in range(q):
            candidates = batch_initial_conditions
            homotopy.restart()
            while not homotopy.should_stop:
                candidates, acq_values = optimize_acqf(
                    q=1,
                    acq_function=acq_function,
                    bounds=bounds,
                    num_restarts=num_restarts,
                    batch_initial_conditions=candidates,
                    raw_samples=raw_samples,
                    fixed_features=fixed_features,
                    return_best_only=False,
                    options=options,
                )
                homotopy.step()
                # Drop near-duplicate restarts so the next step does not waste
                # effort re-optimizing points that converged to the same spot.
                candidates = prune_candidates(
                    candidates=candidates.squeeze(1),
                    acq_values=acq_values,
                    prune_tolerance=prune_tolerance,
                ).unsqueeze(1)
            # Optimize one more time with the final options. `fixed_features`
            # must be passed here too, or the fixed dimensions would be freed
            # in the last pass. `raw_samples` covers the case where the
            # homotopy stopped immediately and `candidates` is still None.
            candidates, acq_values = optimize_acqf(
                q=1,
                acq_function=acq_function,
                bounds=bounds,
                num_restarts=num_restarts,
                batch_initial_conditions=candidates,
                raw_samples=raw_samples,
                fixed_features=fixed_features,
                return_best_only=False,
                options=final_options,
            )
            # Post-process the candidates and re-score them before picking the best.
            if post_processing_func is not None:
                candidates = post_processing_func(candidates)
                acq_values = acq_function(candidates)
            best = torch.argmax(acq_values.view(-1), dim=0)
            # Keep the new candidate and update the pending points.
            candidate_list.append(candidates[best])
            acq_value_list.append(acq_values[best])
            selected_candidates = torch.cat(candidate_list, dim=-2)
            if q > 1:
                acq_function.set_X_pending(
                    torch.cat([base_X_pending, selected_candidates], dim=-2)
                    if base_X_pending is not None
                    else selected_candidates
                )
    finally:
        # Restore caller-visible state even if an optimization step raised.
        if q > 1:
            acq_function.set_X_pending(base_X_pending)
        homotopy.reset()
    return selected_candidates, torch.stack(acq_value_list)