"""Redis based static scheduling sampler."""
import logging
import pickle
from time import sleep
from typing import Callable, List
import cloudpickle
import numpy as np
from jabbar import jabbar
from ...population import Sample
from .cmd import (
GENERATION,
MODE,
MSG,
N_ACC,
N_EVAL,
N_FAIL,
N_JOB,
N_REQ,
N_WORKER,
QUEUE,
SLEEP_TIME,
SSA,
START,
STATIC,
idfy,
)
from .sampler import RedisSamplerBase
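# `idfy` namespaces a variable by analysis id and time point, so that
# concurrent analyses and generations do not collide in redis. The exact
# key format is defined in `.cmd`; as an assumed illustration only:
#
#   idfy(N_EVAL, "a1b2", 3)  # -> something like "n_eval_a1b2_3"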
logger = logging.getLogger("ABC.Sampler")
class RedisStaticSampler(RedisSamplerBase):
"""Redis based static scheduling sampler."""
    def sample_until_n_accepted(
        self,
        n: int,
        simulate_one: Callable,
        t: int,
        *,
        max_eval: float = np.inf,
        all_accepted: bool = False,
        ana_vars=None,
    ) -> Sample:
        """Sample until `n` particles have been accepted for time `t`.

        `max_eval`, `all_accepted`, and `ana_vars` are part of the common
        sampler interface and are not used by the static scheduler.
        """
# get the analysis id
ana_id = self.analysis_id
# tell workers to start
self.start_generation_t(n=n, t=t, simulate_one=simulate_one)
# collect samples
samples = []
with jabbar(total=n, enable=self.show_progress, keep=False) as bar:
while len(samples) < n:
dump = self.redis.blpop(idfy(QUEUE, ana_id, t))[1]
sample = pickle.loads(dump)
                if len(sample.accepted_particles) != 1:
                    # in static scheduling, each queue entry carries exactly
                    # one accepted particle, so this should never happen
                    raise AssertionError(
                        "Expected exactly one accepted particle in sample, "
                        f"got {len(sample.accepted_particles)}."
                    )
samples.append(sample)
bar.inc()
        # wait until all workers have finished
        # this is necessary for a clean intermediate state
        while int(self.redis.get(idfy(N_WORKER, ana_id, t)).decode()) > 0:
            sleep(SLEEP_TIME)
# set total number of evaluations
self.nr_evaluations_ = int(
self.redis.get(idfy(N_EVAL, ana_id, t)).decode()
)
# remove all time-specific variables
self.clear_generation_t(t)
        # combine all results into a single sample
sample = self.create_sample(samples, n)
return sample
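    # Illustrative usage sketch (not part of this module); the constructor
    # arguments are assumptions based on the base sampler, for illustration
    # only:
    #
    #   sampler = RedisStaticSampler(host="localhost", port=6379)
    #   sample = sampler.sample_until_n_accepted(
    #       n=100, simulate_one=simulate_one, t=0)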
def start_generation_t(
self, n: int, t: int, simulate_one: Callable
) -> None:
"""Start generation `t`."""
ana_id = self.analysis_id
        # write initial values to redis via a single pipeline
(
self.redis.pipeline()
# initialize variables for time t
.set(
idfy(SSA, ana_id, t),
cloudpickle.dumps((simulate_one, self.sample_factory)),
)
.set(idfy(N_EVAL, ana_id, t), 0)
            # N_ACC only serves for live monitoring and debugging
.set(idfy(N_ACC, ana_id, t), 0)
.set(idfy(N_REQ, ana_id, t), n)
.set(idfy(N_FAIL, ana_id, t), 0)
.set(idfy(N_WORKER, ana_id, t), 0)
.set(idfy(N_JOB, ana_id, t), n)
.set(idfy(MODE, ana_id, t), STATIC)
# update the current-generation variable
.set(idfy(GENERATION, ana_id), t)
# execute all commands
.execute()
)
# publish start message
self.redis.publish(MSG, START)
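    # For orientation, the worker-side counterpart (implemented elsewhere)
    # conceptually consumes these keys as follows; a simplified sketch under
    # assumed names, not the actual worker code:
    #
    #   while redis.decr(idfy(N_JOB, ana_id, t)) >= 0:  # claim one job
    #       sample = simulate_until_accepted()  # hypothetical helper
    #       redis.rpush(idfy(QUEUE, ana_id, t), pickle.dumps(sample))
    #       redis.incr(idfy(N_ACC, ana_id, t))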
def clear_generation_t(self, t: int) -> None:
"""Clean up after generation `t` has finished.
Parameters
----------
t: The time for which to clear.
"""
ana_id = self.analysis_id
        # delete all generation-specific keys via a single pipeline
(
self.redis.pipeline()
.delete(idfy(SSA, ana_id, t))
.delete(idfy(N_EVAL, ana_id, t))
.delete(idfy(N_ACC, ana_id, t))
.delete(idfy(N_REQ, ana_id, t))
.delete(idfy(N_FAIL, ana_id, t))
.delete(idfy(N_WORKER, ana_id, t))
.delete(idfy(N_JOB, ana_id, t))
.delete(idfy(MODE, ana_id, t))
.delete(idfy(QUEUE, ana_id, t))
.execute()
)
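    # The pipeline sends all DEL commands in one round trip. Since redis'
    # DEL accepts multiple keys, an equivalent single-command sketch would
    # be:
    #
    #   self.redis.delete(*(
    #       idfy(var, ana_id, t)
    #       for var in (SSA, N_EVAL, N_ACC, N_REQ, N_FAIL,
    #                   N_WORKER, N_JOB, MODE, QUEUE)
    #   ))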
def create_sample(self, samples: List[Sample], n: int) -> Sample:
"""Create a single sample result.
Order the results by starting point to avoid a bias towards
short-running simulations (dynamic scheduling).
"""
if len(samples) != n:
raise AssertionError(f"Expected {n} samples, got {len(samples)}.")
        # merge the individual samples into the one to be returned
sample = self._create_empty_sample()
for single_sample in samples:
sample += single_sample
return sample
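    # The `+=` accumulation relies on `Sample` supporting addition that
    # merges particle lists. A minimal equivalent reduction, assuming only
    # that addition:
    #
    #   from functools import reduce
    #   sample = reduce(
    #       lambda a, b: a + b, samples, self._create_empty_sample())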