"""Visualization of distance functions."""
from typing import Any
import matplotlib as mpl
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator
from ..storage import load_dict_from_json
from .util import get_labels, to_lists
[docs]
def plot_distance_weights(
    log_files: list[str] | str,
    ts: list[int] | list[str] | int | str = 'last',
    labels: list[str] | str | None = None,
    colors: list[Any] | Any = None,
    linestyles: list[str] | str | None = None,
    keys_as_labels: bool = True,
    keys: list[str] | None = None,
    xticklabel_rotation: float = 0,
    normalize: bool = True,
    size: tuple[float, float] | None = None,
    xlabel: str = 'Summary statistic',
    ylabel: str = 'Weight',
    title: str | None = None,
    ax: mpl.axes.Axes | None = None,
    **kwargs,
) -> mpl.axes.Axes:
    """Plot distance weights, one curve per log file.

    Assumes that the weights to be plotted from each file and time point
    have the same keys: when `keys` is not given, the keys found in the
    first file are reused for all subsequent files.

    Parameters
    ----------
    log_files:
        The weight log files, as passed to the distance functions, e.g. as
        "scale_log_file" or "info_log_file".
    ts:
        Time points to plot. The value 'last' (default) selects the largest
        time point key found in the respective file.
    labels:
        A label for each log file.
    colors:
        A color for each log file.
    linestyles:
        Linestyles to apply.
    keys_as_labels:
        Whether to use the summary statistic keys as x tick labels.
        If False, integer tick positions are enforced instead.
    keys:
        Data keys to plot. Defaults to all keys of the first plotted file.
    xticklabel_rotation:
        Angle by which to rotate x tick labels.
    normalize:
        Whether to normalize the weights to sum 1.
    size:
        Figure size in inches, (width, height).
    xlabel:
        x-axis label.
    ylabel:
        y-axis label.
    title:
        Plot title.
    ax:
        Axis object to use. A new figure and axis are created if None.
    **kwargs:
        Additional keyword arguments are passed on to `ax.plot()`
        when plotting lines. Unless specified, `marker='x'` is used.

    Returns
    -------
    The used axis object.
    """
    # NOTE(review): `to_lists` is assumed to wrap scalar arguments into
    # lists of matching length -- confirm against `.util`.
    log_files, ts, colors, linestyles = to_lists(
        log_files, ts, colors, linestyles
    )
    labels = get_labels(labels, len(log_files))

    # default keyword arguments
    kwargs.setdefault('marker', 'x')

    n_run = len(log_files)

    # create figure
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # add a line per file
    for log_file, t, label, color, linestyle in zip(
        log_files, ts, labels, colors, linestyles
    ):
        weights_by_t = load_dict_from_json(log_file)
        if t == 'last':
            t = max(weights_by_t.keys())
        weights_at_t = weights_by_t[t]

        # the keys of the first file define the x axis for all files
        if keys is None:
            keys = list(weights_at_t.keys())

        # Force a float array: with integer-only weights the in-place
        # normalization below would otherwise fail, as the float result
        # cannot be cast back into an integer array.
        weights = np.array([weights_at_t[key] for key in keys], dtype=float)
        if normalize:
            weights /= weights.sum()

        ax.plot(
            weights,
            label=label,
            color=color,
            linestyle=linestyle,
            **kwargs,
        )

    # add labels; a legend is only informative for more than one curve
    if n_run > 1:
        ax.legend()

    # x axis ticks
    if keys_as_labels:
        ax.set_xticks(np.arange(len(keys)))
        ax.set_xticklabels(keys, rotation=xticklabel_rotation)
    else:
        # enforce integer labels
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    if size is not None:
        fig.set_size_inches(size)

    return ax