# Source code for rombus.model

import importlib
import sys
import os
import shutil
import timeit
from abc import ABCMeta, abstractmethod
from collections import Counter
from typing import Any, Dict, Self, Optional, Tuple, TYPE_CHECKING

import numpy as np

import rombus._core.hdf5 as hdf5
import rombus.exceptions as exceptions
from rombus.params import Params
from rombus._core.log import log
from typing import NamedTuple

import warnings

# Silence NumPy's ComplexWarning for the whole process.
# NumPy 2.0 removed the top-level ``np.ComplexWarning`` alias; the class now
# lives in ``np.exceptions`` (available since NumPy 1.25).  Older releases have
# no ``np.exceptions`` module, so fall back to the top-level name there.
warnings.simplefilter("ignore", getattr(np, "exceptions", np).ComplexWarning)

# Need to put Samples in quotes below and check TYPE_CHECKING here to
# manage circular imports with models.py
if TYPE_CHECKING:
    from rombus.samples import Samples


class _Ordinate(object):
    def __init__(self):
        self.name = None
        self.dtype = float
        self.label = None

    def set(
        self, name: Optional[str], dtype: type | np.dtype = float, label: str = ""
    ) -> None:
        """Set the details of the model's ordinate.

        Parameters
        ----------
        name : str
            Name of ordinate
        dtype : type|np.dtype
            Datatype used to represent ordinate values
        label : str
            A n optional label to use for the ordinate in plots, etc.
        """
        self.name = name
        self.dtype = dtype
        if label == "":
            self.label = self.name
        else:
            self.label = label


class _Coordinate(object):
    def __init__(self):
        self.name = None
        self.dtype = float
        self.min = None
        self.max = None
        self.label = None
        self._values = None
        self._n_values = 0

    def set(
        self,
        name: str,
        min: Any,
        max: Any,
        n_values: int,
        dtype: type | np.dtype = float,
        label: str = "",
    ) -> None:
        self.name = name
        self.dtype = dtype
        self.min = min
        self.max = max
        if label == "":
            self.label = self.name
        else:
            self.label = label
        if self.dtype != type(min):
            raise exceptions.RombusModelCoordinateError(
                f"Coordinate datatype ({self.dtype}) and min-value datatype ({type(self.min)}) don't match."
            )
        if self.dtype != type(max):
            raise exceptions.RombusModelCoordinateError(
                f"Coordinate datatype ({self.dtype}) and max-value datatype ({type(self.max)}) don't match."
            )
        self._n_values = n_values
        self._values = np.linspace(self.min, self.max, self._n_values, self.dtype)

    def get(self):
        return self._values

    def __len__(self):
        return self._n_values


class _RombusModelMeta(type):
    """Metaclass that pre-populates the namespace of every Rombus model class.

    Each class created with this metaclass starts its body with fresh
    ``coordinate``, ``ordinate`` and ``params`` objects already defined, so
    the class body can call methods on them directly.
    """

    @classmethod
    def __prepare__(mcs, name, *args, **kwargs):
        """Initialise the dictionary that gets passed to __new__.

        This is needed here because we don't want the user to have to
        initialise the member(s) that we are adding.  This is the only
        method that gets sourced before the class code is executed, so
        it needs to be done here, not in __new__.

        The language reference requires ``__prepare__`` to be a classmethod;
        it is invoked as ``metaclass.__prepare__(name, bases, **kwds)``.
        """

        result = dict()

        # Initialise the following members
        result["coordinate"] = _Coordinate()
        result["ordinate"] = _Ordinate()
        result["params"] = Params()

        return result

    def __new__(mcs, name, bases, dct):
        """Create the class object from the namespace built by __prepare__."""

        # Perform super-metaclass construction
        return super().__new__(mcs, name, bases, dct)


class _RombusModelABCMeta(_RombusModelMeta, ABCMeta):
    """Turn _RombusModelMeta into an abstract base class.

    Combining the two metaclasses lets a class both receive the
    pre-populated namespace from ``_RombusModelMeta`` and declare
    ``@abstractmethod`` members through ``ABCMeta``.
    """


class RombusModel(metaclass=_RombusModelABCMeta):
    """Baseclass from which all RombusModels must inherit.

    Subclasses must implement :meth:`compute`.  The ``coordinate``,
    ``ordinate`` and ``params`` members are placed in the class namespace by
    the metaclass before the class body runs.
    """

    # These members are instantiated by the metaclass
    coordinate: _Coordinate
    """The domain on which this model is defined."""

    ordinate: _Ordinate
    """The ordinate that this model maps its domain to."""

    params: Params
    """The parameters defined for this model"""

    def __init__(self, model: str):
        """Initialise a model instance.

        Parameters
        ----------
        model : str
            The import string the model was loaded from; text before the
            first ':' is taken as the module path.

        Raises
        ------
        exceptions.RombusModelParamsError
            If no parameters have been defined for the model.
        """
        # Keep track of the model string so we can reinstantiate from a saved state
        self.model_str = model
        # Last dotted component of the module part of the import string
        self.basename = self.model_str.split(":")[0].split(".")[-1]

        # Initialise the domain
        self.domain = self.coordinate.get()
        self.n_domain = len(self.coordinate)

        # Check that the domain has been suitably set
        assert self.n_domain > 0, "The model's coordinate has not been set."

        # Check that at least one parameter has been defined
        if self.params.count <= 0:
            raise exceptions.RombusModelParamsError(
                f"Invalid number of parameters ({self.params.count}) specified for Rombus model ({self})."
            )

    def __str__(self):
        return f"<RombusModel from {self.model_str}>"

    @abstractmethod  # make sure this is the inner-most decorator
    def compute(self, params: NamedTuple, domain: np.ndarray) -> np.ndarray:
        """Abstract method which computes the user's model.

        This method does all the work of computing the user's model.  It takes a
        parameter set as a named tuple with N elements given by the names given to
        the N calls made to params.add() as well as the array set by coordinate.set()
        and returns a numpy array.

        Parameters
        ----------
        params : NamedTuple
            The parameters to be used when computing the model
        domain : np.ndarray
            The domain on which the model is to be computed

        Returns
        -------
        np.ndarray
            The user's model, computed for the given parameter set and domain
        """
        pass

    @classmethod
    @log.callable("Instantiating model from file")
    def from_file(cls, file_in: hdf5.FileOrFilename) -> Self:
        """Generate a RombusModel instance from a Rombus HDF5 file.

        Parameters
        ----------
        file_in : hdf5.FileOrFilename
            The Rombus HDF5 file to read from.

        Returns
        -------
        Self
            The generated RombusModel instance.
        """
        try:
            h5file, close_file = hdf5.ensure_open(file_in)
            try:
                model_str = h5file["model/model_str"].asstr()[()]
            finally:
                # Release the handle even if the read fails, but only when
                # ensure_open reports that the file should be closed here.
                if close_file:
                    h5file.close()
        except IOError as e:
            # NOTE(review): if log.handle_exception returns instead of raising,
            # model_str is unbound in the return below -- confirm it never returns.
            log.handle_exception(e)

        return cls.load(model_str)

    @classmethod
    @log.callable("Loading model from file")
    def load(cls, model: str | Self) -> Self:
        """Ensure that a model has been imported for use by Rombus.

        Parameters
        ----------
        model : str | Self
            A string of format 'sub.module.name:ClassName' or a RombusModel instance
            (trivially returned in the latter case).

        Returns
        -------
        Self
            A RombusModel instance

        Raises
        ------
        exceptions.RombusModelInitError
            If ``model`` is neither a string nor a RombusModel instance.
        """
        if isinstance(model, str):
            try:
                model_class = _import_from_string(model)
            except exceptions.RombusException as e:
                # NOTE(review): if log.handle_exception returns, control falls
                # through and the original string is returned -- confirm.
                log.handle_exception(e)
            else:
                return model_class(model)
        elif not isinstance(model, RombusModel):
            raise exceptions.RombusModelInitError(
                f"Invalid type ({type(model)}) specified when loading model {model}."
            )
        return model  # type: ignore

    @log.callable("Writing model to file")
    def write(self, h5file: hdf5.File) -> None:
        """Write a RombusModel to a Rombus HDF5 file.

        Only the model's import string is stored (in dataset
        ``model/model_str``); :meth:`from_file` re-imports the model from it.

        Parameters
        ----------
        h5file : hdf5.File
            An open HDF5 file
        """
        try:
            h5_group = h5file.create_group("model")
            h5_group.create_dataset("model_str", data=self.model_str)
        except IOError as e:
            log.handle_exception(e)

    # Need to put Samples in quotes and check TYPE_CHECKING above to manage circular import with models.py
    def generate_model_set(self, samples: "Samples") -> np.ndarray:
        """Generate a set of models for a given set of parameter samples.

        Each model is divided by the square root of its inner product with
        itself before being stored.

        Parameters
        ----------
        samples : "Samples"
            A set of Samples

        Returns
        -------
        np.ndarray
            An array of model results: 1 for each given sample.
        """
        my_ts: np.ndarray = np.zeros(
            shape=(samples.n_samples, self.n_domain), dtype=self.ordinate.dtype
        )
        with log.progress("Generating training set", samples.n_samples) as progress:
            for i, params_numpy in enumerate(samples.samples):
                model_i = self.compute(self.params.np2param(params_numpy), self.domain)
                my_ts[i] = model_i / np.sqrt(np.vdot(model_i, model_i).real)
                progress.update(i)

        return my_ts

    def parse_cli_params(self, args: Tuple[str, ...]) -> Dict[str, Any]:
        """Parse parameters given as a tuple of form 'param0=val0', 'param1=val1', ... etc
        to a dictionary of form {'param0':val0, 'param1':val1, ... }.

        Generally used to parse the optional arguments received from Click into a format
        that can be converted into a Params or Numpy object

        Parameters
        ----------
        args : Tuple[str, ...]
            Given tuple of parameters

        Returns
        -------
        Dict[str, Any]
            Resulting dict of parameters

        Raises
        ------
        Exception
            If an argument starts with '-' or is not of the form 'name=value'.
        """
        model_params: Dict[str, Any] = {}
        for param_i in args:
            if param_i.startswith("-"):
                raise Exception(f"Don't know what to do with option '{param_i}'")
            res = param_i.split("=")
            if len(res) != 2:
                raise Exception(f"Don't know what to do with argument '{param_i}'")
            # NOTE: for now, all parameters are assumed to be floats
            model_params[res[0]] = float(res[1])

        # Check that all parameters are specified and that they match what is
        # defined in the model
        assert Counter(model_params.keys()) == Counter(self.params.names)

        return model_params

    def sample(self, kwargs: Dict[str, Any]) -> NamedTuple:
        """Create a Sample from a dictionary of the form {"param0":val0,"param1":val1,...}

        Parameters
        ----------
        kwargs : Dict[str, Any]
            Dict specifying parameter values

        Returns
        -------
        NamedTuple
            Named tuple specifying parameter values
        """
        return self.params.params_dtype(**kwargs)  # type: ignore

    def timing(self, samples: "Samples") -> float:
        """Generate timing information for the original source model.

        Particularly useful when compared to similar timing information
        computed for ROMs derived from it.

        Parameters
        ----------
        samples : "Samples"
            A set of parameters to generate timing information for.  Should be the
            same as those used when timing a ROM, if comparisons are to be made.

        Returns
        -------
        float
            Seconds elapsed
        """
        with log.context(
            f"Computing timing information for model using {samples.n_samples} samples",
            time_elapsed=False,
        ):
            start_time = timeit.default_timer()
            for sample in samples.samples:
                # The conversion to a parameter tuple is part of the timed work
                params = self.params.np2param(sample)
                _ = self.compute(params, self.domain)
            return timeit.default_timer() - start_time

    @classmethod
    @log.callable("Writing project template")
    def write_project_template(cls, project_name: str) -> None:
        """Write a project model to the current working directory to start a new project from.

        Two files are written to the current working directory: a Python file and a set
        of samples.  These can then be modified to suit the needs of the user.

        Parameters
        ----------
        project_name : str
            Base name to use for the project.
        """
        # Set the model we will template from
        model_name = "sinc"

        # Set source file paths
        pkgdir = sys.modules["rombus"].__path__[0]
        model_file_source = os.path.join(pkgdir, "models", f"{model_name}.py")
        samples_file_source = os.path.join(
            pkgdir, "models", f"{model_name}_samples.csv"
        )

        # Set output file paths
        model_file_out = os.path.join(os.getcwd(), f"{project_name}.py")
        samples_file_out = os.path.join(os.getcwd(), f"{project_name}_samples.csv")

        # Copy files
        with log.context("Writing files"):
            shutil.copy(model_file_source, model_file_out)
            log.comment(f"Written: {os.path.split(model_file_out)[1]}")
            shutil.copy(samples_file_source, samples_file_out)
            log.comment(f"Written: {os.path.split(samples_file_out)[1]}")
RombusModelType = RombusModel | str


# The code that follows is modified from code copied from the Uvicorn codebase:
# https://github.com/encode/uvicorn (commit: d613cbea388bafafb6f642077c035ed137deea61)
# Copyright © 2017-present, [Encode OSS Ltd](https://www.encode.io/).
# All rights reserved.
@log.callable("Importing model")
def _import_from_string(import_str: str) -> Any:
    """Import an object from a string of the form 'python.module.name:ClassName'.

    Generally, the user model will be defined in a file in the current working
    directory with filename (for example) of 'my_model.py', with a class
    inheriting from RombusModel with name 'ClassName'.  It should then be
    referred to in this context as 'my_model:ClassName'.  More generally, the
    model can be anywhere in the user's PYTHONPATH.

    The current working directory is appended to ``sys.path`` (and duplicate
    entries are removed) before the import is attempted.

    Parameters
    ----------
    import_str : str
        Given string of the form 'python.module.name:ClassName'

    Returns
    -------
    Any
        The object found by following the (optionally dotted) attribute path
        after the ':' inside the imported module; normally the user-defined
        model class itself.

    Raises
    ------
    exceptions.RombusModelImportFromStringError
        If ``import_str`` is not a string, is not of the form
        '<module>:<attribute>', names a module that cannot be found, or names
        an attribute the module does not have.
    ImportError
        Re-raised unchanged when the failing import is of some module other
        than the one named in ``import_str``.
    """
    log.append(f"({import_str})...")

    if not isinstance(import_str, str):
        raise exceptions.RombusModelImportFromStringError(
            f'Import string must be a string with format "<module>:<attribute>". It is actually of type {type(import_str)}.'
        )

    # Make sure the CWD is in the import path
    sys.path.append(os.getcwd())
    sys.path = list(dict.fromkeys(sys.path))

    # Separate the module path from the attribute path
    module_str, _sep, attrs_str = import_str.partition(":")
    if not (module_str and attrs_str):
        raise exceptions.RombusModelImportFromStringError(
            f'Import string "{import_str}" must be in format "<module>:<attribute>".'
        )

    # Import the module; only a failure to find *this* module is translated
    # into a Rombus error, anything else propagates as the original ImportError
    try:
        module = importlib.import_module(module_str)
    except ImportError as exc:
        if exc.name == module_str:
            raise exceptions.RombusModelImportFromStringError(
                f'Could not import module "{module_str}".\n'
            )
        raise exc from None

    # Walk the dotted attribute path, starting from the module
    target = module
    try:
        for attr_name in attrs_str.split("."):
            target = getattr(target, attr_name)
    except AttributeError:
        raise exceptions.RombusModelImportFromStringError(
            f'Attribute "{attrs_str}" not found in module "{module_str}".'
        )

    return target