Source code for cbx.utils.resampling

import numpy as np
from typing import Callable, List

def apply_resampling_default(dyn, idx, sigma_indep=0.1, var_name='x'):
    z = dyn.sampler(size=(len(idx), dyn.N, *dyn.d))
    getattr(dyn, var_name)[idx, ...] += sigma_indep * np.sqrt(dyn.dt) * z

[docs] class resampling: """ Resamplings from a list of callables Parameters ---------- resamplings: list The list of resamplings to apply. Each entry should be a callable that accepts exactly one argument (the dynamic object) and returns a one-dimensional numpy array of indices. apply: Callable - ``dyn``: The dynmaic which the resampling is applied to. - ``idx``: List of indices that are resampled. The function that should be performed on a given dynamic for selected indices. This function has to have the signature apply(dyn,idx). """ def __init__(self, resamplings: List[Callable], apply:Callable = None, sigma_indep:float = 0.1, var_name = 'x' ): self.resamplings = resamplings self.num_resampling = None self.apply = apply if apply is not None else apply_resampling_default self.sigma_indep = sigma_indep self.var_name = var_name
[docs] def __call__(self, dyn): """ Applies the resamplings to a given dynamic Parameters ---------- dyn The dynamic object to apply resamplings to Returns ------- None """ self.check_num_resamplings(dyn.M) idx = np.unique(np.concatenate([r(dyn) for r in self.resamplings])) if len(idx)>0: self.apply(dyn, idx, var_name = self.var_name, sigma_indep = self.sigma_indep) self.num_resampling[idx] += 1 if dyn.verbosity > 0: print('Resampled in runs ' + str(idx))
def check_num_resamplings(self, M): if self.num_resampling is None: self.num_resampling = np.zeros(shape=(M))
[docs] class ensemble_update_resampling: """ Resampling based on ensemble update difference Parameters ---------- update_thresh: float The threshold for ensemble update difference. When the update difference is less than this threshold, the ensemble is resampled. Returns ------- The indices of the runs to resample as a numpy array. """ def __init__(self, update_thresh:float): self.update_thresh = update_thresh
[docs] def __call__(self, dyn): return np.where(dyn.update_diff < self.update_thresh)[0]
class consensus_stagnation: def __init__(self, patience=5, update_thresh=1e-4): self.patience = patience self.update_thresh = update_thresh self.consensus_updates = [] self.it = 0 def __call__(self, dyn): return np.where(self.check_consensus_update(dyn) < self.update_thresh)[0] def check_consensus_update(self, dyn): self.it += 1 wt = dyn.x[:, 0, 0] + 1e10 if hasattr(self, 'consensus_old'): self.consensus_updates.append( dyn.norm( dyn.to_numpy(dyn.consensus - self.consensus_old), axis=-1)[:, 0] ) self.consensus_updates = self.consensus_updates[-self.patience:] wt = np.array(self.consensus_updates).max(axis=0) self.consensus_old = dyn.copy(dyn.consensus) return dyn.to_numpy(wt)
[docs] class loss_update_resampling: """ Resampling based on loss update difference Parameters ---------- M: int The number of runs in the dynamic object the resampling is applied to. wait_thresh: int The number of iterations to wait before resampling. The default is 5. If the best loss is not updated after the specified number of iterations, the ensemble is resampled. Returns ------- The indices of the runs to resample as a numpy array. """ def __init__(self, wait_thresh:int = 5): self.wait_thresh = wait_thresh self.initalized = False
[docs] def __call__(self, dyn): self.check_energy_wait(dyn.M) self.wait += 1 u_idx = self.best_energy > dyn.best_energy self.wait[u_idx] = 0 self.best_energy[u_idx] = dyn.best_energy[u_idx] idx = np.where(self.wait >= self.wait_thresh)[0] self.wait = np.mod(self.wait, self.wait_thresh) return idx
def check_energy_wait(self, M): if not self.initalized: self.best_energy = float('inf') * np.ones((M,)) self.wait = np.zeros((M,), dtype=int) self.initalized = True