import numpy as np
#%%
# Registry mapping every accepted spelling of a termination criterion to its
# class. Each class below registers its aliases here; select_term reads it.
term_dict = dict()
#%%
class energy_tol_term:
    """Termination criterion: the energy ``dyn.f_min`` fell below a tolerance.

    Parameters
    ----------
    energy_tol : float, default 1e-7
        Strict upper bound that ``dyn.f_min`` is compared against.
    """

    def __init__(self, energy_tol=1e-7):
        self.energy_tol = energy_tol

    def __call__(self, dyn):
        """Evaluate ``dyn.f_min < energy_tol``.

        Returns
        -------
        bool or array of bool
            True where the energy is strictly below the tolerance
            (element-wise when ``dyn.f_min`` is an array).
        """
        current_energy = dyn.f_min
        below_tolerance = current_energy < self.energy_tol
        return below_tolerance
# Make energy_tol_term available to select_term under all three spellings.
term_dict.update(
    {alias: energy_tol_term for alias in ('energy-tol', 'energy tol', 'energy_tol')}
)
class diff_tol_term:
    """Termination criterion: the last update difference is below a tolerance.

    Parameters
    ----------
    diff_tol : float, default 1e-7
        Strict upper bound that ``dyn.update_diff`` is compared against.
    """

    def __init__(self, diff_tol=1e-7):
        self.diff_tol = diff_tol

    def __call__(self, dyn):
        """Evaluate ``dyn.update_diff < diff_tol``.

        Returns
        -------
        bool or array of bool
            True where the update difference is strictly below the tolerance
            (element-wise when ``dyn.update_diff`` is an array).
        """
        step_size = dyn.update_diff
        is_small = step_size < self.diff_tol
        return is_small
# Make diff_tol_term available to select_term under all three spellings.
term_dict.update(
    {alias: diff_tol_term for alias in ('diff-tol', 'diff tol', 'diff_tol')}
)
class max_eval_term:
    """Termination criterion: the objective-evaluation budget is used up.

    Parameters
    ----------
    max_eval : int or None, default 1000
        Budget compared against ``dyn.num_f_eval``. ``None`` switches the
        criterion off, so it never triggers (the same convention the
        iteration-budget term uses for ``max_it=None``).
    """

    def __init__(self, max_eval=1000):
        self.max_eval = max_eval

    def __call__(self, dyn):
        """Check ``dyn.num_f_eval >= max_eval``.

        Returns
        -------
        array of bool
            True for every run whose evaluation count reached the budget;
            all False (shape ``(dyn.M,)``) when ``max_eval`` is None.
        """
        if self.max_eval is None:
            # Disabled: comparing against None would raise a TypeError.
            return np.zeros((dyn.M), dtype=bool)
        return dyn.num_f_eval >= self.max_eval
# Make max_eval_term available to select_term under all three spellings.
term_dict.update(
    {alias: max_eval_term for alias in ('max-eval', 'max eval', 'max_eval')}
)
class max_it_term:
    """Termination criterion: the iteration counter ``dyn.it`` hit ``max_it``.

    Parameters
    ----------
    max_it : int or None, default 1000
        Iteration budget compared against ``dyn.it``. ``None`` switches the
        criterion off, so it never triggers.
    """

    def __init__(self, max_it=1000):
        self.max_it = max_it

    def __call__(self, dyn):
        """Check ``dyn.it >= self.max_it`` for all runs at once.

        Returns
        -------
        array of bool, shape ``(dyn.M,)``
            Every entry equals ``dyn.it >= self.max_it``; all False when
            ``max_it`` is None.
        """
        every_run = np.ones((dyn.M), dtype=bool)
        if self.max_it is None:
            # Criterion disabled: nothing ever terminates because of it.
            return ~every_run
        return every_run * (dyn.it >= self.max_it)
# Make max_it_term available to select_term under all three spellings.
term_dict.update(
    {alias: max_it_term for alias in ('max-it', 'max it', 'max_it')}
)
class max_time_term:
    """Termination criterion: ``dyn.t`` reached ``max_time``.

    Parameters
    ----------
    max_time : float or None, default 10.
        Limit compared against ``dyn.t``. ``None`` switches the criterion
        off, so it never triggers (the same convention used for
        ``max_it=None``).
    """

    def __init__(self, max_time=10.):
        self.max_time = max_time

    def __call__(self, dyn):
        """Check ``dyn.t >= self.max_time`` for all runs at once.

        Returns
        -------
        array of bool, shape ``(dyn.M,)``
            Every entry equals ``dyn.t >= self.max_time``; all False when
            ``max_time`` is None.
        """
        if self.max_time is None:
            # Disabled: comparing against None would raise a TypeError.
            return np.zeros((dyn.M), dtype=bool)
        return (dyn.t >= self.max_time) * np.ones((dyn.M), dtype=bool)
# Make max_time_term available to select_term under all three spellings.
term_dict.update(
    {alias: max_time_term for alias in ('max-time', 'max time', 'max_time')}
)
#%%
class energy_stagnation_term:
    """Termination criterion: the consensus loss stopped moving.

    On every call (once ``dyn.consensus`` exists) the objective ``dyn.f`` is
    evaluated at the consensus of the active runs and the result is stored
    in a ring buffer of the last ``patience`` values, at slot
    ``dyn.it % patience``. A run is reported as stagnated only after every
    slot of its buffer holds a real loss and the standard deviation of those
    losses is below ``std_thresh``.

    Parameters
    ----------
    patience : int, default 20
        Number of most recent losses that are compared.
    std_thresh : float, default 1e-9
        Threshold on the standard deviation of the buffered losses.

    Notes
    -----
    Each evaluating call adds one to ``dyn.num_f_eval`` for every active run.
    """

    def __init__(self, patience=20, std_thresh=1e-9):
        self.patience = patience
        self.losses = None  # (patience, M) ring buffer, allocated on first call
        self.filled = None  # (patience, M) True where a slot holds a real loss
        self.std_thresh = std_thresh

    def __call__(self, dyn):
        """Return a boolean array of shape ``(dyn.M,)``; True marks stagnation."""
        if self.losses is None:
            # Allocated lazily because the number of runs is only known from dyn.
            # Unfilled slots are tracked explicitly instead of being seeded with
            # random values: that avoids touching the global NumPy RNG and cannot
            # produce a spurious low std for large std_thresh or small patience.
            self.losses = np.zeros((self.patience, dyn.M))
            self.filled = np.zeros((self.patience, dyn.M), dtype=bool)
        if dyn.consensus is None:
            return np.zeros((dyn.M), dtype=bool)
        active = dyn.active_runs_idx
        # eval loss at the consensus of the active runs
        E = dyn.f(dyn.consensus[active, ...])
        dyn.num_f_eval[active] += 1
        # update the ring buffer and remember which slots are real
        slot = dyn.it % self.patience
        self.losses[slot, active] = E
        self.filled[slot, active] = True
        window_full = self.filled.all(axis=0)
        return window_full & (np.std(self.losses, axis=0) < self.std_thresh)
# Make energy_stagnation_term available to select_term under all three spellings.
term_dict.update(
    {
        alias: energy_stagnation_term
        for alias in ('energy-stagnation', 'energy stagnation', 'energy_stagnation')
    }
)
#%%
def select_term(term):
    """Resolve a termination-criterion specification into a term object.

    Parameters
    ----------
    term : str, mapping, or object
        * ``str``: an alias registered in ``term_dict``; the registered class
          is instantiated with its default arguments.
        * mapping (anything with a ``keys`` method): must contain ``'name'``,
          an alias registered in ``term_dict``; every other entry is passed
          as a keyword argument to that class.
        * anything else is returned unchanged.

    Returns
    -------
    The instantiated term, or ``term`` itself when it is neither a string
    nor a mapping.

    Raises
    ------
    KeyError
        If the given name is not registered in ``term_dict``; the message
        lists the registered names.
    ValueError
        If a mapping without a ``'name'`` key is given.
    """
    if isinstance(term, str):
        return _lookup_term_class(term)()
    if hasattr(term, 'keys'):
        if 'name' not in term.keys():
            raise ValueError('The given term dict: ' + str(term) + '\n ' +
                             'does not have the necessary key ' +
                             '"name"')
        kwargs = {k: v for k, v in term.items() if k != 'name'}
        return _lookup_term_class(term['name'])(**kwargs)
    return term


def _lookup_term_class(name):
    """Return the class registered under ``name`` in ``term_dict``."""
    try:
        return term_dict[name]
    except KeyError:
        # Same exception type as a plain dict lookup, but with the valid
        # options spelled out so typos are easy to fix.
        raise KeyError(
            'Unknown termination criterion ' + repr(name)
            + '. Registered names: ' + ', '.join(sorted(map(str, term_dict)))
        ) from None