"""Bayesian credible interval plots"""
from typing import List, Union
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from ..storage import History
from ..transition import MultivariateNormalTransition, Transition
from ..weighted_statistics import weighted_quantile
from .util import get_labels, to_lists
def _prepare_credible_intervals(
    history: History,
    m: int,
    ts: Union[List[int], int],
    par_names: List,
    levels: List,
    show_mean: bool,
    show_kde_max: bool,
    show_kde_max_1d: bool,
    kde: Transition,
    kde_1d: Transition,
):
    """Compute the statistics needed by the credible interval plots.

    Parameters
    ----------
    history:
        The history to extract data from.
    m:
        The id of the model to extract data for.
    ts:
        The time points to extract data for. A single int is treated as a
        one-element list. If None, all populations 0..history.max_t are used.
    par_names:
        The parameters to extract. If None, all columns of the distribution
        returned by ``history.get_distribution(m=m)`` are used.
    levels:
        Credibility levels. Defaults to [0.95]. Returned sorted ascending.
    show_mean, show_kde_max, show_kde_max_1d:
        Whether to compute the weighted mean, the maximum of a KDE fitted to
        all parameters jointly, and the per-parameter KDE maximum.
    kde, kde_1d:
        KDEs used for the KDE maxima. If None and needed, a
        MultivariateNormalTransition is created.

    Returns
    -------
    par_names, levels, n_par, n_confidence, n_pop, ts, cis, median, mean,
    kde_max, kde_max_1d:
        ``cis`` has shape (n_par, n_pop, 2 * n_confidence). For the i-th
        (sorted) level, the lower bound is stored at index ``i`` and the upper
        bound at index ``-1 - i`` of the last axis. ``median``, ``mean``,
        ``kde_max`` and ``kde_max_1d`` have shape (n_par, n_pop). Statistics
        that were not requested are filled with NaN.
    """
    levels = sorted([0.95] if levels is None else levels)
    if par_names is None:
        # extract all parameter names
        df, _ = history.get_distribution(m=m)
        par_names = list(df.columns.values)
    if ts is None:
        ts = list(range(0, history.max_t + 1))
    elif isinstance(ts, (int, np.integer)):
        # the signature allows a single time point
        ts = [ts]

    # dimensions
    n_par = len(par_names)
    n_confidence = len(levels)
    n_pop = len(ts)

    # prepare matrices; NaN marks statistics that were not computed
    cis = np.full((n_par, n_pop, 2 * n_confidence), np.nan)
    median = np.full((n_par, n_pop), np.nan)
    mean = np.full((n_par, n_pop), np.nan)
    kde_max = np.full((n_par, n_pop), np.nan)
    kde_max_1d = np.full((n_par, n_pop), np.nan)

    if kde is None and show_kde_max:
        kde = MultivariateNormalTransition()
    if kde_1d is None and show_kde_max_1d:
        kde_1d = MultivariateNormalTransition()

    # iterate over populations
    for i_t, t in enumerate(ts):
        df, w = history.get_distribution(m=m, t=t)
        # normalize weights to be sure; not in place, so the array obtained
        #  from the history is left untouched (and int weights work)
        w = w / w.sum()
        # the joint KDE maximum is one point for all parameters
        if show_kde_max:
            _kde_max_pnt = compute_kde_max(kde, df, w)
        # iterate over parameters
        for i_par, par in enumerate(par_names):
            vals = np.array(df[par])
            median[i_par, i_t] = compute_quantile(vals, w, 0.5)
            if show_mean:
                mean[i_par, i_t] = np.sum(w * vals)
            if show_kde_max:
                kde_max[i_par, i_t] = _kde_max_pnt[par]
            if show_kde_max_1d:
                _kde_max_1d_pnt = compute_kde_max(kde_1d, df[[par]], w)
                kde_max_1d[i_par, i_t] = _kde_max_1d_pnt[par]
            # lower bound of level i at index i, upper bound at -1 - i
            for i_c, confidence in enumerate(levels):
                lb, ub = compute_credible_interval(vals, w, confidence)
                cis[i_par, i_t, i_c] = lb
                cis[i_par, i_t, -1 - i_c] = ub
    return (
        par_names,
        levels,
        n_par,
        n_confidence,
        n_pop,
        ts,
        cis,
        median,
        mean,
        kde_max,
        kde_max_1d,
    )
[docs]
def plot_credible_intervals(
    history: History,
    m: int = 0,
    ts: Union[List[int], int] = None,
    par_names: List = None,
    levels: List = None,
    colors: List = None,
    color_median: str = None,
    show_mean: bool = False,
    color_mean: str = None,
    show_kde_max: bool = False,
    color_kde_max: str = None,
    show_kde_max_1d: bool = False,
    color_kde_max_1d: str = None,
    size: tuple = None,
    refval: dict = None,
    refval_color: str = 'C1',
    kde: Transition = None,
    kde_1d: Transition = None,
    arr_ax: List[matplotlib.axes.Axes] = None,
):
    """Plot credible intervals over time, one subplot per parameter.

    Parameters
    ----------
    history:
        The history to extract data from.
    m:
        The id of the model to plot for.
    ts:
        The time points to plot for.
    par_names:
        The parameters to plot for. If None, then all parameters are used.
    levels:
        Credibility levels to compute. Default is [0.95].
    colors:
        Colors to use for the error bars, one per level.
    color_median:
        Color to use for the median line. Defaults to the first of `colors`.
    show_mean:
        Whether to show the weighted mean in addition to the median.
    color_mean:
        Color to use for the mean.
    show_kde_max:
        Whether to show the sampled point with the highest value of the
        specified KDE.
        Note: the overall highest KDE value is not searched for; the sampled
        point with the highest value serves as an approximation (of the
        MAP value).
    color_kde_max:
        Color to use for the KDE max value.
    show_kde_max_1d:
        Same as `show_kde_max`, but with the KDE applied component-wise.
    color_kde_max_1d:
        Color to use for the component-wise KDE max value.
    size:
        Size of the figure.
    refval:
        A dictionary of reference parameter values to plot for each of
        `par_names`.
    refval_color:
        Color to use for the reference value.
    kde:
        The KDE to use for `show_kde_max`.
        Defaults to :class:`pyabc.MultivariateNormalTransition`.
    kde_1d:
        The KDE to use for `show_kde_max_1d`.
        Defaults to :class:`pyabc.MultivariateNormalTransition`.
    arr_ax:
        Axes to plot into, assumed to be a one-dimensional list.

    Returns
    -------
    arr_ax: The axes plotted into.
    """
    stats = _prepare_credible_intervals(
        history=history,
        m=m,
        ts=ts,
        par_names=par_names,
        levels=levels,
        show_mean=show_mean,
        show_kde_max=show_kde_max,
        show_kde_max_1d=show_kde_max_1d,
        kde=kde,
        kde_1d=kde_1d,
    )
    (
        par_names, levels, n_par, n_confidence, n_pop, ts,
        cis, median, mean, kde_max, kde_max_1d,
    ) = stats

    if colors is None:
        colors = [None] * len(levels)
    if color_median is None:
        color_median = colors[0]

    # create axes unless the caller supplied them
    if arr_ax is None:
        _, arr_ax = plt.subplots(
            nrows=n_par, ncols=1, sharex=False, sharey=False, figsize=size
        )
    if not isinstance(arr_ax, (list, np.ndarray)):
        arr_ax = [arr_ax]
    fig = arr_ax[0].get_figure()

    x_pos = range(n_pop)
    # optional point estimates, in the order they are drawn
    extras = [
        (show_mean, mean, "Mean", color_mean),
        (show_kde_max, kde_max, "Max KDE", color_kde_max),
        (show_kde_max_1d, kde_max_1d, "Max KDE 1d", color_kde_max_1d),
    ]

    for i_par, (par, ax) in enumerate(zip(par_names, arr_ax)):
        par_median = median[i_par]
        # largest level (widest interval) first, narrower ones on top
        for i_c in range(n_confidence - 1, -1, -1):
            lower = cis[i_par, :, i_c]
            upper = cis[i_par, :, -1 - i_c]
            ax.errorbar(
                x=x_pos,
                y=par_median.flatten(),
                yerr=[par_median - lower, upper - par_median],
                color=color_median,
                ecolor=colors[i_c],
                capsize=(5.0 / n_confidence) * (i_c + 1),
                label="{:.2f}".format(levels[i_c]),
            )
        ax.set_title(f"Parameter {par}")

        for enabled, values, extra_label, extra_color in extras:
            if enabled:
                ax.plot(
                    x_pos,
                    values[i_par],
                    'x-',
                    label=extra_label,
                    color=extra_color,
                )

        if refval is not None:
            ax.hlines(
                refval[par],
                xmin=0,
                xmax=n_pop - 1,
                color=refval_color,
                label="Reference value",
            )

        ax.set_xticks(x_pos)
        ax.set_xticklabels(ts)
        ax.set_ylabel(par)
        ax.legend()

    arr_ax[-1].set_xlabel("Population t")
    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()
    return arr_ax
[docs]
def plot_credible_intervals_plotly(
    history: History,
    m: int = 0,
    ts: Union[List[int], int] = None,
    par_names: List = None,
    levels: List = None,
    colors=None,
    size: tuple = None,
    refval: dict = None,
    refval_color: str = 'gray',
    kde: Transition = None,
    kde_1d: Transition = None,
):
    """Plot credible intervals over time using plotly.

    Parameters
    ----------
    history:
        The history to extract data from.
    m:
        The id of the model to plot for.
    ts:
        The time points to plot for.
    par_names:
        The parameters to plot for. If None, then all parameters are used.
    levels:
        Credibility levels to compute. Default is [0.95].
    colors:
        A single color used for all levels. Defaults to the first entry of
        plotly's DEFAULT_PLOTLY_COLORS. Levels are distinguished by opacity
        instead, with narrower intervals drawn more opaque.
    size:
        Tuple (width, height), passed to the figure layout.
    refval:
        A dictionary of reference parameter values to plot for each of
        `par_names`.
    refval_color:
        Color to use for the reference value.
    kde, kde_1d:
        Passed on to the data preparation; KDE maxima are not plotted here.

    Returns
    -------
    fig: The plotly figure.
    """
    import plotly.graph_objects as go
    from plotly.colors import DEFAULT_PLOTLY_COLORS
    from plotly.subplots import make_subplots

    # prepare data
    (
        par_names,
        levels,
        n_par,
        _,
        n_pop,
        ts,
        cis,
        median,
        _,
        _,
        _,
    ) = _prepare_credible_intervals(
        history=history,
        m=m,
        ts=ts,
        par_names=par_names,
        levels=levels,
        show_mean=False,
        show_kde_max=False,
        show_kde_max_1d=False,
        kde=kde,
        kde_1d=kde_1d,
    )

    # create figure
    fig = make_subplots(rows=n_par, cols=1, shared_xaxes=True)

    # one color for all levels
    if colors is None:
        colors = DEFAULT_PLOTLY_COLORS[0]
    colors = [colors for _ in range(len(levels))]
    # levels are sorted ascending, so narrower intervals are more opaque
    opacities = np.linspace(1, 0.3, len(levels))

    for i_par, par in enumerate(par_names):
        # legend entries are shown once, for the first subplot only
        showlegend = i_par == 0
        for i_c, confidence in reversed(list(enumerate(levels))):
            lower = cis[i_par, :, i_c]
            upper = cis[i_par, :, -1 - i_c]
            fig.add_trace(
                go.Scatter(
                    x=ts,
                    y=median[i_par].flatten(),
                    error_y={
                        'type': 'data',
                        'symmetric': False,
                        # 'array' is the extent above the point,
                        #  'arrayminus' the extent below it
                        'array': upper - median[i_par],
                        'arrayminus': median[i_par] - lower,
                    },
                    mode='lines+markers',
                    marker={'color': colors[i_c]},
                    opacity=opacities[i_c],
                    name="{:.2f}".format(confidence),
                    showlegend=showlegend,
                ),
                row=i_par + 1,
                col=1,
            )
        # reference value
        if refval is not None:
            fig.add_trace(
                go.Scatter(
                    x=ts,
                    y=[refval[par]] * n_pop,
                    mode='lines',
                    marker={'color': refval_color},
                    name="Reference value",
                    showlegend=showlegend,
                ),
                row=i_par + 1,
                col=1,
            )
        # y axis label
        fig.update_yaxes(title_text=par, row=i_par + 1, col=1)

    # set size
    if size is not None:
        fig.update_layout(width=size[0], height=size[1])
    return fig
[docs]
def plot_credible_intervals_for_time(
    histories: Union[List[History], History],
    labels: Union[List[str], str] = None,
    ms: Union[List[int], int] = None,
    ts: Union[List[int], int] = None,
    par_names: List[str] = None,
    levels: List[float] = None,
    show_mean: bool = False,
    show_kde_max: bool = False,
    show_kde_max_1d: bool = False,
    size: tuple = None,
    rotation: int = 0,
    refvals: Union[List[dict], dict] = None,
    kde: Transition = None,
    kde_1d: Transition = None,
):
    """
    Plot credible intervals for one time point of each of several histories.

    Parameters
    ----------
    histories:
        The histories to extract data from.
    labels:
        Labels for the histories, used as x tick labels. If None, they are
        just numbered.
    ms:
        The ids of the models to plot for, one per history. A single int or
        a one-element list is used for all histories. Default is model id 0
        for all histories.
    ts:
        The time points to plot for, one per history. A single int is used
        for all histories. If None, the last time point of each history is
        taken.
    par_names:
        The parameters to plot for. If None, then all parameters of the
        first history are used. Assumes all histories have these parameters.
    levels:
        Credibility levels to compute. Default is [0.95].
    show_mean, show_kde_max, show_kde_max_1d:
        As in `plot_credible_intervals`.
    size:
        Size of the figure, passed to `fig.set_size_inches`.
    rotation:
        Rotation of the x tick labels.
    refvals:
        Dictionaries of reference parameter values, one per history (entries
        may be None). A single dictionary is used for all histories.
    kde:
        The KDE to use for `show_kde_max`.
        Defaults to MultivariateNormalTransition.
    kde_1d:
        The KDE to use for `show_kde_max_1d`.
        Defaults to MultivariateNormalTransition.

    Returns
    -------
    arr_ax: The generated axes, one per parameter.
    """
    histories = to_lists(histories)
    labels = get_labels(labels, len(histories))
    n_run = len(histories)

    # one model id per history
    if ms is None:
        ms = [0] * n_run
    elif not isinstance(ms, list):
        ms = [ms] * n_run
    elif len(ms) == 1:
        # repeat the single entry instead of nesting the list
        ms = ms * n_run

    levels = sorted([0.95] if levels is None else levels)
    if par_names is None:
        # extract all parameter names
        df, _ = histories[0].get_distribution(m=ms[0])
        par_names = list(df.columns.values)
    n_par = len(par_names)
    n_confidence = len(levels)

    # one time point per history
    if ts is None:
        ts = [history.max_t for history in histories]
    elif isinstance(ts, (int, np.integer)):
        ts = [ts] * n_run

    # one (possibly None) reference dictionary per history
    if refvals is None:
        refvals = [None] * n_run
    elif not isinstance(refvals, list):
        refvals = [refvals] * n_run
    has_refvals = any(refval is not None for refval in refvals)

    # prepare axes
    fig, arr_ax = plt.subplots(
        nrows=n_par, ncols=1, sharex=False, sharey=False
    )
    if n_par == 1:
        arr_ax = [arr_ax]

    # prepare matrices; NaN marks statistics that were not computed
    cis = np.full((n_par, n_run, 2 * n_confidence), np.nan)
    median = np.full((n_par, n_run), np.nan)
    mean = np.full((n_par, n_run), np.nan)
    kde_max = np.full((n_par, n_run), np.nan)
    kde_max_1d = np.full((n_par, n_run), np.nan)
    if kde is None and show_kde_max:
        kde = MultivariateNormalTransition()
    if kde_1d is None and show_kde_max_1d:
        kde_1d = MultivariateNormalTransition()

    # iterate over runs
    for i_run, (h, t, m) in enumerate(zip(histories, ts, ms)):
        df, w = h.get_distribution(m=m, t=t)
        # normalize weights to be sure, without mutating the fetched array
        w = w / w.sum()
        # the joint KDE maximum is one point for all parameters
        if show_kde_max:
            _kde_max_pnt = compute_kde_max(kde, df, w)
        # iterate over parameters
        for i_par, par in enumerate(par_names):
            vals = np.array(df[par])
            median[i_par, i_run] = compute_quantile(vals, w, 0.5)
            if show_mean:
                mean[i_par, i_run] = np.sum(w * vals)
            if show_kde_max:
                kde_max[i_par, i_run] = _kde_max_pnt[par]
            if show_kde_max_1d:
                _kde_max_1d_pnt = compute_kde_max(kde_1d, df[[par]], w)
                kde_max_1d[i_par, i_run] = _kde_max_1d_pnt[par]
            # lower bound of level i at index i, upper bound at -1 - i
            for i_c, confidence in enumerate(levels):
                lb, ub = compute_credible_interval(vals, w, confidence)
                cis[i_par, i_run, i_c] = lb
                cis[i_par, i_run, -1 - i_c] = ub

    # legend entries are identical for all parameters
    leg_colors = [f'C{i_c}' for i_c in reversed(range(n_confidence))]
    leg_labels = ['{:.2f}'.format(c) for c in reversed(levels)]
    if show_mean:
        leg_colors.append(f'C{n_confidence}')
        leg_labels.append("Mean")
    if show_kde_max:
        leg_colors.append(f'C{n_confidence + 1}')
        leg_labels.append("Max KDE")
    if show_kde_max_1d:
        leg_colors.append(f'C{n_confidence + 2}')
        leg_labels.append("Max KDE 1d")
    if has_refvals:
        leg_colors.append('black')
        leg_labels.append("Reference value")

    # plot
    for i_par, (par, ax) in enumerate(zip(par_names, arr_ax)):
        for i_run in range(n_run):
            # largest level (widest interval) first, narrower ones on top
            for i_c in reversed(range(n_confidence)):
                y_err = np.array(
                    [
                        median[i_par, i_run] - cis[i_par, i_run, i_c],
                        cis[i_par, i_run, -1 - i_c] - median[i_par, i_run],
                    ]
                ).reshape((2, 1))
                ax.errorbar(
                    x=[i_run],
                    y=median[i_par, i_run],
                    yerr=y_err,
                    capsize=(10.0 / n_confidence) * (i_c + 1),
                    color=f'C{i_c}',
                )
            # reference value
            refval = refvals[i_run]
            if refval is not None:
                ax.plot([i_run], [refval[par]], 'x', color='black')
        ax.set_title(f"Parameter {par}")
        if show_mean:
            ax.plot(range(n_run), mean[i_par], 'x', color=f'C{n_confidence}')
        if show_kde_max:
            ax.plot(
                range(n_run), kde_max[i_par], 'x', color=f'C{n_confidence + 1}'
            )
        if show_kde_max_1d:
            ax.plot(
                range(n_run),
                kde_max_1d[i_par],
                'x',
                color=f'C{n_confidence + 2}',
            )
        ax.set_xticks(range(n_run))
        ax.set_xticklabels(labels, rotation=rotation)
        handles = [
            Line2D([0], [0], color=color, label=label)
            for color, label in zip(leg_colors, leg_labels)
        ]
        ax.legend(handles=handles, bbox_to_anchor=(1.04, 1), loc="upper left")

    # format
    arr_ax[-1].set_xlabel("Population t")
    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()
    return arr_ax
def compute_credible_interval(vals, weights, confidence: float = 0.95):
    """
    Equal-tailed weighted credible interval of the points `vals`.

    The bounds are the weighted quantiles at (1 - confidence) / 2 and
    (1 + confidence) / 2, so both tails carry the same excluded mass.

    Parameters
    ----------
    vals:
        Sample points.
    weights:
        Weight of each point in `vals`.
    confidence:
        Credibility level, strictly between 0 and 1.

    Returns
    -------
    lb, ub: tuple of float
        Lower and upper bound of the credible interval.

    Raises
    ------
    ValueError
        If `confidence` is not strictly inside (0, 1).
    """
    if confidence >= 1.0 or confidence <= 0.0:
        raise ValueError(
            f"Confidence {confidence} must be in the interval (0.0, 1.0)."
        )
    lower_alpha = 0.5 * (1.0 - confidence)
    upper_alpha = confidence + lower_alpha
    return (
        compute_quantile(vals, weights, lower_alpha),
        compute_quantile(vals, weights, upper_alpha),
    )
def compute_kde_max(kde, df, w):
    """
    Return the row of `df` at which the fitted KDE density is largest.

    `kde` is fitted to `df` with weights `w` and then evaluated at every row
    of `df`. Only the sampled rows are candidates, so this approximates the
    density maximum rather than searching for it. Ties resolve to the first
    such row.
    """
    kde.fit(df, w)
    densities = [kde.pdf(row) for _, row in df.iterrows()]
    peak = max(densities)
    return df.iloc[densities.index(peak)]
def compute_quantile(vals, weights, alpha):
    """
    Weighted `alpha`-quantile of the points `vals`.

    Forwards to `weighted_quantile`; `weights` gives the weight of each
    point in `vals`.
    """
    quantile = weighted_quantile(vals, weights, alpha=alpha)
    return quantile