mlwiz.model

model.interface

class mlwiz.model.interface.ModelInterface(*args: Any, **kwargs: Any)

Bases: Module

Provides the signature that any main model trained with MLWiz must follow.

Parameters:
  • dim_input_features (Union[int, Tuple[int]]) – dimension of the input features, e.g., node features for graph data (according to the DatasetInterface property)

  • dim_target (int) – dimension of the target (according to the DatasetInterface property)

  • config (dict) – configuration dictionary containing all necessary hyper-parameters, plus any additional information the model needs
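As a minimal sketch of this constructor contract, the hypothetical `MLPModel` below takes the three documented parameters and returns the three-element tuple documented for `forward`. It assumes plain vector or image tensors. It subclasses `torch.nn.Module` instead of `ModelInterface` so it runs without MLWiz installed. It also reads a hypothetical `hidden_units` key from `config`. None of these names come from MLWiz itself.

```python
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn


class MLPModel(nn.Module):
    # Hypothetical example model. In a real MLWiz project this class would
    # subclass mlwiz.model.interface.ModelInterface; nn.Module is used here
    # so the sketch runs without MLWiz installed.
    def __init__(
        self,
        dim_input_features: Union[int, Tuple[int, ...]],
        dim_target: int,
        config: dict,
    ):
        super().__init__()
        # A tuple (e.g. an image shape) is flattened into a single feature count
        if isinstance(dim_input_features, tuple):
            in_features = math.prod(dim_input_features)
        else:
            in_features = dim_input_features
        # "hidden_units" is a hypothetical hyper-parameter name, not an MLWiz key
        hidden = config.get("hidden_units", 32)
        self.encoder = nn.Sequential(
            nn.Flatten(), nn.Linear(in_features, hidden), nn.ReLU()
        )
        self.head = nn.Linear(hidden, dim_target)

    def forward(
        self, data: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[object]]]:
        embeddings = self.encoder(data)
        output = self.head(embeddings)
        # (model output, optional embeddings, optional additional outputs)
        return output, embeddings, None


model = MLPModel(dim_input_features=8, dim_target=3, config={"hidden_units": 16})
out, emb, extra = model(torch.randn(4, 8))
print(out.shape, emb.shape, extra)  # torch.Size([4, 3]) torch.Size([4, 16]) None
```

Flattening a tuple-valued `dim_input_features` is one possible design choice. A convolutional model would instead keep the shape and use it to size its first layer.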

forward(data: torch.Tensor | torch_geometric.data.Batch) → Tuple[torch.Tensor, torch.Tensor | None, List[object] | None]

Performs a forward pass over a batch of samples (e.g., a batch of graphs).

Parameters:

data – a batch of samples

Returns:

a tuple (model’s output, [optional] node embeddings, [optional] additional outputs), where the optional entries may be None
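The forward pass and its return tuple can be illustrated with a sketch of a graph-level model. The hypothetical `MeanPoolGraphModel` accepts either a plain tensor or a graph batch exposing `x` (node features) and `batch` (node-to-graph index), the attributes a `torch_geometric.data.Batch` provides. It mean-pools node embeddings per graph and returns `(output, node embeddings, None)`. The class name, the `hidden_units` key, and the use of `torch.nn.Module` in place of `ModelInterface` are assumptions for illustration. A `SimpleNamespace` stands in for a real `Batch` so the example runs without PyTorch Geometric.

```python
from types import SimpleNamespace
from typing import List, Optional, Tuple

import torch
from torch import nn


class MeanPoolGraphModel(nn.Module):
    # Hypothetical example; a real MLWiz model would subclass ModelInterface.
    def __init__(self, dim_input_features: int, dim_target: int, config: dict):
        super().__init__()
        hidden = config.get("hidden_units", 32)  # hypothetical config key
        self.encoder = nn.Sequential(nn.Linear(dim_input_features, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, dim_target)

    def forward(
        self, data
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[object]]]:
        if isinstance(data, torch.Tensor):
            # Plain tensor: one row per sample, so no pooling is needed
            emb = self.encoder(data)
            return self.head(emb), emb, None
        # Graph batch: .x holds node features and .batch maps every node
        # to the index of the graph it belongs to
        node_emb = self.encoder(data.x)
        num_graphs = int(data.batch.max()) + 1
        # Sum node embeddings per graph, then divide by node counts (mean pooling)
        sums = node_emb.new_zeros(num_graphs, node_emb.size(1))
        sums.index_add_(0, data.batch, node_emb)
        counts = torch.bincount(data.batch, minlength=num_graphs).clamp(min=1)
        graph_emb = sums / counts.unsqueeze(1).to(node_emb.dtype)
        # (graph-level output, node embeddings, no additional outputs)
        return self.head(graph_emb), node_emb, None


# SimpleNamespace stands in for torch_geometric.data.Batch: two graphs
# with 2 and 3 nodes respectively
graphs = SimpleNamespace(x=torch.randn(5, 4), batch=torch.tensor([0, 0, 1, 1, 1]))
model = MeanPoolGraphModel(dim_input_features=4, dim_target=2, config={"hidden_units": 8})
out, node_emb, extra = model(graphs)
print(out.shape, node_emb.shape, extra)  # torch.Size([2, 2]) torch.Size([5, 8]) None
```

Returning None in the third position signals that there are no additional outputs. A model with extra outputs would return a list there instead, matching the `List[object] | None` annotation.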