callback.early_stopping

Early stopping callbacks for the training engine.

Provides EarlyStopper and PatienceEarlyStopper.

class mlwiz.training.callback.early_stopping.EarlyStopper(monitor: str, mode: str, checkpoint: bool = False)

Bases: EventHandler

EarlyStopper is the main event handler for early stopping. To define a custom strategy, create a subclass that implements the stop() method.

Parameters:
  • monitor (str) – the metric to monitor. The format is [TRAINING|VALIDATION]_[METRIC NAME], where TRAINING and VALIDATION are defined in mlwiz.static

  • mode (str) – can be MIN or MAX (as defined in mlwiz.static)

  • checkpoint (bool) – whether we are interested in the checkpoint of the “best” epoch or not

on_epoch_end(state: State)

At the end of an epoch, checks whether the monitored score improves over the current best score. If so, stores the necessary info in a dictionary and saves it into the “best_epoch_results” property of the state. If it is time to stop, updates the stop_training field of the state.

Parameters:

state (State) – object holding training information

stop(state: State, score_or_loss: str, metric: str) bool

Returns True when the early stopping technique decides it is time to stop.

Parameters:
  • state (State) – object holding training information

  • score_or_loss (str) – whether to monitor scores or losses

  • metric (str) – the metric to consider. The format is [TRAINING|VALIDATION]_[METRIC NAME], where TRAINING and VALIDATION are defined in mlwiz.static

Returns:

a boolean specifying whether training should be stopped or not

class mlwiz.training.callback.early_stopping.PatienceEarlyStopper(monitor, mode, patience=30, checkpoint=False)

Bases: EarlyStopper

Early stopper that implements patience: training stops once the monitored metric has failed to improve for more than patience epochs.

Parameters:
  • monitor (str) – the metric to monitor. The format is [TRAINING|VALIDATION]_[METRIC NAME], where TRAINING and VALIDATION are defined in mlwiz.static

  • mode (str) – can be MIN or MAX (as defined in mlwiz.static)

  • patience (int) – the number of epochs of patience

  • checkpoint (bool) – whether we are interested in the checkpoint of the “best” epoch or not

stop(state, score_or_loss, metric)

Returns True when the number of epochs without improvement is greater than the patience parameter.

Parameters:
  • state (State) – object holding training information

  • score_or_loss (str) – whether to monitor scores or losses

  • metric (str) – the metric to consider. The format is [TRAINING|VALIDATION]_[METRIC NAME], where TRAINING and VALIDATION are defined in mlwiz.static

Returns:

a boolean specifying whether training should be stopped or not
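
The patience rule described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the MLWiz implementation: the PatienceTracker class and its update() method are hypothetical names, and the mode is passed as the plain strings "min" and "max" instead of the constants defined in mlwiz.static.

```python
class PatienceTracker:
    """Hypothetical helper illustrating patience-based early stopping."""

    def __init__(self, mode: str, patience: int):
        assert mode in ("min", "max")
        self.mode = mode
        self.patience = patience
        self.best = None
        self.epochs_without_improvement = 0

    def update(self, value: float) -> bool:
        """Record the monitored value for one epoch; return True to stop."""
        improved = (
            self.best is None
            or (self.mode == "min" and value < self.best)
            or (self.mode == "max" and value > self.best)
        )
        if improved:
            self.best = value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        # Stop only once the count exceeds the patience value.
        return self.epochs_without_improvement > self.patience


tracker = PatienceTracker(mode="min", patience=2)
losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]
stop_flags = [tracker.update(loss) for loss in losses]
```

A real training loop would break at the first True flag; the list is kept here only to show when the rule fires.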

callback.engine_callback

Engine callback implementations for data fetching and checkpointing.

Provides EngineCallback and iterable-dataset variants.

class mlwiz.training.callback.engine_callback.EngineCallback(store_last_checkpoint: bool)

Bases: EventHandler

Class responsible for fetching data and handling current-epoch checkpoints at training time.

Parameters:

store_last_checkpoint (bool) – If True, write a checkpoint file at the end of each epoch (see on_epoch_end()).

on_epoch_end(state: State)

Stores the checkpoint in a dictionary with the following fields:

  • EPOCH (as defined in mlwiz.static)

  • MODEL_STATE (as defined in mlwiz.static)

  • OPTIMIZER_STATE (as defined in mlwiz.static)

  • SCHEDULER_STATE (as defined in mlwiz.static)

  • STOP_TRAINING (as defined in mlwiz.static)

Parameters:

state (State) – object holding training information
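
The checkpoint layout listed above can be pictured as an ordinary dictionary. The sketch below is only an illustration: the key strings are placeholders (the real constants live in mlwiz.static and may have different values), build_checkpoint is a hypothetical helper, and pickle stands in for whatever serialization MLWiz actually uses.

```python
import pickle

# Placeholder key names: the real constants are defined in mlwiz.static
# and their string values may differ.
EPOCH = "epoch"
MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state"
STOP_TRAINING = "stop_training"


def build_checkpoint(epoch, model_state, optimizer_state,
                     scheduler_state, stop_training):
    """Collect everything needed to resume training after this epoch."""
    return {
        EPOCH: epoch,
        MODEL_STATE: model_state,
        OPTIMIZER_STATE: optimizer_state,
        SCHEDULER_STATE: scheduler_state,
        STOP_TRAINING: stop_training,
    }


checkpoint = build_checkpoint(
    epoch=3,
    model_state={"weight": [0.1, 0.2]},
    optimizer_state={"lr": 0.01},
    scheduler_state=None,
    stop_training=False,
)
# Serialize and restore, standing in for writing and reading a file.
restored = pickle.loads(pickle.dumps(checkpoint))
```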

on_fetch_data(state: State)

Fetches next batch of data from loader and updates the batch_input field of the state

Parameters:

state (State) – object holding training information

on_forward(state: State)

Calls the forward method of the model and stores the outputs in the batch_outputs field of the state.

Parameters:

state (State) – object holding training information

class mlwiz.training.callback.engine_callback.IterableEngineCallback(store_last_checkpoint: bool)

Bases: EngineCallback

Class that extends mlwiz.training.callback.engine_callback.EngineCallback to process iterable-style datasets. It must be used together with the appropriate engine class (DataStreamTrainingEngine).

on_fetch_data(state: State)

Fetches next batch of data from loader (if any, as data comes from a stream of unknown length) and updates the batch_input field of the state

Parameters:

state (State) – object holding training information
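
Fetching from a stream of unknown length can be sketched with a plain Python iterator. How MLWiz signals the end of the stream is not stated here; the None sentinel and the fetch_next_batch helper below are assumptions made for illustration.

```python
def fetch_next_batch(loader_iterator):
    """Return the next batch, or None when the stream is exhausted."""
    return next(loader_iterator, None)


def stream():
    """Stand-in for an iterable-style data loader of unknown length."""
    yield [1, 2]
    yield [3, 4]
    yield [5]


iterator = iter(stream())
batches = []
while True:
    batch = fetch_next_batch(iterator)
    if batch is None:
        break  # no more data: the epoch ends here
    batches.append(batch)
```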

callback.gradient_clipping

Gradient clipping callback for the training engine.

Defines GradientClipper.

class mlwiz.training.callback.gradient_clipping.GradientClipper(clip_value: float, **kwargs: dict)

Bases: EventHandler

GradientClipper is the main event handler for gradient clippers. Configure it in the experiment configuration to enable gradient clipping.

Parameters:
  • clip_value (float) – the gradient will be clipped in [-clip_value, clip_value]

  • kwargs (dict) – additional arguments

on_backward(state: State)

Clips the gradients of the model before the weights are updated.

Parameters:

state (State) – object holding training information
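
Clipping into [-clip_value, clip_value] is an element-wise clamp. The sketch below shows the rule on a plain list of numbers; clip_values is a hypothetical helper. In PyTorch, torch.nn.utils.clip_grad_value_ applies the same rule in place to parameter gradients, although whether GradientClipper calls that function is not stated in this reference.

```python
def clip_values(gradients, clip_value):
    """Clamp each gradient entry into [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in gradients]


clipped = clip_values([-3.0, -0.5, 0.0, 0.2, 7.5], clip_value=1.0)
```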

callback.metric

Metric and loss callbacks for the training engine.

Defines Metric and composite helpers such as AdditiveLoss and MultiScore.

class mlwiz.training.callback.metric.AdditiveLoss(*args: Any, **kwargs: Any)

Bases: Metric

AdditiveLoss sums an arbitrary number of losses together.

Parameters:
  • use_as_loss (bool) – whether this metric should act as a loss (i.e., it should act when on_backward() is called). Used internally by MLWiz; users do not need to set it.

  • reduction (str) – the type of reduction to apply across samples of the mini-batch. Supports mean and sum. Default is mean.

  • accumulate_over_epoch (bool) – Whether or not to display the epoch-wise metric rather than an average of per-batch metrics. If true, it keeps a list of predictions and target values across the entire epoch. Use it especially with batch-sensitive metrics, such as micro AP/F1 scores. Default is True.

  • force_cpu (bool) – Whether or not to move all predictions to cpu before computing the epoch-wise loss/score. Default is True.

  • device (str) – The device used. Default is ‘cpu’.

  • losses_weights (dict) – dictionary of (loss_name, loss_weight) pairs that specifies the weight to apply to each loss to be summed.

  • losses (dict) – dictionary of metrics to add together

_instantiate_loss(loss)

Instantiate a loss with its own arguments (if any are given).

accumulate_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) None

Accumulates predictions and targets of the batch into a list for each loss, so as to compute aggregated statistics at the end of an epoch.

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

compute_metric(targets: torch.Tensor, predictions: torch.Tensor) torch.tensor

Sums the value of all different losses into one

Parameters:
  • targets (torch.Tensor) – tensor of ground truth values

  • predictions (torch.Tensor) – tensor of predictions of the model

Returns:

A tensor with the metric value
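
The weighted sum controlled by losses_weights can be sketched as follows. This is an illustration on plain floats rather than tensors; weighted_loss_sum and the loss names are hypothetical, and the default weight of 1.0 for losses without an entry is an assumption, not documented behavior.

```python
def weighted_loss_sum(loss_values, losses_weights=None):
    """Sum named loss values, scaling each by its weight (1.0 if absent)."""
    weights = losses_weights or {}
    return sum(
        weights.get(name, 1.0) * value for name, value in loss_values.items()
    )


loss_values = {"reconstruction": 0.5, "regularizer": 2.0}
total_unweighted = weighted_loss_sum(loss_values)
total_weighted = weighted_loss_sum(loss_values, {"regularizer": 0.1})
```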

forward(targets: torch.Tensor, *outputs: List[torch.Tensor]) dict

For each loss, computes its value and returns all values in a dictionary, alongside the sum of all losses.

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

on_eval_batch_end(state: State)

For each loss, computes the average metric in the batch wrt the number of timesteps (default is 1 for static datasets) unless statistics are accumulated over the entire epoch

Parameters:

state (State) – object holding training information

on_eval_epoch_end(state: State)

Computes an averaged or aggregated loss across the entire epoch, including itself as the main loss. Updates the field epoch_loss in state.

Parameters:

state (State) – object holding training information

on_eval_epoch_start(state: State)

Instantiates a dictionary with one list per loss (including itself, representing the sum of all losses)

Parameters:

state (State) – object holding training information

on_training_batch_end(state: State)

For each loss, computes the average metric in the batch wrt the number of timesteps (default is 1 for static datasets) unless statistics are accumulated over the entire epoch

Parameters:

state (State) – object holding training information

on_training_epoch_end(state: State)

Computes an averaged or aggregated loss across the entire epoch, including itself as the main loss. Updates the field epoch_loss in state.

Parameters:

state (State) – object holding training information

on_training_epoch_start(state: State)

Instantiates a dictionary with one list per loss (including itself, representing the sum of all losses)

Parameters:

state (State) – object holding training information

class mlwiz.training.callback.metric.AllocatedGPUMemory(*args: Any, **kwargs: Any)

Bases: Metric

Metric that reports current allocated CUDA memory (in MB).

compute_metric(targets: torch.Tensor, predictions: torch.Tensor) torch.tensor

Compute the metric from predicted values.

Parameters:
  • targets (torch.Tensor) – Unused (kept for API compatibility).

  • predictions (torch.Tensor) – Tensor containing memory values (MB).

Returns:

Mean of predictions.

Return type:

torch.Tensor

get_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) Tuple[torch.Tensor, torch.Tensor]

Query the current allocated GPU memory.

Parameters:
  • targets (torch.Tensor) – Unused.

  • *outputs (List[torch.Tensor]) – Unused.

Returns:

A pair of identical 1D tensors containing the allocated GPU memory in megabytes (MB).

Return type:

Tuple[torch.Tensor, torch.Tensor]

Raises:

RuntimeError – If CUDA is not available or not initialized.

Side effects:

Queries the CUDA runtime via torch.cuda.memory_allocated().

class mlwiz.training.callback.metric.Classification(*args: Any, **kwargs: Any)

Bases: Metric

Generic metric for classification tasks. Used to maximize code reuse for classical metrics.

compute_metric(targets: torch.Tensor, predictions: torch.Tensor) torch.tensor

Applies the classification metric, which subclasses must define (it is None in this class)

Parameters:
  • targets (torch.Tensor) – tensor of ground truth values

  • predictions (torch.Tensor) – tensor of predictions of the model

Returns:

A tensor with the metric value

get_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) Tuple[torch.Tensor, torch.Tensor]

Returns output[0] as predictions and dataset targets. Squeezes the first dimension of output and targets to get a single vector.

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

Returns:

A tuple of tensors (predicted_values, target_values)

class mlwiz.training.callback.metric.MeanAbsoluteError(*args: Any, **kwargs: Any)

Bases: Regression

Wrapper around torch.nn.L1Loss

class mlwiz.training.callback.metric.MeanSquareError(*args: Any, **kwargs: Any)

Bases: Regression

Wrapper around torch.nn.MSELoss

class mlwiz.training.callback.metric.Metric(*args: Any, **kwargs: Any)

Bases: Module, EventHandler

Metric is the main event handler for all metrics. Other metrics can easily subclass by implementing the forward() method, though sometimes more complex implementations are required.

Parameters:
  • use_as_loss (bool) – whether this metric should act as a loss (i.e., it should act when on_backward() is called). Used internally by MLWiz; users do not need to set it.

  • reduction (str) – the type of reduction to apply across samples of the mini-batch. Supports mean and sum. Default is mean.

  • accumulate_over_epoch (bool) – Whether or not to display the epoch-wise metric rather than an average of per-batch metrics. If true, it keeps a list of predictions and target values across the entire epoch. Use it especially with batch-sensitive metrics, such as micro AP/F1 scores. Default is True.

  • force_cpu (bool) – Whether or not to move all predictions to cpu before computing the epoch-wise loss/score. Default is True.

  • device (str) – The device used. Default is ‘cpu’.

  • kwargs (dict) – additional arguments that may depend on the metric
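
The effect of accumulate_over_epoch is easiest to see when mini-batches have different sizes. The sketch below uses plain Python lists and a hand-written mean absolute error (a hypothetical stand-in for any metric) to compare the two aggregation strategies; it illustrates the idea and is not MLWiz code.

```python
def mean_absolute_error(predictions, targets):
    """Average absolute difference between predictions and targets."""
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


# Two mini-batches of different size: (predictions, targets).
batches = [
    ([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]),  # errors 0, 0, 2
    ([4.0], [0.0]),                      # error 4
]

# accumulate_over_epoch=False: average the per-batch values.
per_batch = [mean_absolute_error(p, t) for p, t in batches]
mean_of_batches = sum(per_batch) / len(per_batch)

# accumulate_over_epoch=True: concatenate everything, compute once.
all_predictions = [p for preds, _ in batches for p in preds]
all_targets = [t for _, targs in batches for t in targs]
epoch_wise = mean_absolute_error(all_predictions, all_targets)
```

The two results differ because the single-sample batch counts as much as the three-sample batch in the first strategy, which is why epoch-wise accumulation is recommended for batch-sensitive metrics.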

accumulate_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) None

Used to specify how to accumulate predictions and targets. This can be customized by subclasses like AdditiveLoss and MultiScore to accumulate predictions and targets for different losses/scores.

Parameters:
  • targets – target tensor

  • *outputs – outputs of the model

compute_metric(targets: torch.Tensor, predictions: torch.Tensor) torch.tensor

Computes the metric for a given set of targets and predictions

Parameters:
  • targets (torch.Tensor) – tensor of ground truth values

  • predictions (torch.Tensor) – tensor of predictions of the model

Returns:

A tensor with the metric value

forward(targets: torch.Tensor, *outputs: List[torch.Tensor]) dict

Computes the metric value. Optionally, and only for scores used as losses, some extra information can be also returned.

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

  • batch_loss_extra (dict) – dictionary of information computed by metrics used as losses

Returns:

A dictionary containing associations metric_name - value

get_main_metric_name() str

Return the metric’s main name. Useful when a metric is the combination of many.

Returns:

the metric’s main name

get_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) Tuple[torch.Tensor, torch.Tensor]

Returns predictions and target tensors to be accumulated for a given metric

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

Returns:

A tuple of tensors (predicted_values, target_values)

property name: str

The name of the loss to be used in configuration files and displayed on Tensorboard. It is the same as the class name.

on_backward(state: State)

Calls backward on the loss if the metric is a loss.

Parameters:

state (State) – object holding training information

on_compute_metrics(state: State)

Computes the loss/score depending on the metric, updating the batch_loss or batch_score field in the state. In temporal graph learning, this method is computed more than once before the batch ends, so we accumulate the loss or scores across timesteps of a single batch.

Parameters:

state (State) – object holding training information

on_eval_batch_end(state: State)

If we do not compute aggregated metric values over the entire epoch, populates the batch metrics list with the new loss/score. Divides by the number of timesteps in the batch (default is 1 for static datasets)

Parameters:

state (State) – object holding training information

on_eval_batch_start(state: State)

Initializes the number of potential time steps in a batch (for temporal learning)

Parameters:

state (State) – object holding training information

on_eval_epoch_end(state: State)

Computes the mean of batch metrics or an aggregated score over the entire epoch, depending on the accumulate_over_epoch parameter. Updates epoch_loss and epoch_score fields in the state and resets the basic fields used.

Parameters:

state (State) – object holding training information

on_eval_epoch_start(state: State)

Initialize list of batch metrics as well as the list of batch predictions and targets for the metric

Parameters:

state (State) – object holding training information

on_training_batch_end(state: State)

If we do not compute aggregated metric values over the entire epoch, populates the batch metrics list with the new loss/score. Divides by the number of timesteps in the batch (default is 1 for static datasets)

Parameters:

state (State) – object holding training information

on_training_batch_start(state: State)

Initializes the number of potential time steps in a batch (for temporal learning)

Parameters:

state (State) – object holding training information

on_training_epoch_end(state: State)

Computes the mean of batch metrics or an aggregated score over the entire epoch, depending on the accumulate_over_epoch parameter. Updates epoch_loss and epoch_score fields in the state and resets the basic fields used.

Parameters:

state (State) – object holding training information

on_training_epoch_start(state: State)

Initialize list of batch metrics as well as the list of batch predictions and targets for the metric

Parameters:

state (State) – object holding training information

class mlwiz.training.callback.metric.MultiScore(*args: Any, **kwargs: Any)

Bases: Metric

This class is used to keep track of multiple additional metrics used as scores, rather than losses.

Parameters:
  • use_as_loss (bool) – whether this metric should act as a loss (i.e., it should act when on_backward() is called). Used internally by MLWiz; users do not need to set it.

  • reduction (str) – the type of reduction to apply across samples of the mini-batch. Supports mean and sum. Default is mean.

  • accumulate_over_epoch (bool) – Whether or not to display the epoch-wise metric rather than an average of per-batch metrics. If true, it keeps a list of predictions and target values across the entire epoch. Use it especially with batch-sensitive metrics, such as micro AP/F1 scores. Default is True.

  • force_cpu (bool) – Whether or not to move all predictions to cpu before computing the epoch-wise loss/score. Default is True.

  • device (str) – The device used. Default is ‘cpu’.

  • main_scorer (Metric) – the score on which final results are computed.

  • extra_scorers (dict) – dictionary of other metrics to consider.

_istantiate_scorer(scorer)

Instantiate a scorer with its own arguments (if any are given).

accumulate_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) None

Accumulates predictions and targets of the batch into a list for each scorer, so as to compute aggregated statistics at the end of an epoch.

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

forward(targets: torch.Tensor, *outputs: List[torch.Tensor]) dict | float

For each scorer, computes a score and returns all scores in a dictionary

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

get_main_metric_name()

Returns the name of the first scorer that is passed to this class via the __init__ method.

on_eval_batch_end(state: State)

For each scorer, computes the average metric in the batch wrt the number of timesteps (default is 1 for static datasets) unless statistics are accumulated over the entire epoch

Parameters:

state (State) – object holding training information

on_eval_epoch_end(state: State)

For each score, computes the epoch scores using the same logic as the superclass

Parameters:

state (State) – object holding training information

on_eval_epoch_start(state: State)

Compared to the superclass version, initializes a dictionary for each score to track rather than single lists

Parameters:

state (State) – object holding training information

on_training_batch_end(state: State)

For each scorer, computes the average metric in the batch wrt the number of timesteps (default is 1 for static datasets) unless statistics are accumulated over the entire epoch

Parameters:

state (State) – object holding training information

on_training_epoch_end(state: State)

For each score, computes the epoch scores using the same logic as the superclass

Parameters:

state (State) – object holding training information

on_training_epoch_start(state: State)

Compared to the superclass version, initializes a dictionary for each score to track rather than single lists

Parameters:

state (State) – object holding training information

class mlwiz.training.callback.metric.MulticlassAccuracy(*args: Any, **kwargs: Any)

Bases: Metric

Implements multiclass classification accuracy.

static _get_correct(output)

Returns the argmax of the output along dimension 1.

compute_metric(targets: torch.Tensor, predictions: torch.Tensor) torch.tensor

Computes the accuracy by comparing the predicted classes, already converted to discrete labels by get_predictions_and_targets(), with the targets.

Parameters:
  • targets (torch.Tensor) – tensor of ground truth values

  • predictions (torch.Tensor) – tensor of predictions of the model

Returns:

A tensor with the metric value

get_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) Tuple[torch.Tensor, torch.Tensor]

Takes output[0] as predictions and computes a discrete class using argmax. Returns standard dataset targets as well. Squeezes the first dimension of output and targets to get a single vector.

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

Returns:

A tuple of tensors (predicted_values, target_values)
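
The argmax-then-compare logic described above can be sketched on nested lists. The helpers argmax_classes and accuracy are hypothetical; the real class operates on torch tensors, and this reference does not state whether it reports a fraction or a percentage (the sketch returns a fraction).

```python
def argmax_classes(outputs):
    """Pick the index of the largest score in each row (dimension 1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def accuracy(predicted_classes, targets):
    """Fraction of predictions that match the targets."""
    correct = sum(int(p == t) for p, t in zip(predicted_classes, targets))
    return correct / len(targets)


outputs = [
    [0.1, 0.7, 0.2],
    [0.9, 0.05, 0.05],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
]
targets = [1, 0, 1, 0]
predicted = argmax_classes(outputs)
score = accuracy(predicted, targets)
```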

class mlwiz.training.callback.metric.MulticlassClassification(*args: Any, **kwargs: Any)

Bases: Classification

Wrapper around torch.nn.CrossEntropyLoss

class mlwiz.training.callback.metric.Regression(*args: Any, **kwargs: Any)

Bases: Metric

Generic metric for regression tasks. Used to maximize code reuse for classical metrics.

compute_metric(targets: torch.Tensor, predictions: torch.Tensor) torch.tensor

Applies the regression metric, which subclasses must define (it is None in this class)

Parameters:
  • targets (torch.Tensor) – tensor of ground truth values

  • predictions (torch.Tensor) – tensor of predictions of the model

Returns:

A tensor with the metric value

get_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) Tuple[torch.Tensor, torch.Tensor]

Returns output[0] as predictions and dataset targets. Squeezes the first dimension of output and targets to get a single vector.

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

Returns:

A tuple of tensors (predicted_values, target_values)

class mlwiz.training.callback.metric.SingleGraphMulticlassAccuracy(*args: Any, **kwargs: Any)

Bases: MulticlassAccuracy

Multiclass accuracy for single-graph tasks, where only a subset of nodes (selected by index) is evaluated.

get_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) Tuple[torch.Tensor, torch.Tensor]

Extract predicted classes/targets for single-graph node accuracy.

Parameters:
  • targets (torch.Tensor) – Full target tensor for all nodes.

  • *outputs (List[torch.Tensor]) – Model outputs expected to be a tuple (o, embeddings, idxs) where idxs selects the nodes to evaluate (train or eval indices).

Returns:

(pred, t) where pred are the predicted classes and t are the corresponding targets.

Return type:

Tuple[torch.Tensor, torch.Tensor]

class mlwiz.training.callback.metric.SingleGraphMulticlassClassification(*args: Any, **kwargs: Any)

Bases: MulticlassClassification

Wrapper around torch.nn.CrossEntropyLoss

get_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) Tuple[torch.Tensor, torch.Tensor]

Extract predictions/targets for single-graph node classification.

Parameters:
  • targets (torch.Tensor) – Full target tensor for all nodes.

  • *outputs (List[torch.Tensor]) – Model outputs expected to be a tuple (o, embeddings, idxs) where idxs selects the nodes to evaluate (train or eval indices).

Returns:

(o, t) where o are the selected node logits and t are the corresponding targets.

Return type:

Tuple[torch.Tensor, torch.Tensor]
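
The index-based selection used by the single-graph metrics can be sketched with lists in place of tensors. The select_nodes helper is hypothetical; it assumes, as documented above, that the model outputs form a tuple (o, embeddings, idxs).

```python
def select_nodes(targets, outputs):
    """Keep only the logits and targets of the nodes listed in idxs."""
    logits, _embeddings, idxs = outputs
    return [logits[i] for i in idxs], [targets[i] for i in idxs]


targets = [0, 1, 1, 0, 1]
logits = [[2.0, 0.1], [0.3, 1.5], [0.2, 0.9], [1.1, 0.4], [0.5, 0.6]]
embeddings = None  # unused by the metric in this sketch
train_idxs = [0, 2, 4]

selected_logits, selected_targets = select_nodes(
    targets, (logits, embeddings, train_idxs)
)
```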

class mlwiz.training.callback.metric.ToyMetric(*args: Any, **kwargs: Any)

Bases: Metric

Implements a toy metric.

static _get_correct(output)

Returns the argmax of the output along dimension 1.

compute_metric(targets: torch.Tensor, predictions: torch.Tensor) torch.tensor

Computes a dummy score

Parameters:
  • targets (torch.Tensor) – tensor of ground truth values

  • predictions (torch.Tensor) – tensor of predictions of the model

Returns:

A tensor with the metric value

get_predictions_and_targets(targets: torch.Tensor, *outputs: List[torch.Tensor]) Tuple[torch.Tensor, torch.Tensor]

Returns output[0] and dataset targets

Parameters:
  • targets (torch.Tensor) – ground truth

  • outputs (List[torch.Tensor]) – outputs of the model

Returns:

A tuple of tensors (predicted_values, target_values)

callback.optimizer

Optimizer callback wrapper for the training engine.

Instantiates a PyTorch optimizer from a dotted path and exposes lifecycle hooks.

class mlwiz.training.callback.optimizer.Optimizer(model: ModelInterface, optimizer_class_name: str, accumulate_gradients: bool = False, **kwargs: dict)

Bases: EventHandler

Optimizer is the main event handler for optimizers. Just pass a PyTorch optimizer together with its arguments in the configuration file.

Parameters:
  • model (ModelInterface) – the model that has to be trained

  • optimizer_class_name (str) – dotted path to the optimizer class to use

  • accumulate_gradients (bool) – if True, accumulate mini-batch gradients to perform a batch gradient update without loading the entire batch in memory

  • kwargs (dict) – additional parameters for the specific optimizer
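
Resolving a class from a dotted path, as optimizer_class_name requires, is commonly done with importlib. The sketch below shows one way to do it; instantiate_from_dotted_path is a hypothetical helper, not the MLWiz loader, and it uses a standard-library class so that it runs without PyTorch. For an optimizer, the path would be something like torch.optim.Adam.

```python
import importlib


def instantiate_from_dotted_path(dotted_path, *args, **kwargs):
    """Import the module part of the path and instantiate the class."""
    module_name, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cls(*args, **kwargs)


# Standard-library example: equivalent to fractions.Fraction(3, 4).
fraction = instantiate_from_dotted_path("fractions.Fraction", 3, 4)
```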

load_state_dict(state_dict)

Loads the state_dict of the optimizer from a checkpoint

Parameters:

state_dict (dict) – the state dictionary of the optimizer, as stored in a checkpoint

on_epoch_end(state)

Stores the optimizer state in the state object at the end of the epoch

Parameters:

state (State) – object holding training information

on_fit_start(state)

If a checkpoint is present, load the state of the optimizer

Parameters:

state (State) – object holding training information

on_training_batch_end(state)

At the end of a batch, if batch updates are in order, performs a weight update

Parameters:

state (State) – object holding training information

on_training_batch_start(state)

At the start of a batch, if batch updates are in order, zeroes the gradient of the optimizer

Parameters:

state (State) – object holding training information

on_training_epoch_end(state)

At the end of an epoch, and if the gradient has been accumulated across the entire epoch, performs a weight update

Parameters:

state (State) – object holding training information

on_training_epoch_start(state)

At the start of epoch, and if the gradient has been accumulated across the entire epoch, zeroes the gradient of the optimizer.

Parameters:

state (State) – object holding training information
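
The difference between per-batch updates and accumulate_gradients=True comes down to when zero_grad() and step() are called. The sketch below counts those calls with a stand-in optimizer; CountingOptimizer and run_epoch are hypothetical names, and the control flow is inferred from the hook descriptions above rather than taken from the MLWiz source.

```python
class CountingOptimizer:
    """Stand-in optimizer that only counts calls."""

    def __init__(self):
        self.zero_grad_calls = 0
        self.step_calls = 0

    def zero_grad(self):
        self.zero_grad_calls += 1

    def step(self):
        self.step_calls += 1


def run_epoch(optimizer, num_batches, accumulate_gradients):
    if accumulate_gradients:
        optimizer.zero_grad()      # once, at the start of the epoch
    for _ in range(num_batches):
        if not accumulate_gradients:
            optimizer.zero_grad()  # at the start of every batch
        # forward and backward passes would add to the gradients here
        if not accumulate_gradients:
            optimizer.step()       # at the end of every batch
    if accumulate_gradients:
        optimizer.step()           # once, at the end of the epoch


per_batch = CountingOptimizer()
run_epoch(per_batch, num_batches=5, accumulate_gradients=False)

accumulated = CountingOptimizer()
run_epoch(accumulated, num_batches=5, accumulate_gradients=True)
```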

callback.plotter

TensorBoard logging callback for training runs.

The Plotter writes per-epoch metrics and optional on-disk histories.

class mlwiz.training.callback.plotter.Plotter(exp_path: str, store_on_disk: bool = False, enable_tensorboard: bool = True, **kwargs: dict)

Bases: EventHandler

Plotter is the main event handler for plotting at training time.

Parameters:
  • exp_path (str) – path where to store the Tensorboard logs

  • store_on_disk (bool) – whether to store all metrics on disk. Defaults to False

  • enable_tensorboard (bool) – whether to write metrics to TensorBoard. Defaults to True

  • kwargs (dict) – additional arguments that may depend on the plotter

on_epoch_end(state: State)

Writes Training, Validation and (if any) Test metrics to Tensorboard

Parameters:

state (State) – object holding training information

on_fit_end(state: State)

Frees resources by closing the Tensorboard writer

Parameters:

state (State) – object holding training information

class mlwiz.training.callback.plotter._NullSummaryWriter

Bases: object

Drop-in writer used when TensorBoard logging is disabled.

add_scalars(*args, **kwargs)

Ignore scalar logging calls.

close()

No-op close for API compatibility.
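
A drop-in null writer lets the logging code stay free of "is logging enabled?" checks. The sketch below illustrates the pattern with hypothetical classes: NullWriter mirrors the two methods documented above, while RecordingWriter is a stand-in for a real TensorBoard writer that simply stores calls in a list.

```python
class NullWriter:
    """Writer whose methods accept any arguments and do nothing."""

    def add_scalars(self, *args, **kwargs):
        pass

    def close(self):
        pass


class RecordingWriter:
    """Stand-in for a real TensorBoard writer; stores calls in a list."""

    def __init__(self):
        self.records = []

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
        self.records.append((main_tag, dict(tag_scalar_dict), global_step))

    def close(self):
        pass


def make_writer(enable_tensorboard):
    return RecordingWriter() if enable_tensorboard else NullWriter()


def log_epoch(writer, epoch, train_loss, val_loss):
    # The caller never checks whether logging is enabled.
    writer.add_scalars(
        "loss", {"training": train_loss, "validation": val_loss}, epoch
    )


enabled = make_writer(True)
disabled = make_writer(False)
for writer in (enabled, disabled):
    log_epoch(writer, 1, 0.9, 1.1)
    writer.close()
```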

callback.scheduler

Learning-rate scheduler callbacks for the training engine.

Provides epoch-based and metric-based scheduler wrappers.

class mlwiz.training.callback.scheduler.EpochScheduler(scheduler_class_name: str, optimizer: torch.optim.optimizer.Optimizer, **kwargs: dict)

Bases: Scheduler

Implements a scheduler which uses epochs to modify the step size

on_training_epoch_end(state: State)

Performs a scheduler’s step at the end of the training epoch.

Parameters:

state (State) – object holding training information

class mlwiz.training.callback.scheduler.MetricScheduler(scheduler_class_name: str, use_loss: bool, monitor: str, optimizer: torch.optim.optimizer.Optimizer, **kwargs: dict)

Bases: Scheduler

Implements a scheduler which uses variations in the metric of interest to modify the step size

Parameters:
  • scheduler_class_name (str) – dotted path to class name of the scheduler

  • use_loss (bool) – whether to monitor losses rather than scores

  • monitor (str) – the metric to monitor. The format is [TRAINING|VALIDATION]_[METRIC NAME], where TRAINING and VALIDATION are defined in mlwiz.static

  • optimizer (torch.optim.optimizer.Optimizer) – the PyTorch optimizer to use. This is automatically recovered by MLWiz when providing an optimizer

  • kwargs – additional parameters for the specific scheduler to be used

on_epoch_end(state: State)

Updates the state of the scheduler according to a metric to monitor at each epoch. Finally, loads the scheduler state if already present in the state_dict of a checkpoint

Parameters:

state (State) – object holding training information
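
The two scheduler flavors differ in what is passed to the scheduler at each step: nothing for epoch-based schedulers, the monitored value for metric-based ones. The toy classes below are hypothetical and only mimic that calling convention; they are not PyTorch schedulers and do not reproduce the MLWiz wrappers.

```python
class ToyStepDecay:
    """Halves the learning rate every `step_size` epochs."""

    def __init__(self, lr, step_size):
        self.lr = lr
        self.step_size = step_size
        self.epoch = 0

    def step(self):
        self.epoch += 1
        if self.epoch % self.step_size == 0:
            self.lr *= 0.5


class ToyPlateauDecay:
    """Halves the learning rate whenever the metric fails to improve."""

    def __init__(self, lr):
        self.lr = lr
        self.best = float("inf")

    def step(self, metric):
        if metric < self.best:
            self.best = metric
        else:
            self.lr *= 0.5


epoch_based = ToyStepDecay(lr=1.0, step_size=2)
metric_based = ToyPlateauDecay(lr=1.0)
validation_losses = [0.9, 0.8, 0.85, 0.7]
for loss in validation_losses:
    epoch_based.step()       # EpochScheduler style: no argument
    metric_based.step(loss)  # MetricScheduler style: monitored value
```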

class mlwiz.training.callback.scheduler.Scheduler(scheduler_class_name: str, optimizer: torch.optim.optimizer.Optimizer, **kwargs: dict)

Bases: EventHandler

Scheduler is the main event handler for schedulers. Just pass a PyTorch scheduler together with its arguments in the configuration file.

Parameters:
  • scheduler_class_name (str) – dotted path to class name of the scheduler

  • optimizer (torch.optim.optimizer.Optimizer) – the PyTorch optimizer to use. This is automatically recovered by MLWiz when providing an optimizer

  • kwargs – additional parameters for the specific scheduler to be used

on_epoch_end(state: State)

Stores the current scheduler state in the state object for checkpointing

Parameters:

state (State) – object holding training information

on_fit_start(state: State)

Loads the scheduler state if already present in the state_dict of a checkpoint

Parameters:

state (State) – object holding training information