Source code for cbx.utils.history
import numpy as np
from warnings import warn
#%%
[docs]
class track:
    """
    Common interface for trackers that record a quantity in the history
    dictionary of a dynamics object.

    Both hooks do nothing in this base class; the concrete trackers in this
    module override them.
    """

    @staticmethod
    def init_history(dyn) -> None:
        """
        Set up the history entry of the tracked quantity on ``dyn``.

        Parameters
        ----------
        dyn : object
            The dynamics object whose history dictionary is used.

        Returns
        -------
        None
        """
        return None

    @staticmethod
    def update(dyn) -> None:
        """
        Record the current value of the tracked quantity on ``dyn``.

        Parameters
        ----------
        dyn : object
            The dynamics object whose history dictionary is used.

        Returns
        -------
        None
        """
        return None
#%%
class default_track(track):
def __init__(self, name):
self.name = str(name)
self.tracking = True
def init_history(self, dyn) -> None:
dyn.history[self.name] = []
def update(self, dyn) -> None:
if self.tracking:
if hasattr(dyn, self.name):
dyn.history[self.name].append(
dyn.copy(getattr(dyn, self.name))
)
else:
warn('The tracker tried to track the variable ' +
self.name +
' which is not an attribute of the given dynamic. ' +
'The varibale will not be tracked!',
stacklevel=2)
self.tracking = False
#%%
[docs]
class track_x(track):
"""
Class for tracking of variable 'x' in the history dictionary.
"""
[docs]
@staticmethod
def init_history(dyn) -> None:
dyn.history['x'] = []
dyn.history['x'].append(dyn.x)
[docs]
@staticmethod
def update(dyn) -> None:
"""
Update the history of the 'x' variable by copying the current particles to the next time step.
Parameters
----------
dyn : object
The object to track in the history dictionary.
Returns
-------
None
"""
dyn.history['x'].append(dyn.copy(dyn.x))
[docs]
class track_update_norm(track):
"""
Class for tracking the 'update_norm' entry in the history.
"""
[docs]
@staticmethod
def init_history(dyn) -> None:
dyn.history['update_norm'] = []
[docs]
@staticmethod
def update(dyn) -> None:
"""
Updates the 'update_norm' entry in the 'history' dictionary with the 'update_diff' value.
Parameters:
None
Returns:
None
"""
dyn.history['update_norm'].append(dyn.update_diff)
[docs]
class track_energy(track):
"""
Class for tracking the 'energy' entry in the history.
"""
[docs]
@staticmethod
def init_history(dyn):
dyn.history['energy'] = []
[docs]
@staticmethod
def update(dyn) -> None:
dyn.history['energy'].append(np.copy(dyn.energy))# always assumes energy is numpy
[docs]
class track_consensus(track):
"""
Class for tracking the 'consensus' entry in the dynamic.
"""
[docs]
@staticmethod
def init_history(dyn) -> None:
dyn.history['consensus'] = []
[docs]
@staticmethod
def update(dyn) -> None:
dyn.history['consensus'].append(dyn.copy(dyn.consensus))
[docs]
class track_drift_mean(track):
"""
Class for tracking the 'drift_mean' entry in the history.
"""
[docs]
@staticmethod
def init_history(dyn) -> None:
dyn.history['drift_mean'] = []
[docs]
@staticmethod
def update(dyn) -> None:
dyn.history['drift_mean'].append(np.mean(np.abs(dyn.drift), axis=(-2,-1)))
[docs]
class track_drift(track):
"""
Class for tracking the 'drift' entry in the history.
"""
[docs]
@staticmethod
def init_history(dyn) -> None:
dyn.history['drift'] = []
dyn.history['particle_idx'] = []
[docs]
@staticmethod
def update(dyn) -> None:
dyn.history['drift'].append(dyn.drift)
dyn.history['particle_idx'].append(dyn.particle_idx)