atomgen.data.data_collator module
Data collator for atom modeling.
- class DataCollatorForAtomModeling(tokenizer, mam=True, causal=False, coords_perturb=False, return_lap_pe=False, return_edge_indices=False, k=16, max_radius=12.0, max_neighbors=20, pad=True, pad_to_multiple_of=None, return_tensors='pt')
Bases: DataCollatorMixin
Data collator used for atom modeling.
- Args:
  - tokenizer: The tokenizer used for encoding the data.
  - mam: Whether to use masked atom modeling.
  - causal: Whether to use causal modeling.
  - coords_perturb: Whether to perturb the coordinates.
  - return_lap_pe: Whether to return the Laplacian positional encoding.
  - return_edge_indices: Whether to return edge indices.
  - k: Number of eigenvectors to use for the Laplacian positional encoding.
  - max_radius: Maximum distance for the edge cutoff.
  - max_neighbors: Maximum number of neighbors.
  - pad: Whether to pad the input data. If False, all samples are flattened and concatenated with a batch indicator.
  - pad_to_multiple_of: If set, pad the sequence length to a multiple of this value.
  - return_tensors: Return tensors as "pt" or "tf".
- Returns:
Dictionary of batched data.
- Return type:
Dict[str, Any]
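When pad=True, variable-length samples have to be brought to a common length. The following is a minimal pure-Python sketch of right-padding with an attention mask and the pad_to_multiple_of rounding, assuming a hypothetical helper pad_batch, a pad id of 0, and the key names "input_ids" and "attention_mask"; it is not atomgen's implementation, which works on tensors and uses the tokenizer's own pad token.

```python
from typing import Dict, List, Optional


def pad_batch(
    sequences: List[List[int]],
    pad_token_id: int = 0,  # assumption: the real tokenizer defines its own pad id
    pad_to_multiple_of: Optional[int] = None,
) -> Dict[str, List[List[int]]]:
    """Right-pad token sequences to one length and build a matching attention mask."""
    max_len = max(len(seq) for seq in sequences)
    if pad_to_multiple_of and max_len % pad_to_multiple_of != 0:
        # Round the target length up to the next multiple of pad_to_multiple_of
        max_len = (max_len // pad_to_multiple_of + 1) * pad_to_multiple_of
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(list(seq) + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = real atom, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[5, 6, 7], [8, 9]], pad_to_multiple_of=4)
print(batch["input_ids"])       # [[5, 6, 7, 0], [8, 9, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 0], [1, 1, 0, 0]]
```

Rounding up to a multiple (for example 8 on GPUs) trades a few extra padded positions for more regular tensor shapes.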
- flatten_batch(examples)
Flatten all lists in examples and concatenate them with a batch indicator.
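As an illustration of the flattened (pad=False) layout, here is a standalone sketch: every per-sample list is concatenated, and an extra list records which sample each atom came from. The output key "batch" and the length_key parameter are hypothetical choices for this example, not names confirmed by atomgen, and the real method operates on its own input format.

```python
from typing import Any, Dict, List


def flatten_batch(
    examples: List[Dict[str, List[Any]]], length_key: str = "input_ids"
) -> Dict[str, List[Any]]:
    """Concatenate per-sample lists and add a per-atom sample index under "batch"."""
    flat: Dict[str, List[Any]] = {}
    batch_index: List[int] = []
    for sample_idx, example in enumerate(examples):
        for key, values in example.items():
            flat.setdefault(key, []).extend(values)
        # One batch entry per atom, recording which sample the atom came from
        batch_index.extend([sample_idx] * len(example[length_key]))
    flat["batch"] = batch_index
    return flat


examples = [
    {"input_ids": [6, 8], "coords": [[0.0, 0.0, 0.0], [1.13, 0.0, 0.0]]},
    {"input_ids": [8, 1, 1], "coords": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]},
]
out = flatten_batch(examples)
print(out["input_ids"])  # [6, 8, 8, 1, 1]
print(out["batch"])      # [0, 0, 1, 1, 1]
```

This avoids padding entirely; downstream code uses the batch indicator to tell samples apart, much as graph libraries do with a per-node batch vector.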
- tokenizer: PreTrainedTokenizer
- torch_call(examples)
Collate a batch of samples.
- Args:
examples: List of samples to collate.
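When mam=True, collation typically includes hiding some atoms so the model learns to predict them. The following is a simplified sketch under several assumptions: a 15% masking rate and a single mask token (BERT's 80/10/10 replacement scheme is omitted), labels set to -100 at unmasked positions (the default ignore_index of PyTorch's CrossEntropyLoss), and the hypothetical helper name mask_atoms. The exact masking scheme atomgen uses is not stated here.

```python
import random
from typing import AbstractSet, List, Optional, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_atoms(
    input_ids: List[int],
    mask_token_id: int,
    special_ids: AbstractSet[int] = frozenset(),
    mask_prob: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Replace a random subset of tokens with mask_token_id; return (inputs, labels)."""
    rng = rng or random.Random()
    masked, labels = [], []
    for token in input_ids:
        if token not in special_ids and rng.random() < mask_prob:
            masked.append(mask_token_id)
            labels.append(token)  # the model must recover the original token here
        else:
            masked.append(token)
            labels.append(IGNORE_INDEX)  # position excluded from the loss
    return masked, labels


ids = [6, 8, 1, 1, 8]
masked, labels = mask_atoms(ids, mask_token_id=999, rng=random.Random(0))
```

Special tokens (for example padding) are never masked, so the loss only covers real atoms.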
- torch_compute_edges(coords, attention_mask)
Compute edge indices and distances for each batch.
- Return type:
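To show how max_radius and max_neighbors could interact, here is a per-sample radius-graph sketch in plain Python: padded atoms (mask 0) are skipped, pairs farther than max_radius are dropped, and each atom keeps at most max_neighbors of its nearest neighbors. The function name compute_edges, the inclusive cutoff, and the per-source nearest-neighbor truncation are assumptions; atomgen's method works on batched tensors and may differ in these details.

```python
import math
from typing import List, Sequence, Tuple


def compute_edges(
    coords: Sequence[Sequence[float]],
    attention_mask: Sequence[int],
    max_radius: float = 12.0,
    max_neighbors: int = 20,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Radius graph for one sample: up to max_neighbors nearest atoms within max_radius."""
    real = [i for i, m in enumerate(attention_mask) if m]  # skip padded atoms
    edges: List[Tuple[int, int]] = []
    distances: List[float] = []
    for i in real:
        candidates = []
        for j in real:
            if i == j:
                continue
            d = math.dist(coords[i], coords[j])
            if d <= max_radius:
                candidates.append((d, j))
        candidates.sort()  # nearest first
        for d, j in candidates[:max_neighbors]:
            edges.append((i, j))
            distances.append(d)
    return edges, distances


coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]]
mask = [1, 1, 1, 0]  # last atom is padding
edges, dists = compute_edges(coords, mask, max_radius=2.5)
print(edges)  # [(0, 1), (1, 0)]
```

Because the neighbor cap is applied per source atom, the resulting graph is not guaranteed to be symmetric once max_neighbors is reached.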
- torch_compute_lap_pe(coords, attention_mask)
Compute Laplacian positional encoding for each batch.
- Return type:
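A Laplacian positional encoding uses eigenvectors of a graph Laplacian as per-atom features. The NumPy sketch below assumes a graph that links every pair of real atoms within max_radius, the symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2}, and the k eigenvectors after the lowest one, zero-padded when a sample has too few atoms. The helper name laplacian_pe and these graph and normalization choices are assumptions, not confirmed details of atomgen's method.

```python
import numpy as np


def laplacian_pe(coords, attention_mask, k=16, max_radius=12.0):
    """Normalized-Laplacian eigenvectors as per-atom features, zero-padded to width k."""
    coords = np.asarray(coords, dtype=float)
    mask = np.asarray(attention_mask, dtype=bool)
    pe = np.zeros((coords.shape[0], k))
    real = coords[mask]
    n = real.shape[0]
    if n < 2:
        return pe
    # Assumption: connect real atoms within max_radius (no self-loops)
    dist = np.linalg.norm(real[:, None, :] - real[None, :, :], axis=-1)
    adj = ((dist <= max_radius) & ~np.eye(n, dtype=bool)).astype(float)
    deg = adj.sum(axis=1)
    inv_sqrt = np.zeros_like(deg)
    inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    # Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}
    lap = np.eye(n) - inv_sqrt[:, None] * adj * inv_sqrt[None, :]
    _, eigvecs = np.linalg.eigh(lap)  # eigenvalues in ascending order
    vecs = eigvecs[:, 1 : k + 1]  # drop the lowest (trivial for a connected graph) eigenvector
    pe[np.flatnonzero(mask), : vecs.shape[1]] = vecs
    return pe


coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [5.0, 5.0, 5.0]]
pe = laplacian_pe(coords, [1, 1, 1, 0], k=4, max_radius=2.0)
print(pe.shape)  # (4, 4)
```

Eigenvector signs are arbitrary, so models trained on such encodings commonly randomize signs during training.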