atomgen.models.schnet module

SchNet model for energy prediction.

class SchNetConfig(vocab_size=123, hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0, interaction_graph=None, max_num_neighbors=32, readout='add', dipole=False, mean=None, std=None, atomref=None, mask_token_id=0, pad_token_id=119, bos_token_id=120, eos_token_id=121, cls_token_id=122, **kwargs)

Bases: PretrainedConfig

Stores the configuration of a SchNetModel.

It is used to instantiate a SchNet model according to the specified arguments, defining the model architecture.

Args:
vocab_size (int, optional, defaults to 123):

The size of the vocabulary, i.e. the number of distinct token IDs (atom types and special tokens) the model can embed.

hidden_channels (int, optional, defaults to 128):

The hidden size of the model.
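The default special-token IDs (mask 0, pad 119, bos 120, eos 121, cls 122) leave IDs 1 to 118 free, which matches the number of known chemical elements, and together they span exactly the default vocab_size of 123. The sketch below assumes, without confirmation from this page, that IDs 1 to 118 are atomic numbers; `describe_token` is a hypothetical helper, not part of atomgen.

```python
# Default vocabulary size and special-token IDs, copied from the SchNetConfig signature.
VOCAB_SIZE = 123
SPECIAL_TOKENS = {0: "mask", 119: "pad", 120: "bos", 121: "eos", 122: "cls"}


def describe_token(token_id: int) -> str:
    """Hypothetical helper: label a token ID under an assumed vocabulary layout.

    Assumption: IDs 1..118 are atomic numbers (H through Og). This is inferred
    from the default special-token IDs and is not documented by atomgen.
    """
    if not 0 <= token_id < VOCAB_SIZE:
        raise ValueError(f"token ID {token_id} is outside a vocabulary of size {VOCAB_SIZE}")
    if token_id in SPECIAL_TOKENS:
        return SPECIAL_TOKENS[token_id]
    return f"element Z={token_id}"


# Every default special-token ID fits inside the default vocabulary.
assert all(tid < VOCAB_SIZE for tid in SPECIAL_TOKENS)
print(describe_token(6))    # element Z=6
print(describe_token(119))  # pad
```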

model_type = 'transformer'

vocab_size (int)

The size of the vocabulary, i.e. the number of distinct token IDs (atom types and special tokens) the model can embed.

hidden_channels (int)

The hidden embedding size of the model.

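num_filters (int)

The number of filters used in the continuous-filter convolutions.

num_interactions (int)

The number of interaction blocks.

num_gaussians (int)

The number of Gaussians used to expand interatomic distances.

cutoff (float)

The cutoff distance for interatomic interactions; atoms farther apart than this are not connected.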

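interaction_graph (str, optional)

The interaction graph connecting atoms. If None, neighbors are determined from cutoff and max_num_neighbors.

max_num_neighbors (int)

The maximum number of neighbors collected for each atom within the cutoff.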

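readout (str, optional)

The aggregation used to pool per-atom outputs into one prediction per structure, e.g. 'add' (sum) or 'mean'.

dipole (bool, optional)

Whether to predict the magnitude of the dipole moment instead of a sum of atomic contributions.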

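mean (float, optional)

The mean of the target property, used to rescale the model outputs.

std (float, optional)

The standard deviation of the target property, used to rescale the model outputs.

atomref (float, optional)

Reference values of single-atom properties (for example, isolated-atom energies) added to the per-atom outputs.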

mask_token_id (int, optional)

The token ID of the mask token.

pad_token_id (int, optional)

The token ID used for padding.

bos_token_id (int, optional)

The token ID marking the beginning of a sequence.

eos_token_id (int, optional)

The token ID marking the end of a sequence.

cls_token_id (int, optional)

The token ID of the classification token.
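The cutoff, max_num_neighbors and num_gaussians parameters work together: atoms of the same structure within cutoff of each other become neighbors, each atom keeps at most max_num_neighbors of them, and every neighbor distance is expanded on num_gaussians Gaussians before entering the filter network. The pure-Python sketch below mirrors the evenly spaced Gaussian smearing used by the PyTorch Geometric SchNet implementation; keeping the nearest neighbors when the cap is hit is a simplification of this sketch, and `build_neighbor_list` and `gaussian_expand` are hypothetical helpers, not atomgen functions.

```python
import math
from typing import List, Sequence, Tuple

Coord = Tuple[float, float, float]


def build_neighbor_list(coords: Sequence[Coord], batch: Sequence[int], cutoff: float,
                        max_num_neighbors: int) -> List[Tuple[int, int, float]]:
    """Hypothetical helper: directed edges (i, j, distance) within `cutoff`.

    Atoms with different batch indices are never connected. Keeping the nearest
    neighbors when the cap is reached is an assumption of this sketch.
    """
    edges: List[Tuple[int, int, float]] = []
    for i, ci in enumerate(coords):
        candidates = []
        for j, cj in enumerate(coords):
            if i == j or batch[i] != batch[j]:
                continue  # no self-loops, no edges across structures
            d = math.dist(ci, cj)
            if d < cutoff:
                candidates.append((d, j))
        candidates.sort()  # nearest first
        edges.extend((i, j, d) for d, j in candidates[:max_num_neighbors])
    return edges


def gaussian_expand(distance: float, cutoff: float, num_gaussians: int) -> List[float]:
    """Hypothetical helper: expand one distance on Gaussians centered evenly on [0, cutoff]."""
    step = cutoff / (num_gaussians - 1)  # spacing between neighboring centers
    coeff = -0.5 / step ** 2             # width tied to the spacing, as in PyG's GaussianSmearing
    return [math.exp(coeff * (distance - k * step) ** 2) for k in range(num_gaussians)]


coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 3.0, 0.0), (0.0, 0.0, 0.0)]
batch = [0, 0, 0, 1]  # the last atom belongs to a second structure
edges = build_neighbor_list(coords, batch, cutoff=2.0, max_num_neighbors=32)
print(edges)  # [(0, 1, 1.0), (1, 0, 1.0)]
features = gaussian_expand(edges[0][2], cutoff=10.0, num_gaussians=50)
print(len(features))  # 50
```

Tying the Gaussian width to the center spacing keeps neighboring basis functions overlapping, so a small change in distance produces a smooth change in the feature vector.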

class SchNetModel(config)

Bases: SchNetPreTrainedModel

SchNet model for energy prediction.

Args:
config (SchNetConfig):

The model configuration, holding all parameters of the model.

forward(input_ids, coords, batch, labels_energy=None, fixed=None, attention_mask=None)

Forward pass of the SchNet model.

Args:
input_ids (torch.Tensor of shape (batch_size, num_atoms)):

The token IDs of the atom types.

coords (torch.Tensor of shape (num_atoms, 3)):

The input tensor containing the atom coordinates.

batch (torch.Tensor of shape (num_atoms,)):

The batch index of each atom, i.e. the structure in the batch that the atom belongs to.

labels_energy (torch.Tensor, optional):

The target energies used to compute the loss.

fixed (torch.Tensor, optional):

The input tensor containing the fixed mask.

attention_mask (torch.Tensor, optional):

The mask marking which positions in input_ids hold real atoms (1) and which are padding (0).
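The shapes above mix a padded layout (input_ids of shape (batch_size, num_atoms)) with a flat per-atom layout (coords and batch indexed by atom). One common way to connect the two, assumed here rather than taken from atomgen, is to drop padded positions using attention_mask and record each remaining atom's structure index. `flatten_padded` is a hypothetical helper operating on plain lists instead of tensors.

```python
from typing import List, Sequence, Tuple

Coord = Tuple[float, float, float]


def flatten_padded(
    input_ids: Sequence[Sequence[int]],
    coords: Sequence[Sequence[Coord]],
    attention_mask: Sequence[Sequence[int]],
) -> Tuple[List[int], List[Coord], List[int]]:
    """Hypothetical helper: drop padded positions and build a per-atom batch vector."""
    flat_ids: List[int] = []
    flat_coords: List[Coord] = []
    batch: List[int] = []
    for structure_idx, (ids, xyz, mask) in enumerate(zip(input_ids, coords, attention_mask)):
        for token, position, keep in zip(ids, xyz, mask):
            if keep:  # assumed convention: 1 marks a real atom, 0 marks padding
                flat_ids.append(token)
                flat_coords.append(position)
                batch.append(structure_idx)
    return flat_ids, flat_coords, batch


# Two structures padded to 3 positions; 119 is the default pad_token_id.
input_ids = [[8, 1, 1], [6, 119, 119]]
coords = [
    [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)],
    [(0.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 0.0, 0.0)],
]
attention_mask = [[1, 1, 1], [1, 0, 0]]
ids, xyz, batch = flatten_padded(input_ids, coords, attention_mask)
print(ids)    # [8, 1, 1, 6]
print(batch)  # [0, 0, 0, 1]
```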

Return type:

Tuple[Optional[Tensor], Tensor]

Returns:

tuple:

The loss, which is None when labels_energy is not given, and the energy prediction.
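In the standard SchNet formulation, the network outputs one scalar per atom and pools these into a single energy per structure, with readout selecting the pooling. The sketch below assumes the ordering used by PyTorch Geometric's SchNet (rescale each atomic output by std and mean, then pool by batch index) and omits atomref and dipole handling. `pool_energy` is hypothetical and works on plain lists rather than tensors.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Sequence


def pool_energy(atom_outputs: Sequence[float], batch: Sequence[int], readout: str = "add",
                mean: Optional[float] = None, std: Optional[float] = None) -> List[float]:
    """Hypothetical sketch: pool per-atom outputs into one energy per structure."""
    if mean is not None and std is not None:
        # De-standardize each atomic contribution before pooling (PyG-style ordering).
        atom_outputs = [h * std + mean for h in atom_outputs]
    sums: Dict[int, float] = defaultdict(float)
    counts: Dict[int, int] = defaultdict(int)
    for h, b in zip(atom_outputs, batch):
        sums[b] += h
        counts[b] += 1
    structures = sorted(sums)  # one entry per batch index that has atoms
    if readout == "add":
        return [sums[b] for b in structures]
    if readout == "mean":
        return [sums[b] / counts[b] for b in structures]
    raise ValueError(f"unsupported readout: {readout!r}")


print(pool_energy([1.0, 2.0, 3.0, 4.0], [0, 0, 0, 1]))                  # [6.0, 4.0]
print(pool_energy([1.0, 2.0, 3.0, 4.0], [0, 0, 0, 1], readout="mean"))  # [2.0, 4.0]
```

With 'add', the predicted energy grows with the number of atoms, which suits extensive properties such as total energy; 'mean' yields a per-atom average instead.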

class SchNetPreTrainedModel(config, *inputs, **kwargs)

Bases: PreTrainedModel

A base class for all SchNet models.

An abstract class that handles weight initialization and provides a simple interface for loading and exporting models.

base_model_prefix = 'model'
config_class

alias of SchNetConfig

supports_gradient_checkpointing = False