Source code for botorch.utils.dispatcher
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from inspect import getsource, getsourcefile
from typing import Any, Callable, Optional, Tuple, Type
from multipledispatch.dispatcher import (
    Dispatcher as MDDispatcher,
    MDNotImplementedError,  # trivial subclass of NotImplementedError
    str_signature,
)
class Dispatcher(MDDispatcher):
    r"""Clearing house for multiple dispatch functionality. This class extends
    `<multipledispatch.Dispatcher>` by: (i) generalizing the argument encoding
    convention during method lookup, (ii) implementing `__getitem__` as a dedicated
    method lookup function.
    """

    def __init__(
        self,
        name: str,
        doc: Optional[str] = None,
        encoder: Callable[[Any], Type] = type,
    ) -> None:
        r"""
        Args:
            name: A string identifier for the `Dispatcher` instance.
            doc: A docstring for the multiply dispatched method(s).
            encoder: A callable that individually transforms the arguments passed
                at runtime in order to construct the key used for method lookup as
                `tuple(map(encoder, args))`. Defaults to `type`.
        """
        super().__init__(name=name, doc=doc)
        self._encoder = encoder

    def __getitem__(
        self,
        args: Optional[Any] = None,
        types: Optional[Tuple[Type, ...]] = None,
    ) -> Callable:
        r"""Method lookup.

        Exactly one of `args` or `types` must be given. Subscript syntax
        (`dispatcher[args]`) only ever supplies `args`; `types` is available
        when calling `__getitem__` explicitly, as `__call__` does.

        Args:
            args: A set of arguments that act as identifiers for a stored method.
            types: A tuple of types that encodes `args`.

        Returns:
            A callable corresponding to the given `args` or `types`.

        Raises:
            RuntimeError: If neither or both of `args` and `types` are provided.
            NotImplementedError: If no registered method matches the signature.
        """
        if types is None:
            if args is None:
                raise RuntimeError("One of `args` or `types` must be provided.")
            types = self.encode_args(args)
        elif args is not None:
            raise RuntimeError("Only one of `args` or `types` may be provided.")

        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if not func:
                signature = ", ".join(cls.__name__ for cls in types)
                raise NotImplementedError(
                    f"Could not find signature for {self.name}: <{signature}>"
                )
            # Memoize the resolved method so repeated lookups skip `dispatch`.
            self._cache[types] = func

        return func

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        r"""Multiply dispatches a call to a collection of methods.

        Args:
            args: A set of arguments that act as identifiers for a stored method.
            kwargs: Optional keyword arguments passed to the retrieved method.

        Returns:
            The result of evaluating `func(*args, **kwargs)`, where `func` is
            the function obtained via method lookup.

        Raises:
            NotImplementedError: If no method matches the encoded arguments, or
                if every matching method raised `MDNotImplementedError`.
        """
        types = self.encode_args(args)
        func = self.__getitem__(types=types)
        try:
            return func(*args, **kwargs)
        except MDNotImplementedError:
            # Traverses registered methods in order, yields whenever a match is found
            funcs = self.dispatch_iter(*types)
            # Burn first, same as self.__getitem__(types=types). The default keeps
            # an unexpectedly empty iterator from leaking a StopIteration.
            next(funcs, None)
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    # This candidate declined; fall through to the next match.
                    pass

            raise NotImplementedError(
                f"Matching functions for {self.name:s}: {str_signature(types):s} "
                "found, but none completed successfully"
            )

    def dispatch(self, *types: Type) -> Optional[Callable]:
        r"""Method lookup strategy. Checks for an exact match before traversing
        the set of registered methods according to the current ordering.

        Args:
            types: A tuple of types that gets compared with the signatures
                of registered methods to determine compatibility.

        Returns:
            The first method encountered with a matching signature, or `None`
            if no registered method matches.
        """
        if types in self.funcs:
            return self.funcs[types]
        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def encode_args(self, args: Any) -> Tuple[Type, ...]:
        r"""Converts arguments into a tuple of types used during method lookup.

        Args:
            args: Either a tuple of arguments or a single argument. Anything
                that is not a tuple is treated as one argument.

        Returns:
            A tuple holding the result of applying `self.encoder` to each argument.
        """
        return tuple(map(self.encoder, args if isinstance(args, tuple) else (args,)))

    def _help(self, *args: Any) -> str:
        r"""Returns the retrieved method's docstring.

        Note: when no method matches, `dispatch` returns `None` and this method
        returns `None.__doc__` rather than raising.
        """
        return self.dispatch(*self.encode_args(args)).__doc__

    def help(self, *args: Any, **kwargs: Any) -> None:
        r"""Prints the retrieved method's docstring. `kwargs` are ignored."""
        print(self._help(*args))

    def _source(self, *args: Any) -> str:
        r"""Returns the retrieved method's source file and source code as a string.

        Raises:
            TypeError: If no method matches the encoded arguments.
        """
        func = self.dispatch(*self.encode_args(args))
        if not func:
            raise TypeError("No function found")
        return f"File: {getsourcefile(func)}\n\n{getsource(func)}"

    def source(self, *args: Any, **kwargs: Any) -> None:
        r"""Prints the retrieved method's source file and source code. `kwargs`
        are ignored.
        """
        print(self._source(*args))

    @property
    def encoder(self) -> Callable[[Any], Type]:
        r"""The callable applied to each argument to build the lookup key."""
        return self._encoder