atomgen.data.data_collator module

Data collator for atom modeling.

class DataCollatorForAtomModeling(tokenizer, mam=True, autoregressive=False, coords_perturb=0.0, return_lap_pe=False, return_edge_indices=False, k=16, max_radius=12.0, max_neighbors=20, pad=True, pad_to_multiple_of=None, return_tensors='pt')

Bases: DataCollatorMixin

Data collator used for atom modeling tasks in molecular representations.

This collator prepares input data for various atom modeling tasks, including masked atom modeling (MAM), autoregressive modeling, and coordinate perturbation. It supports both padding and flattening of input data.

Args:

tokenizer (PreTrainedTokenizer): Tokenizer used for encoding the data.
mam (Union[bool, float]): If True, uses original masked atom modeling. If float, masks a constant fraction of atoms/tokens.
autoregressive (bool): Whether to use autoregressive modeling.
coords_perturb (float): Standard deviation for coordinate perturbation.
return_lap_pe (bool): Whether to return Laplacian positional encoding.
return_edge_indices (bool): Whether to return edge indices.
k (int): Number of eigenvectors to use for Laplacian positional encoding.
max_radius (float): Maximum distance for edge cutoff.
max_neighbors (int): Maximum number of neighbors.
pad (bool): Whether to pad the input data.
pad_to_multiple_of (Optional[int]): Pad to multiple of this value.
return_tensors (str): Return tensors as "pt" or "tf".
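The interaction between pad and pad_to_multiple_of can be illustrated with a small sketch in plain Python lists. This is not the collator's own code: the helper names padded_length and pad_batch, the pad id 0, and the "input_ids"/"attention_mask" keys are assumptions made for illustration, and the real collator may delegate padding to the tokenizer.

```python
from typing import Dict, List, Optional


def padded_length(longest: int, pad_to_multiple_of: Optional[int] = None) -> int:
    """Round the longest sequence length up to the next multiple, if requested."""
    if pad_to_multiple_of is None or pad_to_multiple_of <= 0:
        return longest
    # Ceiling division, then scale back up to a multiple.
    return -(-longest // pad_to_multiple_of) * pad_to_multiple_of


def pad_batch(
    sequences: List[List[int]],
    pad_id: int = 0,  # hypothetical pad token id
    pad_to_multiple_of: Optional[int] = None,
) -> Dict[str, List[List[int]]]:
    """Right-pad token id lists and build the matching attention mask."""
    target = padded_length(max(len(s) for s in sequences), pad_to_multiple_of)
    input_ids = [s + [pad_id] * (target - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (target - len(s)) for s in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Two molecules with 3 and 5 atoms, padded up to a length of 8.
batch = pad_batch([[5, 6, 7], [8, 9, 10, 11, 12]], pad_to_multiple_of=8)
```

Padding to a multiple (commonly 8) is usually chosen so that sequence lengths line up with hardware-friendly tensor shapes.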

tokenizer

The tokenizer used for encoding.

Type: PreTrainedTokenizer

mam

The masked atom modeling setting.

Type: Union[bool, float]

autoregressive

The autoregressive modeling setting.

Type: bool

coords_perturb

The coordinate perturbation standard deviation.

Type: float

return_lap_pe

The Laplacian positional encoding setting.

Type: bool

return_edge_indices

The edge indices return setting.

Type: bool

k

The number of eigenvectors for Laplacian PE.

Type: int

max_radius

The maximum distance for edge cutoff.

Type: float

max_neighbors

The maximum number of neighbors.

Type: int

pad

The padding setting.

Type: bool

pad_to_multiple_of

The multiple for padding.

Type: Optional[int]

return_tensors

The tensor return format.

Type: str

apply_mask(inputs, mask, special_tokens_mask)

Apply the mask to the input tokens.

Return type:

Tuple[Tensor, Tensor]
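A minimal sketch of what applying a mask usually involves, written with plain lists instead of tensors. It assumes the common convention that masked positions are replaced by a mask token and that every other position gets the label -100, which PyTorch's cross-entropy loss ignores by default. MASK_TOKEN_ID and apply_mask_sketch are hypothetical names, and the actual method may behave differently (for example, by substituting random tokens).

```python
from typing import List, Tuple

MASK_TOKEN_ID = 1    # hypothetical id of the mask token
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def apply_mask_sketch(
    inputs: List[int],
    mask: List[bool],
    special_tokens_mask: List[bool],
) -> Tuple[List[int], List[int]]:
    """Return (masked inputs, labels); special tokens are never masked."""
    masked_inputs, labels = [], []
    for token, is_masked, is_special in zip(inputs, mask, special_tokens_mask):
        if is_masked and not is_special:
            masked_inputs.append(MASK_TOKEN_ID)
            labels.append(token)  # the model must recover the original token
        else:
            masked_inputs.append(token)
            labels.append(IGNORE_INDEX)  # position excluded from the loss
    return masked_inputs, labels


masked, labels = apply_mask_sketch(
    inputs=[2, 8, 6, 6, 3],
    mask=[True, True, False, True, False],
    special_tokens_mask=[True, False, False, False, True],
)
```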

autoregressive: bool = False
coords_perturb: float = 0.0
flatten_batch(examples)

Flatten all lists in examples and concatenate with batch indicator.

Return type:

Dict[str, Any]
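The idea of flattening with a batch indicator can be sketched as follows. Rather than padding, all per-atom lists are concatenated and a parallel index records which example each atom came from. The key names "input_ids", "coords", and "batch" and the function name flatten_batch_sketch are assumptions for illustration, not confirmed by the source.

```python
from typing import Any, Dict, List


def flatten_batch_sketch(examples: List[Dict[str, List[Any]]]) -> Dict[str, List[Any]]:
    """Concatenate per-atom lists across examples and add a batch index."""
    flat: Dict[str, List[Any]] = {}
    batch: List[int] = []
    for index, example in enumerate(examples):
        for key, values in example.items():
            flat.setdefault(key, []).extend(values)
        # Assumes every list in an example has one entry per atom.
        n_atoms = len(next(iter(example.values())))
        batch.extend([index] * n_atoms)
    flat["batch"] = batch  # hypothetical key name for the batch indicator
    return flat


examples = [
    {"input_ids": [8, 1, 1], "coords": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]},
    {"input_ids": [6, 8], "coords": [[0.0, 0.0, 0.0], [1.13, 0.0, 0.0]]},
]
flat = flatten_batch_sketch(examples)
```

A flattened layout avoids spending memory on padding tokens, at the cost of needing the batch index to separate molecules again downstream.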

k: int = 16
mam: Union[bool, float] = True
max_neighbors: int = 20
max_radius: float = 12.0
pad: bool = True
pad_to_multiple_of: Optional[int] = None
return_edge_indices: bool = False
return_lap_pe: bool = False
return_tensors: str = 'pt'
tokenizer: PreTrainedTokenizer
torch_call(examples)

Collate a batch of samples.

Args:

examples: List of samples to collate.

Return type:

Dict[str, Any]

Returns:

Dict[str, Any]: Dictionary of batched data.

torch_compute_edges(coords, attention_mask)

Compute edge indices and distances for each batch.

Return type:

Any
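How max_radius and max_neighbors might combine is sketched below for a single molecule, using plain Python. Each atom is connected to its nearest neighbours inside the cutoff. This is an assumed reading of the two parameters, not the library's implementation. The real method also takes an attention_mask, presumably to skip padded positions, which this sketch omits. The name radius_graph_sketch is hypothetical.

```python
import math
from typing import List, Tuple

Coord = Tuple[float, float, float]


def radius_graph_sketch(
    coords: List[Coord],
    max_radius: float = 12.0,
    max_neighbors: int = 20,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Return directed edges (i, j) and their distances."""
    edges: List[Tuple[int, int]] = []
    distances: List[float] = []
    for i, center in enumerate(coords):
        # Candidate neighbours within the cutoff, excluding the atom itself.
        candidates = [
            (math.dist(center, other), j)
            for j, other in enumerate(coords)
            if j != i and math.dist(center, other) <= max_radius
        ]
        # Keep only the closest max_neighbors candidates.
        for dist, j in sorted(candidates)[:max_neighbors]:
            edges.append((i, j))
            distances.append(dist)
    return edges, distances


coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (20.0, 0.0, 0.0)]
edges, distances = radius_graph_sketch(coords, max_radius=3.0, max_neighbors=1)
```

The last atom lies outside the cutoff of every other atom, so it contributes no edges.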

torch_compute_lap_pe(coords, attention_mask)

Compute Laplacian positional encoding for each batch.

Return type:

Any

torch_mask_tokens(inputs, t, special_tokens_mask=None)

Prepare masked tokens inputs/labels for masked atom modeling.

Return type:

Tuple[Any, Any]
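The source does not say what t means. One plausible reading is that it is the fraction of non-special tokens to mask, matching the "constant fraction" behaviour described for a float mam. Under that assumption, choosing which positions to mask could look like the sketch below. The function name sample_mask_sketch and the rng parameter are hypothetical.

```python
import random
from typing import List, Optional


def sample_mask_sketch(
    special_tokens_mask: List[bool],
    t: float,
    rng: Optional[random.Random] = None,
) -> List[bool]:
    """Pick a fraction t of the non-special positions to mask."""
    rng = rng or random.Random()
    candidates = [i for i, special in enumerate(special_tokens_mask) if not special]
    n_masked = round(t * len(candidates))
    chosen = set(rng.sample(candidates, n_masked))
    return [i in chosen for i in range(len(special_tokens_mask))]


# Eight atoms between two special tokens; mask a quarter of the atoms.
special = [True] + [False] * 8 + [True]
mask = sample_mask_sketch(special, t=0.25, rng=random.Random(0))
```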

torch_perturb_coords(inputs, fixed, perturb_std)

Prepare perturbed coords inputs/labels for coordinate denoising.

Return type:

Tuple[Tensor, Tensor]
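Coordinate denoising is commonly set up by adding Gaussian noise to the coordinates and asking the model to recover it. The sketch below assumes that fixed is a per-atom boolean marking atoms that must not move, and that the labels are the added noise. Neither assumption is confirmed by the source, and the name perturb_coords_sketch is hypothetical.

```python
import random
from typing import List, Optional, Tuple

Coord = List[float]


def perturb_coords_sketch(
    coords: List[Coord],
    fixed: List[bool],
    perturb_std: float,
    rng: Optional[random.Random] = None,
) -> Tuple[List[Coord], List[Coord]]:
    """Return (noisy coordinates, noise); fixed atoms are left untouched."""
    rng = rng or random.Random()
    noisy: List[Coord] = []
    noise: List[Coord] = []
    for position, is_fixed in zip(coords, fixed):
        # Fixed atoms receive zero noise on every axis.
        delta = [0.0 if is_fixed else rng.gauss(0.0, perturb_std) for _ in position]
        noisy.append([x + d for x, d in zip(position, delta)])
        noise.append(delta)
    return noisy, noise


coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
noisy, noise = perturb_coords_sketch(
    coords, fixed=[True, False], perturb_std=0.1, rng=random.Random(0)
)
```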