Source code for pyabc.util.par_trafo

"""Parameter transformations."""

from abc import ABC, abstractmethod
from typing import Callable, List, Union

import numpy as np

from .dict2arr import dict2arr


class ParTrafoBase(ABC):
    """Interface for parameter transformations used as regression targets.

    Instead of regressing onto the raw parameters theta directly, it can be
    beneficial to regress onto functions of them, e.g. moments such as
    theta**2. Among other things, this can help to overcome
    non-identifiabilities.

    Concrete subclasses must implement :meth:`__call__`, :meth:`__len__`
    and :meth:`get_ids`; :meth:`initialize` is an optional hook.
    """

    def initialize(self, keys: List[str]) -> None:
        """Set up the transformation; invoked once per analysis.

        The base implementation does nothing.

        Parameters
        ----------
        keys:
            Parameter keys handed over at initialization.
        """

    @abstractmethod
    def __call__(self, par_dict: dict) -> np.ndarray:
        """Map a parameter dictionary to the array of transformed values.

        Parameters
        ----------
        par_dict:
            Dictionary holding the parameters to transform.

        Returns
        -------
        The transformed parameters as a numpy array.
        """

    @abstractmethod
    def __len__(self) -> int:
        """Number of entries the parameter transformation is expected to
        produce.
        """

    @abstractmethod
    def get_ids(self) -> List[str]:
        """Labels identifying the individual parameter transformations."""
class ParTrafo(ParTrafoBase):
    """Simple parameter transformation that accepts a list of transformations.

    The implementation assumes that each transformation maps n_par -> n_par.

    Parameters
    ----------
    trafos:
        Transformations to apply. Defaults to a single identity mapping,
        i.e. if None, the parameter array is returned untransformed.
    trafo_ids:
        Identifiers for the transformed parameters. Either a single format
        string, which is formatted with the fields ``par_id`` and
        ``trafo_ix``, or a list with one format string per transformation,
        each of which is formatted with the field ``par_id`` only.
        Not used when `trafos` is None.

    Raises
    ------
    AssertionError
        If `trafo_ids` is a list whose length differs from the number of
        transformations.
    """

    def __init__(
        self,
        trafos: List[Callable[[np.ndarray], np.ndarray]] = None,
        trafo_ids: Union[str, List[str]] = "{par_id}_{trafo_ix}",
    ):
        self.trafos = trafos

        if not isinstance(trafo_ids, str):
            # `trafos=None` means that no explicit transformations were
            # given; count it as zero so that a mismatching id list yields
            # the validation error below rather than `len(None)` failing
            # with an unrelated TypeError.
            n_trafos = 0 if trafos is None else len(trafos)
            if n_trafos != len(trafo_ids):
                raise AssertionError(
                    "Lengths of trafos and trafo_ids must match"
                )
        self.trafo_ids = trafo_ids

        # to maintain key order
        self.keys: Union[List[str], None] = None

    def initialize(self, keys: List[str]):
        """Remember the order of the parameter keys.

        Parameters
        ----------
        keys:
            Parameter keys, in the order in which the values are to be
            arranged in the output array.
        """
        self.keys = keys

    def __call__(self, par_dict: dict) -> np.ndarray:
        """Transform parameters from input dict.

        Parameters
        ----------
        par_dict:
            Dictionary of parameters.

        Returns
        -------
        The parameters converted via `dict2arr` in the stored key order if no
        transformations are set; otherwise the flattened outputs of all
        transformations, concatenated in the order of `trafos`.
        """
        # remember key order in case it was not set yet
        # (fallback, uniqueness not guaranteed)
        if self.keys is None:
            self.keys = list(par_dict.keys())

        # to array
        out = dict2arr(par_dict, keys=self.keys)

        # without transformations, this acts as the identity mapping
        if self.trafos is None:
            return out

        # apply every transformation to the same input array and stack the
        # flattened results one after another
        return np.concatenate(
            [trafo(out).flatten() for trafo in self.trafos],
        )

    def __len__(self):
        """Number of output entries: n_par without transformations, else
        n_par * n_trafo.

        Requires that the keys have been set, via `initialize` or a previous
        call of the object.
        """
        n_par = len(self.keys)
        if self.trafos is None:
            return n_par
        return n_par * len(self.trafos)

    def get_ids(self) -> List[str]:
        """
        Calculate keys as:
        {par_id_1}_{trafo_1}, ..., {par_id_n}_{trafo_1}, ...,
        {par_id_1}_{trafo_m}, ..., {par_id_n}_{trafo_m}

        Requires that the keys have been set, via `initialize` or a previous
        call of the object. Without transformations, the plain parameter ids
        are returned.
        """
        par_ids = [f"{key}" for key in self.keys]

        if self.trafos is None:
            return par_ids

        # transformations vary slowest, parameters fastest, matching the
        # concatenation order in `__call__`
        ids = []
        for trafo_ix in range(len(self.trafos)):
            for par_id in par_ids:
                if isinstance(self.trafo_ids, str):
                    # one shared template for all transformations
                    ids.append(
                        self.trafo_ids.format(
                            par_id=par_id, trafo_ix=trafo_ix
                        )
                    )
                else:
                    # one template per transformation
                    ids.append(
                        self.trafo_ids[trafo_ix].format(par_id=par_id)
                    )
        return ids