# Source code for pyabc.sampler.redis_eps.sampler

"""Redis based sampler base class and dynamic scheduling samplers."""

import copy
import logging
from datetime import datetime
from time import sleep
from typing import Callable, Dict, List, Tuple

import cloudpickle as pickle
import numpy as np
from jabbar import jabbar
from redis import StrictRedis

from ...acceptor import Acceptor
from ...distance import Distance
from ...epsilon import Epsilon
from ...inference_util import (
    AnalysisVars,
    create_simulate_function,
    evaluate_preliminary_particle,
    termination_criteria_fulfilled,
)
from ...population import Sample
from ...sampler import Sampler
from ...weighted_statistics import effective_sample_size
from .cmd import (
    ALL_ACCEPTED,
    ANALYSIS_ID,
    BATCH_SIZE,
    DONE_IXS,
    DYNAMIC,
    GENERATION,
    IS_LOOK_AHEAD,
    MAX_N_EVAL_LOOK_AHEAD,
    MODE,
    MSG,
    N_ACC,
    N_EVAL,
    N_FAIL,
    N_LOOKAHEAD_EVAL,
    N_REQ,
    N_WORKER,
    QUEUE,
    SLEEP_TIME,
    SSA,
    START,
    idfy,
)
from .redis_logging import RedisSamplerLogger

# module-level logger shared by the redis samplers defined in this module
logger = logging.getLogger("ABC.Sampler")


class RedisSamplerBase(Sampler):
    """Abstract base class for redis based samplers.

    Parameters
    ----------
    host:
        IP address or name of the Redis server.
        Default is "localhost".
    port:
        Port of the Redis server.
        Default is 6379.
    password:
        Password for a protected server. Default is None (no protection).
    log_file:
        A file for a dedicated sampler history. Updated in each iteration.
        This log file is complementary to the logging realized via the
        logging module.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        password: str = None,
        log_file: str = None,
    ):
        super().__init__()
        logger.debug(f"Redis sampler: host={host} port={port}")
        # handles the connection to the redis-server
        self.redis: StrictRedis = StrictRedis(
            host=host, port=port, password=password
        )
        # dedicated per-iteration sampler history
        self.logger = RedisSamplerLogger(log_file)

    def n_worker(self) -> int:
        """
        Get the number of connected workers.

        Returns
        -------

        Number of workers connected.
        """
        # pubsub_numsub returns a list of (channel, subscriber_count) pairs;
        # the subscriber count of the message channel is the worker count
        channel_subscriptions = self.redis.pubsub_numsub(MSG)
        return channel_subscriptions[0][-1]

    def set_analysis_id(self, analysis_id: str):
        """Set the analysis id and make sure the server is available."""
        super().set_analysis_id(analysis_id)
        # a pre-existing id indicates another analysis running on this server
        if self.redis.get(ANALYSIS_ID):
            raise AssertionError(
                "The server seems busy with an analysis already"
            )
        self.redis.set(ANALYSIS_ID, analysis_id)

    def sample_until_n_accepted(
        self,
        n: int,
        simulate_one: Callable,
        t: int,
        *,
        max_eval: int = np.inf,
        all_accepted: bool = False,
        ana_vars: AnalysisVars = None,
    ) -> Sample:
        # to be provided by concrete sampler implementations
        raise NotImplementedError()

    def stop(self):
        """Stop potentially still ongoing sampling."""
        # delete ids specifying the current analysis
        # note: the other ana_id-t-specific variables are not deleted, as
        #  these do not hurt anyway and could potentially make the workers
        #  fail
        for key in (ANALYSIS_ID, idfy(GENERATION, self.analysis_id)):
            self.redis.delete(key)

class RedisEvalParallelSampler(RedisSamplerBase):
    """Redis based dynamic scheduling low latency sampler.

    This sampler is well-performing in distributed environments.
    It is usually faster than the
    :class:`pyabc.sampler.DaskDistributedSampler` for short model
    evaluation runtimes. The longer the model evaluation times, the less
    the advantage becomes. It requires a running Redis server as broker.

    This sampler requires workers to be started via the command
    ``abc-redis-worker``. An example call might look like
    ``abc-redis-worker --host=123.456.789.123 --runtime=2h`` to connect to
    a Redis server on IP ``123.456.789.123`` and to terminate the worker
    after finishing the first population which ends after 2 hours since
    worker start. So the actual runtime might be longer than 2h.
    See ``abc-redis-worker --help`` for its options.

    Use the command ``abc-redis-manager`` to retrieve info on and stop the
    running workers.

    Start as many workers as you wish. Workers can be dynamically added
    during the ABC run.

    Currently, a server (specified via host and port) can only meaningfully
    handle one ABCSMC analysis at a time.

    Parameters
    ----------
    host:
        IP address or name of the Redis server.
        Default is "localhost".
    port:
        Port of the Redis server.
        Default is 6379.
    password:
        Password for a protected server. Default is None (no protection).
    batch_size:
        Number of model evaluations the workers perform before contacting
        the REDIS server. Defaults to 1. Increase this value if model
        evaluation times are short or the number of workers is large to
        reduce communication overhead.
    look_ahead:
        Whether to start sampling for the next generation already with
        preliminary results although the current generation has not
        completely finished yet. This increases parallel efficiency, but
        can lead to a higher Monte-Carlo error.
    look_ahead_delay_evaluation:
        In look-ahead mode, acceptance can be delayed until the final
        acceptance criteria for generation t have been decided. This is
        mandatory if the routine has adaptive components (distance,
        epsilon, ...) besides the transition kernel. If not needed,
        enabling it may lead to a worse performance, especially if
        evaluation is costly compared to simulation, because evaluation
        happens sequentially on the main thread. Only effective if
        `look_ahead is True`.
    max_n_eval_look_ahead_factor:
        In delayed evaluation, only this factor times the previous number
        of samples are generated, afterwards the workers wait. Does not
        apply if evaluation is not delayed. This allows to perform a
        reasonable number of evaluations only, as for short-running models
        the number of evaluations can otherwise explode unnecessarily.
    wait_for_all_samples:
        Whether to wait for all simulations in an iteration to finish.
        If not, then the sampler only waits for all simulations that were
        started prior to the last started particle of the first `n`
        acceptances. Waiting for all should not be needed, this is for
        studying purposes.
    adapt_look_ahead_proposal:
        In look-ahead mode, adapt the preliminary proposal based on
        previous acceptances. In theory, as long as proposal >> prior,
        everything is fine. However, in practice, given a finite sample
        size, in some cases the preliminary proposal may be biased towards
        earlier-accepted particles, which can induce a similar bias in the
        next accepted population. Thus, if any parameter dependent
        simulation time heterogeneity is to be expected, i.e. if different
        plausible parameter space regions come with different simulation
        times, then this flag should be set to False. If no such
        heterogeneity is to be expected, this flag can be set to True,
        which can result in improved performance due to a more tailored
        proposal distribution. Only effective if `look_ahead is True`.
    log_file:
        A file for a dedicated sampler history. Updated in each iteration.
        This log file is complementary to the logging realized via the
        logging module.
    """
    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        password: str = None,
        batch_size: int = 1,
        look_ahead: bool = False,
        look_ahead_delay_evaluation: bool = True,
        max_n_eval_look_ahead_factor: float = 10.0,
        wait_for_all_samples: bool = False,
        adapt_look_ahead_proposal: bool = False,
        log_file: str = None,
    ):
        """Initialize the sampler. See the class docstring for parameters."""
        # establish redis connection and dedicated sampler log file
        super().__init__(
            host=host, port=port, password=password, log_file=log_file
        )
        # number of evaluations per worker round-trip to the server
        self.batch_size: int = batch_size
        # look-ahead (proactive next-generation sampling) configuration
        self.look_ahead: bool = look_ahead
        self.look_ahead_delay_evaluation: bool = look_ahead_delay_evaluation
        self.max_n_eval_look_ahead_factor: float = max_n_eval_look_ahead_factor
        # whether to wait for all started simulations each iteration
        self.wait_for_all_samples: bool = wait_for_all_samples
        # whether to fit the preliminary proposal to previous acceptances
        self.adapt_look_ahead_proposal: bool = adapt_look_ahead_proposal
    def sample_until_n_accepted(
        self,
        n,
        simulate_one,
        t,
        *,
        max_eval=np.inf,
        all_accepted=False,
        ana_vars=None,
    ) -> Sample:
        """Sample until `n` particles are accepted for generation `t`.

        Coordinates the workers via redis: starts (or updates) the
        generation, collects results from the result queue, optionally
        head-starts the next generation, waits for the relevant
        simulations to finish, and assembles the final sample.
        """
        # get the analysis id
        ana_id = self.analysis_id

        def get_int(var: str):
            """Convenience function to read an int variable."""
            return int(self.redis.get(idfy(var, ana_id, t)).decode())

        if self.generation_t_was_started(t):
            # update the SSA function
            self.redis.set(
                idfy(SSA, ana_id, t),
                pickle.dumps((simulate_one, self.sample_factory)),
            )
            # update the required population size
            self.redis.set(idfy(N_REQ, ana_id, t), n)
            # let the workers know they should update their ssa
            self.redis.set(idfy(IS_LOOK_AHEAD, ana_id, t), int(False))
            # it can happen that the population size increased, but the
            #  workers believe they are done already
            if get_int(N_WORKER) == 0 and get_int(N_ACC) < get_int(N_REQ):
                # send the start signal again
                self.redis.publish(MSG, START)
        else:
            # set up all variables for the new generation
            self.start_generation_t(
                n=n,
                t=t,
                simulate_one=simulate_one,
                all_accepted=all_accepted,
                is_look_ahead=False,
            )

        # for the results
        id_results = []

        # reset logging counters
        self.logger.reset_counters()

        # wait until n acceptances
        with jabbar(total=n, enable=self.show_progress, keep=False) as bar:
            while len(id_results) < n:
                # pop result from queue, block until one is available
                dump = self.redis.blpop(idfy(QUEUE, ana_id, t))[1]
                # extract pickled object
                sample_with_id = pickle.loads(dump)

                # check whether the sample is really acceptable
                sample_with_id, any_particle_accepted = post_check_acceptance(
                    sample_with_id,
                    ana_id=ana_id,
                    t=t,
                    redis=self.redis,
                    ana_vars=ana_vars,
                    logger=self.logger,
                )

                if any_particle_accepted:
                    # append to collected results
                    id_results.append(sample_with_id)
                    bar.update(len(id_results))

        # maybe head-start the next generation already
        self.maybe_start_next_generation(
            t=t,
            n=n,
            id_results=id_results,
            all_accepted=all_accepted,
            ana_vars=ana_vars,
        )

        # wait until all relevant simulations done
        if self.wait_for_all_samples:
            while get_int(N_WORKER) > 0:
                sleep(SLEEP_TIME)
        else:
            # we only need to wait for simulations that were started
            #  before the last started one among the first n accepted ones,
            #  as later ones would be discarded anyway
            max_ix = sorted(id_result[0] for id_result in id_results)[n - 1]
            # first time index is 1
            missing_ixs = set(range(1, max_ix + 1))
            while (
                # check whether any active evaluation was started earlier
                missing_ixs
                # also stop if no worker is active, useful for server resets
                and get_int(N_WORKER) > 0
            ):
                # extract done indices
                #  use a pipeline for efficient retrieval
                #  transactions are atomic
                _var = idfy(DONE_IXS, ana_id, t)
                with self.redis.pipeline(transaction=True) as p:
                    p.lrange(_var, 0, -1).delete(_var)
                    vals = p.execute()[0]

                # check if missing list can be reduced
                for val in vals:
                    done_ix = int(val.decode())
                    # remove done ix from missing ix list
                    if done_ix in missing_ixs:
                        missing_ixs.discard(done_ix)
                sleep(SLEEP_TIME)

        # collect all remaining results in queue at this point
        while self.redis.llen(idfy(QUEUE, ana_id, t)) > 0:
            # pop result from queue, block until one is available
            dump = self.redis.blpop(idfy(QUEUE, ana_id, t))[1]
            # extract pickled object
            sample_with_id = pickle.loads(dump)
            # check whether the sample is really acceptable
            sample_with_id, any_particle_accepted = post_check_acceptance(
                sample_with_id,
                ana_id=ana_id,
                t=t,
                redis=self.redis,
                ana_vars=ana_vars,
                logger=self.logger,
            )
            if any_particle_accepted:
                # append to collected results
                id_results.append(sample_with_id)

        # set total number of evaluations
        self.nr_evaluations_ = get_int(N_EVAL)
        n_lookahead_eval = get_int(N_LOOKAHEAD_EVAL)

        # remove all time-specific variables if no more active workers,
        #  also for previous generations
        if self.wait_for_all_samples:
            self.clear_generation_t(t=t)
        else:
            for _t in range(-1, t + 1):
                n_worker_b = self.redis.get(idfy(N_WORKER, ana_id, _t))
                if n_worker_b is not None and int(n_worker_b.decode()) == 0:
                    # TODO For fast-running models, communication does not
                    #  always work. Until that is fixed, simply do not
                    #  clear up.
                    # self.clear_generation_t(t=_t)
                    pass

        # create a single sample result, with start time correction
        sample = self.create_sample(id_results, n)

        # logging
        self.logger.add_row(
            t=t, n_evaluated=self.nr_evaluations_, n_lookahead=n_lookahead_eval
        )
        self.logger.write()

        # weight sub-populations suitably
        sample = self_normalize_within_subpopulations(sample, n)

        return sample
[docs] def start_generation_t( self, n: int, t: int, simulate_one: Callable, all_accepted: bool, is_look_ahead: bool, max_n_eval_look_ahead: float = np.inf, ) -> None: """Start generation `t`.""" ana_id = self.analysis_id # write initial values to pipeline ( self.redis.pipeline() # initialize variables for time t .set( idfy(SSA, ana_id, t), pickle.dumps((simulate_one, self.sample_factory)), ) .set(idfy(N_EVAL, ana_id, t), 0) .set(idfy(N_ACC, ana_id, t), 0) .set(idfy(N_REQ, ana_id, t), n) .set(idfy(N_FAIL, ana_id, t), 0) .set(idfy(N_LOOKAHEAD_EVAL, ana_id, t), 0) # encode as int .set(idfy(ALL_ACCEPTED, ana_id, t), int(all_accepted)) .set(idfy(N_WORKER, ana_id, t), 0) .set(idfy(BATCH_SIZE, ana_id, t), self.batch_size) # encode as int .set(idfy(IS_LOOK_AHEAD, ana_id, t), int(is_look_ahead)) .set(idfy(MAX_N_EVAL_LOOK_AHEAD, ana_id, t), max_n_eval_look_ahead) .set(idfy(MODE, ana_id, t), DYNAMIC) # update the current-generation variable .set(idfy(GENERATION, ana_id), t) # execute all commands .execute() ) # publish start message self.redis.publish(MSG, START)
[docs] def generation_t_was_started(self, t: int) -> bool: """Check whether generation `t` was started already. Parameters ---------- t: The time for which to check. """ # just check any of the variables for time t return self.redis.exists(idfy(N_REQ, self.analysis_id, t))
[docs] def clear_generation_t(self, t: int) -> None: """Clean up after generation `t` has finished. Parameters ---------- t: The time for which to clear. """ ana_id = self.analysis_id # delete keys from pipeline ( self.redis.pipeline() .delete(idfy(SSA, ana_id, t)) .delete(idfy(N_EVAL, ana_id, t)) .delete(idfy(N_ACC, ana_id, t)) .delete(idfy(N_REQ, ana_id, t)) .delete(idfy(N_FAIL, ana_id, t)) .delete(idfy(N_LOOKAHEAD_EVAL, ana_id, t)) .delete(idfy(ALL_ACCEPTED, ana_id, t)) .delete(idfy(N_WORKER, ana_id, t)) .delete(idfy(BATCH_SIZE, ana_id, t)) .delete(idfy(IS_LOOK_AHEAD, ana_id, t)) .delete(idfy(MAX_N_EVAL_LOOK_AHEAD, ana_id, t)) .delete(idfy(MODE, ana_id, t)) .delete(idfy(DONE_IXS, ana_id, t)) .delete(idfy(QUEUE, ana_id, t)) .execute() )
    def maybe_start_next_generation(
        self,
        t: int,
        n: int,
        id_results: List,
        all_accepted: bool,
        ana_vars: AnalysisVars,
    ) -> None:
        """Start the next generation already, if that looks reasonable.

        Parameters
        ----------
        t:
            The current time.
        n:
            The current population size.
        id_results:
            The so-far returned samples.
        all_accepted:
            Whether all particles are accepted.
        ana_vars:
            Analysis variables.
        """
        # not in a look-ahead mood
        if not self.look_ahead:
            return

        # all accepted indicates the preliminary iteration, where we don't
        #  want to look ahead yet
        if all_accepted:
            return

        # create a result sample
        sample = self.create_sample(id_results, n)
        # copy as we modify the particles
        sample = copy.deepcopy(sample)

        # weight sub-populations suitably
        sample = self_normalize_within_subpopulations(sample, n)

        # normalize accepted population weight to 1
        sample.normalize_weights()

        # extract population
        population = sample.get_accepted_population()

        # acceptance rate
        nr_evaluations = int(
            self.redis.get(idfy(N_EVAL, self.analysis_id, t)).decode()
        )
        acceptance_rate = len(population) / nr_evaluations

        # check if any termination criterion (based on the current data)
        #  is likely to be fulfilled after the current generation
        total_nr_simulations = (
            ana_vars.prev_total_nr_simulations + nr_evaluations
        )
        walltime = datetime.now() - ana_vars.init_walltime
        if termination_criteria_fulfilled(
            current_eps=ana_vars.eps(t),
            min_eps=ana_vars.min_eps,
            prev_eps=ana_vars.prev_eps,
            min_eps_diff=ana_vars.min_eps_diff,
            stop_if_single_model_alive=ana_vars.stop_if_single_model_alive,  # noqa: E251
            nr_of_models_alive=population.nr_of_models_alive(),
            acceptance_rate=acceptance_rate,
            min_acceptance_rate=ana_vars.min_acceptance_rate,
            total_nr_simulations=total_nr_simulations,
            max_total_nr_simulations=ana_vars.max_total_nr_simulations,
            walltime=walltime,
            max_walltime=ana_vars.max_walltime,
            t=t,
            max_t=ana_vars.max_t,
        ):
            # do not head-start a new generation as this is likely not needed
            return

        # create a preliminary simulate_one function
        simulate_one_prel = create_preliminary_simulate_one(
            t=t + 1,
            population=population,
            delay_evaluation=self.look_ahead_delay_evaluation,
            adapt_proposal=self.adapt_look_ahead_proposal,
            ana_vars=ana_vars,
        )

        # maximum number of look-ahead evaluations
        if self.look_ahead_delay_evaluation:
            # set maximum evaluations to previous simulations * const
            nr_evaluations_ = int(
                self.redis.get(idfy(N_EVAL, self.analysis_id, t)).decode()
            )
            max_n_eval_look_ahead = (
                nr_evaluations_ * self.max_n_eval_look_ahead_factor
            )
        else:
            # no maximum necessary as samples are directly evaluated
            max_n_eval_look_ahead = np.inf

        # head-start the next generation
        #  all_accepted is most certainly False for t>0
        self.start_generation_t(
            n=n,
            t=t + 1,
            simulate_one=simulate_one_prel,
            all_accepted=False,
            is_look_ahead=True,
            max_n_eval_look_ahead=max_n_eval_look_ahead,
        )
[docs] def create_sample(self, id_results: List[Tuple], n: int) -> Sample: """Create a single sample result. Order the results by starting point to avoid a bias towards short-running simulations (dynamic scheduling). """ # sort id_results.sort(key=lambda x: x[0]) # cut off id_results = id_results[:n] # extract simulations results = [res[1] for res in id_results] # create 1 to-be-returned sample from results sample = self._create_empty_sample() for j in range(n): sample += results[j] # check number of acceptances if (n_accepted := sample.n_accepted) != n: raise AssertionError( f"Expected {n} accepted particles but got {n_accepted}" ) return sample
[docs] def check_analysis_variables( self, distance_function: Distance, eps: Epsilon, acceptor: Acceptor, ) -> None: """ "Check analysis variables appropriateness for sampling.""" if self.look_ahead_delay_evaluation: # nothing to be done return def _check_bad(var): """Check whether a component is incompatible.""" # do not check for `requires_calibration()`, because in the first # iteration we do not look ahead if var.is_adaptive(): raise AssertionError( f"{var.__class__.__name__} cannot be used in look-ahead " "mode without delayed acceptance. Consider setting the " "sampler's `look_ahead_delay_evaluation` flag." ) _check_bad(acceptor) _check_bad(distance_function) _check_bad(eps)
def create_preliminary_simulate_one(
    t,
    population,
    delay_evaluation: bool,
    adapt_proposal: bool,
    ana_vars: AnalysisVars,
) -> Callable:
    """Create a preliminary simulate_one function for generation `t`.

    Based on preliminary results, update transitions, distance function,
    epsilon threshold etc., and return a function that samples parameters,
    simulates data and checks their preliminary acceptance.
    As the actual acceptance criteria may be different, samples generated
    by this function must be checked anew afterwards.

    Parameters
    ----------
    t:
        The time index for which to create the function (i.e. call with
        t+1).
    population:
        The preliminary population.
    delay_evaluation:
        Whether to delay evaluation.
    adapt_proposal:
        Whether to fit the proposal distribution to the new population.
    ana_vars:
        The analysis variables.

    Returns
    -------
    simulate_one:
        The preliminary sampling function.
    """
    model_probabilities = population.get_model_probabilities()

    # set proposal distribution
    transitions = ana_vars.transitions
    if adapt_proposal:
        # create deep copy of the transition function
        transitions = copy.deepcopy(transitions)
        # fit transitions
        for m in population.get_alive_models():
            parameters, w = population.get_distribution(m)
            transitions[m].fit(parameters, w)
    elif t == 1:
        # at t=0, the prior is used for sampling
        #  (and the transition not fitted yet)
        transitions = ana_vars.parameter_priors

    return create_simulate_function(
        t=t,
        model_probabilities=model_probabilities,
        model_perturbation_kernel=ana_vars.model_perturbation_kernel,
        transitions=transitions,
        model_prior=ana_vars.model_prior,
        parameter_priors=ana_vars.parameter_priors,
        models=ana_vars.models,
        summary_statistics=ana_vars.summary_statistics,
        x_0=ana_vars.x_0,
        distance_function=ana_vars.distance_function,
        eps=ana_vars.eps,
        acceptor=ana_vars.acceptor,
        evaluate=not delay_evaluation,
        proposal_id=-1,
    )


def post_check_acceptance(
    sample_with_id,
    ana_id,
    t,
    redis,
    ana_vars,
    logger: RedisSamplerLogger,
) -> Tuple:
    """Check whether the sample is really acceptable.

    This is where evaluation of preliminary samples happens, using the
    analysis variables from the actual generation `t` and the previously
    simulated data. The sample is modified in-place.

    Returns
    -------
    sample_with_id, any_accepted:
        The (maybe post-evaluated) id-sample tuple, and an indicator
        whether any particle in the sample was accepted, s.t. the sample
        should be kept.
    """
    # 0 is relative start time, 1 is the actual sample
    sample: Sample = sample_with_id[1]

    # check whether there are preliminary particles
    if not any(particle.preliminary for particle in sample.all_particles):
        n_accepted = len(sample.accepted_particles)
        if n_accepted != 1:
            # this should never happen
            raise AssertionError(
                "Expected exactly one accepted particle in sample."
            )
        # increase general acceptance counter
        logger.n_accepted += 1
        # increase accepted counter if in look-ahead mode
        if sample.is_look_ahead:
            logger.n_lookahead_accepted += 1
        # nothing else to be done
        return sample_with_id, True

    # in preliminary mode, there should only be one particle per sample
    if len(sample.all_particles) != 1:
        # this should never happen
        raise AssertionError(
            "Expected number of particles in sample: 1. "
            f"Got: {len(sample.all_particles)}"
        )

    # from here on, we may assume that all particles (#=1) are yet to be
    #  judged
    logger.n_preliminary += 1

    # iterate over the 1 particle
    for particle in sample.all_particles:
        particle = evaluate_preliminary_particle(particle, t, ana_vars)

        # react to acceptance
        if particle.accepted:
            sample.accepted_particles = [particle]
            sample.rejected_particles = []
            # increase redis shared counter
            redis.incr(idfy(N_ACC, ana_id, t), 1)
            # increase general and lookahead counter
            logger.n_accepted += 1
            logger.n_lookahead_accepted += 1
        else:
            sample.accepted_particles = []
            if sample.record_rejected:
                sample.rejected_particles = [particle]
            else:
                sample.rejected_particles = []

    return sample_with_id, len(sample.accepted_particles) > 0


def self_normalize_within_subpopulations(sample: Sample, n: int) -> Sample:
    """Applies subpopulation-wise self-normalization of samples, in-place.

    The weights are adjusted per proposal id, such that all particles
    belonging to one proposal id have a total weight proportional to the
    effective sample size of the sub-population. This defines the relative
    importances of all particles in the accepted population in a
    reasonable manner. Conceptually, also other normalizations are
    possible.

    Parameters
    ----------
    sample:
        The population to be returned by the sampler.
    n:
        Population size.

    Returns
    -------
    sample:
        The same, weight-adjusted sample.
    """
    prop_ids = {particle.proposal_id for particle in sample.accepted_particles}
    if len(prop_ids) == 1:
        # Nothing to be done, as we only have one proposal, and
        #  normalization is applied later when the population is created
        return sample

    if len(sample.accepted_particles) != n:
        # this should not happen
        raise AssertionError("Unexpected number of acceptances")

    # get particles per proposal
    particles_per_prop = {
        prop_id: [
            particle
            for particle in sample.accepted_particles
            if particle.proposal_id == prop_id
        ]
        for prop_id in prop_ids
    }

    # normalize weights by $ESS_l / sum_{i<=N_l} w^l_i$ for proposal id l
    #  this is s.t. $sum_{i<=N_l} w^l_i \propto ESS_l$
    normalizations: Dict[int, float] = {}
    for prop_id, particles_for_prop in particles_per_prop.items():
        weights = np.array(
            [particle.weight for particle in particles_for_prop]
        )
        ess = effective_sample_size(weights)
        total_weight = weights.sum()
        normalizations[prop_id] = ess / total_weight

    # normalize all particles
    for particle in sample.all_particles:
        if particle.proposal_id in normalizations:
            particle.weight *= normalizations[particle.proposal_id]
        else:
            # set weight of particles from populations None of which was
            #  accepted to 0 (until we start caring about those for real)
            particle.weight = 0.0

    return sample