"""Transform dictionaries to arrays."""

from functools import wraps
from numbers import Number
from typing import List, Union

import numpy as np
import pandas as pd

def dict2arr(dct: Union[dict, np.ndarray], keys: List) -> np.ndarray:
    """Convert dictionary to 1d array, in specified key order.

    Parameters
    ----------
    dct:
        If dict-similar, values of all keys are extracted into a 1d array.
        Entries can be data frames, series, ndarrays, or single numbers.
        If `dct` is already an ndarray, it is returned unchanged.
    keys:
        Keys of interest, also defines the order.

    Returns
    -------
    arr: 1d array of all concatenated values.

    Raises
    ------
    TypeError:
        If a value is not a data frame, series, ndarray or number.
    """
    if isinstance(dct, np.ndarray):
        return dct

    parts = []
    for key in keys:
        val = dct[key]
        if isinstance(val, (pd.DataFrame, pd.Series)):
            # default flattening mode is 'C', i.e. row-major (row-by-row)
            parts.append(val.to_numpy().flatten())
        elif isinstance(val, np.ndarray):
            parts.append(val.flatten())
        elif isinstance(val, Number):
            parts.append([val])
        else:
            raise TypeError(
                f"Cannot parse variable {key}={val} of type {type(val)} "
                "to numeric."
            )

    # no keys requested: np.concatenate would reject an empty sequence
    if not parts:
        return np.asarray([])
    # for efficiency, directly parse single entries
    if len(parts) == 1:
        return np.asarray(parts[0])
    # concatenate in C rather than flattening element-wise in Python
    return np.concatenate(parts)


def dict2arrlabels(dct: dict, keys: List) -> List[str]:
    """Get label array consistent with the output of `dict2arr`.

    Can be called e.g. once on the observed data and used for logging.

    Parameters
    ----------
    dct:
        Model output or observed data.
    keys:
        Keys of interest, also defines the order.

    Returns
    -------
    labels:
        List of labels consistent with the output of `dict2arr`:
        ``"key:col:row"`` for data frame entries, ``"key:ix"`` (flat index)
        for series and ndarray entries, and ``key`` for single numbers.

    Raises
    ------
    TypeError:
        If a value is not a data frame, series, ndarray or number.
    """
    labels = []
    for key in keys:
        val = dct[key]
        if isinstance(val, pd.DataFrame):
            # default flattening mode is 'C', i.e. row-major, i.e. row-by-row
            labels.extend(
                f"{key}:{col}:{row}"
                for row in range(len(val.index))
                for col in val.columns
            )
        elif isinstance(val, (pd.Series, np.ndarray)):
            # series have no columns, and arrays can have any dimension,
            # thus just flat indices
            labels.extend(f"{key}:{ix}" for ix in range(val.size))
        elif isinstance(val, Number):
            labels.append(key)
        else:
            raise TypeError(
                f"Cannot parse variable {key}={val} of type {type(val)} "
                "to numeric."
            )
    return labels


def io_dict2arr(fun):
    """Wrapper parsing input dicts to ndarrays.

    Assumes the data is the first argument after `self`, and that `self`
    holds an `x_keys` attribute defining which keys are extracted, and in
    which order (see `dict2arr`). The output of the wrapped function is
    flattened to a 1d array.
    """

    @wraps(fun)
    def wrapped_fun(self, data: Union[dict, np.ndarray], *args, **kwargs):
        # convert input to array (ndarrays are passed through unchanged)
        data = dict2arr(data, self.x_keys)
        # call the actual function
        ret = fun(self, data, *args, **kwargs)
        # flatten output; asarray also tolerates list or scalar returns
        return np.asarray(ret).flatten()

    return wrapped_fun