from typing import Callable, List, NamedTuple, Optional
import numpy as np
import rombus.exceptions as exceptions
from rombus._core.log import log
MAX_N_TRIES = 1000
[docs]class Params(object):
def __init__(self) -> None:
self.params: List[NamedTuple] = []
self.names: List[str] = []
self.count: int = 0
self._idx: int = -1
self.params_dtype: Optional[NamedTuple] = None
self._validate: Callable = lambda param: True
[docs] def add(
self, name: str, range_min: float, range_max: float, dtype: type = float
) -> None:
"""Add a parameter to this parameter set.
Parameters
----------
name : str
Name of the parameter. This will become the reference name when accessing NamedTuple & dict versions of the parameters
range_min : float
The minimum-allowed value for this parameter
range_max : float
The maximum-allowed value for this parameter
dtype : type
The datatype used to represent this parameter
"""
with log.context(f"Adding parameter ({name})"):
# N.B.: mypy struggles with NamedTuples, so typing is turned off for the following
Param = NamedTuple(f"Param{len(self.params)}", [("name", str), ("min", dtype), ("max", dtype), ("dtype", dtype)]) # type: ignore
self.params.append(Param(name, dtype(range_min), dtype(range_max), dtype)) # type: ignore
self.count = self.count + 1
self.names.append(name)
# Update the named tuple that will be used to convert numpy arrays to Param objects
# N.B.: mypy struggles with NamedTuples, so typing is turned off for the following
self.params_dtype = NamedTuple("params_dtype", [(name, dtype) for name in self.names]) # type: ignore
[docs] def set_validation(self, func: Callable[[NamedTuple], bool]) -> None:
"""Specify a callable which accepts a parameter set (in the form of a NamedTuple) and
returns a bool indicating if the parameter set is valid.
By default, no validation is performed and all passed samples are assumed to be valid.
Parameters
----------
func : Callable
Callable which will perform the validation
"""
self._validate = func
[docs] def generate_random_sample(
self, random_generator: np.random._generator.Generator
) -> np.ndarray:
"""Generate a random parameter set, with values uniformly distributed between their specified min and max values.
Parameters
----------
random_generator : np.random._generator.Generator
An initialised numpy random number generator
Returns
-------
np.ndarray
A parameter set in numpy form
"""
new_sample: np.ndarray = np.ndarray(self.count, dtype=np.float64)
n_tries = 0
try:
while True:
for i, param in enumerate(self.params):
# N.B.: mypy struggles with NamedTuples, so typing is turned off for this next line
new_sample[i] = random_generator.uniform(low=param.min, high=param.max) # type: ignore
param = self.np2param(new_sample)
if self._validate(param):
break
else:
n_tries = n_tries + 1
if n_tries >= MAX_N_TRIES:
raise exceptions.RombusModelParamsError(
f"Max number of tries ({MAX_N_TRIES}) reached when trying to generate a valid random parameter set"
)
except exceptions.RombusException as e:
log.handle_exception(e)
return new_sample
[docs] def np2param(self, params_np: np.ndarray) -> NamedTuple:
"""Convert a numpy version of a parameter set to a NamedTuple format.
Parameters
----------
params_np : np.ndarray
Numpy version of parameter set
Returns
-------
NamedTuple
NamedTuple version of parameter set
"""
if not self.params_dtype:
return NamedTuple("params_empty", [])
else:
# N.B.: mypy struggles with NamedTuples, so typing is turned off for this next line
return self.params_dtype(**dict(zip(self.names, np.atleast_1d(params_np)))) # type: ignore
def __iter__(self):
self._idx = -1
return self
def __next__(self):
self._idx = self._idx + 1
if self._idx >= self.count:
raise StopIteration
return self.params[self._idx]
def __len__(self):
return self.count