#!/usr/bin/env python3
r"""
Some basic data transformation helpers.
"""
from functools import wraps
from typing import Any, Callable, Optional
import torch
from torch import Tensor
def squeeze_last_dim(Y: Tensor) -> Tensor:
    r"""Squeeze the last dimension of a Tensor.

    Args:
        Y: A `... x d`-dim Tensor.

    Returns:
        The input tensor with its last dimension removed if that dimension has
        size 1; otherwise the shape is left unchanged.

    Example:
        >>> Y = torch.rand(4, 1)
        >>> Y_squeezed = squeeze_last_dim(Y)
        >>> Y_squeezed.shape
        torch.Size([4])
    """
    return Y.squeeze(-1)
def standardize(X: Tensor) -> Tensor:
    r"""Standardize a tensor by dim=0.

    Each entry is shifted by the mean and divided by the standard deviation,
    both computed along dim=0. Where the standard deviation is smaller than
    `1e-9`, a divisor of 1 is used instead, so (near-)constant columns are only
    centered rather than blown up by a division by ~0.

    Args:
        X: A `n` or `n x d`-dim tensor

    Returns:
        The standardized `X`.

    Example:
        >>> X = torch.rand(4, 3)
        >>> X_standardized = standardize(X)
    """
    X_std = X.std(dim=0)
    # Keep entries with std >= 1e-9 and replace the rest by 1. A NaN std also
    # fails the comparison and is therefore replaced by 1 as well.
    X_std = X_std.where(X_std >= 1e-9, torch.full_like(X_std, 1.0))
    return (X - X.mean(dim=0)) / X_std
def normalize(X: Tensor, bounds: Tensor) -> Tensor:
    r"""Min-max normalize X to [0, 1] using the provided bounds.

    Computes `(X - lower) / (upper - lower)` column-wise, where `lower` is
    `bounds[0]` and `upper` is `bounds[1]`. Values of `X` outside the bounds
    map outside of [0, 1]; no clamping is applied.

    Args:
        X: `... x d` tensor of data
        bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
            columns.

    Returns:
        A `... x d`-dim tensor of normalized data.

    Example:
        >>> X = torch.rand(4, 3)
        >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
        >>> X_normalized = normalize(X, bounds)
    """
    return (X - bounds[0]) / (bounds[1] - bounds[0])
def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
    r"""Unscale X from [0, 1] to the original scale.

    Computes `X * (upper - lower) + lower` column-wise, where `lower` is
    `bounds[0]` and `upper` is `bounds[1]`.

    Args:
        X: `... x d` tensor of data
        bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
            columns.

    Returns:
        A `... x d`-dim tensor of unnormalized data.

    Example:
        >>> X_normalized = torch.rand(4, 3)
        >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
        >>> X = unnormalize(X_normalized, bounds)
    """
    return X * (bounds[1] - bounds[0]) + bounds[0]
def match_batch_shape(X: Tensor, Y: Tensor) -> Tensor:
    r"""Matches the batch dimension of a tensor to that of another tensor.

    Args:
        X: A `batch_shape_X x q x d` tensor, whose batch dimensions that
            correspond to batch dimensions of `Y` are to be matched to those
            (if compatible).
        Y: A `batch_shape_Y x q' x d` tensor.

    Returns:
        A `batch_shape_Y x q x d` tensor containing the data of `X` expanded to
        the batch dimensions of `Y` (if compatible). For instance, if `X` is
        `b'' x b' x q x d` and `Y` is `b x q x d`, then the returned tensor is
        `b'' x b x q x d`.

    Example:
        >>> X = torch.rand(2, 1, 5, 3)
        >>> Y = torch.rand(2, 6, 4, 3)
        >>> X_matched = match_batch_shape(X, Y)
        >>> X_matched.shape
        torch.Size([2, 6, 5, 3])
    """
    # Target shape = (leading dims of X beyond Y's rank) + (batch dims of Y)
    # + (the trailing `q x d` dims of X). The result is produced via `expand`,
    # so compatibility of the shapes is whatever `Tensor.expand` accepts.
    target_shape = X.shape[: -Y.dim()] + Y.shape[:-2] + X.shape[-2:]
    return X.expand(target_shape)
def convert_to_target_pre_hook(module, *args) -> None:
    r"""Pre-hook for automatically calling `.to(X)` on module prior to `forward`

    Args:
        module: The object to convert; its `.to` method is called.
        *args: Positional arguments supplied by the hook caller. Only
            `args[0][0]` is used, as the target passed to `module.to`.
            NOTE(review): presumably this is registered as a forward pre-hook,
            in which case `args[0]` is the tuple of `forward` inputs and
            `args[0][0]` is the first input -- confirm at the registration site.
    """
    module.to(args[0][0])