r"""
Scheduler
==========
This module implements the schedulers employed in consensus-based schemes.
"""
import numpy as np
from scipy.special import logsumexp
import warnings
class param_update():
r"""Base class for parameter updates
This class implements the base class for parameter updates.
Parameters
----------
name : str
The name of the parameter that should be updated. The default is 'alpha'.
    maximum : float
        The maximum value of the parameter. The default is 1e5.
    minimum : float
        The minimum value of the parameter. The default is 1e-5.
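    Examples
    --------
    A minimal sketch of a custom update rule (the ``dyn`` class and its
    ``sigma`` attribute are hypothetical stand-ins for a dynamic):

    >>> class halve(param_update):
    ...     def update(self, dyn):
    ...         setattr(dyn, self.name, 0.5 * getattr(dyn, self.name))
    >>> class dyn: sigma = 2.0
    >>> halve(name='sigma').update(dyn)
    >>> dyn.sigma
    1.0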
"""
def __init__(
self,
        name: str = 'alpha',
maximum: float = 1e5,
minimum: float = 1e-5
):
self.name = name
self.maximum = maximum
self.minimum = minimum
def update(self, dyn) -> None:
"""
        Updates the specified parameter of the given dynamic. The base class implementation does nothing.
        Parameters
        ----------
        dyn
            The dynamic whose parameter should be updated.
Returns
-------
None
"""
pass
def ensure_max(self, dyn):
r"""Ensures that the parameter does not exceed its maximum value."""
setattr(dyn, self.name, np.minimum(self.maximum, getattr(dyn, self.name)))
class scheduler():
r"""scheduler class
This class allows to update multiple parmeters with one update call.
Parameters
----------
var_params : list
A list of parameter updates, that implement an ``update`` function.
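    Examples
    --------
    A small sketch combining two ``multiply`` updates (the ``dyn`` class below is a
    hypothetical stand-in for a dynamic):

    >>> class dyn: alpha = 1.0; sigma = 2.0
    >>> sched = scheduler([multiply(name='alpha', factor=2.0),
    ...                    multiply(name='sigma', factor=0.5)])
    >>> sched.update(dyn)
    >>> float(dyn.alpha), float(dyn.sigma)
    (2.0, 1.0)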
"""
def __init__(self, var_params):
self.var_params = var_params
def update(self, dyn) -> None:
"""
        Updates all parameters registered in ``var_params``.
        Parameters
        ----------
        dyn
            The dynamic whose parameters should be updated.
Returns
-------
None
"""
for var_param in self.var_params:
var_param.update(dyn)
class multiply(param_update):
def __init__(self,
                 factor=1.0,
**kwargs):
"""
        This scheduler updates the parameter specified by ``name`` by multiplying it with a given ``factor``.
        Parameters
        ----------
        factor : float
            The factor by which the parameter should be multiplied. The default is 1.0.
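        Examples
        --------
        A short sketch of exponential growth of ``alpha`` over three updates (the ``dyn``
        class is a hypothetical stand-in for a dynamic):

        >>> class dyn: alpha = 1.0
        >>> m = multiply(name='alpha', factor=1.05)
        >>> for _ in range(3): m.update(dyn)
        >>> round(float(dyn.alpha), 6)
        1.157625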
"""
        super().__init__(**kwargs)
self.factor = factor
def update(self, dyn) -> None:
r"""Update the parameter as specified by ``'name'``, by multiplying it by a given ``'factor'``."""
old_val = getattr(dyn, self.name)
new_val = self.factor * old_val
setattr(dyn, self.name, new_val)
self.ensure_max(dyn)
# class for alpha_eff scheduler
class effective_sample_size(param_update):
r"""effective sample size scheduler class
This class implements a scheduler for the :math:`\alpha`-parameter based on the effective sample size as inroduced
in [1]_. In every step we try to find :math:`\alpha` such that
The :math:`\alpha`-parameter is updated according to the rule
.. math::
J_{eff}(\alpha) = \frac{\left(\sum_{i=1}^N w_i(\alpha)\right)^2}{\sum_{i=1}^N w_i(\alpha)^2} = \eta N
where :math:`\eta` is a parameter, :math:`N` is the number of particles and :math:`w_i := \exp(-\alpha f(x_i))`. The above equation is solved via bisection.
Parameters
----------
eta : float, optional
The parameter :math:`\eta` of the scheduler. The default is 0.5.
alpha_max : float, optional
The maximum value of the :math:`\alpha`-parameter. The default is 100000.0.
factor : float, optional
The parameter :math:`r` of the scheduler. The default is 1.05.
References
----------
.. [1] Carrillo, J. A., Hoffmann, F., Stuart, A. M., & Vaes, U. (2022). Consensus‐based sampling. Studies in Applied Mathematics, 148(3), 1069-1140.
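    Examples
    --------
    A minimal sketch (the ``dyn`` class below is a hypothetical stand-in that only
    provides the ``energy``, ``alpha`` and ``M`` attributes the update accesses):

    >>> import numpy as np
    >>> class dyn:
    ...     energy = np.array([[0.1, 0.5, 2.0]])  # energies of N=3 particles in M=1 runs
    ...     alpha = np.ones((1, 1))
    ...     M = 1
    >>> effective_sample_size(eta=0.5).update(dyn)
    >>> dyn.alpha.shape  # alpha now approximately solves J_eff(alpha) = eta * N
    (1, 1)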
"""
    def __init__(self, name='alpha', eta=.5, maximum=1e5, solve_max_it=15):
super().__init__(name = name, maximum=maximum)
        if self.name != 'alpha':
            warnings.warn('The effective_sample_size scheduler only works for the alpha parameter! You specified name = {}!'.format(self.name), stacklevel=2)
self.eta = eta
self.J_eff = 1.0
self.solve_max_it = solve_max_it
def update(self, dyn):
        # solve J_eff(alpha) = eta * N for alpha via bisection, separately for each of the M runs
        val = bisection_solve(
            eff_sample_size_gap(dyn.energy, self.eta),
            self.minimum * np.ones((dyn.M,)), self.maximum * np.ones((dyn.M,)),
            max_it=self.solve_max_it, thresh=1e-2
        )
setattr(dyn, self.name, val[:, None])
self.ensure_max(dyn)
class eff_sample_size_gap:
r"""effective sample size gap
This class is used for the effective sample size scheduler. Its call is defined as
.. math::
\alpha \mapsto J_{eff}(\alpha) - \eta N.
Therefore, the root of this non-increasing function solve the effective sampling size equation for :math:`\alpha`.
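    Examples
    --------
    For :math:`\alpha = 0` all weights are equal, so :math:`J_{eff}(0) = N` and the gap equals
    :math:`(1 - \eta)N` (here a sketch with :math:`N = 2` and :math:`\eta = 0.5`):

    >>> import numpy as np
    >>> gap = eff_sample_size_gap(np.array([[0.0, 1.0]]), eta=0.5)
    >>> round(float(gap(np.zeros(1))[0]), 6)
    1.0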
"""
def __init__(self, energy, eta):
self.eta = eta
self.energy = energy
self.N = energy.shape[-1]
    def __call__(self, alpha):
        # evaluate J_eff(alpha) - eta * N stably in log-space:
        # log((sum_i w_i)^2) = 2 * logsumexp(-alpha * f) and log(sum_i w_i^2) = logsumexp(-2 * alpha * f)
        log_num = logsumexp(-alpha[:, None] * self.energy, axis=-1)
        log_denom = logsumexp(-2 * alpha[:, None] * self.energy, axis=-1)
        return np.exp(2 * log_num - log_denom) - self.eta * self.N
def bisection_solve(f, low, high, max_it=100, thresh=1e-2, verbosity=0):
    r"""simple bisection method to solve for roots
    Parameters
    ----------
    f : Callable
        A non-increasing function whose roots we want to find. It expects inputs of shape (M,), where M denotes
        the number of different runs.
    low : Array
        The lower initial value for the bisection, an array of shape (M,). Modified in place.
    high : Array
        The upper initial value for the bisection, an array of shape (M,). Modified in place.
    max_it : int, optional
        The maximum number of bisection iterations. The default is 100.
    thresh : float, optional
        Entries with ``|f(x)| <= thresh`` are considered converged. The default is 1e-2.
    verbosity : int, optional
        If positive, the number of performed iterations is printed. The default is 0.
    Returns
    -------
    x : Array
        The approximate roots of the function ``f``.
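    Examples
    --------
    A small sketch on the non-increasing function :math:`f(x) = 1 - x`, whose root is 1:

    >>> import numpy as np
    >>> f = lambda x: 1.0 - x
    >>> x = bisection_solve(f, np.zeros(1), 2.0 * np.ones(1))
    >>> float(x[0])
    1.0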
"""
    it = 0
    x = high
    term = False
    idx = np.arange(len(low))
    while not term:
        x = (low + high) / 2
        fx = f(x)
        # f is non-increasing: f(x) > 0 means the root lies above x, f(x) < 0 means it lies below
        gtzero = np.where(fx[idx] > 0)[0]
        ltzero = np.where(fx[idx] < 0)[0]
        # update low and high on the not yet converged entries
        low[idx[gtzero]] = x[idx[gtzero]]
        high[idx[ltzero]] = x[idx[ltzero]]
        # update running idx and iteration
        idx = np.where(np.abs(fx) > thresh)[0]
        it += 1
        term = (it > max_it) or (len(idx) == 0)
    if verbosity > 0:
        print(f'Finished after {it} iterations')
    return x