# Source code for pyabc.visualization.epsilon

"""Epsilon threshold plots"""

from typing import TYPE_CHECKING

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator

from ..storage import History
from .util import get_labels, to_lists

if TYPE_CHECKING:
    import plotly.graph_objs as go


def _prepare(
    histories: list | History,
    labels: list | str | None,
    colors: list | None,
):
    """Normalize the plotting inputs and extract the epsilon trajectories.

    Parameters
    ----------
    histories:
        A single history or a list of histories.
    labels:
        Labels for the histories; normalized via ``get_labels`` to one
        entry per history.
    colors:
        Colors for the histories. If None, or if fewer colors than
        histories are given, the missing entries are filled with None so
        that the plotting backend's default colors are used.

    Returns
    -------
    labels, colors, eps:
        The processed labels, a list with one color (or None) per history,
        and a list with one array of epsilon values per history.
    """
    # preprocess input
    histories = to_lists(histories)
    n_histories = len(histories)
    labels = get_labels(labels, n_histories)

    colors = [] if colors is None else list(colors)
    if len(colors) < n_histories:
        # pad instead of letting zip() in the callers silently drop the
        # histories that have no color assigned
        colors += [None] * (n_histories - len(colors))

    # extract epsilons
    # note: first entry is from calibration and thus translates to inf,
    # thus must be discarded
    eps = [
        np.array(history.get_all_populations()['epsilon'][1:])
        for history in histories
    ]

    return labels, colors, eps


def plot_epsilons(
    histories: list | History,
    labels: list | str = None,
    colors: list = None,
    yscale: str = 'log',
    title: str = 'Epsilon values',
    size: tuple = None,
    ax: mpl.axes.Axes = None,
) -> mpl.axes.Axes:
    """
    Plot the epsilon (acceptance threshold) trajectory of each history.

    Parameters
    ----------
    histories:
        The histories to plot from. History ids must be set correctly.
    labels:
        Labels corresponding to the histories. If None are provided,
        indices are used as labels.
    colors:
        Colors to use for the lines. If None, then the matplotlib default
        values are used.
    yscale:
        Scaling to apply to the y-axis. Use matplotlib's notation.
    title:
        Title for the plot.
    size:
        The size of the plot in inches.
    ax:
        The axis object to use. A new one is created if None.

    Returns
    -------
    ax: The axis object used.
    """
    labels, colors, eps = _prepare(histories, labels, colors)

    # reuse the figure owning a given axis, otherwise start from scratch
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots()

    # one crossed line per history
    for trajectory, line_label, line_color in zip(eps, labels, colors):
        ax.plot(trajectory, 'x-', label=line_label, color=line_color)

    ax.set_xlabel('Population index')
    ax.set_ylabel('Epsilon')
    # only draw a legend if at least one line carries a label
    has_labels = any(lbl is not None for lbl in labels)
    if has_labels:
        ax.legend()
    ax.set_title(title)
    ax.set_yscale(yscale)
    # population indices are whole numbers
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()

    return ax
def plot_epsilons_plotly(
    histories: list | History,
    labels: list | str = None,
    colors: list = None,
    yscale: str = 'log',
    title: str = 'Epsilon values',
    size: tuple = None,
    fig: 'go.Figure' = None,
) -> 'go.Figure':
    """Plot epsilon trajectory using plotly.

    Plotly counterpart of ``plot_epsilons``.

    Parameters
    ----------
    histories:
        The histories to plot from. History ids must be set correctly.
    labels:
        Labels corresponding to the histories. If None are provided,
        indices are used as labels.
    colors:
        Colors to use for the lines. If None, plotly's default colors are
        used.
    yscale:
        Type of the y-axis, passed to plotly as ``yaxis_type``
        (e.g. 'log' or 'linear'). Not every matplotlib scale name is a
        valid plotly axis type.
    title:
        Title for the plot.
    size:
        (width, height) of the figure. Unlike ``plot_epsilons``, where the
        size is given in inches, plotly interprets these values as pixels.
    fig:
        The figure to add the traces to. A new one is created if None.

    Returns
    -------
    fig: The figure used.
    """
    import plotly.graph_objects as go

    # process inputs
    labels, colors, eps = _prepare(histories, labels, colors)

    # create figure
    if fig is None:
        fig = go.Figure()

    # plot
    for ep, label, color in zip(eps, labels, colors):
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(ep)),
                y=ep,
                mode='lines+markers',
                name=label,
                line_color=color,
            )
        )

    # format
    fig.update_layout(
        xaxis_title='Population index',
        yaxis_title='Epsilon',
        title=title,
        yaxis_type=yscale,
    )
    # plotly hides the legend when only one trace is present; show it
    # whenever labels exist, consistent with plot_epsilons
    if any(lab is not None for lab in labels):
        fig.update_layout(showlegend=True)

    # set size (pixels)
    if size is not None:
        fig.update_layout(width=size[0], height=size[1])

    return fig