atomgen.data.data_collator module

Data collator for atom modeling.

class DataCollatorForAtomModeling(tokenizer, mam=True, causal=False, coords_perturb=False, return_lap_pe=False, return_edge_indices=False, k=16, max_radius=12.0, max_neighbors=20, pad=True, pad_to_multiple_of=None, return_tensors='pt')

Bases: DataCollatorMixin

Data collator used for atom modeling.

Args:

tokenizer: The tokenizer used for encoding the data.
mam: Whether to use masked atom modeling.
causal: Whether to use causal modeling.
coords_perturb: Whether to perturb the coordinates.
return_lap_pe: Whether to return Laplacian positional encoding.
return_edge_indices: Whether to return edge indices.
k: Number of eigenvectors to use for Laplacian positional encoding.
max_radius: Maximum distance for edge cutoff.
max_neighbors: Maximum number of neighbors.
pad: Whether to pad the input data. If False, all samples are flattened and concatenated with a batch indicator.
pad_to_multiple_of: Pad to a multiple of this value.
return_tensors: Return tensors as “pt” or “tf”.

Returns:

Dictionary of batched data.

Return type:

Dict[str, Any]
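To make the `pad` and `pad_to_multiple_of` options concrete, here is a minimal pure-Python sketch of right-padding variable-length samples to a shared length rounded up to a multiple. The helper `pad_batch` and the `pad_id` value are hypothetical and not part of atomgen; the collator's actual padding logic may differ in detail.

```python
from typing import List, Optional, Tuple


def pad_batch(
    sequences: List[List[int]],
    pad_id: int = 0,  # hypothetical padding token id
    pad_to_multiple_of: Optional[int] = None,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad sequences to a shared length and build an attention mask."""
    target = max(len(seq) for seq in sequences)
    if pad_to_multiple_of:
        # Round the target length up to the next multiple (ceiling division).
        target = -(-target // pad_to_multiple_of) * pad_to_multiple_of
    input_ids = [seq + [pad_id] * (target - len(seq)) for seq in sequences]
    attention_mask = [
        [1] * len(seq) + [0] * (target - len(seq)) for seq in sequences
    ]
    return input_ids, attention_mask


# Two samples of length 3 and 1, padded to a multiple of 4.
ids, mask = pad_batch([[5, 6, 7], [8]], pad_to_multiple_of=4)
```

In this sketch the attention mask marks real atoms with 1 and padding with 0, which is the convention the `attention_mask` arguments of the methods below suggest.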

causal: bool = False
coords_perturb: bool = False
flatten_batch(examples)

Flatten all lists in examples and concatenate with batch indicator.

Return type:

Dict[str, Any]
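The flattening described above can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the function name `flatten_examples`, the output key `"batch"`, and the use of `"input_ids"` to determine each sample's length are all hypothetical.

```python
from typing import Any, Dict, List


def flatten_examples(
    examples: List[Dict[str, List[Any]]],
    length_key: str = "input_ids",  # hypothetical: key that defines sample length
) -> Dict[str, List[Any]]:
    """Concatenate per-sample lists and record the sample index of each item."""
    flat: Dict[str, List[Any]] = {}
    batch: List[int] = []
    for sample_idx, example in enumerate(examples):
        for key, values in example.items():
            flat.setdefault(key, []).extend(values)
        # One batch-indicator entry per atom in this sample.
        batch.extend([sample_idx] * len(example[length_key]))
    flat["batch"] = batch  # hypothetical key name for the batch indicator
    return flat


examples = [
    {"input_ids": [6, 8], "coords": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]},
    {"input_ids": [1], "coords": [[0.0, 0.0, 1.0]]},
]
flat = flatten_examples(examples)
```

The batch indicator lets downstream code recover which sample each atom belongs to without any padding.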

k: int = 16
mam: bool = True
max_neighbors: int = 20
max_radius: float = 12.0
pad: bool = True
pad_to_multiple_of: Optional[int] = None
return_edge_indices: bool = False
return_lap_pe: bool = False
return_tensors: str = 'pt'
tokenizer: PreTrainedTokenizer
torch_call(examples)

Collate a batch of samples.

Args:

examples: List of samples to collate.

Return type:

Dict[str, Any]

Returns:

Dictionary of batched data.

torch_compute_edges(coords, attention_mask)

Compute edge indices and distances for each batch.

Return type:

Any
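As an illustration of how `max_radius` and `max_neighbors` might interact, the sketch below builds a radius graph for a single unpadded sample in pure Python. It is not the library's implementation: the function name `radius_graph` is hypothetical, the inclusive cutoff (`<=`) is an assumption, and the real method also takes an `attention_mask`, which this sketch omits.

```python
import math
from typing import List, Sequence, Tuple


def radius_graph(
    coords: Sequence[Sequence[float]],
    max_radius: float = 12.0,
    max_neighbors: int = 20,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Return directed edges (i, j) and distances for neighbors within a cutoff."""
    edges: List[Tuple[int, int]] = []
    distances: List[float] = []
    for i, ci in enumerate(coords):
        # Distances from atom i to every other atom (no self-loops).
        candidates = [
            (math.dist(ci, cj), j) for j, cj in enumerate(coords) if j != i
        ]
        # Keep atoms inside the cutoff, then the closest max_neighbors of them.
        nearest = sorted(c for c in candidates if c[0] <= max_radius)
        for dist, j in nearest[:max_neighbors]:
            edges.append((i, j))
            distances.append(dist)
    return edges, distances


# The third atom is farther than max_radius from both others, so it has no edges.
edges, dists = radius_graph([(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (20.0, 0.0, 0.0)])
```

Capping the neighbor count bounds the number of edges per atom, which keeps memory predictable for dense structures.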

torch_compute_lap_pe(coords, attention_mask)

Compute Laplacian positional encoding for each batch.

Return type:

Any
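For readers unfamiliar with Laplacian positional encodings, here is a minimal NumPy sketch for a single graph given as an adjacency matrix. It assumes the symmetric normalized Laplacian, skips the first (trivial) eigenvector, and zero-pads when the graph has fewer than `k` usable eigenvectors. Those choices and the function name `laplacian_pe` are assumptions for illustration; the source does not say which variant the collator uses.

```python
import numpy as np


def laplacian_pe(adjacency: np.ndarray, k: int) -> np.ndarray:
    """Return an (n, k) array of Laplacian eigenvector positional encodings."""
    adjacency = np.asarray(adjacency, dtype=float)
    n = adjacency.shape[0]
    degree = adjacency.sum(axis=1)
    # D^{-1/2}, leaving isolated nodes (degree 0) at zero.
    inv_sqrt = np.zeros(n)
    connected = degree > 0
    inv_sqrt[connected] = degree[connected] ** -0.5
    # Symmetric normalized Laplacian: I - D^{-1/2} A D^{-1/2}.
    laplacian = np.eye(n) - inv_sqrt[:, None] * adjacency * inv_sqrt[None, :]
    # eigh returns eigenvalues in ascending order for symmetric matrices.
    _, eigvecs = np.linalg.eigh(laplacian)
    # Drop the first eigenvector and keep up to k of the next ones.
    pe = eigvecs[:, 1 : k + 1]
    if pe.shape[1] < k:
        # Zero-pad the feature dimension for graphs with few nodes.
        pe = np.pad(pe, ((0, 0), (0, k - pe.shape[1])))
    return pe


# Path graph on three atoms: 0 - 1 - 2.
path = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
pe = laplacian_pe(path, k=4)
```

Eigenvector signs are arbitrary, so encodings computed this way are only defined up to a sign flip per column.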

torch_mask_tokens(inputs, t, special_tokens_mask=None)

Prepare masked tokens inputs/labels for masked atom modeling.

Return type:

Tuple[Any, Any]
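The general idea of masked modeling can be sketched in pure Python as below. This sketch does not model the `t` argument, whose meaning is not documented here; it uses a hypothetical fixed `mask_prob` instead. The names `mask_tokens`, `mask_token_id`, and `seed` are likewise illustrative, and the label value -100 is the default `ignore_index` of PyTorch's cross-entropy loss rather than something the source specifies.

```python
import random
from typing import List, Optional, Tuple

IGNORE_INDEX = -100  # positions with this label are skipped by the loss


def mask_tokens(
    inputs: List[int],
    mask_token_id: int,
    mask_prob: float = 0.15,  # hypothetical stand-in for the undocumented `t`
    special_tokens_mask: Optional[List[int]] = None,
    seed: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Replace random non-special tokens with a mask id and build labels."""
    rng = random.Random(seed)
    if special_tokens_mask is None:
        special_tokens_mask = [0] * len(inputs)
    masked: List[int] = []
    labels: List[int] = []
    for token, is_special in zip(inputs, special_tokens_mask):
        if not is_special and rng.random() < mask_prob:
            masked.append(mask_token_id)  # hide the atom from the model
            labels.append(token)          # the model must predict the original
        else:
            masked.append(token)
            labels.append(IGNORE_INDEX)   # no loss on unmasked positions
    return masked, labels


# With mask_prob=1.0 every non-special token is masked.
masked, labels = mask_tokens(
    [1, 6, 8, 2], mask_token_id=99, mask_prob=1.0, special_tokens_mask=[1, 0, 0, 1]
)
```

Special tokens are never masked, so the loss is computed only on real atoms.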

torch_perturb_coords(inputs, fixed, t)

Prepare perturbed coords inputs/labels for coordinate denoising.

Return type:

Tuple[Any, Any]
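Coordinate denoising is commonly set up by adding noise to the positions and asking the model to recover it. The sketch below shows that pattern in pure Python, under several assumptions: Gaussian noise with a hypothetical fixed `noise_scale` (the role of `t` is not documented here), the added noise used as the label, and atoms flagged in `fixed` left unperturbed. The function name `perturb_coords` is illustrative and the library's method may differ on each of these points.

```python
import random
from typing import List, Optional, Sequence, Tuple


def perturb_coords(
    coords: Sequence[Sequence[float]],
    fixed: Sequence[bool],
    noise_scale: float = 0.1,  # hypothetical stand-in for the undocumented `t`
    seed: Optional[int] = None,
) -> Tuple[List[List[float]], List[List[float]]]:
    """Add Gaussian noise to non-fixed atoms; return noisy coords and the noise."""
    rng = random.Random(seed)
    noisy: List[List[float]] = []
    targets: List[List[float]] = []
    for xyz, is_fixed in zip(coords, fixed):
        if is_fixed:
            noise = [0.0] * len(xyz)  # fixed atoms keep their position
        else:
            noise = [rng.gauss(0.0, noise_scale) for _ in xyz]
        noisy.append([x + n for x, n in zip(xyz, noise)])
        targets.append(noise)
    return noisy, targets


coords = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
noisy, targets = perturb_coords(coords, fixed=[True, False], seed=0)
```

Subtracting the target noise from the noisy coordinates recovers the original positions, which is what a denoising objective trains the model to approximate.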