"""Epsilon threshold plots"""
from typing import TYPE_CHECKING
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator
from ..storage import History
from .util import get_labels, to_lists
if TYPE_CHECKING:
import plotly.graph_objs as go
def _prepare(
    histories: list | History,
    labels: list | str | None,
    colors: list | str | None,
):
    """
    Normalize plotting inputs and extract the epsilon trajectories.

    Shared by the matplotlib and plotly epsilon plots.

    Parameters
    ----------
    histories:
        A single history or a list of histories (normalized by
        ``to_lists``).
    labels:
        Labels for the histories (normalized by ``get_labels``).
    colors:
        Line colors, one per history. ``None`` leaves every color at the
        plotting backend's default. A single color string (e.g. ``'red'``)
        or a one-element sequence is applied to all histories. Entries
        beyond the number of histories are ignored. Any other non-string
        value is treated as a sequence of colors, so wrap a single RGB(A)
        tuple in a list.

    Returns
    -------
    labels, colors, eps:
        Lists with one label, one color (possibly ``None``) and one numpy
        array of epsilon values per history.

    Raises
    ------
    ValueError
        If fewer colors than histories are given.
    """
    histories = to_lists(histories)
    n_histories = len(histories)
    labels = get_labels(labels, n_histories)

    if colors is None:
        colors = [None] * n_histories
    elif isinstance(colors, str):
        # A bare string is one color; left as-is, the callers' zip() would
        # iterate it character by character.
        colors = [colors] * n_histories
    else:
        colors = list(colors)
        if len(colors) == 1:
            colors = colors * n_histories
    if len(colors) < n_histories:
        # The callers zip() colors with the histories, so too few colors
        # would silently drop histories from the plot.
        raise ValueError(
            f"Got {len(colors)} colors for {n_histories} histories."
        )
    colors = colors[:n_histories]

    # note: first entry is from calibration and thus translates to inf,
    # thus must be discarded. Converting to an array first makes the
    # slice unambiguously positional.
    eps = [
        np.array(history.get_all_populations()['epsilon'])[1:]
        for history in histories
    ]
    return labels, colors, eps
[docs]
def plot_epsilons(
    histories: list | History,
    labels: list | str = None,
    colors: list = None,
    yscale: str = 'log',
    title: str = 'Epsilon values',
    size: tuple = None,
    ax: mpl.axes.Axes = None,
) -> mpl.axes.Axes:
    """
    Plot the epsilon trajectory of one or several runs with matplotlib.

    Each history is drawn as one line with ``'x'`` markers. The x-axis is
    the population index and the y-axis the epsilon value.

    Parameters
    ----------
    histories:
        The histories to plot from. History ids must be set correctly.
    labels:
        Labels corresponding to the histories, normalized by
        ``get_labels``. A legend is drawn only if at least one label is
        not None.
    colors:
        Line colors, one per history. None entries fall back to
        matplotlib's default color cycle.
    yscale:
        Y-axis scale in matplotlib notation (e.g. ``'log'``, ``'linear'``).
    title:
        Title of the axes.
    size:
        Figure size ``(width, height)`` in inches. If None, the figure
        size is left unchanged.
    ax:
        Axes to draw on. If None, a new figure and axes are created.

    Returns
    -------
    ax:
        The axes that were drawn on.
    """
    labels, colors, eps = _prepare(histories, labels, colors)

    # Draw on the caller's axes when given, otherwise open a fresh figure.
    fig, ax = plt.subplots() if ax is None else (ax.get_figure(), ax)

    for trajectory, lbl, col in zip(eps, labels, colors):
        ax.plot(trajectory, 'x-', label=lbl, color=col)

    ax.set(
        xlabel='Population index',
        ylabel='Epsilon',
        title=title,
        yscale=yscale,
    )
    if not all(lbl is None for lbl in labels):
        ax.legend()
    # Population indices are integers, so avoid fractional tick positions.
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    if size is not None:
        fig.set_size_inches(size)
    # Note: applied to the whole figure, including a caller-provided one.
    fig.tight_layout()
    return ax
[docs]
def plot_epsilons_plotly(
    histories: list | History,
    labels: list | str = None,
    colors: list = None,
    yscale: str = 'log',
    title: str = 'Epsilon values',
    size: tuple = None,
    fig: 'go.Figure' = None,
) -> 'go.Figure':
    """
    Plot the epsilon trajectory of one or several runs with plotly.

    Plotly counterpart of the matplotlib epsilon plot. Each history adds
    one ``'lines+markers'`` trace, with the population index on the x-axis
    and the epsilon value on the y-axis.

    Parameters
    ----------
    histories:
        The histories to plot from.
    labels:
        Labels corresponding to the histories, used as trace names.
    colors:
        Line colors, one per history. None entries use plotly's defaults.
    yscale:
        Y-axis type, passed to plotly as ``yaxis_type`` (e.g. ``'log'``,
        ``'linear'``).
    title:
        Title of the figure.
    size:
        ``(width, height)`` of the figure in pixels. If None, the figure's
        size is left unchanged.
    fig:
        Figure to add the traces to. If None, a new figure is created.

    Returns
    -------
    fig:
        The figure that was drawn on.
    """
    # Imported here rather than at module level, so the module can be
    # imported without plotly installed.
    import plotly.graph_objects as go

    labels, colors, eps = _prepare(histories, labels, colors)
    fig = go.Figure() if fig is None else fig

    for trajectory, lbl, col in zip(eps, labels, colors):
        trace = go.Scatter(
            x=np.arange(len(trajectory)),
            y=trajectory,
            mode='lines+markers',
            name=lbl,
            line_color=col,
        )
        fig.add_trace(trace)

    fig.update_layout(
        title=title,
        xaxis_title='Population index',
        yaxis_title='Epsilon',
        yaxis_type=yscale,
    )
    if size is not None:
        fig.update_layout(width=size[0], height=size[1])
    return fig