Source code for rombus.samples

from typing import Any, Dict, List, Optional, Self, TypeAlias

import numpy as np

import rombus._core.mpi as mpi

from rombus.model import RombusModel
from rombus._core import hdf5

# Module-level defaults. Neither constant is referenced by the code visible in
# this module; presumably they are consumed elsewhere in rombus — TODO confirm.
DEFAULT_TOLERANCE = 1e-14
DEFAULT_REFINE_N_RANDOM = 100

# A single parameter sample is represented as a NumPy array.
Sample: TypeAlias = np.ndarray


class Samples:
    """Class for managing sets of parameter samples for Rombus.

    Samples added through a file or through random generation are distributed
    round-robin across MPI ranks (see :meth:`_decompose_samples`), so
    ``samples`` and ``n_samples`` describe the portion of the set held by the
    local rank.  Samples added with :meth:`extend` are appended to the local
    rank as given, without being redistributed.

    Parameters
    ----------
    model : RombusModel
        Rombus model for which the samples are computed
    filename : Optional[str]
        Optional Numpy (``.npy``) or CSV (``.csv``) file of samples to load
    n_random : int
        Number of random samples to generate and add to the set
    """

    def __init__(
        self, model: RombusModel, filename: Optional[str] = None, n_random: int = 0
    ):
        self.model: RombusModel = model
        """Rombus model for which the samples are computed"""

        self.n_random: int = n_random
        """Number of random points generated for this set"""

        # RNG current and initial state
        self._random: Optional[np.random.Generator] = None
        """The numpy random number generator instance used for generating any random numbers in this Sample set"""

        self._random_starting_state: Optional[Dict[str, Any]] = None
        """The starting state of the generator"""

        # Initialise samples
        self.n_samples: np.int32 = np.int32(0)
        """Number of samples in this set."""

        self.samples: List[Sample] = []
        """List of samples in this set."""

        if filename:
            self._add_from_file(filename)

        if self.n_random > 0:
            self._add_random_samples(self.n_random)

    @classmethod
    def from_file(cls, file_in: hdf5.FileOrFilename) -> Self:
        """Create an instance of a Sample set from a Rombus file on disk.

        Parameters
        ----------
        file_in : hdf5.FileOrFilename
            Rombus file (filename or opened file) to read from

        Returns
        -------
        Self
            A new Sample set populated with the model, samples and sample
            count stored in the file
        """
        h5file, close_file = hdf5.ensure_open(file_in)
        try:
            model_str = h5file["samples/model/model_str"].asstr()[()]
            model = RombusModel.load(model_str)

            samples = cls(model)
            samples.samples = [np.array(x) for x in h5file["samples/samples"]]
            # Read the scalar dataset explicitly before converting it.
            samples.n_samples = np.int32(h5file["samples/n_samples"][()])
        finally:
            # Only close handles that were opened here, and do so even if
            # reading failed so that the file is never leaked.
            if close_file:
                h5file.close()

        return samples

    def extend(self, new_samples: List[Sample]) -> None:
        """Add additional samples to the set.

        The samples are appended to the local rank's list as given.

        Parameters
        ----------
        new_samples : List[Sample]
            A list of new samples
        """
        self.samples.extend(new_samples)
        # Re-wrap so that n_samples keeps its declared np.int32 type
        # regardless of NumPy's scalar promotion rules.
        self.n_samples = np.int32(self.n_samples + len(new_samples))

    def write(self, h5file: hdf5.File):
        """Save samples to an open HDF5 file.

        Creates a ``samples`` group holding the model (written by
        ``self.model.write``), a ``samples`` dataset and an ``n_samples``
        dataset.

        Parameters
        ----------
        h5file : hdf5.File
            An open HDF5 file
        """
        h5_group = h5file.create_group("samples")
        self.model.write(h5_group)
        h5_group.create_dataset("samples", data=self.samples)
        h5_group.create_dataset("n_samples", data=self.n_samples)

    def _add_from_file(self, filename_in: str) -> None:
        """Add samples from file to this set.  Accepts Numpy or CSV files.

        The file is read on the main rank only; the samples read are then
        distributed across all ranks.

        Parameters
        ----------
        filename_in : str
            Filename of a Numpy or CSV file.

        Raises
        ------
        ValueError
            If the filename does not end in ``.npy`` or ``.csv``
        """
        samples: Optional[List[Sample]] = None
        if mpi.RANK_IS_MAIN:
            if filename_in.endswith(".npy"):
                samples = np.load(filename_in)
            elif filename_in.endswith(".csv"):
                # atleast_1d guards against a single-value file, which
                # genfromtxt returns as a non-iterable 0-d array.
                table = np.atleast_1d(
                    np.genfromtxt(filename_in, delimiter=",", comments="#")
                )
                samples = [np.atleast_1d(row) for row in table]
            else:
                raise ValueError(
                    f"Unsupported samples file '{filename_in}': "
                    "expected a '.npy' or '.csv' file."
                )

        self.extend(self._decompose_samples(samples))

    def _add_random_samples(self, n_samples: int) -> None:
        """Add randomly generated samples to the set.

        A fresh generator is created on every call and its initial state is
        recorded in ``_random_starting_state``.

        Parameters
        ----------
        n_samples : int
            Number of random samples to add
        """
        self._random = np.random.default_rng()
        # Snapshot the state of the generator that is actually used below
        # (not NumPy's unrelated legacy global RandomState).
        self._random_starting_state = self._random.bit_generator.state

        samples = [
            self.model.params.generate_random_sample(self._random)
            for _ in range(n_samples)
        ]

        self.extend(self._decompose_samples(samples))

    def _decompose_samples(
        self,
        samples: Optional[List[Sample]],
    ) -> List[Sample]:
        """Split a list of samples across MPI ranks.

        Samples are dealt out round-robin by the main rank; the value passed
        on any other rank is not used.

        Parameters
        ----------
        samples : Optional[List[Sample]]
            Set of samples to split (only read on the main rank)

        Returns
        -------
        List[Sample]
            Set of samples selected for the local rank
        """
        chunks: List[List[Sample]] = [[]]
        if mpi.RANK_IS_MAIN:
            chunks = [[] for _ in range(mpi.SIZE)]
            for i, sample in enumerate(samples):
                chunks[i % mpi.SIZE].append(sample)

        return mpi.COMM.scatter(chunks, root=mpi.MAIN_RANK)