# Source code for pyabc.util.read_sample

"""Read sample to array."""

from typing import Tuple

import numpy as np

from ..population import Sample
from .par_trafo import ParTrafoBase


def _only_finites(*args):
    """Remove samples (rows) where any entry is not finite.

    Parameters
    ----------
    *args:
        A collection of np.ndarray objects sharing the same first (sample)
        dimension, e.g. of shape (n_sample, n_x) or (n_sample,). Arrays with
        more than two dimensions are accepted as well; a row is then
        everything belonging to one index along the first axis.

    Returns
    -------
    A list with one array per input, in input order, excluding rows where
    any entry in any object is not finite. An empty list if no arrays are
    passed.
    """
    # nothing to filter
    if not args:
        return []

    # mask of rows to keep, one flag per sample
    keep = np.ones((args[0].shape[0],), dtype=bool)

    # a row survives only if it is entirely finite in every argument
    for arg in args:
        finite = np.isfinite(arg)
        if finite.ndim > 1:
            # collapse all non-sample axes so that the mask stays 1-D
            finite = finite.all(axis=tuple(range(1, finite.ndim)))
        keep = np.logical_and(keep, finite)

    # reduce arrays
    return [arg[keep] for arg in args]


def read_sample(
    sample: Sample,
    sumstat,
    all_particles: bool,
    par_trafo: ParTrafoBase,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read in sample.

    Evaluates the summary statistic and the parameter transformation on
    every selected particle, collects the results together with the particle
    weights into arrays, and drops all rows containing non-finite entries.

    Parameters
    ----------
    sample:
        Calibration or last generation's sample.
    sumstat:
        Up-chain summary statistic, already fitted. Called on each
        particle's `sum_stat`; the result is flattened.
    all_particles:
        Whether to use all particles or only accepted ones.
    par_trafo:
        Parameter transformation to apply. Called on each particle's
        `parameter`.

    Returns
    -------
    sumstats, parameters, weights:
        Arrays of shape (n_sample, n_out), where n_out is the number of
        flattened summary statistics, the number of transformed parameters,
        and 1 for the weights, respectively. Rows with any non-finite entry
        in any of the three arrays are removed from all of them.

    Notes
    -----
    The first selected particle is used to determine the output dimensions,
    so the selected particle collection must not be empty.
    """
    if all_particles:
        particles = sample.all_particles
    else:
        particles = sample.accepted_particles

    # dimensions of sample, summary statistics, and parameters
    n_sample = len(particles)
    n_sumstat = len(sumstat(particles[0].sum_stat).flatten())
    n_par = len(par_trafo(particles[0].parameter))

    # prepare matrices
    sumstats = np.empty((n_sample, n_sumstat))
    parameters = np.empty((n_sample, n_par))
    weights = np.empty((n_sample, 1))

    # fill by iteration over all particles
    for i_particle, particle in enumerate(particles):
        sumstats[i_particle, :] = sumstat(particle.sum_stat).flatten()
        parameters[i_particle, :] = par_trafo(particle.parameter)
        weights[i_particle] = particle.weight

    # remove samples where an entry is not finite
    sumstats, parameters, weights = _only_finites(
        sumstats,
        parameters,
        weights,
    )

    return sumstats, parameters, weights