Source code for pyabc.visualization.epsilon

"""Epsilon threshold plots"""

from typing import TYPE_CHECKING, List, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator

from ..storage import History
from .util import get_labels, to_lists

if TYPE_CHECKING:
    import plotly.graph_objs as go


def _prepare(
    histories: Union[List, History],
    labels: Union[List, str],
    colors: List,
):
    """
    Normalize the plotting inputs and collect the epsilon trajectories.

    Parameters
    ----------
    histories:
        A single history or a list of histories to read epsilons from.
    labels:
        Labels for the histories; processed via `get_labels`.
    colors:
        Colors for the lines. If None, a list with one None entry per
        history is substituted.

    Returns
    -------
    labels, colors, eps:
        The processed labels, the colors, and a list containing one array
        of epsilon values per history.
    """
    history_list = to_lists(histories)
    n_histories = len(history_list)

    labels = get_labels(labels, n_histories)
    if colors is None:
        colors = [None] * n_histories

    # The first population entry stems from calibration, where epsilon
    # translates to inf, so it is dropped from every trajectory.
    eps = [
        np.array(hist.get_all_populations()['epsilon'][1:])
        for hist in history_list
    ]

    return labels, colors, eps


def plot_epsilons(
    histories: Union[List, History],
    labels: Union[List, str] = None,
    colors: List = None,
    yscale: str = 'log',
    title: str = "Epsilon values",
    size: tuple = None,
    ax: mpl.axes.Axes = None,
) -> mpl.axes.Axes:
    """
    Plot epsilon trajectory.

    Parameters
    ----------
    histories:
        The histories to plot from. History ids must be set correctly.
    labels:
        Labels corresponding to the histories. If None are provided,
        indices are used as labels.
    colors:
        Colors to use for the lines. If None, then the matplotlib
        default values are used.
    yscale:
        Scaling to apply to the y-axis. Use matplotlib's notation.
    title:
        Title for the plot.
    size:
        The size of the plot in inches.
    ax:
        The axis object to use. A new one is created if None.

    Returns
    -------
    ax:
        The axis object used.
    """
    labels, colors, eps = _prepare(histories, labels, colors)

    # draw into the caller's axis when one is given, else open a new figure
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots()

    # one marked line per history
    for trajectory, label, color in zip(eps, labels, colors):
        ax.plot(trajectory, 'x-', label=label, color=color)

    # axis annotations; a legend is only drawn if some label is set
    ax.set_xlabel("Population index")
    ax.set_ylabel("Epsilon")
    if not all(lab is None for lab in labels):
        ax.legend()
    ax.set_title(title)
    ax.set_yscale(yscale)

    # population indices are integers, so restrict the x ticks accordingly
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    # apply the requested figure size, if any
    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()

    return ax


def plot_epsilons_plotly(
    histories: Union[List, History],
    labels: Union[List, str] = None,
    colors: List = None,
    yscale: str = 'log',
    title: str = "Epsilon values",
    size: tuple = None,
    fig: "go.Figure" = None,
) -> "go.Figure":
    """
    Plot epsilon trajectory using plotly.

    Parameters
    ----------
    histories:
        The histories to plot from.
    labels:
        Labels corresponding to the histories, used as trace names.
    colors:
        Colors to use for the lines, one per history.
    yscale:
        Axis type applied to the y-axis of the plotly layout.
    title:
        Title for the plot.
    size:
        Width and height forwarded to the plotly layout. The layout
        size is left untouched if None.
    fig:
        The figure object to add the traces to. A new one is created
        if None.

    Returns
    -------
    fig:
        The figure object used.
    """
    # imported locally so that plotly is only needed when this is called
    import plotly.graph_objects as go

    labels, colors, eps = _prepare(histories, labels, colors)

    if fig is None:
        fig = go.Figure()

    # one trace per history, with the population index on the x-axis
    for trajectory, label, color in zip(eps, labels, colors):
        trace = go.Scatter(
            x=np.arange(len(trajectory)),
            y=trajectory,
            mode='lines+markers',
            name=label,
            line_color=color,
        )
        fig.add_trace(trace)

    # axis titles, plot title and y-axis scaling
    fig.update_layout(
        xaxis_title="Population index",
        yaxis_title="Epsilon",
        title=title,
        yaxis_type=yscale,
    )

    # apply the requested figure size, if any
    if size is not None:
        fig.update_layout(width=size[0], height=size[1])

    return fig