mlwiz.model

model.interface

Abstract model interface expected by the training engine.

Defines ModelInterface, a torch.nn.Module with a standardized forward signature.

class mlwiz.model.interface.ModelInterface(*args: Any, **kwargs: Any)

Bases: Module

Provides the signature that any main model trained under MLWiz must follow.

Parameters:
  • dim_input_features (Union[int, Tuple[int]]) – dimension of node features (according to the DatasetInterface property)

  • dim_target (int) – dimension of the target (according to the DatasetInterface property)

  • config (dict) – config dictionary containing all the necessary hyper-parameters plus additional information (if needed)
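The sketch below shows how a model might consume the three constructor parameters. It is a minimal sketch under stated assumptions, not MLWiz's own example. In MLWiz you would subclass mlwiz.model.interface.ModelInterface. Here a plain torch.nn.Module is used so the code runs without MLWiz installed. The class name MLPModel and the config key "hidden_units" are hypothetical. The sketch also assumes dim_input_features is an int, not a tuple.

```python
import torch
from torch import nn


class MLPModel(nn.Module):
    """Hypothetical model following the documented ModelInterface contract.

    In MLWiz you would subclass mlwiz.model.interface.ModelInterface instead;
    plain nn.Module is used here so the sketch runs without MLWiz installed.
    """

    def __init__(self, dim_input_features, dim_target, config):
        super().__init__()
        # "hidden_units" is a hypothetical hyper-parameter key, not part of MLWiz.
        hidden = config["hidden_units"]
        # Assumes dim_input_features is an int (flat feature vectors).
        self.encoder = nn.Sequential(nn.Linear(dim_input_features, hidden), nn.ReLU())
        self.readout = nn.Linear(hidden, dim_target)

    def forward(self, data):
        # data: a batch of flat feature vectors, shape (batch_size, dim_input_features)
        embeddings = self.encoder(data)
        output = self.readout(embeddings)
        # Two-element form: (output, embeddings)
        return output, embeddings


model = MLPModel(dim_input_features=8, dim_target=3, config={"hidden_units": 16})
output, embeddings = model(torch.randn(4, 8))
print(output.shape, embeddings.shape)  # torch.Size([4, 3]) torch.Size([4, 16])
```

Reading hyper-parameters from config keeps the constructor signature fixed. A training engine can therefore instantiate any model the same way, whatever its architecture.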

forward(data: torch.Tensor | torch_geometric.data.Batch) → Tuple[torch.Tensor, torch.Tensor | None] | Tuple[torch.Tensor, torch.Tensor | None, List[object] | None]

Perform a forward pass over a batch of samples.

Parameters:

data – a batch of samples

Returns:

  • (output, embeddings)

  • (output, embeddings, additional_outputs)

where embeddings may be None for models that do not produce them. The additional_outputs element may also be None, or it can be left out entirely by returning the two-element form.
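Code that calls forward() has to handle both return shapes. The helper below normalizes either shape to a three-element tuple. It is a hypothetical sketch, not part of MLWiz, and MLWiz's own engine may handle this differently.

```python
from typing import Any, List, Optional, Tuple


def unpack_forward_result(result: tuple) -> Tuple[Any, Optional[Any], Optional[List[object]]]:
    """Normalize a forward() result to (output, embeddings, additional_outputs).

    Hypothetical helper, not part of MLWiz: it pads the two-element form with
    additional_outputs=None so callers can treat both documented shapes uniformly.
    """
    if not isinstance(result, tuple) or len(result) not in (2, 3):
        raise ValueError("forward() must return a 2- or 3-element tuple")
    if len(result) == 2:
        output, embeddings = result
        return output, embeddings, None
    output, embeddings, additional_outputs = result
    return output, embeddings, additional_outputs


print(unpack_forward_result(("out", None)))            # ('out', None, None)
print(unpack_forward_result(("out", "emb", ["aux"])))  # ('out', 'emb', ['aux'])
```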

Return type:

Either