#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Some basic data transformation helpers.
"""
from __future__ import annotations
from contextlib import ExitStack, contextmanager
from functools import wraps
from typing import Any, Callable, List, Optional
import torch
from gpytorch import settings as gpt_settings
from torch import Tensor
from .. import settings
def squeeze_last_dim(Y: Tensor) -> Tensor:
    r"""Squeeze the last dimension of a Tensor.

    Follows `torch.Tensor.squeeze` semantics: the last dimension is removed
    only if it has size 1; otherwise the tensor is returned unchanged.

    Args:
        Y: A `... x d`-dim Tensor.

    Returns:
        The input tensor with last dimension squeezed (if it has size 1).

    Example:
        >>> Y = torch.rand(4, 1)
        >>> Y_squeezed = squeeze_last_dim(Y)
        >>> Y_squeezed.shape
        torch.Size([4])
    """
    return Y.squeeze(-1)
def standardize(Y: Tensor) -> Tensor:
    r"""Standardizes (zero mean, unit variance) a tensor by dim=-2.

    If the tensor is single-dimensional, simply standardizes the tensor.
    If for some batch index all elements are equal (or if there is only a single
    data point), this function will return 0 for that batch index.

    Args:
        Y: A `batch_shape x n x m`-dim tensor.

    Returns:
        The standardized `Y`.

    Example:
        >>> Y = torch.rand(4, 3)
        >>> Y_standardized = standardize(Y)
    """
    stddim = -1 if Y.dim() < 2 else -2
    Y_std = Y.std(dim=stddim, keepdim=True)
    # Replace (near-)zero std by 1 so constant slices map to 0 instead of NaN.
    # Using `>=` (not `< 1e-9` negated) matters: a single data point yields a
    # NaN std, and NaN >= 1e-9 is False, so it is replaced as well.
    Y_std = Y_std.where(Y_std >= 1e-9, torch.full_like(Y_std, 1.0))
    return (Y - Y.mean(dim=stddim, keepdim=True)) / Y_std
def normalize(X: Tensor, bounds: Tensor) -> Tensor:
    r"""Min-max normalize X w.r.t. the provided bounds.

    Args:
        X: `... x d` tensor of data
        bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
            columns.

    Returns:
        A `... x d`-dim tensor of normalized data, given by
            `(X - bounds[0]) / (bounds[1] - bounds[0])`. If all elements of `X`
            are contained within `bounds`, the normalized values will be
            contained within `[0, 1]^d`. Columns whose lower and upper bounds
            coincide are divided by zero (yielding inf/nan); no guard is applied.

    Example:
        >>> X = torch.rand(4, 3)
        >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
        >>> X_normalized = normalize(X, bounds)
    """
    lower, upper = bounds[0], bounds[1]
    return (X - lower) / (upper - lower)
def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
    r"""Un-normalizes X w.r.t. the provided bounds.

    Args:
        X: `... x d` tensor of data
        bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
            columns.

    Returns:
        A `... x d`-dim tensor of unnormalized data, given by
            `X * (bounds[1] - bounds[0]) + bounds[0]`. If all elements of `X`
            are contained in `[0, 1]^d`, the un-normalized values will be
            contained within `bounds`.

    Example:
        >>> X_normalized = torch.rand(4, 3)
        >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
        >>> X = unnormalize(X_normalized, bounds)
    """
    lower, upper = bounds[0], bounds[1]
    return X * (upper - lower) + lower
def normalize_indices(indices: Optional[List[int]], d: int) -> Optional[List[int]]:
    r"""Normalize a list of indices to ensure that they are non-negative.

    Args:
        indices: A list of indices (may contain negative indices for indexing
            "from the back").
        d: The dimension of the tensor to index.

    Returns:
        A normalized list of indices such that each index is between `0` and
        `d-1`, or None if indices is None.

    Raises:
        ValueError: If any index lies outside `[-d, d-1]`.
    """
    if indices is None:
        return indices
    normalized_indices = []
    for i in indices:
        # Negative indices count from the back, as in Python sequence indexing.
        j = i + d if i < 0 else i
        if not 0 <= j < d:
            # Report the index exactly as the caller passed it, not the shifted value.
            raise ValueError(f"Index {i} out of bounds for tensor of length {d}.")
        normalized_indices.append(j)
    return normalized_indices
def concatenate_pending_points(
    method: Callable[[Any, Tensor], Any]
) -> Callable[[Any, Tensor], Any]:
    r"""Decorator concatenating X_pending into an acquisition function's argument.

    This decorator works on the `forward` method of acquisition functions taking
    a tensor `X` as the argument. If the acquisition function has an `X_pending`
    attribute (that is not `None`), this is concatenated into the input `X`,
    appropriately expanding the pending points to match the batch shape of `X`.
    If the attribute is missing or `None`, `X` is passed through unchanged.

    Example:
        >>> class ExampleAcquisitionFunction:
        ...     @concatenate_pending_points
        ...     @t_batch_mode_transform()
        ...     def forward(self, X):
        ...         ...
    """

    @wraps(method)
    def decorated(cls: Any, X: Tensor, **kwargs: Any) -> Any:
        # `cls` is the object the wrapped method is bound to (the acquisition
        # function instance). A missing attribute is treated like `None`, as the
        # docstring promises, instead of raising AttributeError.
        X_pending = getattr(cls, "X_pending", None)
        if X_pending is not None:
            # Append pending points along the q-dimension (dim=-2) after
            # broadcasting them to X's batch shape.
            X = torch.cat([X, match_batch_shape(X_pending, X)], dim=-2)
        return method(cls, X, **kwargs)

    return decorated
def match_batch_shape(X: Tensor, Y: Tensor) -> Tensor:
    r"""Matches the batch dimension of a tensor to that of another tensor.

    Args:
        X: A `batch_shape_X x q x d` tensor, whose batch dimensions that
            correspond to batch dimensions of `Y` are to be matched to those
            (if compatible).
        Y: A `batch_shape_Y x q' x d` tensor.

    Returns:
        A `batch_shape_Y x q x d` tensor containing the data of `X` expanded to
        the batch dimensions of `Y` (if compatible). For instance, if `X` is
        `b'' x b' x q x d` and `Y` is `b x q x d`, then the returned tensor is
        `b'' x b x q x d`. The result is an `expand` view, not a copy.

    Example:
        >>> X = torch.rand(2, 1, 5, 3)
        >>> Y = torch.rand(2, 6, 4, 3)
        >>> X_matched = match_batch_shape(X, Y)
        >>> X_matched.shape
        torch.Size([2, 6, 5, 3])
    """
    # Leading dims of X beyond Y's rank are kept as-is (empty if X is not
    # higher-rank than Y); the next dims take Y's batch shape; X keeps q x d.
    extra_batch_shape = X.shape[: -(Y.dim())]
    target_shape = extra_batch_shape + Y.shape[:-2] + X.shape[-2:]
    return X.expand(target_shape)
def convert_to_target_pre_hook(module, *args):
    r"""Pre-hook for automatically calling `.to(X)` on module prior to `forward`.

    Intended for `nn.Module.register_forward_pre_hook`, which calls the hook as
    `hook(module, inputs)` with `inputs` being the tuple of positional forward
    arguments. The module is moved to the dtype/device of the first of them.

    Args:
        module: The module whose parameters/buffers are converted in place.
        *args: Hook arguments; `args[0]` is the positional-input tuple and
            `args[0][0]` the tensor whose dtype/device is adopted.
    """
    target = args[0][0]
    module.to(target)
    # Deliberately return None: a non-None return from a forward pre-hook would
    # replace the module's inputs.
@contextmanager
def gpt_posterior_settings():
    r"""Context manager for settings used for computing model posteriors.

    Enters, for the duration of the block, the GPyTorch settings
    `debug(False)`, `fast_pred_var()`, and
    `detach_test_caches(settings.propagate_grads.off())`. All three are exited
    (in reverse order) when the block ends, including on exceptions.
    """
    with ExitStack() as es:
        es.enter_context(gpt_settings.debug(False))
        es.enter_context(gpt_settings.fast_pred_var())
        # Test caches are detached depending on the package-level
        # `propagate_grads` setting; presumably `.off()` reports whether gradient
        # propagation is disabled -- NOTE(review): confirm against `settings`.
        es.enter_context(
            gpt_settings.detach_test_caches(settings.propagate_grads.off())
        )
        yield