Source code for atomgen.models.schnet

"""SchNet model for energy prediction."""

from typing import Any, Callable, Optional, Tuple

import torch
from torch import nn
from torch_geometric.nn import SchNet
from torch_geometric.typing import OptTensor
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


class SchNetConfig(PretrainedConfig):  # type: ignore
    r"""Store the configuration of a :class:`~transformers.SchNetModel`.

    It is used to instantiate a SchNet model according to the specified
    arguments, defining the model architecture.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 123):
            The size of the vocabulary, used to define the size of the output embeddings.
        hidden_channels (:obj:`int`, `optional`, defaults to 128):
            The hidden size of the model.

    Attributes
    ----------
    vocab_size (:obj:`int`): The size of the vocabulary, used to define the size of the output embeddings.
    hidden_channels (:obj:`int`): The hidden size of the model.
    num_filters (:obj:`int`): The number of filters.
    num_interactions (:obj:`int`): The number of interactions.
    num_gaussians (:obj:`int`): The number of gaussians.
    cutoff (:obj:`float`): The cutoff value.
    interaction_graph (:obj:`str`, `optional`): The interaction graph.
    max_num_neighbors (:obj:`int`): The maximum number of neighbors.
    readout (:obj:`str`, `optional`): The readout method.
    dipole (:obj:`bool`, `optional`): Whether to include dipole.
    mean (:obj:`float`, `optional`): The mean value.
    std (:obj:`float`, `optional`): The standard deviation value.
    atomref (:obj:`float`, `optional`): The atom reference value.
    mask_token_id (:obj:`int`, `optional`): The token ID for masking.
    pad_token_id (:obj:`int`, `optional`): The token ID for padding.
    bos_token_id (:obj:`int`, `optional`): The token ID for the beginning of sequence.
    eos_token_id (:obj:`int`, `optional`): The token ID for the end of sequence.

    """

    # HF model-type tag used by AutoConfig-style registration; in the
    # original page dump this attribute was garbled into the docstring.
    model_type = "transformer"

    def __init__(
        self,
        vocab_size: int = 123,
        hidden_channels: int = 128,
        num_filters: int = 128,
        num_interactions: int = 6,
        num_gaussians: int = 50,
        cutoff: float = 10.0,
        interaction_graph: Optional[Callable[..., Any]] = None,
        max_num_neighbors: int = 32,
        readout: str = "add",
        dipole: bool = False,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        atomref: Optional[OptTensor] = None,
        mask_token_id: int = 0,
        pad_token_id: int = 119,
        bos_token_id: int = 120,
        eos_token_id: int = 121,
        cls_token_id: int = 122,
        **kwargs: Any,
    ):
        # Forward unrecognized kwargs (e.g. name_or_path) to PretrainedConfig.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.interaction_graph = interaction_graph
        self.max_num_neighbors = max_num_neighbors
        self.readout = readout
        self.dipole = dipole
        self.mean = mean
        self.std = std
        self.atomref = atomref
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.cls_token_id = cls_token_id
class SchNetPreTrainedModel(PreTrainedModel):  # type: ignore
    """Abstract base class shared by all SchNet model variants.

    Handles weight initialization and provides the standard
    ``transformers`` interface for loading and exporting pretrained
    checkpoints; concrete architectures subclass this.
    """

    # Configuration class consumed by from_pretrained / save_pretrained.
    config_class = SchNetConfig
    # Attribute name under which the wrapped backbone lives.
    base_model_prefix = "model"
    # Gradient checkpointing is not implemented for this backbone.
    supports_gradient_checkpointing = False
class SchNetModel(SchNetPreTrainedModel):
    """SchNet model for energy prediction.

    Wraps :class:`torch_geometric.nn.SchNet` behind the ``transformers``
    ``PreTrainedModel`` interface.

    Args:
        config (:class:`~transformers.SchNetConfig`):
            Configuration class to store the configuration of a model.
    """

    def __init__(self, config: SchNetConfig):
        super().__init__(config)
        self.config = config
        self.model = SchNet(
            hidden_channels=config.hidden_channels,
            num_filters=config.num_filters,
            num_interactions=config.num_interactions,
            num_gaussians=config.num_gaussians,
            cutoff=config.cutoff,
            interaction_graph=config.interaction_graph,
            max_num_neighbors=config.max_num_neighbors,
            readout=config.readout,
            dipole=config.dipole,
            mean=config.mean,
            std=config.std,
            atomref=config.atomref,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        batch: torch.Tensor,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward pass of the SchNet model.

        Args:
            input_ids (:obj:`torch.Tensor` of shape :obj:`(num_atoms,)`):
                The input tensor containing the atom indices. It is passed
                directly as ``z`` to ``torch_geometric``'s SchNet together
                with a per-atom ``batch`` vector, so atoms of all molecules
                are expected flattened into one dimension (the original
                docs said ``(batch_size, num_atoms)``, which is
                inconsistent with the per-atom ``batch`` argument below —
                TODO confirm against callers).
            coords (:obj:`torch.Tensor` of shape :obj:`(num_atoms, 3)`):
                The input tensor containing the atom coordinates.
            batch (:obj:`torch.Tensor` of shape :obj:`(num_atoms)`):
                The input tensor containing the batch indices.
            labels_energy (:obj:`torch.Tensor`, `optional`):
                The input tensor containing the energy labels.
            fixed (:obj:`torch.Tensor`, `optional`):
                The input tensor containing the fixed mask. Accepted for
                interface compatibility but currently unused.
            attention_mask (:obj:`torch.Tensor`, `optional`):
                The attention mask. Accepted for interface compatibility
                but currently unused.

        Returns
        -------
        :obj:`tuple`: A tuple ``(loss, energy_pred)``; ``loss`` is ``None``
        when ``labels_energy`` is not provided.
        """
        energy_pred: torch.Tensor = self.model(z=input_ids, pos=coords, batch=batch)

        loss = None
        if labels_energy is not None:
            # Move labels to the prediction's device before the L1 (MAE) loss.
            labels_energy = labels_energy.to(energy_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(energy_pred.squeeze(-1), labels_energy)

        return loss, energy_pred