# Source code for pyabc.inference_util.inference_util

# Note: Due to cyclic imports, these need to be separated from other modules

import logging
import uuid
from datetime import datetime, timedelta
from typing import Callable, List

import numpy as np
import pandas as pd

from ..acceptor import Acceptor
from ..distance import Distance
from ..epsilon import Epsilon
from ..model import Model
from ..parameters import Parameter
from ..population import Particle
from ..random_choice import fast_random_choice
from ..random_variables import RV, Distribution
from ..storage.history import History
from ..transition import ModelPerturbationKernel, Transition

# Module-level logger; the inference utilities log under the "ABC" name.
logger = logging.getLogger("ABC")


class AnalysisVars:
    """Contract object class for passing analysis variables.

    Used e.g. to create new sampling tasks or check early stopping.
    """

    def __init__(
        self,
        model_prior: RV,
        parameter_priors: List[Distribution],
        model_perturbation_kernel: ModelPerturbationKernel,
        transitions: List[Transition],
        models: List[Model],
        summary_statistics: Callable,
        x_0: dict,
        distance_function: Distance,
        eps: Epsilon,
        acceptor: Acceptor,
        min_acceptance_rate: float,
        min_eps: float,
        stop_if_single_model_alive: bool,
        max_t: int,
        max_total_nr_simulations: int,
        prev_total_nr_simulations: int,
        max_walltime: timedelta,
        init_walltime: datetime,
        min_eps_diff: float,
        prev_eps: float,
    ):
        # sampling ingredients: priors, perturbation kernels, models
        self.model_prior = model_prior
        self.parameter_priors = parameter_priors
        self.model_perturbation_kernel = model_perturbation_kernel
        self.transitions = transitions
        self.models = models
        # data side: summary statistics function and observed statistics
        self.summary_statistics = summary_statistics
        self.x_0 = x_0
        # acceptance ingredients: distance, epsilon threshold, acceptor
        self.distance_function = distance_function
        self.eps = eps
        self.acceptor = acceptor
        # stopping criteria, consumed by termination_criteria_fulfilled
        self.min_acceptance_rate = min_acceptance_rate
        self.min_eps = min_eps
        self.stop_if_single_model_alive = stop_if_single_model_alive
        self.max_t = max_t
        self.max_total_nr_simulations = max_total_nr_simulations
        self.prev_total_nr_simulations = prev_total_nr_simulations
        self.max_walltime = max_walltime
        self.init_walltime = init_walltime
        self.min_eps_diff = min_eps_diff
        self.prev_eps = prev_eps
def create_simulate_from_prior_function(
    model_prior: RV,
    parameter_priors: List[Distribution],
    models: List[Model],
    summary_statistics: Callable,
) -> Callable:
    """Create a function that simulates from the prior.

    Similar to _create_simulate_function, apart here we sample from the
    prior and accept all.

    Parameters
    ----------
    model_prior:
        The model prior.
    parameter_priors:
        The parameter priors.
    models:
        List of all models.
    summary_statistics:
        Computes summary statistics from model output.

    Returns
    -------
    simulate_one:
        A function that returns a sampled particle.
    """

    def simulate_one():
        # draw a model index and a parameter from the respective priors
        model_index = int(model_prior.rvs())
        parameter = parameter_priors[model_index].rvs()
        # run the model and reduce its output to summary statistics
        model_result = models[model_index].summary_statistics(
            0, parameter, summary_statistics
        )
        return Particle(
            m=model_index,
            parameter=parameter,
            # prior samples all carry the same (uniform) weight
            weight=1.0,
            sum_stat=model_result.sum_stat,
            # placeholder; the distance is computed after the distance
            # function has been initialized
            distance=np.inf,
            # prior samples are all accepted
            accepted=True,
            proposal_id=0,
            preliminary=False,
        )

    return simulate_one
def generate_valid_proposal(
    t: int,
    m: np.ndarray,
    p: np.ndarray,
    model_prior: RV,
    parameter_priors: List[Distribution],
    model_perturbation_kernel: ModelPerturbationKernel,
    transitions: List[Transition],
):
    """Sample a (model, parameter) pair with positive prior density.

    Parameters
    ----------
    t:
        Population index to generate for.
    m:
        Indices of alive models.
    p:
        Probabilities of alive models.
    model_prior:
        The model prior.
    parameter_priors:
        The parameter priors.
    model_perturbation_kernel:
        The model perturbation kernel.
    transitions:
        The transitions, one per model.

    Returns
    -------
    (m_ss, theta_ss): Model, parameter.
    """
    # generation 0: draw directly from the priors
    if t == 0:
        m_ss = int(model_prior.rvs())
        return m_ss, parameter_priors[m_ss].rvs()

    # later generations: perturb the previous population until a sample
    # with positive prior density is found
    n_tries, soft_limit = 0, 1000
    while True:
        if len(m) > 1:
            m_s = m[fast_random_choice(p)]
            m_ss = model_perturbation_kernel.rvs(m_s)
            # the perturbation kernel may propose a model index that has
            # died out; in that case, draw again
            if m_ss not in m:
                continue
        else:
            # only one model alive
            m_ss = m[0]
        theta_ss = transitions[m_ss].rvs()

        # keep the sample only if it is possible under the prior
        if model_prior.pmf(m_ss) * parameter_priors[m_ss].pdf(theta_ss) > 0:
            return m_ss, theta_ss

        # unhealthy sampling detection: warn once after many rejections
        n_tries += 1
        if n_tries == soft_limit:
            logger.warning(
                "Unusually many (model, parameter) samples have prior "
                "density zero. The transition might be inappropriate."
            )
def evaluate_proposal(
    m_ss: int,
    theta_ss: Parameter,
    t: int,
    models: List[Model],
    summary_statistics: Callable,
    distance_function: Distance,
    eps: Epsilon,
    acceptor: Acceptor,
    x_0: dict,
    weight_function: Callable,
    proposal_id: int,
) -> Particle:
    """Evaluate a proposed parameter.

    Data for the given parameters theta_ss are simulated, summary
    statistics computed and evaluated.

    Parameters
    ----------
    m_ss, theta_ss:
        The proposed (model, parameter) sample.
    t:
        The current time.
    models:
        List of all models.
    summary_statistics:
        Function to compute summary statistics from model output.
    distance_function:
        The distance function.
    eps:
        The epsilon threshold.
    acceptor:
        The acceptor.
    x_0:
        The observed summary statistics.
    weight_function:
        Function by which to reweight the sample.
    proposal_id:
        Id of the transition kernel.

    Returns
    -------
    particle:
        A particle containing all information.
    """
    # simulate, compute the distance and check acceptance in a single call
    model_result = models[m_ss].accept(
        t, theta_ss, summary_statistics, distance_function, eps, acceptor, x_0
    )

    # rejected samples carry zero importance weight
    if model_result.accepted:
        weight = weight_function(m_ss, theta_ss, model_result.weight)
    else:
        weight = 0

    return Particle(
        m=m_ss,
        parameter=theta_ss,
        weight=weight,
        sum_stat=model_result.sum_stat,
        distance=model_result.distance,
        accepted=model_result.accepted,
        preliminary=False,
        proposal_id=proposal_id,
    )
def create_prior_pdf(
    model_prior: RV, parameter_priors: List[Distribution]
) -> Callable:
    """Create a function that calculates a sample's prior density.

    Parameters
    ----------
    model_prior:
        The model prior.
    parameter_priors:
        The parameter priors, one for each model.

    Returns
    -------
    prior_pdf:
        The prior density function.
    """

    def prior_pdf(m_ss, theta_ss):
        # joint density = discrete model mass times parameter density
        return model_prior.pmf(m_ss) * parameter_priors[m_ss].pdf(theta_ss)

    return prior_pdf
def create_transition_pdf(
    transitions: List[Transition],
    model_probabilities: pd.DataFrame,
    model_perturbation_kernel: ModelPerturbationKernel,
) -> Callable:
    """Create the transition probability density function for time `t`.

    Parameters
    ----------
    transitions:
        The list of parameter transition functions.
    model_probabilities:
        The last generation's model probabilities.
    model_perturbation_kernel:
        The kernel perturbing the models.

    Returns
    -------
    transition_pdf:
        The transition density function.
    """

    def transition_pdf(m_ss, theta_ss):
        # marginalize the model jump over the previous model distribution
        model_factor = 0
        for m, row in model_probabilities.iterrows():
            model_factor += row.p * model_perturbation_kernel.pmf(m_ss, m)

        transition_pd = model_factor * transitions[m_ss].pdf(theta_ss)
        if transition_pd == 0:
            logger.debug("Transition density is zero!")
        return transition_pd

    return transition_pdf
def create_weight_function(
    prior_pdf: Callable,
    transition_pdf: Callable,
) -> Callable:
    """Create a function that calculates a sample's importance weight.

    The weight is the prior divided by the transition density and the
    acceptance step weight.

    Parameters
    ----------
    prior_pdf:
        The prior density.
    transition_pdf:
        The transition density.

    Returns
    -------
    weight_function:
        The importance sample weight function.
    """

    def weight_function(m_ss, theta_ss, acceptance_weight: float):
        """Calculate total weight, from sampling and acceptance weight.

        Parameters
        ----------
        m_ss:
            The model sample.
        theta_ss:
            The parameter sample.
        acceptance_weight:
            The acceptance weight sample. In most cases 1.

        Returns
        -------
        weight:
            The total weight.
        """
        # importance weight = acceptance weight * prior / proposal density
        # (prior and transition density can coincide in generation 0)
        return (
            acceptance_weight
            * prior_pdf(m_ss, theta_ss)
            / transition_pdf(m_ss, theta_ss)
        )

    return weight_function
def create_simulate_function(
    t: int,
    model_probabilities: pd.DataFrame,
    model_perturbation_kernel: ModelPerturbationKernel,
    transitions: List[Transition],
    model_prior: RV,
    parameter_priors: List[Distribution],
    models: List[Model],
    summary_statistics: Callable,
    x_0: dict,
    distance_function: Distance,
    eps: Epsilon,
    acceptor: Acceptor,
    evaluate: bool = True,
    proposal_id: int = 0,
) -> Callable:
    """Create the sampling function that is passed to the sampler.

    The returned function performs the sampling of parameters, simulation
    of data, and acceptance checking.

    Parameters
    ----------
    t:
        The time index to simulate for.
    model_probabilities:
        The last generation's model probabilities.
    model_perturbation_kernel:
        The model perturbation kernel.
    transitions:
        The parameter transition kernels.
    model_prior:
        The model prior.
    parameter_priors:
        The parameter priors.
    models:
        List of all models.
    summary_statistics:
        Function to compute summary statistics from model output.
    x_0:
        The observed summary statistics.
    distance_function:
        The distance function.
    eps:
        The epsilon threshold.
    acceptor:
        The acceptor.
    evaluate:
        Whether to actually evaluate the sample. Should be True except
        for certain preliminary settings.
    proposal_id:
        Identifier for the proposal distribution.

    Returns
    -------
    simulate_one: callable
        Function that samples parameters, simulates data, and checks
        acceptance.

    .. note::
        Some samplers serialize the sampling function to transport it to
        where the sampling happens. Therefore, the returned function
        should be light, and in particular not contain references to the
        ABCSMC class.
    """
    # cache the alive models and probabilities as plain arrays, to not
    # query the database so often
    m = np.array(model_probabilities.index)
    p = np.array(model_probabilities.p)

    # densities entering the importance weight
    prior_pdf = create_prior_pdf(
        model_prior=model_prior, parameter_priors=parameter_priors
    )
    if t == 0:
        # in the first generation, proposals come from the prior itself
        transition_pdf = prior_pdf
    else:
        transition_pdf = create_transition_pdf(
            transitions=transitions,
            model_probabilities=model_probabilities,
            model_perturbation_kernel=model_perturbation_kernel,
        )
    weight_function = create_weight_function(
        prior_pdf=prior_pdf, transition_pdf=transition_pdf
    )

    # the actual per-particle task handed to the sampler
    def simulate_one():
        m_ss, theta_ss = generate_valid_proposal(
            t=t,
            m=m,
            p=p,
            model_prior=model_prior,
            parameter_priors=parameter_priors,
            model_perturbation_kernel=model_perturbation_kernel,
            transitions=transitions,
        )
        if evaluate:
            return evaluate_proposal(
                m_ss,
                theta_ss,
                t=t,
                models=models,
                summary_statistics=summary_statistics,
                distance_function=distance_function,
                eps=eps,
                acceptor=acceptor,
                x_0=x_0,
                weight_function=weight_function,
                proposal_id=proposal_id,
            )
        return only_simulate_data_for_proposal(
            m_ss,
            theta_ss,
            t=t,
            models=models,
            summary_statistics=summary_statistics,
            weight_function=weight_function,
            proposal_id=proposal_id,
        )

    return simulate_one
def only_simulate_data_for_proposal(
    m_ss: int,
    theta_ss: Parameter,
    t: int,
    models: List[Model],
    summary_statistics: Callable,
    weight_function: Callable,
    proposal_id: int,
) -> Particle:
    """Simulate data for parameters.

    Similar to `evaluate_proposal`, however here for the passed parameters
    only data are simulated, but no distances calculated or acceptance
    checked. That needs to be done post-hoc then, not checked here.
    """
    # simulate the model and reduce output to summary statistics
    model_result = models[m_ss].summary_statistics(
        t, theta_ss, summary_statistics
    )

    # acceptance weight is a dummy (1.0) and folded into the sampling
    # weight; the real acceptance step happens post-hoc
    weight = weight_function(m_ss, theta_ss, 1.0)

    return Particle(
        m=m_ss,
        parameter=theta_ss,
        weight=weight,
        sum_stat=model_result.sum_stat,
        # placeholder distance, recomputed later
        distance=np.inf,
        # must be flagged accepted so the sampler forwards the particle
        accepted=True,
        preliminary=True,
        proposal_id=proposal_id,
    )
def evaluate_preliminary_particle(
    particle: Particle, t, ana_vars: AnalysisVars
) -> Particle:
    """Evaluate a preliminary particle.

    I.e. compute distance and check acceptance.

    Returns
    -------
    evaluated_particle: The evaluated particle
    """
    if not particle.preliminary:
        raise AssertionError("Particle is not preliminary")

    # run the deferred acceptance step on the stored summary statistics
    acc_res = ana_vars.acceptor(
        distance_function=ana_vars.distance_function,
        eps=ana_vars.eps,
        x=particle.sum_stat,
        x_0=ana_vars.x_0,
        t=t,
        par=particle.parameter,
    )

    # the particle's stored weight is the sampling weight; the total
    # weight multiplies in the acceptance weight, or is zero on rejection
    if acc_res.accept:
        weight = particle.weight * acc_res.weight
    else:
        weight = 0.0

    return Particle(
        m=particle.m,
        parameter=particle.parameter,
        weight=weight,
        sum_stat=particle.sum_stat,
        distance=acc_res.distance,
        accepted=acc_res.accept,
        preliminary=False,
        proposal_id=particle.proposal_id,
    )
def termination_criteria_fulfilled(
    current_eps: float,
    min_eps: float,
    prev_eps: float,
    min_eps_diff: float,
    stop_if_single_model_alive: bool,
    nr_of_models_alive: int,
    acceptance_rate: float,
    min_acceptance_rate: float,
    total_nr_simulations: int,
    max_total_nr_simulations: int,
    walltime: timedelta,
    max_walltime: timedelta,
    t: int,
    max_t: int,
) -> bool:
    """Check termination criteria.

    Parameters
    ----------
    current_eps:
        The last generation's epsilon value.
    min_eps:
        The minimum allowed epsilon value.
    prev_eps:
        The previous generation's epsilon value, or None if not available.
    min_eps_diff:
        Minimum required difference between consecutive epsilon values.
    stop_if_single_model_alive:
        Whether to stop with a single model left.
    nr_of_models_alive:
        The number of models alive in the last generation.
    acceptance_rate:
        The last generation's acceptance rate.
    min_acceptance_rate:
        The minimum acceptance rate.
    total_nr_simulations:
        The total number of simulations so far.
    max_total_nr_simulations:
        Bound on the total number of simulations.
    walltime:
        Walltime passed since start of the analysis.
    max_walltime:
        Maximum allowed walltime, or None for no bound.
    t:
        The last generation's time index.
    max_t:
        The maximum allowed time index.

    Returns
    -------
    True if any criterion is met, otherwise False.
    """
    # each criterion is an independent guard clause; the first one that
    # fires terminates the analysis (the original mixed `if` and `elif`,
    # which was equivalent but harder to read since every branch returns)
    if t >= max_t:
        logger.info("Stop: Maximum number of generations.")
        return True
    if current_eps <= min_eps:
        logger.info("Stop: Minimum epsilon.")
        return True
    if prev_eps is not None and abs(current_eps - prev_eps) < min_eps_diff:
        logger.info("Stop: Minimum epsilon difference")
        return True
    if stop_if_single_model_alive and nr_of_models_alive <= 1:
        logger.info("Stop: Single model alive.")
        return True
    if acceptance_rate < min_acceptance_rate:
        logger.info("Stop: Minimum acceptance rate.")
        return True
    if total_nr_simulations >= max_total_nr_simulations:
        logger.info("Stop: Total simulations budget.")
        return True
    if max_walltime is not None and walltime > max_walltime:
        logger.info("Stop: Maximum walltime.")
        return True
    return False
def create_analysis_id():
    """Create a universally unique id for a given analysis.

    Used by the inference routine to uniquely associated results with
    analyses.
    """
    # uuid4 draws a random 128-bit identifier; its canonical string form
    # labels the analysis
    analysis_uuid = uuid.uuid4()
    return str(analysis_uuid)
def eps_from_hist(history: "History", t: int = None) -> float:
    """Read epsilon value for time `t` from `history`. Defaults to latest.

    Parameters
    ----------
    history:
        The history object; its `get_all_populations()` provides a
        DataFrame with columns `t` and `epsilon`.
    t:
        The time index to read the epsilon value for. If None, the last
        population's value is returned.

    Returns
    -------
    The epsilon value, or None if there are no populations or `t` is not
    among the stored time indices.
    """
    pops = history.get_all_populations()
    # bug fix: `t in pops.t` tests membership in the Series *index*
    # (dict-like semantics), not in the time values. Whenever the
    # positional index differs from the `t` column (e.g. a calibration
    # iteration with t=-1), an existing t was reported missing, and a
    # missing t could slip through and make `.loc[t]` raise. Check the
    # column values instead.
    if len(pops) == 0 or (t is not None and t not in pops["t"].values):
        return None
    if t is None:
        # latest epsilon: last row of the population table
        return pops.epsilon.to_numpy()[-1]
    return pops.set_index("t").loc[t].epsilon