# Source code for pyabc.transition.grid_search

import logging

import numpy as np
from sklearn.model_selection import GridSearchCV as GridSearchCVSKL

from .multivariatenormal import MultivariateNormalTransition

logger = logging.getLogger("ABC.Transition")


class GridSearchCV(GridSearchCVSKL):
    """
    Do a grid search to automatically select the best parameters for
    transition classes such as the
    :class:`pyabc.transition.MultivariateNormalTransition`.

    This is essentially a thin wrapper around
    :class:`sklearn.model_selection.GridSearchCV`. It translates the
    scikit-learn interface to the interface used in pyABC, i.e. it
    implements a thin
    `adapter pattern <https://en.wikipedia.org/wiki/Adapter_pattern>`_:
    attribute lookups that fail on this object are forwarded to the best
    estimator found by the last :meth:`fit`.

    The parameters are just as for
    :class:`sklearn.model_selection.GridSearchCV`.
    Default values set by this wrapper:

    - estimator = MultivariateNormalTransition()
    - param_grid = {'scaling': np.linspace(0.05, 1.0, 5)}
    - cv = 5
    - n_jobs = 1
    - error_score = 'raise'
    - return_train_score = True
    """

    def __init__(
        self,
        estimator=None,
        param_grid=None,
        scoring=None,
        n_jobs=1,
        refit=True,
        cv=5,
        verbose=0,
        pre_dispatch='2*n_jobs',
        error_score='raise',
        return_train_score=True,
    ):
        if estimator is None:
            estimator = MultivariateNormalTransition()
        if param_grid is None:
            param_grid = {'scaling': np.linspace(0.05, 1.0, 5)}
        # Defined up front so ``__getattr__`` can resolve it (to None)
        # before the first fit.
        self.best_estimator_ = None
        super().__init__(
            estimator=estimator,
            param_grid=param_grid,
            scoring=scoring,
            n_jobs=n_jobs,
            pre_dispatch=pre_dispatch,
            cv=cv,
            refit=refit,
            verbose=verbose,
            error_score=error_score,
            return_train_score=return_train_score,
        )

    def fit(self, X, y=None, groups=None):
        """
        Fit the density estimator (perturber) to the sampled data.

        Three cases are distinguished:

        - a single sample: no cross-validation is possible, so the
          template ``estimator`` is fitted directly (in place) and used
          as ``best_estimator_``;
        - an integer ``cv`` larger than the number of samples: the number
          of folds is temporarily reduced to the number of samples;
        - otherwise: a regular scikit-learn grid search.

        Parameters
        ----------
        X:
            The samples to fit the estimator to.
        y:
            Passed on to the estimator's ``fit``; presumably the sample
            weights in pyABC -- TODO confirm against callers.
        groups:
            Group labels passed on to the scikit-learn grid search.

        Returns
        -------
        self
        """
        n_samples = len(X)

        if n_samples == 1:
            self.estimator.fit(X, y)
            self.best_estimator_ = self.estimator
            logger.debug(
                "Single sample Gridsearch. Params: %s",
                self.estimator.get_params(),
            )
            # Return self, as the other branches (and sklearn) do.
            return self

        cv = self.cv  # pylint: disable=E0203
        # Only an integer ``cv`` is a fold count that can exceed the number
        # of samples; None, splitter objects and iterables of splits are
        # passed through to scikit-learn unchanged (comparing them with an
        # int would raise a TypeError).
        if isinstance(cv, (int, np.integer)) and cv > n_samples:
            self.cv = n_samples
            try:
                res = super().fit(X, y, groups=groups)
            finally:
                # Restore the user-supplied value even if fitting fails.
                self.cv = cv
            logger.info(
                "Reduced CV Gridsearch %s -> %s. Best params: %s",
                cv,
                n_samples,
                self.best_params_,
            )
            return res

        res = super().fit(X, y, groups=groups)
        logger.info("Best params: %s", self.best_params_)
        return res

    def __getattr__(self, item):
        """
        Forward attribute lookups that failed on this object to the best
        estimator.

        ``__getattr__`` is only invoked when normal lookup fails.
        ``best_estimator_`` itself is never forwarded; this avoids infinite
        recursion when it is not set on the instance (e.g. while
        unpickling).

        Raises
        ------
        AttributeError
            If ``item`` is ``best_estimator_``, if no estimator has been
            fitted yet, or if the best estimator lacks ``item``.
        """
        if item == "best_estimator_":
            raise AttributeError(item)
        best = self.best_estimator_
        if best is None:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                f"{item!r} (no estimator has been fitted yet)"
            )
        return getattr(best, item)