Source code for pyabc.visualization.effective_sample_size

"""Effective sample size plots"""

from typing import TYPE_CHECKING, List, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

from ..storage import History
from ..weighted_statistics import effective_sample_size
from .util import get_labels, to_lists

if TYPE_CHECKING:
    import plotly.graph_objs as go


def _prepare_plot_effective_sample_sizes(
    histories: Union[List, History],
    labels: Union[List, str],
    colors: List,
    relative: bool,
):
    # preprocess input
    histories = to_lists(histories)
    labels = get_labels(labels, len(histories))
    if colors is None:
        colors = [None for _ in range(len(histories))]

    # extract effective sample sizes
    essss = []  # :)
    for history in histories:
        esss = []
        for t in range(0, history.max_t + 1):
            # we need the weights not normalized to 1 for each model
            w = history.get_weighted_distances(t=t)['w']
            ess = effective_sample_size(w)
            if relative:
                ess /= len(w)
            esss.append(ess)
        essss.append(esss)

    return labels, colors, essss


[docs] def plot_effective_sample_sizes( histories: Union[List, History], labels: Union[List, str] = None, rotation: int = 0, title: str = "Effective sample size", relative: bool = False, colors: List = None, size: tuple = None, ax: mpl.axes.Axes = None, ) -> mpl.axes.Axes: """ Plot effective sample sizes over all iterations. Parameters ---------- histories: The histories to plot from. History ids must be set correctly. labels: Labels corresponding to the histories. If None are provided, indices are used as labels. rotation: Rotation to apply to the plot's x tick labels. For longer labels, a tilting of 45 or even 90 can be preferable. title: Title for the plot. relative: Whether to show relative sizes (to 1) or w.r.t. the real number of particles. colors: Colors to use for the lines. If None, then the matplotlib default values are used. size: The size of the plot in inches. ax: The matplotlib axes object to use. If None, a new figure is created. Returns ------- ax: The matplotlib axes object. """ # prepare data labels, colors, essss = _prepare_plot_effective_sample_sizes( histories, labels, colors, relative ) # create figure if ax is None: fig, ax = plt.subplots() else: fig = ax.get_figure() # plot for esss, label, color in zip(essss, labels, colors): ax.plot(range(0, len(esss)), esss, 'x-', label=label, color=color) # format ax.set_xlabel("Population index") ax.set_ylabel("ESS") if any(lab is not None for lab in labels): ax.legend() ax.set_title(title) # enforce integer ticks ax.xaxis.set_major_locator(MaxNLocator(integer=True)) # rotate x tick labels plt.setp(ax.get_xticklabels(), rotation=rotation) # set size if size is not None: fig.set_size_inches(size) fig.tight_layout() return ax
[docs] def plot_effective_sample_sizes_plotly( histories: Union[List, History], labels: Union[List, str] = None, rotation: int = 0, title: str = "Effective sample size", relative: bool = False, colors: List = None, size: tuple = None, fig: "go.Figure" = None, ) -> "go.Figure": """Plot effective sample sizes using plotly.""" import plotly.graph_objects as go # prepare data labels, colors, essss = _prepare_plot_effective_sample_sizes( histories, labels, colors, relative ) # create figure if fig is None: fig = go.Figure() # plot for esss, label, color in zip(essss, labels, colors): fig.add_trace( go.Scatter( x=list(range(0, len(esss))), y=esss, mode="lines+markers", name=label, marker={'color': color}, ) ) # format fig.update_layout( xaxis_title="Population index", yaxis_title="ESS", title=title, xaxis={'tickmode': "linear"}, ) # rotate x tick labels fig.update_xaxes(tickangle=rotation) # set size if size is not None: fig.update_layout(width=size[0], height=size[1]) return fig