Source code for cbx.scheduler

r"""
Scheduler
==========

This module implements the schedulers employed in consensus-based schemes.

"""

import numpy as np
from scipy.special import logsumexp
import warnings

class param_update():
    r"""Base class for parameter updates

    This class implements the base class for parameter updates.

    Parameters
    ----------
    name : str
        The name of the parameter that should be updated. The default is 'alpha'.
    maximum : float
        The maximum value of the parameter. The default is 1e5.
    minimum : float
        The minimum value of the parameter. The default is 1e-5.
    """

    def __init__(
        self,
        name: str = 'alpha',
        maximum: float = 1e5,
        minimum: float = 1e-5
    ):
        self.name = name
        self.maximum = maximum
        self.minimum = minimum
    def update(self, dyn) -> None:
        """Update the parameter ``self.name`` of the given dynamic.

        Parameters
        ----------
        dyn
            The dynamic of which the parameter should be updated.

        Returns
        -------
        None
        """
        pass
    def ensure_max(self, dyn):
        r"""Ensures that the parameter does not exceed its maximum value."""
        setattr(dyn, self.name, np.minimum(self.maximum, getattr(dyn, self.name)))
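# Illustrative sketch (not part of the original module): a custom parameter
# update only needs to subclass ``param_update`` and override ``update``. The
# additive rule ``add_constant`` and the ``MockDyn`` object below are
# hypothetical, chosen to show the contract: read the attribute named by
# ``self.name`` from ``dyn``, write the new value back, then clip it.
#
# >>> class add_constant(param_update):
# ...     def __init__(self, step=1.0, **kwargs):
# ...         super().__init__(**kwargs)
# ...         self.step = step
# ...     def update(self, dyn) -> None:
# ...         setattr(dyn, self.name, getattr(dyn, self.name) + self.step)
# ...         self.ensure_max(dyn)
# >>> class MockDyn:
# ...     alpha = 1.0
# >>> dyn = MockDyn()
# >>> add_constant(step=0.5).update(dyn)
# >>> print(dyn.alpha)
# 1.5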
class scheduler():
    r"""scheduler class

    This class allows updating multiple parameters with one ``update`` call.

    Parameters
    ----------
    var_params : list
        A list of parameter updates, each of which implements an ``update``
        function.
    """

    def __init__(self, var_params):
        self.var_params = var_params
    def update(self, dyn) -> None:
        """Apply all parameter updates to the given dynamic.

        Parameters
        ----------
        dyn
            The dynamic whose parameters should be updated.

        Returns
        -------
        None
        """
        for var_param in self.var_params:
            var_param.update(dyn)
class multiply(param_update):
    def __init__(self, factor=1.0, **kwargs):
        """
        This scheduler updates the parameter specified by ``name`` by
        multiplying it with a given ``factor``.

        Parameters
        ----------
        factor : float
            The factor by which the parameter is multiplied. The default is 1.0.
        """
        super(multiply, self).__init__(**kwargs)
        self.factor = factor
    def update(self, dyn) -> None:
        r"""Update the parameter specified by ``name`` by multiplying it with
        the given ``factor``."""
        old_val = getattr(dyn, self.name)
        new_val = self.factor * old_val
        setattr(dyn, self.name, new_val)
        self.ensure_max(dyn)
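# Usage sketch (not part of the original module): composing a ``scheduler``
# from several ``multiply`` updates. ``MockDyn`` is a hypothetical stand-in
# for a cbx dynamic; it only carries the attributes the updates touch.
#
# >>> class MockDyn:
# ...     alpha = 10.0
# ...     sigma = 1.0
# >>> dyn = MockDyn()
# >>> sched = scheduler([multiply(name='alpha', factor=1.05),
# ...                    multiply(name='sigma', factor=0.99)])
# >>> sched.update(dyn)
# >>> print(dyn.alpha, dyn.sigma)
# 10.5 0.99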
# scheduler for the alpha-parameter based on the effective sample size
class effective_sample_size(param_update):
    r"""effective sample size scheduler class

    This class implements a scheduler for the :math:`\alpha`-parameter based
    on the effective sample size, as introduced in [1]_. In every step we try
    to find :math:`\alpha` such that

    .. math::

        J_{eff}(\alpha) = \frac{\left(\sum_{i=1}^N w_i(\alpha)\right)^2}{\sum_{i=1}^N w_i(\alpha)^2} = \eta N,

    where :math:`\eta` is a parameter, :math:`N` is the number of particles
    and :math:`w_i := \exp(-\alpha f(x_i))`. The above equation is solved via
    bisection.

    Parameters
    ----------
    name : str, optional
        The name of the parameter to update. The default is 'alpha'; the
        scheduler is only meaningful for the :math:`\alpha`-parameter.
    eta : float, optional
        The parameter :math:`\eta` of the scheduler. The default is 0.5.
    maximum : float, optional
        The maximum value of the :math:`\alpha`-parameter. The default is 1e5.
    solve_max_it : int, optional
        The maximum number of bisection iterations. The default is 15.

    References
    ----------
    .. [1] Carrillo, J. A., Hoffmann, F., Stuart, A. M., & Vaes, U. (2022).
       Consensus-based sampling. Studies in Applied Mathematics, 148(3),
       1069-1140.
    """

    def __init__(self, name='alpha', eta=.5, maximum=1e5, solve_max_it=15):
        super().__init__(name=name, maximum=maximum)
        if self.name != 'alpha':
            warnings.warn(
                'The effective_sample_size scheduler only works for the alpha '
                'parameter! You specified name = {}!'.format(self.name),
                stacklevel=2
            )
        self.eta = eta
        self.J_eff = 1.0
        self.solve_max_it = solve_max_it
    def update(self, dyn):
        r"""Set the :math:`\alpha`-parameter of the given dynamic by solving
        the effective sample size equation via bisection."""
        val = bisection_solve(
            eff_sample_size_gap(dyn.energy, self.eta),
            self.minimum * np.ones((dyn.M,)),
            self.maximum * np.ones((dyn.M,)),
            max_it=self.solve_max_it,
            thresh=1e-2
        )
        setattr(dyn, self.name, val[:, None])
        self.ensure_max(dyn)
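# Usage sketch (not part of the original module): running the scheduler on a
# hypothetical mock dynamic. ``energy`` has shape (M, N), i.e. one energy per
# particle in each of the M parallel runs, and the scheduler writes one alpha
# value per run.
#
# >>> class MockDyn:
# ...     M = 2
# ...     energy = np.array([[0.0, 1.0, 2.0],
# ...                        [0.0, 2.0, 4.0]])
# ...     alpha = 1.0
# >>> dyn = MockDyn()
# >>> effective_sample_size(eta=0.5).update(dyn)
# >>> dyn.alpha.shape
# (2, 1)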
class eff_sample_size_gap:
    r"""effective sample size gap

    This class is used for the effective sample size scheduler. Its call is
    defined as

    .. math::

        \alpha \mapsto J_{eff}(\alpha) - \eta N.

    Therefore, the roots of this non-increasing function solve the effective
    sample size equation for :math:`\alpha`.
    """

    def __init__(self, energy, eta):
        self.eta = eta
        self.energy = energy
        self.N = energy.shape[-1]

    def __call__(self, alpha):
        # evaluate J_eff(alpha) - eta * N in log-space for numerical stability
        nom = logsumexp(-alpha[:, None] * self.energy, axis=-1)
        denom = logsumexp(-2 * alpha[:, None] * self.energy, axis=-1)
        return np.exp(2 * nom - denom) - self.eta * self.N


def bisection_solve(f, low, high, max_it=100, thresh=1e-2, verbosity=0):
    r"""simple bisection to solve for roots

    Parameters
    ----------
    f : Callable
        A non-increasing function whose roots we want to find. It expects
        inputs of shape (M,), where M denotes the number of different runs.
    low : Array
        The low initial value for the bisection, an array of shape (M,).
    high : Array
        The high initial value for the bisection, an array of shape (M,).
    max_it : int, optional
        The maximum number of bisection iterations. The default is 100.
    thresh : float, optional
        Entries with :math:`|f(x)|` below this threshold are considered
        converged. The default is 1e-2.
    verbosity : int, optional
        If positive, the number of iterations is printed. The default is 0.

    Returns
    -------
    Roots of the function f
    """
    it = 0
    x = high
    term = False
    idx = np.arange(len(low))
    while not term:
        x = (low + high) / 2
        fx = f(x)
        gtzero = np.where(fx[idx] > 0)[0]
        ltzero = np.where(fx[idx] < 0)[0]
        # update low and high
        low[idx[gtzero]] = x[idx[gtzero]]
        high[idx[ltzero]] = x[idx[ltzero]]
        # update running idx and iteration
        idx = np.where(np.abs(fx) > thresh)[0]
        it += 1
        term = (it > max_it) | (len(idx) == 0)
    if verbosity > 0:
        print('Finishing after ' + str(it) + ' Iterations')
    return x
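# Usage sketch (not part of the original module): ``bisection_solve`` on a
# simple non-increasing function with root x = 1, run as a single batch entry.
# Note that ``low`` and ``high`` are modified in place.
#
# >>> f = lambda x: 1.0 - x
# >>> bisection_solve(f, np.zeros(1), 2.0 * np.ones(1))
# array([1.])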