# Source code for botorch.settings
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
BoTorch settings.
"""
from __future__ import annotations
import typing  # noqa F401
import warnings
from botorch.exceptions import BotorchWarning
from botorch.logging import LOG_LEVEL_DEFAULT, logger
class _Flag:
    r"""Base class for context managers for a binary setting."""
    _state: bool = False
    @classmethod
    def on(cls) -> bool:
        return cls._state
    @classmethod
    def off(cls) -> bool:
        return not cls._state
    @classmethod
    def _set_state(cls, state: bool) -> None:
        cls._state = state
    def __init__(self, state: bool = True) -> None:
        self.prev = self.__class__.on()
        self.state = state
    def __enter__(self) -> None:
        self.__class__._set_state(self.state)
    def __exit__(self, *args) -> None:
        self.__class__._set_state(self.prev)
class propagate_grads(_Flag):
    r"""Flag for propagating gradients to model training inputs / training data.

    When set to `True`, gradients will be propagated to the training inputs.
    This is useful in particular for propagating gradients through fantasy
    models.
    """

    _state: bool = False
def suppress_botorch_warnings(suppress: bool) -> None:
    r"""Set the botorch warning filter.

    Args:
        suppress: If `True`, warnings of category `BotorchWarning` (and its
            subclasses) are ignored. Otherwise the "default" filter action is
            installed for that category, so such warnings are printed.
    """
    warnings.simplefilter("ignore" if suppress else "default", BotorchWarning)
class debug(_Flag):
    r"""Flag for printing verbose BotorchWarnings.

    When set to `True`, verbose `BotorchWarning`s will be printed for
    debuggability. Warnings that are not subclasses of `BotorchWarning` will
    not be affected by this context manager.
    """

    _state: bool = False
    # Executed when the class body runs (i.e. at import), so the warning
    # filter starts out consistent with the default flag value: with
    # ``_state = False``, BotorchWarnings are suppressed.
    suppress_botorch_warnings(suppress=not _state)

    @classmethod
    def _set_state(cls, state: bool) -> None:
        r"""Set the flag and keep the BotorchWarning filter in sync with it."""
        cls._state = state
        suppress_botorch_warnings(suppress=not cls._state)
class log_level:
    r"""Flag for printing verbose logging statements.

    Applies the given level to logging.getLogger('botorch') calls. For
    instance, when set to logging.INFO, all logger calls of level INFO or
    above will be printed to STDERR.

    The current level is stored on the class (``level``) and applied to the
    botorch ``logger``. Exiting the context restores the level that was
    active when the context was entered.
    """

    level: int = LOG_LEVEL_DEFAULT

    @classmethod
    def _set_level(cls, level: int) -> None:
        r"""Record ``level`` on the class and apply it to the botorch logger."""
        cls.level = level
        logger.setLevel(level)

    def __init__(self, level: int = LOG_LEVEL_DEFAULT) -> None:
        r"""
        Args:
            level: The logging level to apply while inside the context.
        """
        # Kept for backward compatibility; refreshed on every ``__enter__``.
        self.prev = self.__class__.level
        self.level = level
        # Levels captured at each ``__enter__`` so that a reused or nested
        # instance restores the level active on entry, not at construction.
        self._prev_stack = []

    def __enter__(self) -> None:
        self.prev = self.__class__.level
        self._prev_stack.append(self.prev)
        self.__class__._set_level(self.level)

    def __exit__(self, *args) -> None:
        # Fall back to ``self.prev`` if ``__exit__`` is called without a
        # matching ``__enter__``, mirroring the original behavior.
        prev = self._prev_stack.pop() if self._prev_stack else self.prev
        self.__class__._set_level(prev)