atomgen.models.schnet module
SchNet model for energy prediction.
- class SchNetConfig(vocab_size=123, hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0, interaction_graph=None, max_num_neighbors=32, readout='add', dipole=False, mean=None, std=None, atomref=None, mask_token_id=0, pad_token_id=119, bos_token_id=120, eos_token_id=121, cls_token_id=122, **kwargs)
Bases: PretrainedConfig
Stores the configuration of a SchNetModel. It is used to instantiate a SchNet model according to the specified arguments, defining the model architecture.
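As a rough picture of what this configuration holds, the sketch below mirrors the constructor defaults in a plain Python dataclass. `SchNetConfigSketch` is a hypothetical stand-in, not the real class (which subclasses `PretrainedConfig`), and the check that every special token ID fits inside `vocab_size` is an assumption about how the IDs are used.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class SchNetConfigSketch:
    """Hypothetical plain-Python mirror of SchNetConfig's constructor defaults."""

    vocab_size: int = 123
    hidden_channels: int = 128
    num_filters: int = 128
    num_interactions: int = 6
    num_gaussians: int = 50
    cutoff: float = 10.0
    interaction_graph: Optional[str] = None
    max_num_neighbors: int = 32
    readout: str = "add"
    dipole: bool = False
    mean: Optional[float] = None
    std: Optional[float] = None
    atomref: Optional[float] = None
    mask_token_id: int = 0
    pad_token_id: int = 119
    bos_token_id: int = 120
    eos_token_id: int = 121
    cls_token_id: int = 122

    def special_token_ids(self) -> Dict[str, int]:
        """Collect the special token IDs by name."""
        return {
            "mask": self.mask_token_id,
            "pad": self.pad_token_id,
            "bos": self.bos_token_id,
            "eos": self.eos_token_id,
            "cls": self.cls_token_id,
        }

    def __post_init__(self) -> None:
        # Assumption: each special token must index a row of a vocab_size-entry embedding table.
        for name, token_id in self.special_token_ids().items():
            if not 0 <= token_id < self.vocab_size:
                raise ValueError(
                    f"{name}_token_id={token_id} is outside [0, {self.vocab_size})"
                )


# Override a few architecture fields; the special-token defaults are kept.
small = SchNetConfigSketch(hidden_channels=64, num_interactions=3, cutoff=5.0)
```

With the defaults, the special tokens occupy IDs 0 and 119 to 122, so 123 is the smallest vocabulary that contains them all; in this sketch, shrinking `vocab_size` without moving the token IDs raises `ValueError`.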
- model_type = 'transformer'
- Args:
- vocab_size (int): The size of the vocabulary, used to define the size of the output embeddings.
- hidden_channels (int): The hidden size of the model.
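- num_filters (int): The number of filters in each continuous-filter convolution.
- num_interactions (int): The number of interaction blocks.
- num_gaussians (int): The number of Gaussians used to expand interatomic distances.
- cutoff (float): The cutoff distance for interatomic interactions.
- interaction_graph (str, optional): The interaction graph. If None, the graph is built from cutoff and max_num_neighbors.
- max_num_neighbors (int): The maximum number of neighbors collected for each atom within the cutoff distance.
- readout (str, optional): The global aggregation used to pool per-atom outputs, e.g. 'add' or 'mean'.
- dipole (bool, optional): Whether to use the magnitude of the dipole moment for the final prediction.
- mean (float, optional): The mean of the target property, used to rescale the output.
- std (float, optional): The standard deviation of the target property, used to rescale the output.
- atomref (float, optional): The reference value of single-atom properties.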
- mask_token_id (int, optional): The token ID used for masking.
- pad_token_id (int, optional): The token ID used for padding.
- bos_token_id (int, optional): The token ID for the beginning of a sequence.
- eos_token_id (int, optional): The token ID for the end of a sequence.
- cls_token_id (int, optional): The token ID for the classification token.
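The `cutoff` and `num_gaussians` parameters correspond, in the standard SchNet formulation and in PyTorch Geometric's `SchNet` (whose constructor arguments this config mirrors), to expanding each interatomic distance into `num_gaussians` Gaussian radial basis values with centers evenly spaced on [0, cutoff], and to smoothly damping each filter to zero at the cutoff. The pure-Python sketch below illustrates both ideas; that atomgen uses exactly these formulas is an assumption, and `gaussian_expansion` and `cosine_cutoff` are hypothetical helper names.

```python
import math
from typing import List


def gaussian_expansion(
    distance: float, cutoff: float = 10.0, num_gaussians: int = 50
) -> List[float]:
    """Expand one interatomic distance into num_gaussians radial basis values."""
    # Centers are evenly spaced on [0, cutoff], both ends included.
    step = cutoff / (num_gaussians - 1)
    centers = [k * step for k in range(num_gaussians)]
    # Each Gaussian's width is tied to the spacing between neighboring centers.
    coeff = -0.5 / step**2
    return [math.exp(coeff * (distance - mu) ** 2) for mu in centers]


def cosine_cutoff(distance: float, cutoff: float = 10.0) -> float:
    """Smooth envelope equal to 1 at distance 0 and 0 at (and beyond) the cutoff."""
    if distance >= cutoff:
        return 0.0
    return 0.5 * (math.cos(distance * math.pi / cutoff) + 1.0)


features = gaussian_expansion(1.5)  # 50 values, largest for the center nearest 1.5
weight = cosine_cutoff(1.5)  # factor that scales the filter for this atom pair
```

Raising `num_gaussians` packs the centers more densely and narrows each Gaussian, giving a finer-grained distance encoding at the cost of a wider filter input.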
- class SchNetModel(config)
Bases: SchNetPreTrainedModel
SchNet model for energy prediction.
- Args:
- config (SchNetConfig): The configuration object holding the model's hyperparameters.
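SchNetModel predicts one energy per structure. In PyTorch Geometric's `SchNet`, whose constructor arguments `SchNetConfig` mirrors, per-atom outputs are rescaled as `h * std + mean` when both are set and then pooled per structure along the `batch` vector with the `readout` aggregation. The sketch below illustrates that pooling step in plain Python; that atomgen follows it exactly is an assumption, `readout_energy` is a hypothetical helper, and the `atomref` and `dipole` paths are left out.

```python
from typing import List, Optional


def readout_energy(
    atom_outputs: List[float],
    batch: List[int],
    readout: str = "add",
    mean: Optional[float] = None,
    std: Optional[float] = None,
) -> List[float]:
    """Pool per-atom scalars into one value per structure, SchNet-style."""
    if len(atom_outputs) != len(batch):
        raise ValueError("atom_outputs and batch must have the same length")
    if not batch:
        return []
    # Undo target normalization per atom when both statistics are given.
    if mean is not None and std is not None:
        atom_outputs = [h * std + mean for h in atom_outputs]
    num_structures = max(batch) + 1
    sums = [0.0] * num_structures
    counts = [0] * num_structures
    # batch[i] is the index of the structure that atom i belongs to.
    for h, b in zip(atom_outputs, batch):
        sums[b] += h
        counts[b] += 1
    if readout == "add":
        return sums
    if readout == "mean":
        return [s / c if c else 0.0 for s, c in zip(sums, counts)]
    raise ValueError(f"unsupported readout: {readout!r}")


# Two structures with two atoms each.
energies = readout_energy([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1])
```

Here `energies` is `[3.0, 7.0]`; `readout="mean"` would give `[1.5, 3.5]`. Summation suits extensive targets such as total energy, which grow with the number of atoms.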
- forward(input_ids, coords, batch, labels_energy=None, fixed=None, attention_mask=None)
Forward pass of the SchNet model.
- Args:
- input_ids (torch.Tensor of shape (batch_size, num_atoms)): The input tensor containing the atom indices.
- coords (torch.Tensor of shape (num_atoms, 3)): The input tensor containing the atom coordinates.
- batch (torch.Tensor of shape (num_atoms,)): The input tensor containing the batch indices, i.e. the index of the structure each atom belongs to.
- labels_energy (torch.Tensor, optional): The input tensor containing the energy labels.
- fixed (torch.Tensor, optional): The input tensor containing the fixed mask.
- attention_mask (torch.Tensor, optional): The attention mask for the transformer.
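When `interaction_graph` is `None`, PyTorch Geometric's `SchNet` builds the graph from `coords` and `batch` by connecting atoms of the same structure that lie within `cutoff`, keeping at most `max_num_neighbors` neighbors per atom. The brute-force O(N²) sketch below illustrates that construction; `radius_graph` here is a hypothetical pure-Python helper, and keeping the nearest neighbors when the cap is exceeded is an assumption (PyTorch Geometric does not guarantee which neighbors are kept).

```python
import math
from typing import List, Sequence, Tuple


def radius_graph(
    coords: Sequence[Tuple[float, float, float]],
    batch: Sequence[int],
    cutoff: float = 10.0,
    max_num_neighbors: int = 32,
) -> List[Tuple[int, int, float]]:
    """Return directed edges (source, target, distance) between atoms within the cutoff."""
    edges = []
    for i, (xi, bi) in enumerate(zip(coords, batch)):
        neighbors = []
        for j, (xj, bj) in enumerate(zip(coords, batch)):
            # Only connect distinct atoms that belong to the same structure.
            if i == j or bi != bj:
                continue
            d = math.dist(xi, xj)
            if d <= cutoff:
                neighbors.append((d, j))
        # Assumption: keep the closest atoms when the cap is exceeded.
        neighbors.sort()
        for d, j in neighbors[:max_num_neighbors]:
            # Messages flow from neighbor j to atom i.
            edges.append((j, i, d))
    return edges


# Atoms 0-2 form structure 0; atom 3 is a separate structure at the same position as atom 0.
coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0), (0.0, 0.0, 0.0)]
edges = radius_graph(coords, [0, 0, 0, 1], cutoff=2.0)
```

Only atoms 0 and 1 are connected: atom 2 lies outside the cutoff, and atom 3 is excluded because its batch index differs, even though it overlaps atom 0 in space.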
- class SchNetPreTrainedModel(config, *inputs, **kwargs)
Bases: PreTrainedModel
A base class for all SchNet models.
An abstract class that handles weight initialization and provides a simple interface for loading and exporting models.
- base_model_prefix = 'model'
- config_class: alias of SchNetConfig
- supports_gradient_checkpointing = False