# Source code for ddg.nonexact

"""Module for global tolerance defaults and nonexact calculations.

Contains ways to set, get and reset the global tolerances and functions that use the
global values as defaults.
"""
import numpy as np

__all__ = [
    "get_tol_defaults",
    "set_tol_defaults",
    "reset_tol_defaults",
    "TolDefaults",
    "isclose",
    "allclose",
    "svd_rank",
    "combine_tols",
]

###########################################################
# Getting, setting and resetting global tolerance defaults
###########################################################

# The tolerance defaults live in private module-level variables and are changed
# through get/set/reset_tol_defaults and TolDefaults rather than by rebinding the
# globals directly. Rebinding module variables from outside feels like something
# one is not supposed to do; this small indirection makes adjusting the defaults
# feel sanctioned. API and implementation are modelled on numpy.set_printoptions.
_RTOL_DEFAULT_ORIGINAL = 0.0  # pristine values, restored by reset_tol_defaults
_ATOL_DEFAULT_ORIGINAL = 1e-7
# Current (mutable) defaults, initialised from the pristine values.
_rtol_default, _atol_default = _RTOL_DEFAULT_ORIGINAL, _ATOL_DEFAULT_ORIGINAL


def get_tol_defaults():
    """Return the current global tolerance defaults.

    Returns
    -------
    dict
        ``{'atol': float, 'rtol': float}``. A new dict is built on every call,
        so mutating it does not change the global defaults.
    """
    return dict(atol=_atol_default, rtol=_rtol_default)
def set_tol_defaults(atol=None, rtol=None):
    """Update one or both global tolerance defaults.

    Parameters
    ----------
    atol, rtol : float or None (default=None)
        New absolute / relative tolerance. Passing None leaves the
        corresponding default untouched.
    """
    global _atol_default, _rtol_default
    if atol is not None:
        _atol_default = atol
    if rtol is not None:
        _rtol_default = rtol
def reset_tol_defaults():
    """Restore the global tolerance defaults to their original values.

    The originals are ``_ATOL_DEFAULT_ORIGINAL`` and ``_RTOL_DEFAULT_ORIGINAL``.
    """
    originals = {"atol": _ATOL_DEFAULT_ORIGINAL, "rtol": _RTOL_DEFAULT_ORIGINAL}
    set_tol_defaults(**originals)
class TolDefaults:
    """Context manager to temporarily set global tolerances.

    The requested tolerances are applied when the context is entered. All
    alterations made to the tolerance defaults inside the context are undone
    when it is exited, also when the context is left through an exception.
    Constructing an instance has no effect by itself; the globals only change
    while the ``with`` block is active.

    Parameters
    ----------
    atol, rtol : float or None (default=None)
        Tolerances to use inside the context. None leaves the corresponding
        default unchanged.

    Examples
    --------
    >>> from ddg.nonexact import TolDefaults, get_tol_defaults
    >>> get_tol_defaults()["atol"]
    1e-07
    >>> with TolDefaults(atol=1e-5) as global_tols:
    ...     get_tol_defaults()["atol"] == global_tols["atol"] == 1e-5
    ...
    True
    >>> get_tol_defaults()["atol"]  # Context has been exited
    1e-07

    `global_tols` is the same as `get_tol_defaults` called right after entering
    the context.
    """

    def __init__(self, atol=None, rtol=None):
        # Only remember the requested values here. Touching the globals is
        # deferred to __enter__ for two reasons:
        #   - creating an instance without `with` must not leak a change;
        #   - the backup must reflect the state at entry, not at construction.
        self._atol = atol
        self._rtol = rtol
        # Stack of (atol, rtol) backups, so one instance can be reused or
        # entered in a nested fashion and each exit restores its own entry state.
        self._backups = []

    def __enter__(self):
        current = get_tol_defaults()
        self._backups.append((current["atol"], current["rtol"]))
        set_tol_defaults(atol=self._atol, rtol=self._rtol)
        return get_tol_defaults()

    def __exit__(self, exc_type, exc_val, exc_tb):
        atol, rtol = self._backups.pop()
        set_tol_defaults(atol=atol, rtol=rtol)
        # Exceptions that occur in this context should not be suppressed
        return False
########################################################
# Functions/wrappers that use global tolerance defaults
########################################################
def isclose(a, b, rtol=None, atol=None, equal_nan=False):
    """Element-wise approximate equality; thin wrapper around numpy.isclose.

    Whichever of ``rtol`` / ``atol`` is left as None is replaced by the current
    global default from ``get_tol_defaults``.
    """
    if rtol is None:
        rtol = get_tol_defaults()["rtol"]
    if atol is None:
        atol = get_tol_defaults()["atol"]
    return np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def allclose(a, b, rtol=None, atol=None, equal_nan=False):
    """Whole-array approximate equality; thin wrapper around numpy.allclose.

    ``rtol`` and ``atol`` default to None, meaning "use the current global
    default" as reported by ``get_tol_defaults``.
    """
    resolved_rtol = rtol if rtol is not None else get_tol_defaults()["rtol"]
    resolved_atol = atol if atol is not None else get_tol_defaults()["atol"]
    return np.allclose(
        a, b, rtol=resolved_rtol, atol=resolved_atol, equal_nan=equal_nan
    )
def svd_rank(s, rtol=None, atol=None):
    """Approximates the rank from singular values using global tolerance defaults.

    Returns the number of singular values that are greater than or equal to
    ``max(atol, rtol * max(s))``.

    Parameters
    ----------
    s : array_like of shape (n,)
        Singular values. Plain Python sequences are accepted as well as
        numpy arrays.
    rtol, atol : float or None (default=None)
        Relative / absolute tolerance. When None, the corresponding global
        default from ``get_tol_defaults`` is used.

    Returns
    -------
    int
        Number of singular values at or above the threshold; 0 for an empty
        ``s``.
    """
    rtol_ = get_tol_defaults()["rtol"] if rtol is None else rtol
    atol_ = get_tol_defaults()["atol"] if atol is None else atol
    # asarray makes the element-wise comparison below work for lists/tuples too.
    s = np.asarray(s)
    if s.size == 0:
        # max() of an empty sequence raises; an empty spectrum has rank 0.
        return 0
    tol = max(atol_, rtol_ * s.max())
    return int((s >= tol).sum())
def combine_tols(*objects):
    """Merge the tolerances of several objects by taking the largest of each.

    Objects lacking an ``atol`` / ``rtol`` attribute, or holding None there,
    do not contribute. If no object supplies a value for a tolerance, the
    current global default from ``get_tol_defaults`` is used for it.

    Parameters
    ----------
    *objects
        Arbitrary objects that may carry ``atol`` / ``rtol`` attributes holding
        a float or None.

    Returns
    -------
    atol, rtol : float
    """

    def _present(name):
        # Values of attribute `name` across all objects, skipping missing/None.
        return [
            getattr(obj, name)
            for obj in objects
            if hasattr(obj, name) and getattr(obj, name) is not None
        ]

    atols = _present("atol")
    rtols = _present("rtol")
    defaults = get_tol_defaults()
    return max(atols, default=defaults["atol"]), max(rtols, default=defaults["rtol"])