
API Reference

Data module for the AtomGen library.

This module contains the data classes and functions for pre-processing and collating data for training/inference.

data_collator

Data collator for atom modeling.

DataCollatorForAtomModeling dataclass

Bases: DataCollatorMixin

Data collator used for atom modeling tasks in molecular representations.

This collator prepares input data for various atom modeling tasks, including masked atom modeling (MAM), autoregressive modeling, and coordinate perturbation. It supports both padding and flattening of input data.

Args:

    tokenizer (PreTrainedTokenizer): Tokenizer used for encoding the data.
    mam (Union[bool, float]): If True, uses original masked atom modeling. If float, masks a constant fraction of atoms/tokens.
    autoregressive (bool): Whether to use autoregressive modeling.
    coords_perturb (float): Standard deviation for coordinate perturbation.
    return_lap_pe (bool): Whether to return Laplacian positional encoding.
    return_edge_indices (bool): Whether to return edge indices.
    k (int): Number of eigenvectors to use for Laplacian positional encoding.
    max_radius (float): Maximum distance for edge cutoff.
    max_neighbors (int): Maximum number of neighbors.
    pad (bool): Whether to pad the input data.
    pad_to_multiple_of (Optional[int]): Pad to multiple of this value.
    return_tensors (str): Return tensors as "pt" or "tf".

Attributes:

    tokenizer (PreTrainedTokenizer): The tokenizer used for encoding.
    mam (Union[bool, float]): The masked atom modeling setting.
    autoregressive (bool): The autoregressive modeling setting.
    coords_perturb (float): The coordinate perturbation standard deviation.
    return_lap_pe (bool): The Laplacian positional encoding setting.
    return_edge_indices (bool): The edge indices return setting.
    k (int): The number of eigenvectors for Laplacian PE.
    max_radius (float): The maximum distance for edge cutoff.
    max_neighbors (int): The maximum number of neighbors.
    pad (bool): The padding setting.
    pad_to_multiple_of (Optional[int]): The multiple for padding.
    return_tensors (str): The tensor return format.
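A minimal usage sketch is shown below. The vocabulary path, atom strings, and per-atom coordinates are illustrative assumptions, not the library's canonical dataset schema.

from atomgen.data.data_collator import DataCollatorForAtomModeling
from atomgen.data.tokenizer import AtomTokenizer

# Illustrative vocabulary path; the real file maps chemical symbols to integer ids.
tokenizer = AtomTokenizer(vocab_file="vocab.json")
collator = DataCollatorForAtomModeling(
    tokenizer=tokenizer,
    mam=0.15,            # mask a constant 15% fraction of atoms
    coords_perturb=0.1,  # Gaussian coordinate noise with std 0.1
)

# Two toy systems: tokenized atoms plus per-atom 3D coordinates (values are made up).
examples = [
    {
        "input_ids": tokenizer("CO")["input_ids"],
        "coords": [[0.0, 0.0, 0.0], [1.13, 0.0, 0.0]],
    },
    {
        "input_ids": tokenizer("OHH")["input_ids"],
        "coords": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
    },
]

batch = collator(examples)  # padded tensors: input_ids, labels, coords, attention_mask, ...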

Source code in atomgen/data/data_collator.py
@dataclass
class DataCollatorForAtomModeling(DataCollatorMixin):
    """
    Data collator used for atom modeling tasks in molecular representations.

    This collator prepares input data for various atom modeling tasks, including
    masked atom modeling (MAM), autoregressive modeling, and coordinate perturbation.
    It supports both padding and flattening of input data.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer used for encoding the data.
        mam (Union[bool, float]): If True, uses original masked atom modeling.
                                  If float, masks a constant fraction of atoms/tokens.
        autoregressive (bool): Whether to use autoregressive modeling.
        coords_perturb (float): Standard deviation for coordinate perturbation.
        return_lap_pe (bool): Whether to return Laplacian positional encoding.
        return_edge_indices (bool): Whether to return edge indices.
        k (int): Number of eigenvectors to use for Laplacian positional encoding.
        max_radius (float): Maximum distance for edge cutoff.
        max_neighbors (int): Maximum number of neighbors.
        pad (bool): Whether to pad the input data.
        pad_to_multiple_of (Optional[int]): Pad to multiple of this value.
        return_tensors (str): Return tensors as "pt" or "tf".

    Attributes
    ----------
        tokenizer (PreTrainedTokenizer): The tokenizer used for encoding.
        mam (Union[bool, float]): The masked atom modeling setting.
        autoregressive (bool): The autoregressive modeling setting.
        coords_perturb (float): The coordinate perturbation standard deviation.
        return_lap_pe (bool): The Laplacian positional encoding setting.
        return_edge_indices (bool): The edge indices return setting.
        k (int): The number of eigenvectors for Laplacian PE.
        max_radius (float): The maximum distance for edge cutoff.
        max_neighbors (int): The maximum number of neighbors.
        pad (bool): The padding setting.
        pad_to_multiple_of (Optional[int]): The multiple for padding.
        return_tensors (str): The tensor return format.
    """

    tokenizer: PreTrainedTokenizer
    mam: Union[bool, float] = True
    autoregressive: bool = False
    coords_perturb: float = 0.0
    return_lap_pe: bool = False
    return_edge_indices: bool = False
    k: int = 16
    max_radius: float = 12.0
    max_neighbors: int = 20
    pad: bool = True
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    # ruff: noqa: PLR0912
    def torch_call(
        self, examples: List[Union[List[int], Any, Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Collate a batch of samples.

        Args:
            examples: List of samples to collate.

        Returns
        -------
            Dict[str, Any]: Dictionary of batched data.

        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if self.pad:
            if isinstance(examples[0], Mapping):
                batch: Dict[str, Any] = self.tokenizer.pad(  # type: ignore[assignment]
                    examples,  # type: ignore[arg-type]
                    return_tensors="pt",
                    pad_to_multiple_of=self.pad_to_multiple_of,
                )
            else:
                batch = {
                    "input_ids": _torch_collate_batch(
                        examples,
                        self.tokenizer,
                        pad_to_multiple_of=self.pad_to_multiple_of,
                    )
                }

            if self.return_lap_pe:
                # Compute Laplacian and positional encoding
                (
                    batch["node_pe"],
                    batch["edge_pe"],
                    batch["attention_mask"],
                ) = self.torch_compute_lap_pe(batch["coords"], batch["attention_mask"])
            if self.return_edge_indices:
                # Compute edge indices and distances
                (
                    batch["edge_indices"],
                    batch["edge_distances"],
                    batch["attention_mask"],
                ) = self.torch_compute_edges(batch["coords"], batch["attention_mask"])
        else:
            # flatten all lists in examples and concatenate
            batch = self.flatten_batch(examples)

        t = torch.zeros(batch["input_ids"].shape[0]).float().uniform_(0, 1)
        t = torch.cos(t * math.pi * 0.5)

        if self.mam:
            special_tokens_mask = batch.pop("special_tokens_mask", None)
            if special_tokens_mask is None:
                special_tokens_mask = [
                    self.tokenizer.get_special_tokens_mask(
                        val, already_has_special_tokens=True
                    )
                    for val in batch["input_ids"].tolist()
                ]
                special_tokens_mask = torch.tensor(
                    special_tokens_mask, dtype=torch.bool
                )
            else:
                special_tokens_mask = special_tokens_mask.bool()

            if isinstance(self.mam, float):
                # Constant masking of a float fraction of the atoms/tokens
                mask = torch.bernoulli(
                    torch.full(batch["input_ids"].shape, self.mam)
                ).bool()
                batch["input_ids"], batch["labels"] = self.apply_mask(
                    batch["input_ids"], mask, special_tokens_mask
                )
            else:
                # Original MaskGIT functionality
                batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                    batch["input_ids"], t, special_tokens_mask=special_tokens_mask
                )

        if self.autoregressive:
            # extend coords
            batch["coords"] = torch.cat(
                [
                    torch.zeros_like(batch["coords"][:, :1]),
                    batch["coords"],
                    torch.zeros_like(batch["coords"][:, :1]),
                ],
                dim=1,
            )
            if "labels" not in batch:
                batch["labels"] = batch["input_ids"].clone()
                batch["labels_coords"] = batch["coords"].clone()

            # create mask of ~special_tokens_mask and exclude bos and eos tokens
            special_tokens_mask[batch["labels"] == self.tokenizer.bos_token_id] = False
            special_tokens_mask[batch["labels"] == self.tokenizer.eos_token_id] = False
            batch["labels"] = torch.where(~special_tokens_mask, batch["labels"], -100)

        if self.coords_perturb > 0:
            batch["coords"], batch["labels_coords"] = self.torch_perturb_coords(
                batch["coords"],
                batch.get("fixed", None),
                self.coords_perturb,
            )

        return batch

    def torch_mask_tokens(
        self, inputs: Any, t: Any, special_tokens_mask: Optional[Any] = None
    ) -> Tuple[Any, Any]:
        """Prepare masked tokens inputs/labels for masked atom modeling."""
        labels = inputs.clone()

        batch, seq_len = inputs.shape
        num_token_masked = (seq_len * t).round().clamp(min=1)
        batch_randperm = torch.rand((batch, seq_len)).argsort(dim=-1)
        mask = batch_randperm < num_token_masked.unsqueeze(1)
        inputs = torch.where(
            ~mask,
            inputs,
            self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token),  # type: ignore[arg-type]
        )
        labels = torch.where(mask, labels, -100)
        if special_tokens_mask is not None:
            labels = torch.where(~special_tokens_mask, labels, -100)

        return inputs, labels

    def apply_mask(
        self,
        inputs: torch.Tensor,
        mask: torch.Tensor,
        special_tokens_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the mask to the input tokens."""
        labels = inputs.clone()
        inputs = torch.where(
            mask,
            torch.tensor(self.tokenizer.mask_token_id, device=inputs.device),
            inputs,
        )
        labels = torch.where(
            ~mask | special_tokens_mask,
            torch.tensor(-100, device=labels.device),
            labels,
        )
        return inputs, labels

    def torch_perturb_coords(
        self, inputs: torch.Tensor, fixed: Optional[torch.Tensor], perturb_std: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepare perturbed coords inputs/labels for coordinate denoising."""
        if fixed is None:
            fixed = torch.zeros_like(inputs).bool()
        labels = inputs.clone()
        noise = torch.empty_like(inputs).normal_(0, perturb_std)
        inputs[~fixed.bool()] += noise[~fixed.bool()]
        return inputs, labels

    def flatten_batch(self, examples: Any) -> Dict[str, Any]:
        """Flatten all lists in examples and concatenate with batch indicator."""
        batch = {}
        for key in examples[0]:
            if key == "input_ids":
                lengths = []
                for sample in examples:
                    lengths.append(len(sample[key]))
                batch["batch"] = torch.arange(len(examples)).repeat_interleave(
                    torch.tensor(lengths)
                )
                batch[key] = torch.cat(
                    [torch.tensor(sample[key]) for sample in examples], dim=0
                )
            elif (
                key.startswith("coords")
                or key.endswith("coords")
                or (key.startswith("fixed") or key.endswith("fixed"))
            ):
                batch[key] = torch.cat(
                    [torch.tensor(sample[key]) for sample in examples], dim=0
                )
            elif key.startswith("energy") or key.endswith("energy"):
                batch[key] = torch.tensor([sample[key] for sample in examples])
            elif key.startswith("forces") or key.endswith("forces"):
                batch[key] = torch.cat(
                    [torch.tensor(sample[key]) for sample in examples], dim=0
                )
        return batch

    def torch_compute_edges(self, coords: Any, attention_mask: Any) -> Any:
        """Compute edge indices and distances for each batch."""
        dist_matrix = torch.cdist(coords, coords, p=2)
        b, n, _ = dist_matrix.shape

        # ignore distance in padded coords by setting to large number
        attention_mask_mult = (1.0 - attention_mask) * 1e6
        dist_matrix = dist_matrix + attention_mask_mult.unsqueeze(1)
        dist_matrix = dist_matrix + attention_mask_mult.unsqueeze(2)

        # to avoid self-loop, set diagonal to a large number
        dist_matrix = dist_matrix + torch.eye(n) * 1e6

        # get adjacency matrix using cutoff
        adjacency_matrix = torch.where(dist_matrix <= self.max_radius, 1, 0).float()

        # set max_num_neighbors to 20 to get closest 20 neighbors and set rest to zero
        _, topk_indices = torch.topk(
            dist_matrix,
            k=min(self.max_neighbors, dist_matrix.size(2)),
            dim=2,
            largest=False,
        )
        mask = torch.zeros_like(dist_matrix)
        mask.scatter_(2, topk_indices, 1)
        adjacency_matrix *= mask

        # get distances for each batch in for loop
        distance_list = []
        for bi in range(b):
            distance = dist_matrix[bi][adjacency_matrix[bi] != 0]
            distance_list.append(distance)

        # get edge_indices for each batch in for loop
        edge_indices_list = []
        lengths = []
        for bi in range(b):
            edge_indices = torch.column_stack(torch.where(adjacency_matrix[bi] != 0))
            lengths.append(edge_indices.size(0))
            edge_indices_list.append(edge_indices)

        edge_indices = pad_sequence(
            edge_indices_list, batch_first=True, padding_value=0
        )
        edge_distances = pad_sequence(distance_list, batch_first=True, padding_value=-1)
        edge_attention_mask = torch.cat(
            [
                torch.cat(
                    [
                        torch.ones(1, length),
                        torch.zeros(1, edge_indices.size(1) - length),
                    ],
                    dim=1,
                )
                for length in lengths
            ],
            dim=0,
        )
        attention_mask = torch.cat([attention_mask, edge_attention_mask], dim=1)

        return edge_indices, edge_distances, attention_mask

    def torch_compute_lap_pe(self, coords: Any, attention_mask: Any) -> Any:
        """Compute Laplacian positional encoding for each batch."""
        dist_matrix = torch.cdist(coords, coords, p=2)
        b, n, _ = dist_matrix.shape

        # ignore distance in padded coords by setting to large number
        attention_mask_mult = (1.0 - attention_mask) * 1e6
        dist_matrix = dist_matrix + attention_mask_mult.unsqueeze(1)
        dist_matrix = dist_matrix + attention_mask_mult.unsqueeze(2)

        # to avoid self-loop, set diagonal to a large number
        dist_matrix = dist_matrix + torch.eye(n) * 1e6

        # get adjacency matrix using cutoff
        adjacency_matrix = torch.where(dist_matrix <= self.max_radius, 1, 0).float()

        # set max_num_neighbors to 20 to get closest 20 neighbors and set rest to zero
        _, topk_indices = torch.topk(
            dist_matrix,
            k=min(self.max_neighbors, dist_matrix.size(2)),
            dim=2,
            largest=False,
        )
        mask = torch.zeros_like(dist_matrix)
        mask.scatter_(2, topk_indices, 1)
        adjacency_matrix *= mask

        # get distances for each batch in for loop
        distance_list = []
        for bi in range(b):
            distance = dist_matrix[bi][adjacency_matrix[bi] != 0]
            distance_list.append(distance)

        # get edge_indices for each batch in for loop
        edge_indices_list = []
        for bi in range(b):
            edge_indices = torch.column_stack(torch.where(adjacency_matrix[bi] != 0))
            edge_indices_list.append(edge_indices)

        # Construct graph Laplacian for each batch
        degree_matrix = torch.diag_embed(adjacency_matrix.sum(dim=2).clip(1) ** -0.5)
        laplacian_matrix = (
            torch.eye(n) - degree_matrix @ adjacency_matrix @ degree_matrix
        )

        # Eigenvalue decomposition for each batch
        eigval, eigvec = torch.linalg.eigh(laplacian_matrix)

        eigvec = eigvec.float()  # [N, N (channels)]
        eigval = torch.sort(torch.abs(torch.real(eigval)))[0].float()  # [N (channels),]

        if eigvec.size(1) < self.k:
            node_pe = f.pad(eigvec, (0, self.k - eigvec.size(2), 0, 0))
        else:
            # use smallest eigenvalues
            node_pe = eigvec[:, :, : self.k]

        all_edges_pe_list = []
        lengths = []
        for i, edge_indices in enumerate(edge_indices_list):
            e = edge_indices.shape[0]
            lengths.append(e)
            all_edges_pe = torch.zeros([e, 2 * self.k])
            all_edges_pe[:, : self.k] = torch.index_select(
                node_pe[i], 0, edge_indices[:, 0]
            )
            all_edges_pe[:, self.k :] = torch.index_select(
                node_pe[i], 0, edge_indices[:, 1]
            )
            all_edges_pe_list.append(all_edges_pe)

        # get attention mask for edge_pe based on all_edges_pe_list

        edge_pe = pad_sequence(all_edges_pe_list, batch_first=True, padding_value=0)
        edge_attention_mask = torch.cat(
            [
                torch.cat(
                    [torch.ones(1, length), torch.zeros(1, edge_pe.size(1) - length)],
                    dim=1,
                )
                for length in lengths
            ],
            dim=0,
        )
        attention_mask = torch.cat([attention_mask, edge_attention_mask], dim=1)

        edge_distances = pad_sequence(distance_list, batch_first=True, padding_value=-1)
        edge_pe = torch.cat([edge_pe, edge_distances.unsqueeze(-1)], dim=2)

        node_pe = torch.cat([node_pe, node_pe], dim=2)

        return node_pe, edge_pe, attention_mask

torch_call

torch_call(examples)

Collate a batch of samples.

Args: examples: List of samples to collate.

Returns:

Dict[str, Any]: Dictionary of batched data.
Source code in atomgen/data/data_collator.py
def torch_call(
    self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]:
    """Collate a batch of samples.

    Args:
        examples: List of samples to collate.

    Returns
    -------
        Dict[str, Any]: Dictionary of batched data.

    """
    # Handle dict or lists with proper padding and conversion to tensor.
    if self.pad:
        if isinstance(examples[0], Mapping):
            batch: Dict[str, Any] = self.tokenizer.pad(  # type: ignore[assignment]
                examples,  # type: ignore[arg-type]
                return_tensors="pt",
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
        else:
            batch = {
                "input_ids": _torch_collate_batch(
                    examples,
                    self.tokenizer,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                )
            }

        if self.return_lap_pe:
            # Compute Laplacian and positional encoding
            (
                batch["node_pe"],
                batch["edge_pe"],
                batch["attention_mask"],
            ) = self.torch_compute_lap_pe(batch["coords"], batch["attention_mask"])
        if self.return_edge_indices:
            # Compute edge indices and distances
            (
                batch["edge_indices"],
                batch["edge_distances"],
                batch["attention_mask"],
            ) = self.torch_compute_edges(batch["coords"], batch["attention_mask"])
    else:
        # flatten all lists in examples and concatenate
        batch = self.flatten_batch(examples)

    t = torch.zeros(batch["input_ids"].shape[0]).float().uniform_(0, 1)
    t = torch.cos(t * math.pi * 0.5)

    if self.mam:
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(
                    val, already_has_special_tokens=True
                )
                for val in batch["input_ids"].tolist()
            ]
            special_tokens_mask = torch.tensor(
                special_tokens_mask, dtype=torch.bool
            )
        else:
            special_tokens_mask = special_tokens_mask.bool()

        if isinstance(self.mam, float):
            # Constant masking of a float fraction of the atoms/tokens
            mask = torch.bernoulli(
                torch.full(batch["input_ids"].shape, self.mam)
            ).bool()
            batch["input_ids"], batch["labels"] = self.apply_mask(
                batch["input_ids"], mask, special_tokens_mask
            )
        else:
            # Original MaskGIT functionality
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                batch["input_ids"], t, special_tokens_mask=special_tokens_mask
            )

    if self.autoregressive:
        # extend coords
        batch["coords"] = torch.cat(
            [
                torch.zeros_like(batch["coords"][:, :1]),
                batch["coords"],
                torch.zeros_like(batch["coords"][:, :1]),
            ],
            dim=1,
        )
        if "labels" not in batch:
            batch["labels"] = batch["input_ids"].clone()
            batch["labels_coords"] = batch["coords"].clone()

        # create mask of ~special_tokens_mask and exclude bos and eos tokens
        special_tokens_mask[batch["labels"] == self.tokenizer.bos_token_id] = False
        special_tokens_mask[batch["labels"] == self.tokenizer.eos_token_id] = False
        batch["labels"] = torch.where(~special_tokens_mask, batch["labels"], -100)

    if self.coords_perturb > 0:
        batch["coords"], batch["labels_coords"] = self.torch_perturb_coords(
            batch["coords"],
            batch.get("fixed", None),
            self.coords_perturb,
        )

    return batch

torch_mask_tokens

torch_mask_tokens(inputs, t, special_tokens_mask=None)

Prepare masked tokens inputs/labels for masked atom modeling.

Source code in atomgen/data/data_collator.py
def torch_mask_tokens(
    self, inputs: Any, t: Any, special_tokens_mask: Optional[Any] = None
) -> Tuple[Any, Any]:
    """Prepare masked tokens inputs/labels for masked atom modeling."""
    labels = inputs.clone()

    batch, seq_len = inputs.shape
    num_token_masked = (seq_len * t).round().clamp(min=1)
    batch_randperm = torch.rand((batch, seq_len)).argsort(dim=-1)
    mask = batch_randperm < num_token_masked.unsqueeze(1)
    inputs = torch.where(
        ~mask,
        inputs,
        self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token),  # type: ignore[arg-type]
    )
    labels = torch.where(mask, labels, -100)
    if special_tokens_mask is not None:
        labels = torch.where(~special_tokens_mask, labels, -100)

    return inputs, labels
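The number of masked positions per example follows a cosine schedule over a uniform draw, matching the t computed in torch_call. A small standalone sketch of that schedule (toy batch size and sequence length):

import math
import torch

seq_len, batch_size = 10, 4
u = torch.rand(batch_size)           # one uniform draw per example
t = torch.cos(u * math.pi * 0.5)     # cosine schedule: masking fraction in (0, 1]
num_masked = (seq_len * t).round().clamp(min=1)

# Mask the first `num_masked` positions of a random permutation per example.
randperm = torch.rand((batch_size, seq_len)).argsort(dim=-1)
mask = randperm < num_masked.unsqueeze(1)
print(num_masked)          # e.g. tensor([7., 3., 9., 1.])
print(mask.sum(dim=-1))    # per-example mask counts agree with num_masked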

apply_mask

apply_mask(inputs, mask, special_tokens_mask)

Apply the mask to the input tokens.

Source code in atomgen/data/data_collator.py
def apply_mask(
    self,
    inputs: torch.Tensor,
    mask: torch.Tensor,
    special_tokens_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the mask to the input tokens."""
    labels = inputs.clone()
    inputs = torch.where(
        mask,
        torch.tensor(self.tokenizer.mask_token_id, device=inputs.device),
        inputs,
    )
    labels = torch.where(
        ~mask | special_tokens_mask,
        torch.tensor(-100, device=labels.device),
        labels,
    )
    return inputs, labels

torch_perturb_coords

torch_perturb_coords(inputs, fixed, perturb_std)

Prepare perturbed coords inputs/labels for coordinate denoising.

Source code in atomgen/data/data_collator.py
def torch_perturb_coords(
    self, inputs: torch.Tensor, fixed: Optional[torch.Tensor], perturb_std: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Prepare perturbed coords inputs/labels for coordinate denoising."""
    if fixed is None:
        fixed = torch.zeros_like(inputs).bool()
    labels = inputs.clone()
    noise = torch.empty_like(inputs).normal_(0, perturb_std)
    inputs[~fixed.bool()] += noise[~fixed.bool()]
    return inputs, labels
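A standalone sketch of what the denoising setup produces: the clean coordinates become the labels, and Gaussian noise is added only to atoms that are not flagged as fixed (the values below are illustrative).

import torch

coords = torch.tensor([[0.00, 0.00, 0.00],
                       [0.96, 0.00, 0.00],
                       [-0.24, 0.93, 0.00]])
fixed = torch.tensor([True, False, False])   # keep the first atom anchored

labels = coords.clone()                      # denoising target: clean coordinates
noise = torch.empty_like(coords).normal_(0, 0.1)
coords[~fixed] += noise[~fixed]              # noisy model input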

flatten_batch

flatten_batch(examples)

Flatten all lists in examples and concatenate with batch indicator.

Source code in atomgen/data/data_collator.py
def flatten_batch(self, examples: Any) -> Dict[str, Any]:
    """Flatten all lists in examples and concatenate with batch indicator."""
    batch = {}
    for key in examples[0]:
        if key == "input_ids":
            lengths = []
            for sample in examples:
                lengths.append(len(sample[key]))
            batch["batch"] = torch.arange(len(examples)).repeat_interleave(
                torch.tensor(lengths)
            )
            batch[key] = torch.cat(
                [torch.tensor(sample[key]) for sample in examples], dim=0
            )
        elif (
            key.startswith("coords")
            or key.endswith("coords")
            or (key.startswith("fixed") or key.endswith("fixed"))
        ):
            batch[key] = torch.cat(
                [torch.tensor(sample[key]) for sample in examples], dim=0
            )
        elif key.startswith("energy") or key.endswith("energy"):
            batch[key] = torch.tensor([sample[key] for sample in examples])
        elif key.startswith("forces") or key.endswith("forces"):
            batch[key] = torch.cat(
                [torch.tensor(sample[key]) for sample in examples], dim=0
            )
    return batch
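When pad=False, the collator returns this flattened layout: atom-level tensors from all samples are concatenated and a batch vector records which sample each atom belongs to, similar to the batching used by graph libraries. A tiny illustration of the batch indicator:

import torch

lengths = [3, 2]   # two samples with 3 and 2 atoms
batch_index = torch.arange(len(lengths)).repeat_interleave(torch.tensor(lengths))
print(batch_index)  # tensor([0, 0, 0, 1, 1])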

torch_compute_edges

torch_compute_edges(coords, attention_mask)

Compute edge indices and distances for each batch.

Source code in atomgen/data/data_collator.py
def torch_compute_edges(self, coords: Any, attention_mask: Any) -> Any:
    """Compute edge indices and distances for each batch."""
    dist_matrix = torch.cdist(coords, coords, p=2)
    b, n, _ = dist_matrix.shape

    # ignore distance in padded coords by setting to large number
    attention_mask_mult = (1.0 - attention_mask) * 1e6
    dist_matrix = dist_matrix + attention_mask_mult.unsqueeze(1)
    dist_matrix = dist_matrix + attention_mask_mult.unsqueeze(2)

    # to avoid self-loop, set diagonal to a large number
    dist_matrix = dist_matrix + torch.eye(n) * 1e6

    # get adjacency matrix using cutoff
    adjacency_matrix = torch.where(dist_matrix <= self.max_radius, 1, 0).float()

    # set max_num_neighbors to 20 to get closest 20 neighbors and set rest to zero
    _, topk_indices = torch.topk(
        dist_matrix,
        k=min(self.max_neighbors, dist_matrix.size(2)),
        dim=2,
        largest=False,
    )
    mask = torch.zeros_like(dist_matrix)
    mask.scatter_(2, topk_indices, 1)
    adjacency_matrix *= mask

    # get distances for each batch in for loop
    distance_list = []
    for bi in range(b):
        distance = dist_matrix[bi][adjacency_matrix[bi] != 0]
        distance_list.append(distance)

    # get edge_indices for each batch in for loop
    edge_indices_list = []
    lengths = []
    for bi in range(b):
        edge_indices = torch.column_stack(torch.where(adjacency_matrix[bi] != 0))
        lengths.append(edge_indices.size(0))
        edge_indices_list.append(edge_indices)

    edge_indices = pad_sequence(
        edge_indices_list, batch_first=True, padding_value=0
    )
    edge_distances = pad_sequence(distance_list, batch_first=True, padding_value=-1)
    edge_attention_mask = torch.cat(
        [
            torch.cat(
                [
                    torch.ones(1, length),
                    torch.zeros(1, edge_indices.size(1) - length),
                ],
                dim=1,
            )
            for length in lengths
        ],
        dim=0,
    )
    attention_mask = torch.cat([attention_mask, edge_attention_mask], dim=1)

    return edge_indices, edge_distances, attention_mask
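A tiny standalone sketch of the cutoff graph this method builds, for a single unpadded system (numbers are illustrative; the real method additionally caps the neighbor count at max_neighbors and pads edge lists across the batch):

import torch

coords = torch.tensor([[0.0, 0.0, 0.0],
                       [1.0, 0.0, 0.0],
                       [0.0, 5.0, 0.0]])
max_radius = 2.0

dist = torch.cdist(coords, coords)                   # pairwise distances, shape (3, 3)
dist = dist + torch.eye(3) * 1e6                     # exclude self-loops
adj = dist <= max_radius                             # cutoff adjacency
edge_indices = torch.column_stack(torch.where(adj))  # (num_edges, 2) source/target pairs
edge_distances = dist[adj]
print(edge_indices)    # tensor([[0, 1], [1, 0]]): only atoms 0 and 1 are within the cutoff
print(edge_distances)  # tensor([1., 1.])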

torch_compute_lap_pe

torch_compute_lap_pe(coords, attention_mask)

Compute Laplacian positional encoding for each batch.

Source code in atomgen/data/data_collator.py
def torch_compute_lap_pe(self, coords: Any, attention_mask: Any) -> Any:
    """Compute Laplacian positional encoding for each batch."""
    dist_matrix = torch.cdist(coords, coords, p=2)
    b, n, _ = dist_matrix.shape

    # ignore distance in padded coords by setting to large number
    attention_mask_mult = (1.0 - attention_mask) * 1e6
    dist_matrix = dist_matrix + attention_mask_mult.unsqueeze(1)
    dist_matrix = dist_matrix + attention_mask_mult.unsqueeze(2)

    # to avoid self-loop, set diagonal to a large number
    dist_matrix = dist_matrix + torch.eye(n) * 1e6

    # get adjacency matrix using cutoff
    adjacency_matrix = torch.where(dist_matrix <= self.max_radius, 1, 0).float()

    # set max_num_neighbors to 20 to get closest 20 neighbors and set rest to zero
    _, topk_indices = torch.topk(
        dist_matrix,
        k=min(self.max_neighbors, dist_matrix.size(2)),
        dim=2,
        largest=False,
    )
    mask = torch.zeros_like(dist_matrix)
    mask.scatter_(2, topk_indices, 1)
    adjacency_matrix *= mask

    # get distances for each batch in for loop
    distance_list = []
    for bi in range(b):
        distance = dist_matrix[bi][adjacency_matrix[bi] != 0]
        distance_list.append(distance)

    # get edge_indices for each batch in for loop
    edge_indices_list = []
    for bi in range(b):
        edge_indices = torch.column_stack(torch.where(adjacency_matrix[bi] != 0))
        edge_indices_list.append(edge_indices)

    # Construct graph Laplacian for each batch
    degree_matrix = torch.diag_embed(adjacency_matrix.sum(dim=2).clip(1) ** -0.5)
    laplacian_matrix = (
        torch.eye(n) - degree_matrix @ adjacency_matrix @ degree_matrix
    )

    # Eigenvalue decomposition for each batch
    eigval, eigvec = torch.linalg.eigh(laplacian_matrix)

    eigvec = eigvec.float()  # [N, N (channels)]
    eigval = torch.sort(torch.abs(torch.real(eigval)))[0].float()  # [N (channels),]

    if eigvec.size(1) < self.k:
        node_pe = f.pad(eigvec, (0, self.k - eigvec.size(2), 0, 0))
    else:
        # use smallest eigenvalues
        node_pe = eigvec[:, :, : self.k]

    all_edges_pe_list = []
    lengths = []
    for i, edge_indices in enumerate(edge_indices_list):
        e = edge_indices.shape[0]
        lengths.append(e)
        all_edges_pe = torch.zeros([e, 2 * self.k])
        all_edges_pe[:, : self.k] = torch.index_select(
            node_pe[i], 0, edge_indices[:, 0]
        )
        all_edges_pe[:, self.k :] = torch.index_select(
            node_pe[i], 0, edge_indices[:, 1]
        )
        all_edges_pe_list.append(all_edges_pe)

    # get attention mask for edge_pe based on all_edges_pe_list

    edge_pe = pad_sequence(all_edges_pe_list, batch_first=True, padding_value=0)
    edge_attention_mask = torch.cat(
        [
            torch.cat(
                [torch.ones(1, length), torch.zeros(1, edge_pe.size(1) - length)],
                dim=1,
            )
            for length in lengths
        ],
        dim=0,
    )
    attention_mask = torch.cat([attention_mask, edge_attention_mask], dim=1)

    edge_distances = pad_sequence(distance_list, batch_first=True, padding_value=-1)
    edge_pe = torch.cat([edge_pe, edge_distances.unsqueeze(-1)], dim=2)

    node_pe = torch.cat([node_pe, node_pe], dim=2)

    return node_pe, edge_pe, attention_mask
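A standalone sketch of the positional encoding itself: the symmetric normalized Laplacian of a toy three-node path graph and its k smallest-eigenvalue eigenvectors. The method above does this per batch element on the radius graph and additionally builds the edge-level encoding and attention mask.

import torch

adj = torch.tensor([[0.0, 1.0, 0.0],
                    [1.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0]])

deg_inv_sqrt = torch.diag(adj.sum(dim=1).clip(1) ** -0.5)
lap = torch.eye(3) - deg_inv_sqrt @ adj @ deg_inv_sqrt   # symmetric normalized Laplacian

eigval, eigvec = torch.linalg.eigh(lap)   # eigenvalues in ascending order
k = 2
node_pe = eigvec[:, :k]                   # per-node encoding from the k smallest eigenvalues
print(node_pe.shape)                      # torch.Size([3, 2])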

tokenizer

Tokenization module for atom modeling.

AtomTokenizer

Bases: PreTrainedTokenizer

Tokenizer for atomistic data.

Args:

    vocab_file: The path to the vocabulary file.
    pad_token: The padding token.
    mask_token: The mask token.
    bos_token: The beginning of system token.
    eos_token: The end of system token.
    cls_token: The classification token.
    kwargs: Additional keyword arguments.
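A minimal usage sketch; the vocab.json path is an assumption, and the vocabulary is assumed to contain the chemical symbols used below.

from atomgen.data.tokenizer import AtomTokenizer

tokenizer = AtomTokenizer(vocab_file="vocab.json")   # illustrative path

# Greedy two-character matching keeps two-letter symbols whole.
print(tokenizer.tokenize("NaCl"))   # ['Na', 'Cl']
print(tokenizer.tokenize("CCO"))    # ['C', 'C', 'O']

# Calling the tokenizer wraps the atom ids with the <bos>/<eos> special tokens.
encoded = tokenizer("NaCl")
print(encoded["input_ids"])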

Source code in atomgen/data/tokenizer.py
class AtomTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for atomistic data.

    Args:
        vocab_file: The path to the vocabulary file.
        pad_token: The padding token.
        mask_token: The mask token.
        bos_token: The beginning of system token.
        eos_token: The end of system token.
        cls_token: The classification token.
        kwargs: Additional keyword arguments.

    """

    def __init__(
        self,
        vocab_file: str,
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
        bos_token: str = "<bos>",
        eos_token: str = "<eos>",
        cls_token: str = "<graph>",
        **kwargs: Dict[str, Union[bool, str, PaddingStrategy]],
    ) -> None:
        self.vocab: Dict[str, int] = self.load_vocab(vocab_file)
        self.ids_to_tokens: Dict[int, str] = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()]
        )

        super().__init__(  # type: ignore[no-untyped-call]
            pad_token=pad_token,
            mask_token=mask_token,
            bos_token=bos_token,
            eos_token=eos_token,
            cls_token=cls_token,
            **kwargs,
        )

    @staticmethod
    def load_vocab(vocab_file: str) -> Dict[str, int]:
        """Load the vocabulary from a json file."""
        with open(vocab_file, "r") as f:
            vocab = json.load(f)
            if not isinstance(vocab, dict):
                raise ValueError(
                    "The vocabulary file is not a json file or is not formatted correctly."
                )
        return vocab

    def _tokenize(self, text: str) -> List[str]:  # type: ignore[override]
        """Tokenize the text."""
        tokens = []
        i = 0
        while i < len(text):
            if i + 1 < len(text) and text[i : i + 2] in self.vocab:
                tokens.append(text[i : i + 2])
                i += 2
            else:
                tokens.append(text[i])
                i += 1
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert the chemical symbols to atomic numbers."""
        return self.vocab[token]

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens[index]

    def get_vocab(self) -> Dict[str, int]:
        """Get the vocabulary."""
        return self.vocab

    def get_vocab_size(self) -> int:
        """Get the size of the vocabulary."""
        return len(self.vocab)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert the list of chemical symbol tokens to a concatenated string."""
        return "".join(tokens)

    def pad(  # type: ignore[override]
        self,
        encoded_inputs: Union[
            BatchEncoding,
            List[BatchEncoding],
            Dict[str, EncodedInput],
            Dict[str, List[EncodedInput]],
            List[Dict[str, EncodedInput]],
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ) -> BatchEncoding:
        """Pad the input data."""
        if isinstance(encoded_inputs, list):
            if isinstance(encoded_inputs[0], Mapping):
                if any(
                    key.startswith("coords") or key.endswith("coords")
                    for key in encoded_inputs[0]
                ):
                    encoded_inputs = self.pad_coords(  # type: ignore[assignment]
                        encoded_inputs,  # type: ignore[arg-type]
                        max_length=max_length,
                        pad_to_multiple_of=pad_to_multiple_of,
                    )
                if any(
                    key.startswith("forces") or key.endswith("forces")
                    for key in encoded_inputs[0]
                ):
                    encoded_inputs = self.pad_forces(  # type: ignore[assignment]
                        encoded_inputs,  # type: ignore[arg-type]
                        max_length=max_length,
                        pad_to_multiple_of=pad_to_multiple_of,
                    )
                if any(
                    key.startswith("fixed") or key.endswith("fixed")
                    for key in encoded_inputs[0]
                ):
                    encoded_inputs = self.pad_fixed(  # type: ignore[assignment]
                        encoded_inputs,  # type: ignore[arg-type]
                        max_length=max_length,
                        pad_to_multiple_of=pad_to_multiple_of,
                    )
        elif isinstance(encoded_inputs, Mapping):
            if any("coords" in key for key in encoded_inputs):
                encoded_inputs = self.pad_coords(  # type: ignore[assignment]
                    encoded_inputs,
                    max_length=max_length,
                    pad_to_multiple_of=pad_to_multiple_of,
                )
            if any("fixed" in key for key in encoded_inputs):
                encoded_inputs = self.pad_fixed(  # type: ignore[assignment]
                    encoded_inputs,
                    max_length=max_length,
                    pad_to_multiple_of=pad_to_multiple_of,
                )

        return super().pad(
            encoded_inputs=encoded_inputs,
            padding=padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_tensors=return_tensors,
            verbose=verbose,
        )

    def pad_coords(
        self,
        batch: Union[Mapping[str, Any], List[Mapping[str, Any]]],
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
        """Pad the coordinates to the same length."""
        if isinstance(batch, Mapping):
            coord_keys = [
                key
                for key in batch
                if key.startswith("coords") or key.endswith("coords")
            ]
        elif isinstance(batch, list):
            coord_keys = [
                key
                for key in batch[0]
                if key.startswith("coords") or key.endswith("coords")
            ]
        for key in coord_keys:
            if isinstance(batch, Mapping):
                coords = batch[key]
            elif isinstance(batch, list):
                coords = [sample[key] for sample in batch]
            max_length = (
                max([len(c) for c in coords]) if max_length is None else max_length
            )
            if pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
                max_length = (
                    (max_length // pad_to_multiple_of) + 1
                ) * pad_to_multiple_of
            for c in coords:
                c.extend([[0.0, 0.0, 0.0]] * (max_length - len(c)))
            if isinstance(batch, list):
                for i, sample in enumerate(batch):
                    sample[key] = coords[i]  # type: ignore[index]
        return batch

    def pad_forces(
        self,
        batch: Union[Mapping[str, Any], List[Mapping[str, Any]]],
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
        """Pad the forces to the same length."""
        if isinstance(batch, Mapping):
            force_keys = [
                key
                for key in batch
                if key.startswith("forces") or key.endswith("forces")
            ]
        elif isinstance(batch, list):
            force_keys = [
                key
                for key in batch[0]
                if key.startswith("forces") or key.endswith("forces")
            ]
        for key in force_keys:
            if isinstance(batch, Mapping):
                forces = batch[key]
            elif isinstance(batch, list):
                forces = [sample[key] for sample in batch]
            max_length = (
                max([len(c) for c in forces]) if max_length is None else max_length
            )
            if pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
                max_length = (
                    (max_length // pad_to_multiple_of) + 1
                ) * pad_to_multiple_of
            for f in forces:
                f.extend([[0.0, 0.0, 0.0]] * (max_length - len(f)))
            if isinstance(batch, list):
                for i, sample in enumerate(batch):
                    sample[key] = forces[i]  # type: ignore[index]
        return batch

    def pad_fixed(
        self,
        batch: Union[Mapping[str, Any], List[Mapping[str, Any]]],
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
        """Pad the fixed mask to the same length."""
        if isinstance(batch, Mapping):
            fixed_keys = [
                key for key in batch if key.startswith("fixed") or key.endswith("fixed")
            ]
        elif isinstance(batch, list):
            fixed_keys = [
                key
                for key in batch[0]
                if key.startswith("fixed") or key.endswith("fixed")
            ]
        for key in fixed_keys:
            if isinstance(batch, Mapping):
                fixed = batch[key]
            elif isinstance(batch, list):
                fixed = [sample[key] for sample in batch]
            max_length = (
                max([len(c) for c in fixed]) if max_length is None else max_length
            )
            if pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
                max_length = (
                    (max_length // pad_to_multiple_of) + 1
                ) * pad_to_multiple_of
            for f in fixed:
                f.extend([True] * (max_length - len(f)))
            if isinstance(batch, list):
                for i, sample in enumerate(batch):
                    sample[key] = fixed[i]  # type: ignore[index]
        return batch

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Save the vocabulary to a json file."""
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )
        with open(vocab_file, "w") as f:
            json.dump(self.vocab, f)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, *inputs: Any, **kwargs: Any) -> Any:
        """Load the tokenizer from a pretrained model."""
        return super().from_pretrained(*inputs, **kwargs)

    # add special tokens <bos> and <eos>
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Build the input with special tokens."""
        bos = [self.bos_token_id]
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return bos + token_ids_0 + eos
        return bos + token_ids_0 + eos + token_ids_1 + eos

load_vocab staticmethod

load_vocab(vocab_file)

Load the vocabulary from a json file.

Source code in atomgen/data/tokenizer.py
@staticmethod
def load_vocab(vocab_file: str) -> Dict[str, int]:
    """Load the vocabulary from a json file."""
    with open(vocab_file, "r") as f:
        vocab = json.load(f)
        if not isinstance(vocab, dict):
            raise ValueError(
                "The vocabulary file is not a json file or is not formatted correctly."
            )
    return vocab

get_vocab

get_vocab()

Get the vocabulary.

Source code in atomgen/data/tokenizer.py
def get_vocab(self) -> Dict[str, int]:
    """Get the vocabulary."""
    return self.vocab

get_vocab_size

get_vocab_size()

Get the size of the vocabulary.

Source code in atomgen/data/tokenizer.py
def get_vocab_size(self) -> int:
    """Get the size of the vocabulary."""
    return len(self.vocab)

convert_tokens_to_string

convert_tokens_to_string(tokens)

Convert the list of chemical symbol tokens to a concatenated string.

Source code in atomgen/data/tokenizer.py
def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """Convert the list of chemical symbol tokens to a concatenated string."""
    return "".join(tokens)

pad

pad(
    encoded_inputs,
    padding=True,
    max_length=None,
    pad_to_multiple_of=None,
    return_attention_mask=None,
    return_tensors=None,
    verbose=True,
)

Pad the input data.

Source code in atomgen/data/tokenizer.py
def pad(  # type: ignore[override]
    self,
    encoded_inputs: Union[
        BatchEncoding,
        List[BatchEncoding],
        Dict[str, EncodedInput],
        Dict[str, List[EncodedInput]],
        List[Dict[str, EncodedInput]],
    ],
    padding: Union[bool, str, PaddingStrategy] = True,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    verbose: bool = True,
) -> BatchEncoding:
    """Pad the input data."""
    if isinstance(encoded_inputs, list):
        if isinstance(encoded_inputs[0], Mapping):
            if any(
                key.startswith("coords") or key.endswith("coords")
                for key in encoded_inputs[0]
            ):
                encoded_inputs = self.pad_coords(  # type: ignore[assignment]
                    encoded_inputs,  # type: ignore[arg-type]
                    max_length=max_length,
                    pad_to_multiple_of=pad_to_multiple_of,
                )
            if any(
                key.startswith("forces") or key.endswith("forces")
                for key in encoded_inputs[0]
            ):
                encoded_inputs = self.pad_forces(  # type: ignore[assignment]
                    encoded_inputs,  # type: ignore[arg-type]
                    max_length=max_length,
                    pad_to_multiple_of=pad_to_multiple_of,
                )
            if any(
                key.startswith("fixed") or key.endswith("fixed")
                for key in encoded_inputs[0]
            ):
                encoded_inputs = self.pad_fixed(  # type: ignore[assignment]
                    encoded_inputs,  # type: ignore[arg-type]
                    max_length=max_length,
                    pad_to_multiple_of=pad_to_multiple_of,
                )
    elif isinstance(encoded_inputs, Mapping):
        if any("coords" in key for key in encoded_inputs):
            encoded_inputs = self.pad_coords(  # type: ignore[assignment]
                encoded_inputs,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
            )
        if any("fixed" in key for key in encoded_inputs):
            encoded_inputs = self.pad_fixed(  # type: ignore[assignment]
                encoded_inputs,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
            )

    return super().pad(
        encoded_inputs=encoded_inputs,
        padding=padding,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        return_attention_mask=return_attention_mask,
        return_tensors=return_tensors,
        verbose=verbose,
    )
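A sketch of padding a ragged batch (the integer ids and coordinates are illustrative): coordinate lists are padded with zero vectors to the longest sample before the base-class padding handles input_ids and the attention mask.

from atomgen.data.tokenizer import AtomTokenizer

tokenizer = AtomTokenizer(vocab_file="vocab.json")   # illustrative path

batch = tokenizer.pad(
    [
        {"input_ids": [3, 6, 8, 4], "coords": [[0.0, 0.0, 0.0], [1.1, 0.0, 0.0]]},
        {"input_ids": [3, 8, 4], "coords": [[0.0, 0.0, 0.0]]},
    ],
    return_tensors="pt",
)
# batch["coords"] has shape (2, 2, 3); the shorter sample is padded with [0.0, 0.0, 0.0],
# and batch["input_ids"] / batch["attention_mask"] are padded to the longest sequence.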

pad_coords

pad_coords(batch, max_length=None, pad_to_multiple_of=None)

Pad the coordinates to the same length.

Source code in atomgen/data/tokenizer.py
def pad_coords(
    self,
    batch: Union[Mapping[str, Any], List[Mapping[str, Any]]],
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
    """Pad the coordinates to the same length."""
    if isinstance(batch, Mapping):
        coord_keys = [
            key
            for key in batch
            if key.startswith("coords") or key.endswith("coords")
        ]
    elif isinstance(batch, list):
        coord_keys = [
            key
            for key in batch[0]
            if key.startswith("coords") or key.endswith("coords")
        ]
    for key in coord_keys:
        if isinstance(batch, Mapping):
            coords = batch[key]
        elif isinstance(batch, list):
            coords = [sample[key] for sample in batch]
        max_length = (
            max([len(c) for c in coords]) if max_length is None else max_length
        )
        if pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
            max_length = (
                (max_length // pad_to_multiple_of) + 1
            ) * pad_to_multiple_of
        for c in coords:
            c.extend([[0.0, 0.0, 0.0]] * (max_length - len(c)))
        if isinstance(batch, list):
            for i, sample in enumerate(batch):
                sample[key] = coords[i]  # type: ignore[index]
    return batch

pad_forces

pad_forces(batch, max_length=None, pad_to_multiple_of=None)

Pad the forces to the same length.

Source code in atomgen/data/tokenizer.py
def pad_forces(
    self,
    batch: Union[Mapping[str, Any], List[Mapping[str, Any]]],
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
    """Pad the forces to the same length."""
    if isinstance(batch, Mapping):
        force_keys = [
            key
            for key in batch
            if key.startswith("forces") or key.endswith("forces")
        ]
    elif isinstance(batch, list):
        force_keys = [
            key
            for key in batch[0]
            if key.startswith("forces") or key.endswith("forces")
        ]
    for key in force_keys:
        if isinstance(batch, Mapping):
            forces = batch[key]
        elif isinstance(batch, list):
            forces = [sample[key] for sample in batch]
        max_length = (
            max([len(c) for c in forces]) if max_length is None else max_length
        )
        if pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
            max_length = (
                (max_length // pad_to_multiple_of) + 1
            ) * pad_to_multiple_of
        for f in forces:
            f.extend([[0.0, 0.0, 0.0]] * (max_length - len(f)))
        if isinstance(batch, list):
            for i, sample in enumerate(batch):
                sample[key] = forces[i]  # type: ignore[index]
    return batch

pad_fixed

pad_fixed(batch, max_length=None, pad_to_multiple_of=None)

Pad the fixed mask to the same length.

Source code in atomgen/data/tokenizer.py
def pad_fixed(
    self,
    batch: Union[Mapping[str, Any], List[Mapping[str, Any]]],
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
    """Pad the fixed mask to the same length."""
    if isinstance(batch, Mapping):
        fixed_keys = [
            key for key in batch if key.startswith("fixed") or key.endswith("fixed")
        ]
    elif isinstance(batch, list):
        fixed_keys = [
            key
            for key in batch[0]
            if key.startswith("fixed") or key.endswith("fixed")
        ]
    for key in fixed_keys:
        if isinstance(batch, Mapping):
            fixed = batch[key]
        elif isinstance(batch, list):
            fixed = [sample[key] for sample in batch]
        max_length = (
            max([len(c) for c in fixed]) if max_length is None else max_length
        )
        if pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
            max_length = (
                (max_length // pad_to_multiple_of) + 1
            ) * pad_to_multiple_of
        for f in fixed:
            f.extend([True] * (max_length - len(f)))
        if isinstance(batch, list):
            for i, sample in enumerate(batch):
                sample[key] = fixed[i]  # type: ignore[index]
    return batch

save_vocabulary

save_vocabulary(save_directory, filename_prefix=None)

Save the vocabulary to a json file.

Source code in atomgen/data/tokenizer.py
def save_vocabulary(
    self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
    """Save the vocabulary to a json file."""
    vocab_file = os.path.join(
        save_directory,
        (filename_prefix + "-" if filename_prefix else "")
        + VOCAB_FILES_NAMES["vocab_file"],
    )
    with open(vocab_file, "w") as f:
        json.dump(self.vocab, f)
    return (vocab_file,)

from_pretrained classmethod

from_pretrained(*inputs, **kwargs)

Load the tokenizer from a pretrained model.

Source code in atomgen/data/tokenizer.py
@classmethod
def from_pretrained(cls, *inputs: Any, **kwargs: Any) -> Any:
    """Load the tokenizer from a pretrained model."""
    return super().from_pretrained(*inputs, **kwargs)

build_inputs_with_special_tokens

build_inputs_with_special_tokens(
    token_ids_0, token_ids_1=None
)

Build the input with special tokens.

Source code in atomgen/data/tokenizer.py
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """Build the input with special tokens."""
    bos = [self.bos_token_id]
    eos = [self.eos_token_id]

    if token_ids_1 is None:
        return bos + token_ids_0 + eos
    return bos + token_ids_0 + eos + token_ids_1 + eos

utils

Utilities for data processing and evaluation.

compute_metrics_smp

compute_metrics_smp(eval_pred)

Compute MAE for 20 regression labels for the SMP task.

Source code in atomgen/data/utils.py
def compute_metrics_smp(eval_pred: EvalPrediction) -> Dict[str, Any]:
    """Compute MAE for 20 regression labels for the SMP task."""
    pred = eval_pred.predictions
    label = eval_pred.label_ids

    # Ensure predictions and labels are arrays, not tuples
    if isinstance(pred, tuple):
        pred = pred[0]
    if isinstance(label, tuple):
        label = label[0]

    # get Mean absolute error for each 20 pred and labels
    maes: Dict[str, Optional[float]] = {
        "rot_const_A": None,
        "rot_const_B": None,
        "rot_const_C": None,
        "dipole_moment": None,
        "isotropic_polarizability": None,
        "HOMO": None,
        "LUMO": None,
        "gap": None,
        "electronic_spatial_extent": None,
        "zero_point_vib_energy": None,
        "internal_energy_0K": None,
        "internal_energy_298.15K": None,
        "enthalpy_298.15K": None,
        "free_energy_298.15K": None,
        "heat_capacity_298.15K": None,
        "thermochem_internal_energy_0K": None,
        "thermochem_internal_energy_298.15K": None,
        "thermochem_enthalpy_298.15K": None,
        "thermochem_free_energy_298.15K": None,
        "thermochem_heat_capacity_298.15K": None,
    }
    for i, key in enumerate(maes):
        maes[key] = float(np.mean(np.abs(pred[:, i] - label[:, i])))

    return maes
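
These metric functions are meant to be passed to a transformers Trainer via its compute_metrics argument; a minimal direct call on random data (purely illustrative, not real SMP values) looks like this:

import numpy as np
from transformers import EvalPrediction

from atomgen.data.utils import compute_metrics_smp

# Random predictions/labels with the expected (num_samples, 20) shape.
rng = np.random.default_rng(0)
eval_pred = EvalPrediction(
    predictions=rng.normal(size=(8, 20)).astype(np.float32),
    label_ids=rng.normal(size=(8, 20)).astype(np.float32),
)
metrics = compute_metrics_smp(eval_pred)
print(metrics["HOMO"], metrics["gap"])  # per-target mean absolute errors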

compute_metrics_ppi

compute_metrics_ppi(eval_pred)

Compute AUROC for the PIP task.

Source code in atomgen/data/utils.py
def compute_metrics_ppi(eval_pred: EvalPrediction) -> Dict[str, Any]:
    """Compute AUROC for the PIP task."""
    predictions = eval_pred.predictions
    label = eval_pred.label_ids

    # Ensure predictions and labels are arrays, not tuples
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if isinstance(label, tuple):
        label = label[0]

    # convert logits to probabilities
    pred = expit(predictions)

    # compute the AUROC (macro-averaged over labels when they are 2D)
    auroc = roc_auc_score(label, pred)

    return {"auroc": auroc}

compute_metrics_res

compute_metrics_res(eval_pred)

Compute accuracy for the RES task.

Source code in atomgen/data/utils.py
def compute_metrics_res(eval_pred: EvalPrediction) -> Dict[str, Any]:
    """Compute accuracy for the RES task."""
    pred = softmax(eval_pred.predictions).argmax(axis=1)
    label = eval_pred.label_ids

    # compute accuracy
    acc = accuracy_score(label, pred)

    return {"accuracy": acc}

compute_metrics_msp

compute_metrics_msp(eval_pred)

Compute AUROC for the MSP task.

Source code in atomgen/data/utils.py
def compute_metrics_msp(eval_pred: EvalPrediction) -> Dict[str, Any]:
    """Compute AUROC for the MSP task."""
    pred = eval_pred.predictions
    label = eval_pred.label_ids

    # compute AUROC for each label
    auroc = roc_auc_score(label, pred)

    return {"auroc": auroc}

compute_metrics_lba

compute_metrics_lba(eval_pred)

Compute RMSE for the LBA task.

Source code in atomgen/data/utils.py
def compute_metrics_lba(eval_pred: EvalPrediction) -> Dict[str, Any]:
    """Compute RMSE for the LBA task."""
    pred = eval_pred.predictions
    label = eval_pred.label_ids

    # Ensure predictions and labels are arrays, not tuples
    if isinstance(pred, tuple):
        pred = pred[0]
    if isinstance(label, tuple):
        label = label[0]

    # compute RMSE for each label
    rmse = float(np.sqrt(np.mean((pred - label) ** 2)))
    global_pearson = float(pearsonr(pred.flatten(), label.flatten())[0])
    global_spearman = float(spearmanr(pred.flatten(), label.flatten())[0])

    return {
        "rmse": rmse,
        "global_pearson": global_pearson,
        "global_spearman": global_spearman,
    }

compute_metrics_lep

compute_metrics_lep(eval_pred)

Compute AUROC for the LEP task.

Source code in atomgen/data/utils.py
def compute_metrics_lep(eval_pred: EvalPrediction) -> Dict[str, Any]:
    """Compute AUROC for the LEP task."""
    pred = expit(eval_pred.predictions)
    label = eval_pred.label_ids

    # compute AUROC
    auroc = roc_auc_score(label, pred)

    return {"auroc": auroc}

compute_metrics_psr

compute_metrics_psr(eval_pred)

Compute global spearman correlation for the PSR task.

Source code in atomgen/data/utils.py
def compute_metrics_psr(eval_pred: EvalPrediction) -> Dict[str, Any]:
    """Compute global spearman correlation for the PSR task."""
    pred = eval_pred.predictions
    label = eval_pred.label_ids

    # Ensure predictions and labels are arrays, not tuples
    if isinstance(pred, tuple):
        pred = pred[0]
    if isinstance(label, tuple):
        label = label[0]

    # compute global spearman correlation
    global_spearman = float(spearmanr(pred.flatten(), label.flatten())[0])

    return {"global_spearman": global_spearman}

compute_metrics_rsr

compute_metrics_rsr(eval_pred)

Compute global spearman correlation for the RSR task.

Source code in atomgen/data/utils.py
def compute_metrics_rsr(eval_pred: EvalPrediction) -> Dict[str, Any]:
    """Compute global spearman correlation for the RSR task."""
    pred = eval_pred.predictions
    label = eval_pred.label_ids

    # Ensure predictions and labels are arrays, not tuples
    if isinstance(pred, tuple):
        pred = pred[0]
    if isinstance(label, tuple):
        label = label[0]

    # compute global spearman correlation
    global_spearman = float(spearmanr(pred.flatten(), label.flatten())[0])

    return {"global_spearman": global_spearman}

Models module for the AtomGen library.

This module contains the model classes and functions for training and inference.

configuration_atomformer

Configuration class for Atomformer.

AtomformerConfig

Bases: PretrainedConfig

Configuration of a :class:~transformers.AtomformerModel.

It is used to instantiate an Atomformer model according to the specified arguments.

Source code in atomgen/models/configuration_atomformer.py
class AtomformerConfig(PretrainedConfig):
    r"""
    Configuration of a :class:`~transformers.AtomformerModel`.

    It is used to instantiate an Atomformer model according to the specified arguments.
    """

    model_type = "atomformer"

    def __init__(
        self,
        vocab_size: int = 123,
        dim: int = 768,
        num_heads: int = 32,
        depth: int = 12,
        mlp_ratio: int = 1,
        k: int = 128,
        dropout: float = 0.0,
        mask_token_id: int = 0,
        pad_token_id: int = 119,
        bos_token_id: int = 120,
        eos_token_id: int = 121,
        cls_token_id: int = 122,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)  # type: ignore[no-untyped-call]
        self.vocab_size = vocab_size
        self.dim = dim
        self.num_heads = num_heads
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.k = k

        self.dropout = dropout
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.cls_token_id = cls_token_id
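
A minimal sketch of instantiating a smaller-than-default configuration (argument values are illustrative):

from atomgen.models.configuration_atomformer import AtomformerConfig

config = AtomformerConfig(dim=256, num_heads=8, depth=4)
print(config.model_type)  # "atomformer"
print(config.vocab_size)  # 123 (default)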

modeling_atomformer

Implementation of the Atomformer model.

GaussianLayer

Bases: Module

Gaussian pairwise positional embedding layer.

This layer computes the Gaussian positional embeddings for the pairwise distances between atoms in a molecule.

Taken from: https://github.com/microsoft/Graphormer/blob/main/graphormer/models/graphormer_3d.py

Source code in atomgen/models/modeling_atomformer.py
class GaussianLayer(nn.Module):
    """Gaussian pairwise positional embedding layer.

    This layer computes the Gaussian positional embeddings for the pairwise distances
    between atoms in a molecule.

    Taken from: https://github.com/microsoft/Graphormer/blob/main/graphormer/models/graphormer_3d.py
    """

    def __init__(self, k: int = 128, edge_types: int = 1024):
        super().__init__()
        self.k = k
        self.means = nn.Embedding(1, k)
        self.stds = nn.Embedding(1, k)
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x: torch.Tensor, edge_types: torch.Tensor) -> torch.Tensor:
        """Forward pass to compute the Gaussian pos. embeddings."""
        mul = self.mul(edge_types)
        bias = self.bias(edge_types)
        x = mul * x.unsqueeze(-1) + bias
        x = x.expand(-1, -1, -1, self.k)
        mean = self.means.weight.float().view(-1)
        std = self.stds.weight.float().view(-1).abs() + 1e-5
        output: torch.Tensor = gaussian(x.float(), mean, std).type_as(self.means.weight)
        return output

forward

forward(x, edge_types)

Forward pass to compute the Gaussian pos. embeddings.

Source code in atomgen/models/modeling_atomformer.py
def forward(self, x: torch.Tensor, edge_types: torch.Tensor) -> torch.Tensor:
    """Forward pass to compute the Gaussian pos. embeddings."""
    mul = self.mul(edge_types)
    bias = self.bias(edge_types)
    x = mul * x.unsqueeze(-1) + bias
    x = x.expand(-1, -1, -1, self.k)
    mean = self.means.weight.float().view(-1)
    std = self.stds.weight.float().view(-1).abs() + 1e-5
    output: torch.Tensor = gaussian(x.float(), mean, std).type_as(self.means.weight)
    return output
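
A shape sketch for the layer on random inputs (the k and edge_types values are illustrative):

import torch

from atomgen.models.modeling_atomformer import GaussianLayer

layer = GaussianLayer(k=16, edge_types=1024)
r_ij = torch.rand(2, 5, 5) * 3.0               # pairwise distances [B, N, N]
edge_type = torch.randint(0, 1024, (2, 5, 5))  # atom-pair type ids [B, N, N]
out = layer(r_ij, edge_type)
print(out.shape)  # torch.Size([2, 5, 5, 16])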

ParallelBlock

Bases: Module

Parallel transformer block (MLP & Attention in parallel).

Based on: 'Scaling Vision Transformers to 22 Billion Parameters' - https://arxiv.org/abs/2302.05442

Adapted from TIMM implementation.

Source code in atomgen/models/modeling_atomformer.py
class ParallelBlock(nn.Module):
    """Parallel transformer block (MLP & Attention in parallel).

    Based on:
      'Scaling Vision Transformers to 22 Billion Parameters' - https://arxiv.org/abs/2302.05442

    Adapted from TIMM implementation.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
        k: int = 128,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        assert dim % num_heads == 0, (
            f"dim {dim} should be divisible by num_heads {num_heads}"
        )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.mlp_hidden_dim = int(mlp_ratio * dim)
        self.proj_drop = nn.Dropout(dropout)
        self.attn_drop = nn.Dropout(dropout)
        self.gradient_checkpointing = gradient_checkpointing

        self.in_proj_in_dim = dim
        self.in_proj_out_dim = self.mlp_hidden_dim + 3 * dim
        self.out_proj_in_dim = self.mlp_hidden_dim + dim
        self.out_proj_out_dim = 2 * dim

        self.in_split = [self.mlp_hidden_dim] + [dim] * 3
        self.out_split = [dim] * 2

        self.in_norm = nn.LayerNorm(dim)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.in_proj = nn.Linear(self.in_proj_in_dim, self.in_proj_out_dim, bias=False)
        self.act_fn = nn.GELU()
        self.out_proj = nn.Linear(
            self.out_proj_in_dim, self.out_proj_out_dim, bias=False
        )
        self.gaussian_proj = nn.Linear(k, 1)
        self.pos_embed_ff_norm = nn.LayerNorm(k)

    def forward(
        self,
        x: torch.Tensor,
        pos_embed: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for the parallel block."""
        b, n, c = x.shape
        res = x

        # Combined MLP fc1 & qkv projections
        x = self.in_proj(self.in_norm(x))
        x, q, k, v = torch.split(x, self.in_split, dim=-1)
        x = self.act_fn(x)
        x = self.proj_drop(x)

        # Dot product attention
        q = self.q_norm(q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        k = self.k_norm(k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        x_attn = (
            f.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask
                + self.gaussian_proj(self.pos_embed_ff_norm(pos_embed)).permute(
                    0, 3, 1, 2
                ),
                is_causal=False,
            )
            .transpose(1, 2)
            .reshape(b, n, c)
        )

        # Combined MLP fc2 & attn_output projection
        x_mlp, x_attn = self.out_proj(torch.cat([x, x_attn], dim=-1)).split(
            self.out_split, dim=-1
        )
        # Residual connections
        x = x_mlp + x_attn + res
        del x_mlp, x_attn, res

        return x, pos_embed

forward

forward(x, pos_embed, attention_mask=None)

Forward pass for the parallel block.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    x: torch.Tensor,
    pos_embed: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass for the parallel block."""
    b, n, c = x.shape
    res = x

    # Combined MLP fc1 & qkv projections
    x = self.in_proj(self.in_norm(x))
    x, q, k, v = torch.split(x, self.in_split, dim=-1)
    x = self.act_fn(x)
    x = self.proj_drop(x)

    # Dot product attention
    q = self.q_norm(q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
    k = self.k_norm(k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
    v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

    x_attn = (
        f.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask
            + self.gaussian_proj(self.pos_embed_ff_norm(pos_embed)).permute(
                0, 3, 1, 2
            ),
            is_causal=False,
        )
        .transpose(1, 2)
        .reshape(b, n, c)
    )

    # Combined MLP fc2 & attn_output projection
    x_mlp, x_attn = self.out_proj(torch.cat([x, x_attn], dim=-1)).split(
        self.out_split, dim=-1
    )
    # Residual connections
    x = x_mlp + x_attn + res
    del x_mlp, x_attn, res

    return x, pos_embed
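
A shape sketch for a single block on random inputs; an all-zero additive attention mask is passed explicitly because the block adds the Gaussian positional bias to it:

import torch

from atomgen.models.modeling_atomformer import ParallelBlock

block = ParallelBlock(dim=64, num_heads=4, mlp_ratio=1, k=16)
x = torch.randn(2, 5, 64)             # atom embeddings [B, N, C]
pos_embed = torch.randn(2, 5, 5, 16)  # pairwise Gaussian embeddings [B, N, N, K]
attn_mask = torch.zeros(2, 1, 5, 5)   # additive mask; 0 means attend everywhere
x_out, pos_out = block(x, pos_embed, attn_mask)
print(x_out.shape, pos_out.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 5, 5, 16])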

AtomformerEncoder

Bases: Module

Atomformer encoder.

The transformer encoder consists of a series of parallel blocks, each containing a multi-head self-attention mechanism and a feed-forward network.

Source code in atomgen/models/modeling_atomformer.py
class AtomformerEncoder(nn.Module):
    """Atomformer encoder.

    The transformer encoder consists of a series of parallel blocks,
    each containing a multi-head self-attention mechanism and a feed-forward network.
    """

    def __init__(self, config: AtomformerConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.dim = config.dim
        self.num_heads = config.num_heads
        self.depth = config.depth
        self.mlp_ratio = config.mlp_ratio
        self.dropout = config.dropout
        self.k = config.k
        self.gradient_checkpointing = config.gradient_checkpointing

        self.metadata_vocab = nn.Embedding(self.vocab_size, 17)
        self.metadata_vocab.weight.requires_grad = False
        self.metadata_vocab.weight.fill_(-1)
        self.metadata_vocab.weight[1:-4] = torch.tensor(
            ATOM_METADATA, dtype=torch.float32
        )
        self.embed_metadata = nn.Linear(17, self.dim)

        self.gaussian_embed = GaussianLayer(
            k=self.k, edge_types=(self.vocab_size + 1) ** 2
        )

        self.token_type_embedding = nn.Embedding(4, self.dim)
        nn.init.zeros_(self.token_type_embedding.weight)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.dim)
        nn.init.normal_(self.embed_tokens.weight, std=0.02)

        self.blocks = nn.ModuleList()
        for _ in range(self.depth):
            self.blocks.append(
                ParallelBlock(
                    self.dim,
                    self.num_heads,
                    self.mlp_ratio,
                    self.dropout,
                    self.k,
                    self.gradient_checkpointing,
                )
            )

    def _expand_mask(
        self,
        mask: torch.Tensor,
        dtype: torch.dtype,
        device: torch.device,
        tgt_len: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Expand attention mask.

        Expands attention_mask from `[bsz, seq_len]` to
        `[bsz, 1, tgt_seq_len, src_seq_len]`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = (
            mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        )

        inverted_mask: torch.Tensor = 1.0 - expanded_mask

        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        ).to(device)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for the transformer encoder."""
        # prepend the coordinate centroid as the graph-token position
        coords_center = torch.sum(coords, dim=1, keepdim=True) / coords.shape[1]
        coords = torch.cat([coords_center, coords], dim=1)

        r_ij = torch.cdist(coords, coords, p=2)  # [B, N, N]
        # prepend the graph token id
        input_ids = torch.cat(
            [
                torch.zeros(
                    input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
                ).fill_(122),
                input_ids,
            ],
            dim=1,
        )
        edge_type = input_ids.unsqueeze(-1) * self.vocab_size + input_ids.unsqueeze(
            -2
        )  # [B, N, N]
        pos_embeds = self.gaussian_embed(r_ij, edge_type)  # [B, N, N, K]

        input_embeds = self.embed_tokens(input_ids)
        if token_type_ids is not None:
            token_type_ids = torch.cat(
                [
                    torch.empty(
                        input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
                    ).fill_(3),
                    token_type_ids,
                ],
                dim=1,
            )
            token_type_embeddings = self.token_type_embedding(token_type_ids)
            input_embeds = input_embeds + token_type_embeddings

        atom_metadata = self.metadata_vocab(input_ids)
        input_embeds = input_embeds + self.embed_metadata(atom_metadata)  # [B, N, C]

        attention_mask = (
            torch.cat(
                [
                    torch.ones(
                        attention_mask.size(0),
                        1,
                        dtype=torch.bool,
                        device=attention_mask.device,
                    ),
                    attention_mask.bool(),
                ],
                dim=1,
            )
            if attention_mask is not None
            else None
        )

        attention_mask = (
            self._expand_mask(attention_mask, input_embeds.dtype, input_embeds.device)
            if attention_mask is not None
            else None
        )

        for blk in self.blocks:
            input_embeds, pos_embeds = blk(input_embeds, pos_embeds, attention_mask)

        return input_embeds, pos_embeds

forward

forward(
    input_ids,
    coords,
    attention_mask=None,
    token_type_ids=None,
)

Forward pass for the transformer encoder.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass for the transformer encoder."""
    # prepend the coordinate centroid as the graph-token position
    coords_center = torch.sum(coords, dim=1, keepdim=True) / coords.shape[1]
    coords = torch.cat([coords_center, coords], dim=1)

    r_ij = torch.cdist(coords, coords, p=2)  # [B, N, N]
    # prepend the graph token id
    input_ids = torch.cat(
        [
            torch.zeros(
                input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
            ).fill_(122),
            input_ids,
        ],
        dim=1,
    )
    edge_type = input_ids.unsqueeze(-1) * self.vocab_size + input_ids.unsqueeze(
        -2
    )  # [B, N, N]
    pos_embeds = self.gaussian_embed(r_ij, edge_type)  # [B, N, N, K]

    input_embeds = self.embed_tokens(input_ids)
    if token_type_ids is not None:
        token_type_ids = torch.cat(
            [
                torch.empty(
                    input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
                ).fill_(3),
                token_type_ids,
            ],
            dim=1,
        )
        token_type_embeddings = self.token_type_embedding(token_type_ids)
        input_embeds = input_embeds + token_type_embeddings

    atom_metadata = self.metadata_vocab(input_ids)
    input_embeds = input_embeds + self.embed_metadata(atom_metadata)  # [B, N, C]

    attention_mask = (
        torch.cat(
            [
                torch.ones(
                    attention_mask.size(0),
                    1,
                    dtype=torch.bool,
                    device=attention_mask.device,
                ),
                attention_mask.bool(),
            ],
            dim=1,
        )
        if attention_mask is not None
        else None
    )

    attention_mask = (
        self._expand_mask(attention_mask, input_embeds.dtype, input_embeds.device)
        if attention_mask is not None
        else None
    )

    for blk in self.blocks:
        input_embeds, pos_embeds = blk(input_embeds, pos_embeds, attention_mask)

    return input_embeds, pos_embeds
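
A minimal end-to-end sketch of running the encoder on a toy batch. All values are illustrative, and gradient_checkpointing is passed explicitly because the encoder reads it from the config:

import torch

from atomgen.models.configuration_atomformer import AtomformerConfig
from atomgen.models.modeling_atomformer import AtomformerEncoder

config = AtomformerConfig(
    dim=64, num_heads=4, depth=2, k=16, gradient_checkpointing=False
)
encoder = AtomformerEncoder(config)

input_ids = torch.tensor([[6, 1, 1, 8]])  # atomic numbers as token ids [B, N]
coords = torch.randn(1, 4, 3)             # 3D coordinates [B, N, 3]
attention_mask = torch.ones(1, 4)         # 1 = real atom, 0 = padding

atom_states, pos_states = encoder(input_ids, coords, attention_mask)
print(atom_states.shape)  # torch.Size([1, 5, 64]); index 0 is the prepended graph token
print(pos_states.shape)   # torch.Size([1, 5, 5, 16])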

AtomformerPreTrainedModel

Bases: PreTrainedModel

Base class for all transformer models.

Source code in atomgen/models/modeling_atomformer.py
class AtomformerPreTrainedModel(PreTrainedModel):
    """Base class for all transformer models."""

    config_class = AtomformerConfig  # type: ignore[assignment]
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParallelBlock"]

    def _set_gradient_checkpointing(  # type: ignore[override]
        self, module: nn.Module, value: bool = False
    ) -> None:
        if isinstance(module, (AtomformerEncoder)):
            module.gradient_checkpointing = value

AtomformerModel

Bases: AtomformerPreTrainedModel

Atomformer model for atom modeling.

Source code in atomgen/models/modeling_atomformer.py
class AtomformerModel(AtomformerPreTrainedModel):
    """Atomformer model for atom modeling."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function call for the transformer model."""
        output: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
        return output

forward

forward(input_ids, coords, attention_mask=None)

Forward function call for the transformer model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward function call for the transformer model."""
    output: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
    return output

AtomformerForMaskedAM

Bases: AtomformerPreTrainedModel

Atomformer with an atom modeling head on top for masked atom modeling.

Source code in atomgen/models/modeling_atomformer.py
class AtomformerForMaskedAM(AtomformerPreTrainedModel):
    """Atomformer with an atom modeling head on top for masked atom modeling."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.am_head = nn.Linear(config.dim, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the masked atom modeling model."""
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # drop the prepended graph token so logits align with the input atoms
        logits = self.am_head(hidden_states[:, 1:])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            logits, labels = logits.view(-1, self.config.vocab_size), labels.view(-1)
            loss = loss_fct(logits, labels)

        return loss, logits

forward

forward(
    input_ids,
    coords,
    labels=None,
    fixed=None,
    attention_mask=None,
)

Forward function call for the masked atom modeling model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call for the masked atom modeling model."""
    hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
    # drop the prepended graph token so logits align with the input atoms
    logits = self.am_head(hidden_states[:, 1:])

    loss = None
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss()
        logits, labels = logits.view(-1, self.config.vocab_size), labels.view(-1)
        loss = loss_fct(logits, labels)

    return loss, logits
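
A minimal sketch of a masked-atom forward pass on a toy batch (illustrative values; -100 marks label positions ignored by the cross-entropy loss):

import torch

from atomgen.models.configuration_atomformer import AtomformerConfig
from atomgen.models.modeling_atomformer import AtomformerForMaskedAM

config = AtomformerConfig(
    dim=64, num_heads=4, depth=2, k=16, gradient_checkpointing=False
)
model = AtomformerForMaskedAM(config)

input_ids = torch.tensor([[6, 0, 1, 8]])        # 0 is mask_token_id in the default config
coords = torch.randn(1, 4, 3)
labels = torch.tensor([[-100, 1, -100, -100]])  # only the masked position is scored
attention_mask = torch.ones(1, 4)

loss, logits = model(input_ids, coords, labels=labels, attention_mask=attention_mask)
print(float(loss))  # cross-entropy at the masked position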

AtomformerForCoordinateAM

Bases: AtomformerPreTrainedModel

Atomformer with an atom coordinate head on top for coordinate denoising.

Source code in atomgen/models/modeling_atomformer.py
class AtomformerForCoordinateAM(AtomformerPreTrainedModel):
    """Atomformer with an atom coordinate head on top for coordinate denoising."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the coordinate atom modeling model."""
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # drop the prepended graph token so predictions align with the input atoms
        coords_pred = self.coords_head(hidden_states[:, 1:])

        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)

        return loss, coords_pred

forward

forward(
    input_ids,
    coords,
    labels_coords=None,
    fixed=None,
    attention_mask=None,
)

Forward function call for the coordinate atom modeling model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels_coords: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call for the coordinate atom modeling model."""
    hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
    # drop the prepended graph token so predictions align with the input atoms
    coords_pred = self.coords_head(hidden_states[:, 1:])

    loss = None
    if labels_coords is not None:
        labels_coords = labels_coords.to(coords_pred.device)
        loss_fct = nn.L1Loss()
        loss = loss_fct(coords_pred, labels_coords)

    return loss, coords_pred

InitialStructure2RelaxedStructure

Bases: AtomformerPreTrainedModel

Atomformer with a coordinate head on top for relaxed structure prediction.

Source code in atomgen/models/modeling_atomformer.py
class InitialStructure2RelaxedStructure(AtomformerPreTrainedModel):
    """Atomformer with an coordinate head on top for relaxed structure prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call.

        Initial structure to relaxed structure model.
        """
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # drop the prepended graph token so predictions align with the input atoms
        coords_pred = self.coords_head(hidden_states[:, 1:])

        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)

        return loss, coords_pred

forward

forward(
    input_ids,
    coords,
    labels_coords=None,
    fixed=None,
    attention_mask=None,
)

Forward function call.

Initial structure to relaxed structure model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels_coords: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call.

    Initial structure to relaxed structure model.
    """
    hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
    # drop the prepended graph token so predictions align with the input atoms
    coords_pred = self.coords_head(hidden_states[:, 1:])

    loss = None
    if labels_coords is not None:
        labels_coords = labels_coords.to(coords_pred.device)
        loss_fct = nn.L1Loss()
        loss = loss_fct(coords_pred, labels_coords)

    return loss, coords_pred

InitialStructure2RelaxedEnergy

Bases: AtomformerPreTrainedModel

Atomformer with an energy head on top for relaxed energy prediction.

Source code in atomgen/models/modeling_atomformer.py
class InitialStructure2RelaxedEnergy(AtomformerPreTrainedModel):
    """Atomformer with an energy head on top for relaxed energy prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the relaxed energy prediction model."""
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        energy = self.energy_head(self.energy_norm(hidden_states[:, 0])).squeeze(-1)

        loss = None
        if labels_energy is not None:
            loss_fct = nn.L1Loss()
            loss = loss_fct(energy, labels_energy)

        return loss, energy

forward

forward(
    input_ids,
    coords,
    labels_energy=None,
    fixed=None,
    attention_mask=None,
)

Forward function call for the relaxed energy prediction model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels_energy: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call for the relaxed energy prediction model."""
    hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
    energy = self.energy_head(self.energy_norm(hidden_states[:, 0])).squeeze(-1)

    loss = None
    if labels_energy is not None:
        loss_fct = nn.L1Loss()
        loss = loss_fct(energy, labels_energy)

    return loss, energy

InitialStructure2RelaxedStructureAndEnergy

Bases: AtomformerPreTrainedModel

Atomformer with a coordinate and energy head.

Source code in atomgen/models/modeling_atomformer.py
class InitialStructure2RelaxedStructureAndEnergy(AtomformerPreTrainedModel):
    """Atomformer with an coordinate and energy head."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function call for the relaxed structure and energy model."""
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )

        formation_energy_pred = self.formation_energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        loss_formation_energy = None
        if formation_energy is not None:
            loss_fct = nn.L1Loss()
            loss_formation_energy = loss_fct(
                formation_energy_pred[has_formation_energy],
                formation_energy[has_formation_energy],
            )
        coords_pred = self.coords_head(atom_hidden_states[:, 1:])
        loss_coords = None
        if labels_coords is not None:
            loss_fct = nn.L1Loss()
            loss_coords = loss_fct(coords_pred, labels_coords)

        loss = torch.tensor(0.0).to(coords.device)
        loss = (
            loss + loss_formation_energy if loss_formation_energy is not None else loss
        )
        loss = loss + loss_coords if loss_coords is not None else loss

        return loss, (formation_energy_pred, coords_pred)

forward

forward(
    input_ids,
    coords,
    labels_coords=None,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the relaxed structure and energy model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels_coords: Optional[torch.Tensor] = None,
    forces: Optional[torch.Tensor] = None,
    total_energy: Optional[torch.Tensor] = None,
    formation_energy: Optional[torch.Tensor] = None,
    has_formation_energy: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Forward function call for the relaxed structure and energy model."""
    atom_hidden_states, pos_hidden_states = self.encoder(
        input_ids, coords, attention_mask
    )

    formation_energy_pred = self.formation_energy_head(
        self.energy_norm(atom_hidden_states[:, 0])
    ).squeeze(-1)
    loss_formation_energy = None
    if formation_energy is not None:
        loss_fct = nn.L1Loss()
        loss_formation_energy = loss_fct(
            formation_energy_pred[has_formation_energy],
            formation_energy[has_formation_energy],
        )
    coords_pred = self.coords_head(atom_hidden_states[:, 1:])
    loss_coords = None
    if labels_coords is not None:
        loss_fct = nn.L1Loss()
        loss_coords = loss_fct(coords_pred, labels_coords)

    loss = torch.tensor(0.0).to(coords.device)
    loss = (
        loss + loss_formation_energy if loss_formation_energy is not None else loss
    )
    loss = loss + loss_coords if loss_coords is not None else loss

    return loss, (formation_energy_pred, coords_pred)

Structure2Energy

Bases: AtomformerPreTrainedModel

Atomformer with an energy head on top for formation energy prediction.

Source code in atomgen/models/modeling_atomformer.py
class Structure2Energy(AtomformerPreTrainedModel):
    """Atomformer with an atom modeling head on top for masked atom modeling."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to energy model."""
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )

        formation_energy_pred: torch.Tensor = self.formation_energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        loss = torch.tensor(0.0).to(coords.device)
        if formation_energy is not None:
            loss_fct = nn.L1Loss()
            loss = loss_fct(
                formation_energy_pred[has_formation_energy],
                formation_energy[has_formation_energy],
            )

        return loss, (
            formation_energy_pred,
            attention_mask.bool() if attention_mask is not None else None,
        )

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the structure to energy model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    forces: Optional[torch.Tensor] = None,
    total_energy: Optional[torch.Tensor] = None,
    formation_energy: Optional[torch.Tensor] = None,
    has_formation_energy: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """Forward function call for the structure to energy model."""
    atom_hidden_states, pos_hidden_states = self.encoder(
        input_ids, coords, attention_mask
    )

    formation_energy_pred: torch.Tensor = self.formation_energy_head(
        self.energy_norm(atom_hidden_states[:, 0])
    ).squeeze(-1)
    loss = torch.tensor(0.0).to(coords.device)
    if formation_energy is not None:
        loss_fct = nn.L1Loss()
        loss = loss_fct(
            formation_energy_pred[has_formation_energy],
            formation_energy[has_formation_energy],
        )

    return loss, (
        formation_energy_pred,
        attention_mask.bool() if attention_mask is not None else None,
    )

Structure2Forces

Bases: AtomformerPreTrainedModel

Atomformer with a forces head on top for forces prediction.

Source code in atomgen/models/modeling_atomformer.py
class Structure2Forces(AtomformerPreTrainedModel):
    """Atomformer with a forces head on top for forces prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to forces model."""
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )
        attention_mask = attention_mask.bool() if attention_mask is not None else None

        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(atom_hidden_states[:, 1:])
        )
        loss = torch.tensor(0.0).to(coords.device)
        if forces is not None:
            loss_fct = nn.L1Loss()
            loss = loss_fct(forces_pred[attention_mask], forces[attention_mask])

        return loss, (
            forces_pred,
            attention_mask if attention_mask is not None else None,
        )

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the structure to forces model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    forces: Optional[torch.Tensor] = None,
    total_energy: Optional[torch.Tensor] = None,
    formation_energy: Optional[torch.Tensor] = None,
    has_formation_energy: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """Forward function call for the structure to forces model."""
    atom_hidden_states, pos_hidden_states = self.encoder(
        input_ids, coords, attention_mask
    )
    attention_mask = attention_mask.bool() if attention_mask is not None else None

    forces_pred: torch.Tensor = self.force_head(
        self.force_norm(atom_hidden_states[:, 1:])
    )
    loss = torch.tensor(0.0).to(coords.device)
    if forces is not None:
        loss_fct = nn.L1Loss()
        loss = loss_fct(forces_pred[attention_mask], forces[attention_mask])

    return loss, (
        forces_pred,
        attention_mask if attention_mask is not None else None,
    )

Structure2EnergyAndForces

Bases: AtomformerPreTrainedModel

Atomformer with an energy and forces head for energy and forces prediction.

Source code in atomgen/models/modeling_atomformer.py
class Structure2EnergyAndForces(AtomformerPreTrainedModel):
    """Atomformer with an energy and forces head for energy and forces prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to energy and forces model."""
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )

        formation_energy_pred: torch.Tensor = self.formation_energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        loss = None
        loss_formation_energy = None
        if formation_energy is not None:
            loss_fct = nn.L1Loss()
            loss_formation_energy = loss_fct(
                formation_energy_pred[has_formation_energy],
                formation_energy[has_formation_energy],
            )
            loss = loss_formation_energy
        attention_mask = attention_mask.bool() if attention_mask is not None else None
        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(atom_hidden_states[:, 1:])
        )
        loss_forces = None
        if forces is not None:
            loss_fct = nn.L1Loss()
            loss_forces = loss_fct(forces_pred[attention_mask], forces[attention_mask])
            loss = loss + loss_forces if loss is not None else loss_forces

        return loss, (formation_energy_pred, forces_pred, attention_mask)

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the structure to energy and forces model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    forces: Optional[torch.Tensor] = None,
    total_energy: Optional[torch.Tensor] = None,
    formation_energy: Optional[torch.Tensor] = None,
    has_formation_energy: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]:
    """Forward function call for the structure to energy and forces model."""
    atom_hidden_states, pos_hidden_states = self.encoder(
        input_ids, coords, attention_mask
    )

    formation_energy_pred: torch.Tensor = self.formation_energy_head(
        self.energy_norm(atom_hidden_states[:, 0])
    ).squeeze(-1)
    loss = None
    loss_formation_energy = None
    if formation_energy is not None:
        loss_fct = nn.L1Loss()
        loss_formation_energy = loss_fct(
            formation_energy_pred[has_formation_energy],
            formation_energy[has_formation_energy],
        )
        loss = loss_formation_energy
    attention_mask = attention_mask.bool() if attention_mask is not None else None
    forces_pred: torch.Tensor = self.force_head(
        self.force_norm(atom_hidden_states[:, 1:])
    )
    loss_forces = None
    if forces is not None:
        loss_fct = nn.L1Loss()
        loss_forces = loss_fct(forces_pred[attention_mask], forces[attention_mask])
        loss = loss + loss_forces if loss is not None else loss_forces

    return loss, (formation_energy_pred, forces_pred, attention_mask)
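
A minimal sketch of a joint energy-and-forces forward pass on a toy batch (all values are illustrative):

import torch

from atomgen.models.configuration_atomformer import AtomformerConfig
from atomgen.models.modeling_atomformer import Structure2EnergyAndForces

config = AtomformerConfig(
    dim=64, num_heads=4, depth=2, k=16, gradient_checkpointing=False
)
model = Structure2EnergyAndForces(config)

input_ids = torch.tensor([[6, 1, 1, 8]])
coords = torch.randn(1, 4, 3)
forces = torch.randn(1, 4, 3)
formation_energy = torch.tensor([-1.2])
has_formation_energy = torch.tensor([True])
attention_mask = torch.ones(1, 4)

loss, (energy_pred, forces_pred, mask) = model(
    input_ids,
    coords,
    forces=forces,
    formation_energy=formation_energy,
    has_formation_energy=has_formation_energy,
    attention_mask=attention_mask,
)
print(energy_pred.shape, forces_pred.shape)  # torch.Size([1]) torch.Size([1, 4, 3])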

Structure2TotalEnergyAndForces

Bases: AtomformerPreTrainedModel

Atomformer with an energy and forces head for energy and forces prediction.

Source code in atomgen/models/modeling_atomformer.py
class Structure2TotalEnergyAndForces(AtomformerPreTrainedModel):
    """Atomformer with an energy and forces head for energy and forces prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3, bias=False)
        nn.init.zeros_(self.force_head.weight)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.total_energy_head = nn.Linear(config.dim, 1, bias=False)
        nn.init.zeros_(self.total_energy_head.weight)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[
        Optional[torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
    ]:
        """Forward function call for the structure to energy and forces model."""
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )

        loss = None
        total_energy_pred: torch.Tensor = self.total_energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        loss_total_energy = None
        if total_energy is not None:
            loss_fct = nn.L1Loss()
            loss_total_energy = loss_fct(
                total_energy_pred,
                total_energy,
            )
            loss = loss_total_energy
        attention_mask = attention_mask.bool() if attention_mask is not None else None
        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(atom_hidden_states[:, 1:])
        )
        loss_forces = None
        if forces is not None:
            loss_fct = nn.L1Loss()
            loss_forces = loss_fct(forces_pred[attention_mask], forces[attention_mask])
            loss = loss + loss_forces if loss is not None else loss_forces

        return loss, (total_energy_pred, forces_pred, attention_mask)

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the structure to energy and forces model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    forces: Optional[torch.Tensor] = None,
    total_energy: Optional[torch.Tensor] = None,
    formation_energy: Optional[torch.Tensor] = None,
    has_formation_energy: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[
    Optional[torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
]:
    """Forward function call for the structure to energy and forces model."""
    atom_hidden_states, pos_hidden_states = self.encoder(
        input_ids, coords, attention_mask
    )

    loss = None
    total_energy_pred: torch.Tensor = self.total_energy_head(
        self.energy_norm(atom_hidden_states[:, 0])
    ).squeeze(-1)
    loss_total_energy = None
    if total_energy is not None:
        loss_fct = nn.L1Loss()
        loss_total_energy = loss_fct(
            total_energy_pred,
            total_energy,
        )
        loss = loss_total_energy
    attention_mask = attention_mask.bool() if attention_mask is not None else None
    forces_pred: torch.Tensor = self.force_head(
        self.force_norm(atom_hidden_states[:, 1:])
    )
    loss_forces = None
    if forces is not None:
        loss_fct = nn.L1Loss()
        loss_forces = loss_fct(forces_pred[attention_mask], forces[attention_mask])
        loss = loss + loss_forces if loss is not None else loss_forces

    return loss, (total_energy_pred, forces_pred, attention_mask)

AtomFormerForSystemClassification

Bases: AtomformerPreTrainedModel

Atomformer with a classification head for system classification.

Source code in atomgen/models/modeling_atomformer.py
class AtomFormerForSystemClassification(AtomformerPreTrainedModel):
    """Atomformer with a classification head for system classification."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.problem_type = config.problem_type
        self.config = config

        self.encoder = AtomformerEncoder(config)

        self.cls_norm = nn.LayerNorm(config.dim)
        self.classification_head = nn.Linear(config.dim, self.num_labels, bias=False)
        nn.init.zeros_(self.classification_head.weight)

        self.loss_fct: Union[nn.L1Loss, nn.BCEWithLogitsLoss, nn.CrossEntropyLoss]

        if self.problem_type == "regression":
            self.loss_fct = nn.L1Loss()
        elif self.problem_type == "classification":
            self.loss_fct = nn.BCEWithLogitsLoss()
        elif self.problem_type == "multiclass_classification":
            self.loss_fct = nn.CrossEntropyLoss()

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the structure to energy and forces model."""
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask, token_type_ids
        )
        pred = self.classification_head(self.cls_norm(atom_hidden_states[:, 0]))

        loss = None
        if labels is not None:
            if self.problem_type == "multiclass_classification":
                labels = labels.long()
            elif self.problem_type == "classification":
                labels = labels.float()

            loss = self.loss_fct(pred.squeeze(), labels.squeeze())

        return loss, pred

forward

forward(
    input_ids,
    coords,
    labels=None,
    attention_mask=None,
    token_type_ids=None,
)

Forward function call for the structure to energy and forces model.

Source code in atomgen/models/modeling_atomformer.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call for the structure to energy and forces model."""
    atom_hidden_states, pos_hidden_states = self.encoder(
        input_ids, coords, attention_mask, token_type_ids
    )
    pred = self.classification_head(self.cls_norm(atom_hidden_states[:, 0]))

    loss = None
    if labels is not None:
        if self.problem_type == "multiclass_classification":
            labels = labels.long()
        elif self.problem_type == "classification":
            labels = labels.float()

        loss = self.loss_fct(pred.squeeze(), labels.squeeze())

    return loss, pred

gaussian

gaussian(x, mean, std)

Compute the Gaussian distribution probability density.

Taken from: https://github.com/microsoft/Graphormer/blob/main/graphormer/models/graphormer_3d.py

Source code in atomgen/models/modeling_atomformer.py
@torch.jit.script
def gaussian(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Compute the Gaussian distribution probability density.

    Taken from: https://github.com/microsoft/Graphormer/blob/main/graphormer/models/graphormer_3d.py

    """
    pi = 3.14159
    a = (2 * pi) ** 0.5
    output: torch.Tensor = torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
    return output
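
A quick numerical check against the closed-form standard normal density; the small difference comes from the truncated value of pi used above:

import math

import torch

from atomgen.models.modeling_atomformer import gaussian

x = torch.tensor([0.0, 1.0, 2.0])
mean = torch.tensor(0.0)
std = torch.tensor(1.0)
print(gaussian(x, mean, std))
print(torch.exp(-0.5 * x**2) / math.sqrt(2 * math.pi))  # reference pdf values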

schnet

SchNet model for energy prediction.

SchNetConfig

Bases: PretrainedConfig

Stores the configuration of a :class:~transformers.SchNetModel.

It is used to instantiate an SchNet model according to the specified arguments, defining the model architecture.

Args: vocab_size (:obj:int, optional, defaults to 123): The size of the vocabulary, used to define the size of the output embeddings.

hidden_channels (:obj:`int`, `optional`, defaults to 128):
    The hidden size of the model.

model_type = "transformer"

Attributes:

Name Type Description
vocab_size (:obj:`int`):
The size of the vocabulary, used to define
the size of the output embeddings.

hidden_channels (:obj:int): The hidden size of the model.

num_filters (:obj:int): The number of filters.

num_interactions (:obj:int): The number of interactions.

num_gaussians (:obj:int): The number of gaussians.

cutoff (:obj:float): The cutoff value.

interaction_graph (:obj:str, optional): The interaction graph.

max_num_neighbors (:obj:int): The maximum number of neighbors.

readout (:obj:str, optional): The readout method.

dipole (:obj:bool, optional): Whether to include dipole.

mean (:obj:float, optional): The mean value.

std (:obj:float, optional): The standard deviation value.

atomref (:obj:float, optional): The atom reference value.

mask_token_id (:obj:int, optional): The token ID for masking.

pad_token_id (:obj:int, optional): The token ID for padding.

bos_token_id (:obj:int, optional): The token ID for the beginning of sequence.

eos_token_id (:obj:int, optional): The token ID for the end of sequence.

Source code in atomgen/models/schnet.py
class SchNetConfig(PretrainedConfig):
    r"""
    Stores the configuration of a :class:`~transformers.SchNetModel`.

    It is used to instantiate an SchNet model according to the specified arguments,
    defining the model architecture.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 123):
            The size of the vocabulary, used to define the size
            of the output embeddings.

        hidden_channels (:obj:`int`, `optional`, defaults to 128):
            The hidden size of the model.

    model_type = "transformer"

    Attributes
    ----------
        vocab_size (:obj:`int`):
            The size of the vocabulary, used to define
            the size of the output embeddings.

        hidden_channels (:obj:`int`):
            The hidden size of the model.

        num_filters (:obj:`int`):
            The number of filters.

        num_interactions (:obj:`int`):
            The number of interactions.

        num_gaussians (:obj:`int`):
            The number of gaussians.

        cutoff (:obj:`float`):
            The cutoff value.

        interaction_graph (:obj:`str`, `optional`):
            The interaction graph.

        max_num_neighbors (:obj:`int`):
            The maximum number of neighbors.

        readout (:obj:`str`, `optional`):
            The readout method.

        dipole (:obj:`bool`, `optional`):
            Whether to include dipole.

        mean (:obj:`float`, `optional`):
            The mean value.

        std (:obj:`float`, `optional`):
            The standard deviation value.

        atomref (:obj:`float`, `optional`):
            The atom reference value.

        mask_token_id (:obj:`int`, `optional`):
            The token ID for masking.

        pad_token_id (:obj:`int`, `optional`):
            The token ID for padding.

        bos_token_id (:obj:`int`, `optional`):
            The token ID for the beginning of sequence.

        eos_token_id (:obj:`int`, `optional`):
            The token ID for the end of sequence.

    """

    def __init__(
        self,
        vocab_size: int = 123,
        hidden_channels: int = 128,
        num_filters: int = 128,
        num_interactions: int = 6,
        num_gaussians: int = 50,
        cutoff: float = 10.0,
        interaction_graph: Optional[Callable[..., Any]] = None,
        max_num_neighbors: int = 32,
        readout: str = "add",
        dipole: bool = False,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        atomref: Optional[OptTensor] = None,
        mask_token_id: int = 0,
        pad_token_id: int = 119,
        bos_token_id: int = 120,
        eos_token_id: int = 121,
        cls_token_id: int = 122,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)  # type: ignore[no-untyped-call]
        self.vocab_size = vocab_size
        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.interaction_graph = interaction_graph
        self.max_num_neighbors = max_num_neighbors
        self.readout = readout
        self.dipole = dipole
        self.mean = mean
        self.std = std
        self.atomref = atomref
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.cls_token_id = cls_token_id
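
Example usage (a minimal sketch; defaults mirror the __init__ signature above, and the import path follows the source location):

from atomgen.models.schnet import SchNetConfig

# Override only the fields you need; everything else keeps its default.
config = SchNetConfig(hidden_channels=256, num_interactions=4, cutoff=6.0)
print(config.hidden_channels, config.num_interactions, config.cutoff)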

SchNetPreTrainedModel

Bases: PreTrainedModel

A base class for all SchNet models.

An abstract class to handle weights initialization and a simple interface for loading and exporting models.

Source code in atomgen/models/schnet.py
class SchNetPreTrainedModel(PreTrainedModel):
    """
    A base class for all SchNet models.

    An abstract class to handle weights initialization and a
    simple interface for loading and exporting models.
    """

    config_class = SchNetConfig  # type: ignore[assignment]
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

SchNetModel

Bases: SchNetPreTrainedModel

SchNet model for energy prediction.

Args: config (:class:~transformers.SchNetConfig): Configuration class to store the configuration of a model.

Source code in atomgen/models/schnet.py
class SchNetModel(SchNetPreTrainedModel):
    """
    SchNet model for energy prediction.

    Args:
        config (:class:`~transformers.SchNetConfig`):
            Configuration class to store the configuration of a model.
    """

    def __init__(self, config: SchNetConfig):
        super().__init__(config)
        self.config = config
        self.model = SchNet(
            hidden_channels=config.hidden_channels,
            num_filters=config.num_filters,
            num_interactions=config.num_interactions,
            num_gaussians=config.num_gaussians,
            cutoff=config.cutoff,
            interaction_graph=config.interaction_graph,
            max_num_neighbors=config.max_num_neighbors,
            readout=config.readout,
            dipole=config.dipole,
            mean=config.mean,
            std=config.std,
            atomref=config.atomref,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        batch: torch.Tensor,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """
        Forward pass of the SchNet model.

        Args:
            input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, num_atoms)`):
                The input tensor containing the atom indices.

            coords (:obj:`torch.Tensor` of shape :obj:`(num_atoms, 3)`):
                The input tensor containing the atom coordinates.

            batch (:obj:`torch.Tensor` of shape :obj:`(num_atoms)`):
                The input tensor containing the batch indices.

            labels_energy (:obj:`torch.Tensor`, `optional`):
                The input tensor containing the energy labels.

            fixed (:obj:`torch.Tensor`, `optional`):
                The input tensor containing the fixed mask.

            attention_mask (:obj:`torch.Tensor`, `optional`):
                The attention mask for the transformer.

        Returns
        -------
            :obj:`tuple`:
                A tuple of the loss and the energy prediction.
        """
        energy_pred: torch.Tensor = self.model(z=input_ids, pos=coords, batch=batch)

        loss = None
        if labels_energy is not None:
            labels_energy = labels_energy.to(energy_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(energy_pred.squeeze(-1), labels_energy)
        return loss, energy_pred

forward

forward(
    input_ids,
    coords,
    batch,
    labels_energy=None,
    fixed=None,
    attention_mask=None,
)

Forward pass of the SchNet model.

Args: input_ids (:obj:torch.Tensor of shape :obj:(batch_size, num_atoms)): The input tensor containing the atom indices.

coords (:obj:`torch.Tensor` of shape :obj:`(num_atoms, 3)`):
    The input tensor containing the atom coordinates.

batch (:obj:`torch.Tensor` of shape :obj:`(num_atoms)`):
    The input tensor containing the batch indices.

labels_energy (:obj:`torch.Tensor`, `optional`):
    The input tensor containing the energy labels.

fixed (:obj:`torch.Tensor`, `optional`):
    The input tensor containing the fixed mask.

attention_mask (:obj:`torch.Tensor`, `optional`):
    The attention mask for the transformer.

Returns:

Type Description
:obj:`tuple`:

A tuple of the loss and the energy prediction.

Source code in atomgen/models/schnet.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    batch: torch.Tensor,
    labels_energy: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """
    Forward pass of the SchNet model.

    Args:
        input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, num_atoms)`):
            The input tensor containing the atom indices.

        coords (:obj:`torch.Tensor` of shape :obj:`(num_atoms, 3)`):
            The input tensor containing the atom coordinates.

        batch (:obj:`torch.Tensor` of shape :obj:`(num_atoms)`):
            The input tensor containing the batch indices.

        labels_energy (:obj:`torch.Tensor`, `optional`):
            The input tensor containing the energy labels.

        fixed (:obj:`torch.Tensor`, `optional`):
            The input tensor containing the fixed mask.

        attention_mask (:obj:`torch.Tensor`, `optional`):
            The attention mask for the transformer.

    Returns
    -------
        :obj:`tuple`:
            A tuple of the loss and the energy prediction.
    """
    energy_pred: torch.Tensor = self.model(z=input_ids, pos=coords, batch=batch)

    loss = None
    if labels_energy is not None:
        labels_energy = labels_energy.to(energy_pred.device)
        loss_fct = nn.L1Loss()
        loss = loss_fct(energy_pred.squeeze(-1), labels_energy)
    return loss, energy_pred
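
Example usage (a minimal sketch, not taken from the library itself): the wrapped SchNet follows the torch_geometric convention of flattened per-atom tensors with a per-atom batch index, so the shapes below assume that convention; the data collator may deliver inputs differently.

import torch
from atomgen.models.schnet import SchNetConfig, SchNetModel

config = SchNetConfig(hidden_channels=64, num_interactions=3)
model = SchNetModel(config)

# Two structures with 5 and 3 atoms, flattened into single tensors.
input_ids = torch.randint(1, 100, (8,))          # atomic numbers
coords = torch.randn(8, 3)                       # Cartesian coordinates
batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1])   # structure index per atom

loss, energy_pred = model(input_ids=input_ids, coords=coords, batch=batch)
print(loss, energy_pred.shape)  # loss is None when labels_energy is not given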

tokengt

Implementation of the TokenGT model.

ParallelBlock

Bases: Module

Parallel transformer block.

Source code in atomgen/models/tokengt.py
class ParallelBlock(nn.Module):
    """Parallel transformer block."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert dim % num_heads == 0, (
            f"dim {dim} should be divisible by num_heads {num_heads}"
        )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.mlp_hidden_dim = int(mlp_ratio * dim)
        self.dropout = dropout
        self.proj_drop = nn.Dropout(self.dropout)

        self.in_proj_in_dim = dim
        self.in_proj_out_dim = self.mlp_hidden_dim + 3 * dim
        self.out_proj_in_dim = self.mlp_hidden_dim + dim
        self.out_proj_out_dim = 2 * dim

        self.in_split = [self.mlp_hidden_dim] + [dim] * 3
        self.out_split = [dim] * 2

        self.in_norm = nn.LayerNorm(dim)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        # fused input projection: MLP hidden features plus q, k, v
        self.in_proj = nn.Linear(self.in_proj_in_dim, self.in_proj_out_dim, bias=False)
        self.act_fn = nn.GELU()
        self.out_proj = nn.Linear(
            self.out_proj_in_dim, self.out_proj_out_dim, bias=False
        )

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward function call for the parallel transformer block."""
        b, n, c = x.shape
        res = x
        x = self.in_proj(self.in_norm(x))

        x, q, k, v = torch.split(x, self.in_split, dim=-1)
        x = self.act_fn(x)
        x = self.proj_drop(x)

        q = self.q_norm(q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        k = self.k_norm(k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        x_attn = (
            f.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=self.dropout
            )
            .transpose(1, 2)
            .reshape(b, n, c)
        )

        x_mlp, x_attn = self.out_proj(torch.cat([x, x_attn], dim=-1)).split(
            self.out_split, dim=-1
        )
        out: torch.Tensor = x_mlp + x_attn + res

        return out

forward

forward(x, attention_mask=None)

Forward function call for the parallel transformer block.

Source code in atomgen/models/tokengt.py
def forward(
    self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Forward function call for the parallel transformer block."""
    b, n, c = x.shape
    res = x
    x = self.in_proj(self.in_norm(x))

    x, q, k, v = torch.split(x, self.in_split, dim=-1)
    x = self.act_fn(x)
    x = self.proj_drop(x)

    q = self.q_norm(q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
    k = self.k_norm(k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
    v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

    x_attn = (
        f.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=self.dropout
        )
        .transpose(1, 2)
        .reshape(b, n, c)
    )

    x_mlp, x_attn = self.out_proj(torch.cat([x, x_attn], dim=-1)).split(
        self.out_split, dim=-1
    )
    out: torch.Tensor = x_mlp + x_attn + res

    return out
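
Example usage (a minimal sketch on random features; requires a PyTorch version that provides scaled_dot_product_attention):

import torch
from atomgen.models.tokengt import ParallelBlock

block = ParallelBlock(dim=64, num_heads=4)
x = torch.randn(2, 10, 64)   # (batch, tokens, dim)
out = block(x)               # attention and MLP branches computed from one fused projection
print(out.shape)             # torch.Size([2, 10, 64])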

TransformerConfig

Bases: PretrainedConfig

Configuration class to store the configuration of a TokenGT model.

Source code in atomgen/models/tokengt.py
class TransformerConfig(PretrainedConfig):
    """Configuration class to store the configuration of a TokenGT model."""

    def __init__(
        self,
        vocab_size: int = 123,
        dim: int = 768,
        num_heads: int = 12,
        depth: int = 12,
        mlp_ratio: int = 4,
        k: int = 16,
        sigma: float = 0.03,
        type_id_dim: int = 64,
        dropout: float = 0.0,
        mask_token_id: int = 0,
        pad_token_id: int = 119,
        bos_token_id: int = 120,
        eos_token_id: int = 121,
        cls_token_id: int = 122,
        gradient_checkpointing: bool = False,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)  # type: ignore[no-untyped-call]
        self.vocab_size = vocab_size
        self.dim = dim
        self.num_heads = num_heads
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.k = k
        self.sigma = sigma
        self.type_id_dim = type_id_dim
        self.dropout = dropout
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.cls_token_id = cls_token_id
        self.gradient_checkpointing = gradient_checkpointing
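
Example usage (a minimal sketch; defaults follow the __init__ signature above):

from atomgen.models.tokengt import TransformerConfig

# A small configuration for quick experiments.
config = TransformerConfig(dim=128, num_heads=4, depth=2, k=8)
print(config.dim, config.num_heads, config.k)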

TransformerEncoder

Bases: Module

Transformer encoder for atom modeling.

Source code in atomgen/models/tokengt.py
class TransformerEncoder(nn.Module):
    """Transformer encoder for atom modeling."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.dim = config.dim
        self.num_heads = config.num_heads
        self.depth = config.depth
        self.mlp_ratio = config.mlp_ratio
        self.k = config.k
        self.sigma = config.sigma
        self.type_id_dim = config.type_id_dim
        self.gradient_checkpointing = config.gradient_checkpointing
        self.dropout = config.dropout
        self.metadata_vocab = nn.Embedding(122, 17)
        vocab_weight = torch.empty(122, 17).fill_(-1.0)
        vocab_weight[2:-2] = torch.tensor(ATOM_METADATA, dtype=torch.float32)
        self.metadata_vocab.weight = nn.Parameter(vocab_weight, requires_grad=False)
        self.node_id = nn.Embedding(1, self.type_id_dim)
        self.edge_id = nn.Embedding(1, self.type_id_dim)
        self.embed_proj = nn.Linear(17 + 2 * self.k + self.type_id_dim, self.dim)
        self.graph = nn.Embedding(1, self.dim)
        self.distance = nn.Embedding(1, 17)
        self.distance_norm = nn.LayerNorm(17)

        self.blocks = nn.ModuleList()
        for _ in range(self.depth):
            self.blocks.append(
                ParallelBlock(self.dim, self.num_heads, self.mlp_ratio, self.dropout)
            )

    def _expand_mask(
        self,
        mask: torch.Tensor,
        dtype: torch.dtype,
        device: torch.device,
        tgt_len: Optional[int] = None,
    ) -> torch.Tensor:
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = (
            mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        )

        inverted_mask: torch.Tensor = 1.0 - expanded_mask

        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        ).to(device)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        node_pe: torch.Tensor,
        edge_pe: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function call for the transformer encoder."""
        atom_metadata = self.metadata_vocab(input_ids)  # (B, N, 17)

        node_ids = self.node_id(
            torch.zeros(
                node_pe.size(0),
                node_pe.size(1),
                dtype=torch.long,
                device=node_pe.device,
            )
        )
        edge_ids = self.edge_id(
            torch.zeros(
                edge_pe.size(0),
                edge_pe.size(1),
                dtype=torch.long,
                device=edge_pe.device,
            )
        )

        graph_tokens = self.graph(
            torch.zeros(node_pe.size(0), 1, dtype=torch.long, device=node_pe.device)
        )

        nodes = torch.cat([atom_metadata, node_pe, node_ids], dim=-1)
        distance_embed = self.distance_norm(
            self.distance(
                torch.zeros(
                    edge_pe.size(0),
                    edge_pe.size(1),
                    dtype=torch.long,
                    device=edge_pe.device,
                )
            )
            * edge_pe[:, :, -1:]
        )
        edges = torch.cat([distance_embed, edge_pe[:, :, :-1], edge_ids], dim=-1)

        input_embeds: torch.Tensor = self.embed_proj(torch.cat([nodes, edges], dim=1))
        input_embeds = torch.cat([graph_tokens, input_embeds], dim=1)

        # convert attention mask from long into Boolean and add ones for graph token
        attention_mask = (
            torch.cat(
                [
                    torch.ones(
                        attention_mask.size(0),
                        1,
                        dtype=torch.bool,
                        device=attention_mask.device,
                    ),
                    attention_mask.bool(),
                ],
                dim=1,
            )
            if attention_mask is not None
            else None
        )

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module: Any) -> Callable[..., Any]:
                    def custom_forward(*inputs: Any) -> Any:
                        return module(*inputs)

                    return custom_forward

                input_embeds = checkpoint(
                    create_custom_forward(blk),
                    input_embeds,
                    attention_mask,
                )
            else:
                input_embeds = blk(input_embeds, attention_mask)
        return input_embeds

forward

forward(
    input_ids, coords, node_pe, edge_pe, attention_mask=None
)

Forward function call for the transformer encoder.

Source code in atomgen/models/tokengt.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    node_pe: torch.Tensor,
    edge_pe: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward function call for the transformer encoder."""
    atom_metadata = self.metadata_vocab(input_ids)  # (B, N, 17)

    node_ids = self.node_id(
        torch.zeros(
            node_pe.size(0),
            node_pe.size(1),
            dtype=torch.long,
            device=node_pe.device,
        )
    )
    edge_ids = self.edge_id(
        torch.zeros(
            edge_pe.size(0),
            edge_pe.size(1),
            dtype=torch.long,
            device=edge_pe.device,
        )
    )

    graph_tokens = self.graph(
        torch.zeros(node_pe.size(0), 1, dtype=torch.long, device=node_pe.device)
    )

    nodes = torch.cat([atom_metadata, node_pe, node_ids], dim=-1)
    distance_embed = self.distance_norm(
        self.distance(
            torch.zeros(
                edge_pe.size(0),
                edge_pe.size(1),
                dtype=torch.long,
                device=edge_pe.device,
            )
        )
        * edge_pe[:, :, -1:]
    )
    edges = torch.cat([distance_embed, edge_pe[:, :, :-1], edge_ids], dim=-1)

    input_embeds: torch.Tensor = self.embed_proj(torch.cat([nodes, edges], dim=1))
    input_embeds = torch.cat([graph_tokens, input_embeds], dim=1)

    # convert attention mask from long into Boolean and add ones for graph token
    attention_mask = (
        torch.cat(
            [
                torch.ones(
                    attention_mask.size(0),
                    1,
                    dtype=torch.bool,
                    device=attention_mask.device,
                ),
                attention_mask.bool(),
            ],
            dim=1,
        )
        if attention_mask is not None
        else None
    )

    for blk in self.blocks:
        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module: Any) -> Callable[..., Any]:
                def custom_forward(*inputs: Any) -> Any:
                    return module(*inputs)

                return custom_forward

            input_embeds = checkpoint(
                create_custom_forward(blk),
                input_embeds,
                attention_mask,
            )
        else:
            input_embeds = blk(input_embeds, attention_mask)
    return input_embeds
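
Example usage (a minimal sketch with random inputs; it assumes node_pe has width 2 * k and edge_pe has width 2 * k + 1, which matches the projection sizes above):

import torch
from atomgen.models.tokengt import TransformerConfig, TransformerEncoder

config = TransformerConfig(dim=64, num_heads=4, depth=2, k=4, type_id_dim=16)
encoder = TransformerEncoder(config)

batch, n_nodes, n_edges = 2, 6, 10
input_ids = torch.randint(2, 119, (batch, n_nodes))      # atom token ids
coords = torch.randn(batch, n_nodes, 3)
node_pe = torch.randn(batch, n_nodes, 2 * config.k)      # node positional encodings
edge_pe = torch.randn(batch, n_edges, 2 * config.k + 1)  # edge encodings plus distance

out = encoder(input_ids, coords, node_pe, edge_pe)
print(out.shape)  # (batch, 1 + n_nodes + n_edges, dim): graph token, nodes, edges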

TransformerPreTrainedModel

Bases: PreTrainedModel

Base class for all transformer models.

Source code in atomgen/models/tokengt.py
class TransformerPreTrainedModel(PreTrainedModel):
    """Base class for all transformer models."""

    config_class = TransformerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParallelBlock"]

    def _set_gradient_checkpointing(
        self, module: nn.Module, value: bool = False
    ) -> None:
        if isinstance(module, (TransformerEncoder)):
            module.gradient_checkpointing = value

TransformerModel

Bases: TransformerPreTrainedModel

Transformer model for atom modeling.

Source code in atomgen/models/tokengt.py
class TransformerModel(TransformerPreTrainedModel):
    """Transformer model for atom modeling."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)  # type: ignore[no-untyped-call]
        self.config = config
        self.encoder = TransformerEncoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function call for the transformer model."""
        out: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
        return out

forward

forward(input_ids, coords, attention_mask=None)

Forward function call for the transformer model.

Source code in atomgen/models/tokengt.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward function call for the transformer model."""
    out: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
    return out

TransformerForMaskedAM

Bases: TransformerPreTrainedModel

Transformer with an atom modeling head on top for masked atom modeling.

Source code in atomgen/models/tokengt.py
class TransformerForMaskedAM(TransformerPreTrainedModel):
    """Transformer with an atom modeling head on top for masked atom modeling."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)  # type: ignore[no-untyped-call]
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.am_head = nn.Linear(config.dim, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the masked atom modeling model."""
        hidden_states = self.encoder(input_ids, coords, attention_mask)
        logits = self.am_head(hidden_states[:, 1 : input_ids.size(1)])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            logits, labels = logits.view(-1, self.config.vocab_size), labels.view(-1)
            loss = loss_fct(logits, labels)

        return loss, logits

forward

forward(
    input_ids,
    coords,
    labels=None,
    fixed=None,
    attention_mask=None,
)

Forward function call for the masked atom modeling model.

Source code in atomgen/models/tokengt.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call for the masked atom modeling model."""
    hidden_states = self.encoder(input_ids, coords, attention_mask)
    logits = self.am_head(hidden_states[:, 1 : input_ids.size(1)])

    loss = None
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss()
        logits, labels = logits.view(-1, self.config.vocab_size), labels.view(-1)
        loss = loss_fct(logits, labels)

    return loss, logits

TransformerForCoordinateAM

Bases: TransformerPreTrainedModel

Transformer with an atom coordinate head on top for coordinate denoising.

Source code in atomgen/models/tokengt.py
class TransformerForCoordinateAM(TransformerPreTrainedModel):
    """Transformer with an atom coordinate head on top for coordinate denoising."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)  # type: ignore[no-untyped-call]
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the coordinate atom modeling model."""
        hidden_states = self.encoder(input_ids, coords, attention_mask)
        coords_pred = self.coords_head(hidden_states[:, 1 : input_ids.size(1)])

        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)

        return loss, coords_pred

forward

forward(
    input_ids,
    coords,
    labels_coords=None,
    fixed=None,
    attention_mask=None,
)

Forward function call for the coordinate atom modeling model.

Source code in atomgen/models/tokengt.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels_coords: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call for the coordinate atom modeling model."""
    hidden_states = self.encoder(input_ids, coords, attention_mask)
    coords_pred = self.coords_head(hidden_states[:, 1 : input_ids.size(1)])

    loss = None
    if labels_coords is not None:
        labels_coords = labels_coords.to(coords_pred.device)
        loss_fct = nn.L1Loss()
        loss = loss_fct(coords_pred, labels_coords)

    return loss, coords_pred

InitialStructure2RelaxedStructure

Bases: TransformerPreTrainedModel

Transformer with a coordinate head on top for relaxed structure prediction.

Source code in atomgen/models/tokengt.py
class InitialStructure2RelaxedStructure(TransformerPreTrainedModel):
    """Transformer with an coordinate head on top for relaxed structure prediction."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)  # type: ignore[no-untyped-call]
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call.

        Initial structure to relaxed structure model.
        """
        hidden_states = self.encoder(input_ids, coords, attention_mask)
        coords_pred = self.coords_head(hidden_states[:, 1 : input_ids.size(1)])

        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)

        return loss, coords_pred

forward

forward(
    input_ids,
    coords,
    labels_coords=None,
    fixed=None,
    attention_mask=None,
)

Forward function call.

Initial structure to relaxed structure model.

Source code in atomgen/models/tokengt.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels_coords: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call.

    Initial structure to relaxed structure model.
    """
    hidden_states = self.encoder(input_ids, coords, attention_mask)
    coords_pred = self.coords_head(hidden_states[:, 1 : input_ids.size(1)])

    loss = None
    if labels_coords is not None:
        labels_coords = labels_coords.to(coords_pred.device)
        loss_fct = nn.L1Loss()
        loss = loss_fct(coords_pred, labels_coords)

    return loss, coords_pred

InitialStructure2RelaxedEnergy

Bases: TransformerPreTrainedModel

Transformer with an energy head on top for relaxed energy prediction.

Source code in atomgen/models/tokengt.py
class InitialStructure2RelaxedEnergy(TransformerPreTrainedModel):
    """Transformer with an energy head on top for relaxed energy prediction."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)  # type: ignore[no-untyped-call]
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the initial structure to relaxed energy model."""
        hidden_states = self.encoder(input_ids, coords, attention_mask)
        energy = self.energy_head(self.energy_norm(hidden_states[:, 0])).squeeze(-1)

        loss = None
        if labels_energy is not None:
            loss_fct = nn.L1Loss()
            loss = loss_fct(energy, labels_energy)

        return loss, energy

forward

forward(
    input_ids,
    coords,
    labels_energy=None,
    fixed=None,
    attention_mask=None,
)

Forward function call for the initial structure to relaxed energy model.

Source code in atomgen/models/tokengt.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels_energy: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Forward function call for the initial structure to relaxed energy model."""
    hidden_states = self.encoder(input_ids, coords, attention_mask)
    energy = self.energy_head(self.energy_norm(hidden_states[:, 0])).squeeze(-1)

    loss = None
    if labels_energy is not None:
        loss_fct = nn.L1Loss()
        loss = loss_fct(energy, labels_energy)

    return loss, energy

InitialStructure2RelaxedStructureAndEnergy

Bases: TransformerPreTrainedModel

Initial structure to relaxed structure and energy prediction model.

Transformer with a coordinate and energy head on top for relaxed structure and energy prediction.

Source code in atomgen/models/tokengt.py
class InitialStructure2RelaxedStructureAndEnergy(TransformerPreTrainedModel):
    """Initial structure to relaxed structure and energy prediction model.

    Transformer with a coordinate and energy head on top for
    relaxed structure and energy prediction.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__(config)  # type: ignore[no-untyped-call]
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function call.

        Initial structure to relaxed structure and energy model.
        """
        hidden_states = self.encoder(input_ids, coords, attention_mask)
        coords_pred: torch.Tensor = self.coords_head(
            hidden_states[:, 1 : input_ids.size(1)]
        )
        energy: torch.Tensor = self.energy_head(
            self.energy_norm(hidden_states[:, 0])
        ).squeeze(-1)

        loss_coords = torch.tensor(0.0, device=input_ids.device)
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss_coords = loss_fct(coords_pred, labels_coords)

        loss_energy = torch.tensor(0.0, device=input_ids.device)
        if labels_energy is not None:
            loss_fct = nn.L1Loss()
            loss_energy = loss_fct(energy, labels_energy)

        loss = loss_coords + loss_energy

        return loss, (coords_pred, energy)

forward

forward(
    input_ids,
    coords,
    labels_coords=None,
    labels_energy=None,
    fixed=None,
    attention_mask=None,
)

Forward function call.

Initial structure to relaxed structure and energy model.

Source code in atomgen/models/tokengt.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    labels_coords: Optional[torch.Tensor] = None,
    labels_energy: Optional[torch.Tensor] = None,
    fixed: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Forward function call.

    Initial structure to relaxed structure and energy model.
    """
    hidden_states = self.encoder(input_ids, coords, attention_mask)
    coords_pred: torch.Tensor = self.coords_head(
        hidden_states[:, 1 : input_ids.size(1)]
    )
    energy: torch.Tensor = self.energy_head(
        self.energy_norm(hidden_states[:, 0])
    ).squeeze(-1)

    loss_coords = torch.tensor(0.0, device=input_ids.device)
    if labels_coords is not None:
        labels_coords = labels_coords.to(coords_pred.device)
        loss_fct = nn.L1Loss()
        loss_coords = loss_fct(coords_pred, labels_coords)

    loss_energy = torch.tensor(0.0, device=input_ids.device)
    if labels_energy is not None:
        loss_fct = nn.L1Loss()
        loss_energy = loss_fct(energy, labels_energy)

    loss = loss_coords + loss_energy

    return loss, (coords_pred, energy)

Structure2EnergyAndForces

Bases: TransformerPreTrainedModel

Structure to energy and forces prediction model.

Transformer with an energy and forces head on top for energy and forces prediction.

Source code in atomgen/models/tokengt.py
class Structure2EnergyAndForces(TransformerPreTrainedModel):
    """Structure to energy and forces prediction model.

    Transformer with an energy and forces head on top for energy and forces prediction.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__(config)  # type: ignore[no-untyped-call]
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        node_pe: Optional[torch.Tensor] = None,
        edge_pe: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to energy and forces model."""
        # argument order matches TransformerEncoder.forward(input_ids, coords, node_pe, edge_pe, attention_mask)
        hidden_states = self.encoder(
            input_ids, coords, node_pe, edge_pe, attention_mask
        )

        formation_energy_pred: torch.Tensor = self.formation_energy_head(
            self.energy_norm(hidden_states[:, 0])
        ).squeeze(-1)

        loss_formation_energy = torch.tensor(0.0, device=input_ids.device)
        if formation_energy is not None:
            loss_fct = nn.L1Loss()
            loss_formation_energy = loss_fct(
                formation_energy_pred[has_formation_energy],
                formation_energy[has_formation_energy],
            )

        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(hidden_states[:, 1 : input_ids.size(1)])
        )
        loss_forces = torch.tensor(0.0, device=input_ids.device)
        if forces is not None and attention_mask is not None:
            loss_fct = nn.L1Loss()
            loss_forces = loss_fct(
                forces_pred[attention_mask[:, 1 : input_ids.size(1)].bool()],
                forces[attention_mask[:, 1 : input_ids.size(1)].bool()],
            )

        loss = loss_formation_energy + loss_forces

        return loss, (
            formation_energy_pred,
            forces_pred,
            attention_mask.bool() if attention_mask is not None else attention_mask,
        )

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
    node_pe=None,
    edge_pe=None,
)

Forward function call for the structure to energy and forces model.

Source code in atomgen/models/tokengt.py
def forward(
    self,
    input_ids: torch.Tensor,
    coords: torch.Tensor,
    forces: Optional[torch.Tensor] = None,
    total_energy: Optional[torch.Tensor] = None,
    formation_energy: Optional[torch.Tensor] = None,
    has_formation_energy: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    node_pe: Optional[torch.Tensor] = None,
    edge_pe: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]:
    """Forward function call for the structure to energy and forces model."""
    # argument order matches TransformerEncoder.forward(input_ids, coords, node_pe, edge_pe, attention_mask)
    hidden_states = self.encoder(
        input_ids, coords, node_pe, edge_pe, attention_mask
    )

    formation_energy_pred: torch.Tensor = self.formation_energy_head(
        self.energy_norm(hidden_states[:, 0])
    ).squeeze(-1)

    loss_formation_energy = torch.tensor(0.0, device=input_ids.device)
    if formation_energy is not None:
        loss_fct = nn.L1Loss()
        loss_formation_energy = loss_fct(
            formation_energy_pred[has_formation_energy],
            formation_energy[has_formation_energy],
        )

    forces_pred: torch.Tensor = self.force_head(
        self.force_norm(hidden_states[:, 1 : input_ids.size(1)])
    )
    loss_forces = torch.tensor(0.0, device=input_ids.device)
    if forces is not None and attention_mask is not None:
        loss_fct = nn.L1Loss()
        loss_forces = loss_fct(
            forces_pred[attention_mask[:, 1 : input_ids.size(1)].bool()],
            forces[attention_mask[:, 1 : input_ids.size(1)].bool()],
        )

    loss = loss_formation_energy + loss_forces

    return loss, (
        formation_energy_pred,
        forces_pred,
        attention_mask.bool() if attention_mask is not None else attention_mask,
    )
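
Example usage (a minimal inference sketch, assuming the encoder call order shown above; without labels the returned loss is simply zero):

import torch
from atomgen.models.tokengt import TransformerConfig, Structure2EnergyAndForces

config = TransformerConfig(dim=64, num_heads=4, depth=2, k=4, type_id_dim=16)
model = Structure2EnergyAndForces(config)

batch, n_nodes, n_edges = 2, 6, 10
input_ids = torch.randint(2, 119, (batch, n_nodes))
coords = torch.randn(batch, n_nodes, 3)
node_pe = torch.randn(batch, n_nodes, 2 * config.k)
edge_pe = torch.randn(batch, n_edges, 2 * config.k + 1)

loss, (energy_pred, forces_pred, mask) = model(
    input_ids, coords, node_pe=node_pe, edge_pe=edge_pe
)
print(energy_pred.shape, forces_pred.shape)  # per-structure energies, per-atom forces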