class SslTensorDataset(TensorDataset):
    def __init__(
        self,
        data: torch.Tensor,
        targets: torch.Tensor | None = None,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
    ) -> None:
        """
        Tensor dataset whose targets are produced from the inputs when a sample is fetched, rather than being
        stored alongside the data.

        Args:
            data (torch.Tensor): Input tensor, forwarded to the parent ``TensorDataset``.
            targets (torch.Tensor | None, optional): Must be None; any other value trips the assertion below.
                Defaults to None.
            transform (Callable | None, optional): Optional callable applied to a sample before its target is
                derived. Defaults to None.
            target_transform (Callable | None, optional): Callable that maps the (possibly transformed) sample to
                its target. It is not checked here, but it must be set before samples are fetched.
                Defaults to None.
        """
        assert targets is None, "SslTensorDataset targets must be None"
        super().__init__(data, targets, transform, target_transform)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Fetch one sample and derive its target from it.

        Args:
            index (int): Position of the sample within ``self.data``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The sample (after ``transform``, when one is set) followed by the
            result of applying ``target_transform`` to that sample.
        """
        sample = self.data[index]

        assert self.target_transform is not None, "Target transform cannot be None."

        if self.transform is not None:
            sample = self.transform(sample)

        # The target is computed from the input at load time. This is more memory efficient than pre-computing
        # the transforms, which would require storing multiple copies of each sample.
        return sample, self.target_transform(sample)
def __init__(self, data: dict[str, list[torch.Tensor]], targets: torch.Tensor) -> None:
    """
    Torch dataset whose inputs are given as a dictionary of tensor lists instead of a single ``torch.Tensor``.

    This is useful when a model needs non-trivial inputs. For example, a language model may require both token
    ids AND attention masks, each stored under its own key.

    Args:
        data (dict[str, list[torch.Tensor]]): Model inputs, keyed by name, with a list of tensors per key.
        targets (torch.Tensor): Target tensor.
    """
    # Both arguments are stored as given; nothing is copied or validated here.
    self.data, self.targets = data, targets
def __init__(
    self,
    data: torch.Tensor,
    targets: torch.Tensor,
):
    """
    Dataset for synthetically created data held strictly as pytorch tensors. Generally, this dataset is just
    used for tests.

    Args:
        data (torch.Tensor): Data tensor whose first dimension corresponds to the number of datapoints.
        targets (torch.Tensor): Target tensor whose first dimension corresponds to the number of datapoints.
    """
    # Data and targets must describe the same number of datapoints.
    num_datapoints, num_targets = data.shape[0], targets.shape[0]
    assert num_datapoints == num_targets

    self.data, self.targets = data, targets
def select_by_indices(dataset: D, selected_indices: torch.Tensor) -> D:
    """
    Build a subset of ``dataset`` containing only the entries at ``selected_indices``.

    For a ``TensorDataset`` the result is a deep copy of the input whose data (and targets, when present) are
    replaced by the selected slices. For a ``DictionaryDataset`` a new ``DictionaryDataset`` is constructed from
    the selected targets and, for every key, the selected list entries.

    Args:
        dataset (D): Dataset to be "subsampled" using the provided indices.
        selected_indices (torch.Tensor): Indices into the dataset's data and targets (if they exist) to keep.

    Raises:
        TypeError: If the dataset is neither a ``TensorDataset`` nor a ``DictionaryDataset``.

    Returns:
        D: Dataset holding only the data associated with the provided indices.
    """
    if isinstance(dataset, TensorDataset):
        # Start from a deep copy so every other attribute of the dataset carries over to the subset.
        subset = copy.deepcopy(dataset)
        subset.data = dataset.data[selected_indices]
        if dataset.targets is not None:
            subset.targets = dataset.targets[selected_indices]
        # cast being used here until the mypy bug mentioned in https://github.com/python/mypy/issues/12800 and the
        # duplicate ticket https://github.com/python/mypy/issues/10817 are fixed
        return cast(D, subset)

    if isinstance(dataset, DictionaryDataset):
        picked_targets = dataset.targets[selected_indices]
        # Each value is a list of tensors, which can't be indexed directly with a tensor of indices, so the
        # entries are pulled out one index at a time.
        picked_data: dict[str, list[torch.Tensor]] = {
            name: [tensors[i] for i in selected_indices] for name, tensors in dataset.data.items()
        }
        # cast being used here until the mypy bug mentioned in https://github.com/python/mypy/issues/12800 and the
        # duplicate ticket https://github.com/python/mypy/issues/10817 are fixed
        return cast(D, DictionaryDataset(picked_data, picked_targets))

    raise TypeError("Dataset type is not supported by this function.")