mlwiz.training
training.callback
- callback.early_stopping
- callback.engine_callback
- callback.gradient_clipping
- callback.metric
- callback.optimizer
- callback.plotter
- callback.scheduler
training.event
- event.dispatcher
- event.handler
- event.state
training.engine
Training and evaluation loop implementation.
Defines TrainingEngine and helpers for checkpointing, metric reporting, and batching.
- class mlwiz.training.engine.DataStreamTrainingEngine(engine_callback: Callable[[...], EngineCallback], model: ModelInterface, loss: Metric, optimizer: Optimizer, scorer: Metric, scheduler: Scheduler | None = None, early_stopper: EarlyStopper | None = None, gradient_clipper: GradientClipper | None = None, device: str = 'cpu', plotter: Plotter | None = None, exp_path: str | None = None, evaluate_every: int = 1, eval_training: bool = False, eval_test_every_epoch: bool = False, store_last_checkpoint: bool = False)
Bases: TrainingEngine
Class that handles a stream of samples that could end at any moment.
- _loop(loader: torch_geometric.loader.DataLoader, _notify_progress: Callable[[str, dict], None])
Compared to the superclass version, this method handles a stream of data that could end at any moment. This is done using an additional boolean state variable.
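The idea of tracking the end of a stream with a boolean flag can be sketched as follows. This is only an illustration in plain Python, not the library's implementation: `sample_stream`, `loop_over_stream` and the `stream_ended` flag are hypothetical names, and the real method operates on a PyG `DataLoader`.

```python
def sample_stream(n):
    """Hypothetical stream that yields n batches and then ends."""
    for i in range(n):
        yield i


def loop_over_stream(stream):
    """Consume batches until the stream ends, recording that fact in a flag."""
    iterator = iter(stream)
    stream_ended = False  # the additional boolean state variable (assumed name)
    processed = []
    while not stream_ended:
        try:
            batch = next(iterator)
        except StopIteration:
            # The stream ended: set the flag so the loop exits cleanly.
            stream_ended = True
            continue
        processed.append(batch)
    return processed, stream_ended


processed, ended = loop_over_stream(sample_stream(3))
print(processed, ended)
```

The flag lets the surrounding code distinguish "the stream is over" from "this batch failed" without relying on the loader's length, which a stream may not expose.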
- class mlwiz.training.engine.TrainingEngine(engine_callback: Callable[[...], EngineCallback], model: ModelInterface, loss: Metric, optimizer: Optimizer, scorer: Metric, scheduler: Scheduler | None = None, early_stopper: EarlyStopper | None = None, gradient_clipper: GradientClipper | None = None, device: str = 'cpu', plotter: Plotter | None = None, exp_path: str | None = None, evaluate_every: int = 1, eval_training: bool = False, eval_test_every_epoch: bool = False, store_last_checkpoint: bool = False)
Bases: EventDispatcher
This is the most important class when it comes to training a model. It implements the EventDispatcher interface, which means that after registering some callbacks in a given order, it will proceed to trigger specific events that will result in the shared State object being updated by the callbacks. Callbacks implement the EventHandler interface, and they receive the shared State object when any event is triggered. Knowing the order in which callbacks are called is important. The order is:
loss function
score function
gradient clipper
optimizer
early stopper
scheduler
plotter
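The dispatch pattern described above, in which handlers are called in registration order and each one updates a shared state, can be sketched in plain Python. All classes below (`State`, `NamedHandler`, `DispatcherSketch`) are hypothetical stand-ins written for illustration; they are not the mlwiz classes and their signatures are assumptions.

```python
class State:
    """Hypothetical shared state passed to every callback."""

    def __init__(self):
        self.log = []


class NamedHandler:
    """Hypothetical handler that records its own name when an event fires."""

    def __init__(self, name):
        self.name = name

    def on_training_batch_end(self, state):
        # Each callback mutates the shared state in turn.
        state.log.append(self.name)


class DispatcherSketch:
    """Hypothetical dispatcher: calls handlers in registration order."""

    def __init__(self):
        self._handlers = []

    def register(self, handler):
        self._handlers.append(handler)

    def dispatch(self, event_name, state):
        for handler in self._handlers:
            getattr(handler, event_name)(state)


dispatcher = DispatcherSketch()
order = ["loss function", "score function", "gradient clipper",
         "optimizer", "early stopper", "scheduler", "plotter"]
for name in order:
    dispatcher.register(NamedHandler(name))

state = State()
dispatcher.dispatch("on_training_batch_end", state)
print(state.log)
```

Because every handler sees the state left by the previous one, the registration order determines, for example, that the loss is computed before the optimizer runs.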
- Parameters:
engine_callback (Callable[…, EngineCallback]) – the engine callback object to be used for data fetching and checkpoints (or even other purposes if necessary)
model (ModelInterface) – the model to be trained
loss (Metric) – the loss to be used
optimizer (Optimizer) – the optimizer to be used
scorer (Metric) – the score to be used
scheduler (Scheduler) – the scheduler to be used. Default is None.
early_stopper (EarlyStopper) – the early stopper to be used. Default is None.
gradient_clipper (GradientClipper) – the gradient clipper to be used. Default is None.
device (str) – the device on which to train. Default is cpu.
plotter (Plotter) – the plotter to be used. Default is None.
exp_path (str) – the path of the experiment folder. Default is None but it is always instantiated.
evaluate_every (int) – the frequency of logging epoch results. Default is 1.
eval_training (bool) – whether to re-evaluate loss and scores on the training set after a training epoch. Default is False.
eval_test_every_epoch (bool) – whether to evaluate loss and scores on the test set (if available) after each training epoch. Default is False because one should not care about test performance until the very end of risk assessment. However, set this to True if you want to log test metrics during training.
store_last_checkpoint (bool) – whether to store a checkpoint at the end of each epoch, which allows training to resume from the last epoch. Default is False.
- _check_termination()
Raises if an external termination has been requested.
- _loop(loader: torch_geometric.loader.DataLoader, _notify_progress: Callable[[str, dict], None])
Main method that computes a pass over the dataset using the data loader provided.
- Parameters:
loader (torch_geometric.loader.DataLoader) – the loader to be used
- _loop_helper()
Helper function that loops over the data.
- _restore_checkpoint_and_best_results(ckpt_filename, best_ckpt_filename, zero_epoch)
Restores the (best or last) checkpoint from a given file, and loads the best results so far into the state if any.
- _to_data_list(x: torch.Tensor, batch: torch.Tensor, y: torch.Tensor | None) → List[torch_geometric.data.Data]
Converts model outputs back to a list of Data elements. Used for graph data.
- Parameters:
x (torch.Tensor) – tensor holding the embeddings of the different nodes/graphs
batch (torch.Tensor) – the usual PyG batch tensor. Used to split node/graph embeddings graph-wise.
y (torch.Tensor) – target labels, used to determine whether the task is graph prediction or node prediction. Can be None.
- Returns:
a list of PyG Data objects (with only x and y attributes)
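Splitting embeddings graph-wise with the batch tensor can be illustrated with a small sketch. It uses plain Python lists instead of `torch.Tensor` objects so that it runs without any dependency, and it ignores the `y` argument entirely; `split_by_graph` is a hypothetical helper, not part of mlwiz.

```python
from collections import defaultdict


def split_by_graph(x, batch):
    """Group the rows of x by the graph index stored in batch."""
    groups = defaultdict(list)
    for row, graph_id in zip(x, batch):
        groups[graph_id].append(row)
    # Return one group per graph, in ascending graph-index order.
    return [groups[g] for g in sorted(groups)]


# Five node embeddings: the first two belong to graph 0, the rest to graph 1.
x = [[0.1], [0.2], [0.3], [0.4], [0.5]]
batch = [0, 0, 1, 1, 1]
per_graph = split_by_graph(x, batch)
print(per_graph)
```

In PyG the batch vector assigns each node to its graph in the mini-batch, which is why it is enough to recover per-graph groups from a flat embedding matrix.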
- _to_list(data_list: List[torch_geometric.data.Data], embeddings: Tuple[torch.Tensor] | torch.Tensor, batch: torch.Tensor, y: torch.Tensor | None) → List[torch_geometric.data.Data]
Extends the data_list list of PyG Data objects with new samples.
- Parameters:
data_list – a list of PyG Data objects (with only x and y attributes)
embeddings (torch.Tensor) – tensor holding the embeddings of the different nodes/graphs
batch (torch.Tensor) – the usual PyG batch tensor. Used to split node/graph embeddings graph-wise.
y (torch.Tensor) – target labels, used to determine whether the task is graph prediction or node prediction. Can be None.
- Returns:
a list of PyG Data objects (with only x and y attributes)
- _train(loader, _notify_progress: Callable[[str, dict], None])
Implements a loop over the data in training mode.
- cumulative_batch_unsent_time: float = 0.0
- cumulative_epoch_unsent_time: float = 0.0
- infer(loader: torch_geometric.loader.DataLoader, set: str, _notify_progress: Callable[[str, dict], None]) → Tuple[dict, dict, List[torch.Tensor | torch_geometric.data.Data]]
Performs an evaluation step on the data.
- Parameters:
loader (torch_geometric.loader.DataLoader) – the loader to be used
set (str) – the type of dataset being used, can be TRAINING, VALIDATION or TEST (as defined in mlwiz.static)
- Returns:
a tuple (loss dict, score dict, list of torch_geometric.data.Data objects with x and y attributes only). The data list can be used, for instance, in semi-supervised experiments or in incremental architectures
- set_device()
Moves the model and the loss metric to the proper device.
- set_eval_mode()
Sets the model and the internal state in EVALUATION mode
- set_training_mode()
Sets the model and the internal state in TRAINING mode
- train(train_loader: torch_geometric.loader.DataLoader, validation_loader: torch_geometric.loader.DataLoader | None = None, test_loader: torch_geometric.loader.DataLoader | None = None, max_epochs: int = 100, zero_epoch: bool = False, logger: Logger | None = None, training_timeout_seconds: int = -1, progress_callback: Callable[[dict], None] | None = None, should_terminate: Callable[[], bool] | None = None) → Tuple[dict, dict, List[torch.Tensor | torch_geometric.data.Data], dict, dict, List[torch.Tensor | torch_geometric.data.Data], dict, dict, List[torch.Tensor | torch_geometric.data.Data]]
Trains the model and regularly evaluates on validation and test data (if given). May perform early stopping and checkpointing.
- Parameters:
train_loader (torch_geometric.loader.DataLoader) – the DataLoader associated with training data
validation_loader (torch_geometric.loader.DataLoader) – the DataLoader associated with validation data, if any
test_loader (torch_geometric.loader.DataLoader) – the DataLoader associated with test data, if any
max_epochs (int) – maximum number of training epochs. Default is 100
zero_epoch – if True, starts again from epoch 0 and resets optimizer and scheduler states. Default is False
logger – the logger
progress_callback – optional callable that receives dictionaries with progress information for external consumers
should_terminate – optional callable returning True when a graceful termination has been requested
- Returns:
a tuple (train_loss, train_score, train_embeddings, validation_loss, validation_score, validation_embeddings, test_loss, test_score, test_embeddings)
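The two optional callables accepted by `train()` can be built from standard-library pieces. The sketch below is one possible arrangement, not a prescribed one: the `threading.Event`, the `progress_log` list, the `engine` object in the commented call, and the `{"epoch": 1}` payload are all assumptions made for illustration, since the exact contents of the progress dictionaries are not specified here.

```python
import threading

stop_event = threading.Event()  # set from another thread to request a stop
progress_log = []               # collects progress dictionaries as they arrive


def should_terminate() -> bool:
    """Return True once a graceful termination has been requested."""
    return stop_event.is_set()


def progress_callback(info: dict) -> None:
    """Store progress information for an external consumer."""
    progress_log.append(info)


# Hypothetical call, assuming `engine` is an already-built TrainingEngine:
# results = engine.train(
#     train_loader,
#     validation_loader=validation_loader,
#     max_epochs=100,
#     progress_callback=progress_callback,
#     should_terminate=should_terminate,
# )

# Simulate what a caller would observe (the payload keys are invented here).
progress_callback({"epoch": 1})
stop_event.set()
print(should_terminate(), progress_log)
```

Using an event object keeps the termination request thread-safe, which matters when training runs in a worker and the stop signal comes from a controlling process or UI.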
- mlwiz.training.engine.fmt(x, decimals=2, sci_decimals=2)
Format number with fixed-point unless it’s small, then scientific.
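A formatter with this behaviour could look like the sketch below. The cutoff below which a value counts as "small" is not stated in this page, so the `small` parameter and its default are assumptions, and `fmt_sketch` is a hypothetical name rather than the library function.

```python
def fmt_sketch(x, decimals=2, sci_decimals=2, small=1e-2):
    """Fixed-point for ordinary values, scientific for small nonzero ones.

    The `small` cutoff is a hypothetical choice made for this illustration.
    """
    if x != 0 and abs(x) < small:
        return f"{x:.{sci_decimals}e}"
    return f"{x:.{decimals}f}"


print(fmt_sketch(3.14159))
print(fmt_sketch(0.000123))
print(fmt_sketch(0))
```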
- mlwiz.training.engine.reorder(obj: List[object], perm: List[int])
Reorders a list of objects in ascending order according to the indices given in the perm argument.
- Parameters:
obj (List[object]) – the list of objects
perm (List[int]) – the permutation
- Returns:
The reordered list of objects
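One reading of this description is that `perm[i]` holds the original position of `obj[i]`, so sorting by those indices restores the original order. The sketch below implements that reading; it is an assumption about the semantics, and `reorder_sketch` is a hypothetical name.

```python
def reorder_sketch(obj, perm):
    """Sort the items of obj by their associated index in perm (ascending)."""
    if len(obj) != len(perm):
        raise ValueError("obj and perm must have the same length")
    pairs = sorted(zip(perm, obj), key=lambda pair: pair[0])
    return [item for _, item in pairs]


# Items were shuffled; perm records where each one originally lived.
shuffled = ["c", "a", "b"]
perm = [2, 0, 1]
restored = reorder_sketch(shuffled, perm)
print(restored)
```

This kind of helper is typically useful when a shuffled loader produced outputs that must be matched back to the dataset order.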
training.profiler
Callback execution-time profiler.
Provides Profiler, a decorator that records per-callback timings and generates a report.
- class mlwiz.training.profiler.Profiler(threshold: float)
Bases: object
A decorator class that is applied to an EventHandler object implementing a set of callback functions. For each callback, the Profiler stores the average and total running time across epochs. When the experiment terminates (either correctly or abruptly) the Profiler can produce a report to be stored in the experiment’s log file.
The Profiler is used as a singleton, and it produces wrappers that update its own state.
- Parameters:
threshold (float) – used to filter out callback functions that consume a negligible amount of time from the report
- Usage:
Instantiate a profiler, and then register an event_handler with the syntax profiler(event_handler), which returns another object implementing the EventHandler interface
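The `profiler(event_handler)` pattern can be sketched with a small wrapper that times every `on_*` callback. This is a simplified illustration under stated assumptions, not the mlwiz `Profiler`: `ProfilerSketch`, `DummyHandler`, the report format and the use of a plain dict as state are all hypothetical.

```python
import time
from collections import defaultdict


class ProfilerSketch:
    """Hypothetical profiler: wraps a handler and times its on_* callbacks."""

    def __init__(self, threshold: float):
        self.threshold = threshold            # seconds; hides cheap callbacks
        self.total_time = defaultdict(float)  # accumulated seconds per callback
        self.calls = defaultdict(int)         # number of calls per callback

    def __call__(self, handler):
        profiler = self

        class Wrapped:
            def __getattr__(self, name):
                attr = getattr(handler, name)
                if not (callable(attr) and name.startswith("on_")):
                    return attr  # pass non-callback attributes through

                def timed(*args, **kwargs):
                    start = time.perf_counter()
                    try:
                        return attr(*args, **kwargs)
                    finally:
                        key = f"{type(handler).__name__}.{name}"
                        profiler.total_time[key] += time.perf_counter() - start
                        profiler.calls[key] += 1

                return timed

        return Wrapped()

    def report(self) -> str:
        lines = []
        for key, total in self.total_time.items():
            if total >= self.threshold:
                avg = total / self.calls[key]
                lines.append(f"{key}: total={total:.6f}s avg={avg:.6f}s")
        return "\n".join(lines)


class DummyHandler:
    def on_fit_start(self, state):
        state["started"] = True


profiler = ProfilerSketch(threshold=0.0)
wrapped = profiler(DummyHandler())
state = {}
wrapped.on_fit_start(state)
print(profiler.report())
```

Recording the time in a `finally` clause means a callback that raises still contributes to the statistics, which matches the idea of producing a report even when an experiment ends abruptly.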
- report() → str
Builds a report string containing the statistics of the experiment accumulated so far.
- Returns:
a string containing the report
- property total_elapsed_time: timedelta
Return the accumulated elapsed time across all profiled callbacks.
- Returns:
Total elapsed time (sum of callback runtimes).
- Return type:
datetime.timedelta
training.util
Training utility helpers.
Includes atomic_torch_save() for safely writing checkpoint dictionaries.
- mlwiz.training.util.atomic_torch_save(data: dict, filepath: str)
Atomically stores a dictionary that can be serialized by torch.save(), exploiting the atomic os.replace().
- Parameters:
data (dict) – the dictionary to be stored
filepath (str) – the absolute filepath where to store the dictionary
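The write-then-replace technique behind an atomic save can be sketched as follows. To keep the example runnable without PyTorch, it serializes with `pickle.dump` where the real function uses `torch.save()`; the temporary-file naming and the cleanup logic are assumptions, and `atomic_save_sketch` is a hypothetical name.

```python
import os
import pickle
import tempfile


def atomic_save_sketch(data: dict, filepath: str) -> None:
    """Write data to a temporary file, then atomically move it into place."""
    directory = os.path.dirname(os.path.abspath(filepath))
    # Create the temporary file in the destination directory so that
    # os.replace() never has to cross a filesystem boundary.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(data, f)  # stand-in for torch.save(data, f)
            f.flush()
            os.fsync(f.fileno())
        # Readers see either the old file or the complete new one.
        os.replace(tmp_path, filepath)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


target = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
atomic_save_sketch({"epoch": 3}, target)
with open(target, "rb") as f:
    restored = pickle.load(f)
print(restored)
```

If the process is killed while writing, only the temporary file is left incomplete, so a previously stored checkpoint at `filepath` remains intact.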