API Reference
Data module for the AtomGen library.
This module contains the data classes and functions for pre-processing and collating data for training/inference.
data_collator

Data collator for atom modeling.
DataCollatorForAtomModeling (dataclass)

Bases: DataCollatorMixin
Data collator used for atom modeling tasks in molecular representations.
This collator prepares input data for various atom modeling tasks, including masked atom modeling (MAM), autoregressive modeling, and coordinate perturbation. It supports both padding and flattening of input data.
Args:
    tokenizer (PreTrainedTokenizer): Tokenizer used for encoding the data.
    mam (Union[bool, float]): If True, uses original masked atom modeling. If a float, masks a constant fraction of atoms/tokens.
    autoregressive (bool): Whether to use autoregressive modeling.
    coords_perturb (float): Standard deviation for coordinate perturbation.
    return_lap_pe (bool): Whether to return Laplacian positional encoding.
    return_edge_indices (bool): Whether to return edge indices.
    k (int): Number of eigenvectors to use for Laplacian positional encoding.
    max_radius (float): Maximum distance for edge cutoff.
    max_neighbors (int): Maximum number of neighbors.
    pad (bool): Whether to pad the input data.
    pad_to_multiple_of (Optional[int]): Pad to a multiple of this value.
    return_tensors (str): Return tensors as "pt" or "tf".
Attributes:
    tokenizer (PreTrainedTokenizer): The tokenizer used for encoding.
    mam (Union[bool, float]): The masked atom modeling setting.
    autoregressive (bool): The autoregressive modeling setting.
    coords_perturb (float): The coordinate perturbation standard deviation.
    return_lap_pe (bool): The Laplacian positional encoding setting.
    return_edge_indices (bool): The edge indices return setting.
    k (int): The number of eigenvectors for Laplacian PE.
    max_radius (float): The maximum distance for edge cutoff.
    max_neighbors (int): The maximum number of neighbors.
    pad (bool): The padding setting.
    pad_to_multiple_of (Optional[int]): The multiple for padding.
    return_tensors (str): The tensor return format.
Source code in atomgen/data/data_collator.py
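A minimal instantiation sketch. This is hedged: the vocab path is a placeholder, and the masking fraction and perturbation scale are illustrative values, not library defaults.

```python
from atomgen.data.data_collator import DataCollatorForAtomModeling
from atomgen.data.tokenizer import AtomTokenizer

# "vocab.json" is a placeholder path; point it at your tokenizer's vocabulary file.
tokenizer = AtomTokenizer(vocab_file="vocab.json")

collator = DataCollatorForAtomModeling(
    tokenizer=tokenizer,
    mam=0.15,            # mask a constant 15% fraction of atoms
    coords_perturb=0.1,  # std of the Gaussian noise added to coordinates
    pad=True,
    return_tensors="pt",
)
```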
torch_call

Collate a batch of samples.
Args:
    examples: List of samples to collate.
Returns:
    Dict[str, Any]: Dictionary of batched data.
Source code in atomgen/data/data_collator.py
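Since the collator inherits from DataCollatorMixin, calling it on a list of feature dicts dispatches to torch_call when return_tensors is "pt". A sketch with illustrative keys and values:

```python
# Keys and values are illustrative; real datasets may carry additional fields.
examples = [
    {"input_ids": [8, 1, 1], "coords": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]},
    {"input_ids": [6, 8], "coords": [[0.0, 0.0, 0.0], [1.13, 0.0, 0.0]]},
]
batch = collator(examples)  # dispatches to torch_call, returns a dict of batched tensors
```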
torch_mask_tokens

Prepare masked tokens inputs/labels for masked atom modeling.
Source code in atomgen/data/data_collator.py

apply_mask

Apply the mask to the input tokens.
Source code in atomgen/data/data_collator.py

torch_perturb_coords

Prepare perturbed coords inputs/labels for coordinate denoising.
Source code in atomgen/data/data_collator.py

flatten_batch

Flatten all lists in examples and concatenate with batch indicator.
Source code in atomgen/data/data_collator.py

torch_compute_edges

Compute edge indices and distances for each batch.
Source code in atomgen/data/data_collator.py

torch_compute_lap_pe

Compute Laplacian positional encoding for each batch.
Source code in atomgen/data/data_collator.py
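For intuition, here is a self-contained sketch of the standard Laplacian PE recipe (eigenvectors of the symmetric normalized graph Laplacian, dropping the trivial constant mode). It is not the library's exact implementation:

```python
import torch

def laplacian_pe(adj: torch.Tensor, k: int) -> torch.Tensor:
    """k-dimensional Laplacian positional encoding for one graph.

    adj: dense (N, N) adjacency matrix.
    """
    deg = adj.sum(dim=-1)
    d_inv_sqrt = torch.where(deg > 0, deg.pow(-0.5), torch.zeros_like(deg))
    lap = torch.eye(adj.size(0)) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    _, eigvecs = torch.linalg.eigh(lap)   # eigenvalues in ascending order
    return eigvecs[:, 1 : k + 1]          # skip the trivial constant eigenvector
```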
tokenizer

Tokenization module for atom modeling.

AtomTokenizer

Bases: PreTrainedTokenizer
Tokenizer for atomistic data.
Args:
    vocab_file: The path to the vocabulary file.
    pad_token: The padding token.
    mask_token: The mask token.
    bos_token: The beginning-of-system token.
    eos_token: The end-of-system token.
    cls_token: The classification token.
    kwargs: Additional keyword arguments.
Source code in atomgen/data/tokenizer.py
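A minimal sketch of constructing and using the tokenizer; the vocabulary path is a placeholder and the dict keys are illustrative, but the pad call uses only the signature documented below:

```python
from atomgen.data.tokenizer import AtomTokenizer

tokenizer = AtomTokenizer(vocab_file="vocab.json")  # placeholder path

# Pad two variable-length systems to a common length.
padded = tokenizer.pad(
    {"input_ids": [[8, 1, 1], [6, 8]]},
    padding=True,
    return_tensors="pt",
)
```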
load_vocab (staticmethod)

Load the vocabulary from a JSON file.
Source code in atomgen/data/tokenizer.py
get_vocab

Return the vocabulary as a dictionary.

get_vocab_size

Return the size of the vocabulary.

convert_tokens_to_string

Convert a sequence of tokens to a single string.
pad

pad(
    encoded_inputs,
    padding=True,
    max_length=None,
    pad_to_multiple_of=None,
    return_attention_mask=None,
    return_tensors=None,
    verbose=True,
)
Pad the input data.
Source code in atomgen/data/tokenizer.py
pad_coords

Pad the coordinates to the same length.
Source code in atomgen/data/tokenizer.py

pad_forces

Pad the forces to the same length.
Source code in atomgen/data/tokenizer.py

pad_fixed

Pad the fixed mask to the same length.
Source code in atomgen/data/tokenizer.py

save_vocabulary

Save the vocabulary to a JSON file.
Source code in atomgen/data/tokenizer.py

from_pretrained (classmethod)

Load a tokenizer from a pretrained checkpoint or directory.

build_inputs_with_special_tokens

Build the input with special tokens.
Source code in atomgen/data/tokenizer.py
utils

Utilities for data processing and evaluation.
compute_metrics_smp

Compute MAE for 20 regression labels for the SMP task.
Source code in atomgen/data/utils.py

compute_metrics_ppi

Compute AUROC for the PIP task.
Source code in atomgen/data/utils.py

compute_metrics_res

Compute accuracy for the RES task.
Source code in atomgen/data/utils.py

compute_metrics_msp

Compute AUROC for the MSP task.
Source code in atomgen/data/utils.py

compute_metrics_lba

Compute RMSE for the LBA task.
Source code in atomgen/data/utils.py

compute_metrics_lep

Compute AUROC for the LEP task.
Source code in atomgen/data/utils.py

compute_metrics_psr

Compute global Spearman correlation for the PSR task.
Source code in atomgen/data/utils.py

compute_metrics_rsr

Compute global Spearman correlation for the RSR task.
Source code in atomgen/data/utils.py
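These helpers are named like Hugging Face Trainer compute_metrics callbacks. Assuming they follow that convention (accepting an EvalPrediction), wiring one into a Trainer would look roughly like this; model, eval_dataset, and collator are stand-ins for objects built earlier:

```python
from transformers import Trainer, TrainingArguments
from atomgen.data.utils import compute_metrics_lba

trainer = Trainer(
    model=model,                          # your model
    args=TrainingArguments(output_dir="out"),
    eval_dataset=eval_dataset,            # your tokenized LBA split
    data_collator=collator,
    compute_metrics=compute_metrics_lba,  # assumed to accept an EvalPrediction
)
metrics = trainer.evaluate()
```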
Models module for the AtomGen library.
This module contains the model classes and functions for training and inference.
configuration_atomformer

Configuration class for Atomformer.
AtomformerConfig

Bases: PretrainedConfig

Configuration of an AtomformerModel. It is used to instantiate an Atomformer model according to the specified arguments.
Source code in atomgen/models/configuration_atomformer.py
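Following the standard Hugging Face config/model pattern, instantiation looks like this (hyperparameter names and defaults are not enumerated in this reference, so none are shown):

```python
from atomgen.models.configuration_atomformer import AtomformerConfig
from atomgen.models.modeling_atomformer import AtomformerModel

config = AtomformerConfig()      # library defaults
model = AtomformerModel(config)  # randomly initialized Atomformer
```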
modeling_atomformer

Implementation of the Atomformer model.
GaussianLayer

Bases: Module

Gaussian pairwise positional embedding layer.
This layer computes the Gaussian positional embeddings for the pairwise distances between atoms in a molecule.
Taken from: https://github.com/microsoft/Graphormer/blob/main/graphormer/models/graphormer_3d.py
Source code in atomgen/models/modeling_atomformer.py
forward

Forward pass to compute the Gaussian positional embeddings.
Source code in atomgen/models/modeling_atomformer.py
ParallelBlock

Bases: Module

Parallel transformer block (MLP and attention in parallel).
Based on "Scaling Vision Transformers to 22 Billion Parameters" (https://arxiv.org/abs/2302.05442). Adapted from the timm implementation.
Source code in atomgen/models/modeling_atomformer.py
forward

Forward pass for the parallel block.
Source code in atomgen/models/modeling_atomformer.py
AtomformerEncoder

Bases: Module

Atomformer encoder.
The transformer encoder consists of a series of parallel blocks, each containing a multi-head self-attention mechanism and a feed-forward network.
Source code in atomgen/models/modeling_atomformer.py
forward

Forward pass for the transformer encoder.
Source code in atomgen/models/modeling_atomformer.py
AtomformerPreTrainedModel

Bases: PreTrainedModel

Base class for all transformer models.
Source code in atomgen/models/modeling_atomformer.py

AtomformerModel

Bases: AtomformerPreTrainedModel

Atomformer model for atom modeling.
Source code in atomgen/models/modeling_atomformer.py

forward

Forward function call for the transformer model.
Source code in atomgen/models/modeling_atomformer.py

AtomformerForMaskedAM

Bases: AtomformerPreTrainedModel

Atomformer with an atom modeling head on top for masked atom modeling.
Source code in atomgen/models/modeling_atomformer.py

forward

Forward function call for the masked atom modeling model.
Source code in atomgen/models/modeling_atomformer.py

AtomformerForCoordinateAM

Bases: AtomformerPreTrainedModel

Atomformer with an atom coordinate head on top for coordinate denoising.
Source code in atomgen/models/modeling_atomformer.py

forward

Forward function call for the coordinate atom modeling model.
Source code in atomgen/models/modeling_atomformer.py

InitialStructure2RelaxedStructure

Bases: AtomformerPreTrainedModel

Atomformer with a coordinate head on top for relaxed structure prediction.
Source code in atomgen/models/modeling_atomformer.py

forward

Forward function call for the initial structure to relaxed structure model.
Source code in atomgen/models/modeling_atomformer.py

InitialStructure2RelaxedEnergy

Bases: AtomformerPreTrainedModel

Atomformer with an energy head on top for relaxed energy prediction.
Source code in atomgen/models/modeling_atomformer.py

forward

Forward function call for the relaxed energy prediction model.
Source code in atomgen/models/modeling_atomformer.py
InitialStructure2RelaxedStructureAndEnergy

Bases: AtomformerPreTrainedModel

Atomformer with a coordinate and energy head.
Source code in atomgen/models/modeling_atomformer.py

forward

forward(
    input_ids,
    coords,
    labels_coords=None,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the relaxed structure and energy model.
Source code in atomgen/models/modeling_atomformer.py
Structure2Energy

Bases: AtomformerPreTrainedModel

Atomformer with an energy head on top for energy prediction.
Source code in atomgen/models/modeling_atomformer.py

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the structure to energy model.
Source code in atomgen/models/modeling_atomformer.py
Structure2Forces

Bases: AtomformerPreTrainedModel

Atomformer with a forces head on top for force prediction.
Source code in atomgen/models/modeling_atomformer.py

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the structure to forces model.
Source code in atomgen/models/modeling_atomformer.py
Structure2EnergyAndForces

Bases: AtomformerPreTrainedModel

Atomformer with an energy and forces head for energy and force prediction.
Source code in atomgen/models/modeling_atomformer.py

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the structure to energy and forces model.
Source code in atomgen/models/modeling_atomformer.py
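A forward-call sketch using the signature above. The batched shapes and label values are assumptions for illustration, not documented requirements:

```python
import torch
from atomgen.models.configuration_atomformer import AtomformerConfig
from atomgen.models.modeling_atomformer import Structure2EnergyAndForces

model = Structure2EnergyAndForces(AtomformerConfig())

input_ids = torch.tensor([[8, 1, 1]])  # one water-like system (assumed batched shape)
coords = torch.randn(1, 3, 3)          # (batch, atoms, xyz)
outputs = model(
    input_ids=input_ids,
    coords=coords,
    forces=torch.zeros(1, 3, 3),         # optional supervision target
    total_energy=torch.tensor([-76.4]),  # optional supervision target
)
```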
Structure2TotalEnergyAndForces

Bases: AtomformerPreTrainedModel

Atomformer with an energy and forces head for total energy and force prediction.
Source code in atomgen/models/modeling_atomformer.py

forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
)

Forward function call for the structure to total energy and forces model.
Source code in atomgen/models/modeling_atomformer.py
AtomFormerForSystemClassification

Bases: AtomformerPreTrainedModel

Atomformer with a classification head for system classification.
Source code in atomgen/models/modeling_atomformer.py

forward

Forward function call for the system classification model.
Source code in atomgen/models/modeling_atomformer.py
gaussian

Compute the Gaussian distribution probability density.
Taken from: https://github.com/microsoft/Graphormer/blob/main/graphormer/models/graphormer_3d.py
Source code in atomgen/models/modeling_atomformer.py
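For reference, the density being computed; a standalone sketch of the Graphormer-3D formulation rather than the library's vectorized code:

```python
import math

def gaussian_density(x: float, mean: float, std: float) -> float:
    """Gaussian probability density evaluated at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (math.sqrt(2.0 * math.pi) * std)
```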
schnet

SchNet model for energy prediction.

SchNetConfig

Bases: PretrainedConfig

Stores the configuration of a SchNetModel. It is used to instantiate a SchNet model according to the specified arguments, defining the model architecture.
Args:
    vocab_size (int, optional, defaults to 122): The size of the vocabulary, used to define the size of the output embeddings.
    hidden_channels (int, optional, defaults to 128): The hidden size of the model.

model_type = "transformer"

Attributes:
    vocab_size (int): The size of the vocabulary.
    hidden_channels, num_filters, num_interactions, num_gaussians, cutoff, interaction_graph, max_num_neighbors, readout, dipole, mean, std, atomref, mask_token_id, pad_token_id, bos_token_id, eos_token_id: additional configuration options.
Source code in atomgen/models/schnet.py
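The two documented arguments can be set explicitly; everything else falls back to library defaults:

```python
from atomgen.models.schnet import SchNetConfig

# vocab_size and hidden_channels shown at their documented defaults.
config = SchNetConfig(vocab_size=122, hidden_channels=128)
```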
SchNetPreTrainedModel

Bases: PreTrainedModel

A base class for all SchNet models.
An abstract class to handle weights initialization and a simple interface for loading and exporting models.
Source code in atomgen/models/schnet.py
SchNetModel

Bases: SchNetPreTrainedModel

SchNet model for energy prediction.
Args:
    config (SchNetConfig): Configuration class to store the configuration of a model.
Source code in atomgen/models/schnet.py
forward

Forward pass of the SchNet model.
Args:
    input_ids (torch.Tensor of shape (batch_size, num_atoms)): The input tensor containing the atom indices.
    coords (torch.Tensor of shape (num_atoms, 3)): The input tensor containing the atom coordinates.
    batch (torch.Tensor of shape (num_atoms,)): The input tensor containing the batch indices.
    labels_energy (torch.Tensor, optional): The input tensor containing the energy labels.
    fixed (torch.Tensor, optional): The input tensor containing the fixed mask.
    attention_mask (torch.Tensor, optional): The attention mask for the transformer.
Returns:
    tuple: A tuple of the loss and the energy prediction.
Source code in atomgen/models/schnet.py
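A call sketch following the documented argument shapes. Note the docstring mixes a batched input_ids shape with flattened coords/batch shapes (PyG-style); this sketch uses the flattened convention throughout and should be treated as illustrative:

```python
import torch
from atomgen.models.schnet import SchNetConfig, SchNetModel

model = SchNetModel(SchNetConfig())

input_ids = torch.tensor([8, 1, 1])       # atomic numbers for one water molecule
coords = torch.randn(3, 3)                # (num_atoms, 3)
batch = torch.zeros(3, dtype=torch.long)  # all atoms belong to system 0

# Returns a tuple of (loss, energy prediction) per the docs;
# the loss is only meaningful when labels_energy is given.
outputs = model(input_ids=input_ids, coords=coords, batch=batch,
                labels_energy=torch.tensor([-76.4]))
```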
tokengt

Implementation of the TokenGT model.

ParallelBlock

Bases: Module

Parallel transformer block.
Source code in atomgen/models/tokengt.py
forward

Forward function call for the parallel transformer block.
Source code in atomgen/models/tokengt.py

TransformerConfig

Bases: PretrainedConfig

Configuration class to store the configuration of a TokenGT model.
Source code in atomgen/models/tokengt.py

TransformerEncoder

Bases: Module

Transformer encoder for atom modeling.
Source code in atomgen/models/tokengt.py
forward

Forward function call for the transformer encoder.
Source code in atomgen/models/tokengt.py
TransformerPreTrainedModel

Bases: PreTrainedModel

Base class for all transformer models.
Source code in atomgen/models/tokengt.py

TransformerModel

Bases: TransformerPreTrainedModel

Transformer model for atom modeling.
Source code in atomgen/models/tokengt.py

forward

Forward function call for the transformer model.
Source code in atomgen/models/tokengt.py

TransformerForMaskedAM

Bases: TransformerPreTrainedModel

Transformer with an atom modeling head on top for masked atom modeling.
Source code in atomgen/models/tokengt.py

forward

Forward function call for the masked atom modeling model.
Source code in atomgen/models/tokengt.py

TransformerForCoordinateAM

Bases: TransformerPreTrainedModel

Transformer with an atom coordinate head on top for coordinate denoising.
Source code in atomgen/models/tokengt.py

forward

Forward function call for the coordinate atom modeling model.
Source code in atomgen/models/tokengt.py
InitialStructure2RelaxedStructure

Bases: TransformerPreTrainedModel

Transformer with a coordinate head on top for relaxed structure prediction.
Source code in atomgen/models/tokengt.py

forward

Forward function call for the initial structure to relaxed structure model.
Source code in atomgen/models/tokengt.py

InitialStructure2RelaxedEnergy

Bases: TransformerPreTrainedModel

Transformer with an energy head on top for relaxed energy prediction.
Source code in atomgen/models/tokengt.py

forward

Forward function call for the initial structure to relaxed energy model.
Source code in atomgen/models/tokengt.py
InitialStructure2RelaxedStructureAndEnergy

Bases: TransformerPreTrainedModel

Initial structure to relaxed structure and energy prediction model. Transformer with a coordinate and energy head on top for relaxed structure and energy prediction.
Source code in atomgen/models/tokengt.py

forward

forward(
    input_ids,
    coords,
    labels_coords=None,
    labels_energy=None,
    fixed=None,
    attention_mask=None,
)

Forward function call for the initial structure to relaxed structure and energy model.
Source code in atomgen/models/tokengt.py
Structure2EnergyAndForces

Bases: TransformerPreTrainedModel

Structure to energy and forces prediction model. Transformer with an energy and forces head on top for energy and force prediction.
Source code in atomgen/models/tokengt.py
forward

forward(
    input_ids,
    coords,
    forces=None,
    total_energy=None,
    formation_energy=None,
    has_formation_energy=None,
    attention_mask=None,
    node_pe=None,
    edge_pe=None,
)

Forward function call for the structure to energy and forces model.
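A closing sketch tying this model to the collator: node_pe and edge_pe line up with the collator's return_lap_pe/return_edge_indices outputs. Shapes and the batched layout are assumptions for illustration:

```python
import torch
from atomgen.models.tokengt import TransformerConfig, Structure2EnergyAndForces

model = Structure2EnergyAndForces(TransformerConfig())

input_ids = torch.tensor([[6, 8]])  # assumed batched shape
coords = torch.randn(1, 2, 3)
outputs = model(input_ids=input_ids, coords=coords)  # node_pe/edge_pe are optional
```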