mlwiz.model
model.interface
Abstract model interface expected by the training engine.
Defines ModelInterface, a torch.nn.Module with a standardized forward signature.
- class mlwiz.model.interface.ModelInterface(*args: Any, **kwargs: Any)
Bases: Module
Provides the signature for any main model to be trained under MLWiz.
- Parameters:
dim_input_features (Union[int, Tuple[int]]) – dimension of node features (according to the DatasetInterface property)
dim_target (int) – dimension of the target (according to the DatasetInterface property)
config (dict) – config dictionary containing all the necessary hyper-parameters plus additional information (if needed)
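The constructor arguments above can be illustrated with a minimal sketch. It subclasses torch.nn.Module as a stand-in so that it runs without MLWiz installed; in an MLWiz project the class would presumably derive from ModelInterface instead. The class name MLPSketch and the "dim_embedding" config key are hypothetical, and dim_input_features is assumed to be a plain int rather than a tuple.

```python
from typing import Optional, Tuple

import torch


class MLPSketch(torch.nn.Module):
    """Hypothetical model that mirrors the documented constructor arguments.

    Stand-in base class: a real MLWiz model would subclass
    mlwiz.model.interface.ModelInterface rather than torch.nn.Module.
    """

    def __init__(self, dim_input_features: int, dim_target: int, config: dict):
        super().__init__()
        # "dim_embedding" is an assumed hyper-parameter name, not an MLWiz key.
        dim_embedding = config.get("dim_embedding", 16)
        self.encoder = torch.nn.Linear(dim_input_features, dim_embedding)
        self.readout = torch.nn.Linear(dim_embedding, dim_target)

    def forward(
        self, data: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Return the two-element form: (output, embeddings).
        embeddings = torch.relu(self.encoder(data))
        output = self.readout(embeddings)
        return output, embeddings


model = MLPSketch(dim_input_features=8, dim_target=3, config={"dim_embedding": 16})
output, embeddings = model(torch.randn(4, 8))
```

Here the hyper-parameters are read from config with a default, so the same class can be instantiated from different configuration dictionaries.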
- forward(data: torch.Tensor | torch_geometric.data.Batch) → Tuple[torch.Tensor, torch.Tensor | None] | Tuple[torch.Tensor, torch.Tensor | None, List[object] | None]
Perform a forward pass over a batch of samples.
- Parameters:
data – a batch of samples
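- Returns:
Either (output, embeddings) or (output, embeddings, additional_outputs), where embeddings and additional_outputs are optional and may be omitted by models that do not produce them.
- Return type:
Tuple[torch.Tensor, torch.Tensor | None] | Tuple[torch.Tensor, torch.Tensor | None, List[object] | None]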
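Because forward may return either two or three elements, calling code may want to normalize the result before using it. The following is a minimal sketch of one way to do that; unpack_forward_result is a hypothetical helper, not part of MLWiz, and plain Python lists stand in for tensors.

```python
from typing import Any, List, Optional, Tuple


def unpack_forward_result(
    result: Tuple[Any, ...]
) -> Tuple[Any, Optional[Any], Optional[List[object]]]:
    """Normalize a forward() result to (output, embeddings, additional_outputs).

    Hypothetical helper: it pads the two-element form with None so that
    callers can always unpack three values.
    """
    if len(result) == 2:
        output, embeddings = result
        return output, embeddings, None
    if len(result) == 3:
        output, embeddings, additional_outputs = result
        return output, embeddings, additional_outputs
    raise ValueError(f"expected a tuple of length 2 or 3, got length {len(result)}")


# Two-element form, with embeddings not produced by the model.
two = unpack_forward_result(([0.1, 0.9], None))

# Three-element form, with a list of additional outputs.
three = unpack_forward_result(([0.1, 0.9], [1.0, 2.0], ["attention"]))
```

Padding with None keeps the downstream code identical for both return forms.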