"""PyTorch implementation of the LM-Supervised Retriever (LSR) loss.

Given input context x and ground-truth continuation y, computes the KL
divergence between the retrieval likelihood P_R(d|x) and the language model
likelihood Q_LM(d|x, y), where d is the retrieved document.

Source: Shi, Weijia, et al. "REPLUG: Retrieval-Augmented Black-Box Language
Models." arXiv preprint arXiv:2301.12652 (2023).
https://arxiv.org/pdf/2301.12652
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from typing import assert_never  # Python 3.11+
except ImportError:  # pragma: no cover
    from typing_extensions import assert_never
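
# The class below references ReductionMode and InvalidReductionParam without
# showing their definitions; they presumably live elsewhere in the repo. The
# following is a minimal sketch of those dependencies, assuming a str-valued
# enum with a members_list() helper and a plain exception subclass, so this
# file runs standalone. Treat it as illustrative, not the repo's actual code.
import enum


class ReductionMode(str, enum.Enum):
    """Supported reductions over the per-example KL terms (sketch)."""

    MEAN = "mean"
    SUM = "sum"

    @classmethod
    def members_list(cls) -> list["ReductionMode"]:
        # Members are str subclasses, so they compare and join cleanly as strings.
        return list(cls)


class InvalidReductionParam(ValueError):
    """Raised when an unsupported reduction is requested (sketch)."""
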
class LSRLoss(nn.Module):
    """PyTorch implementation of the LM-Supervised Retriever Loss.

    Given input context x and ground-truth continuation y, computes the KL
    divergence between the retrieval likelihood P_R(d|x) and the language
    model likelihood Q_LM(d|x, y), where d is the retrieved document.

    Source: Shi, Weijia, et al. "REPLUG: Retrieval-Augmented Black-Box
    Language Models." arXiv preprint arXiv:2301.12652 (2023).
    https://arxiv.org/pdf/2301.12652
    """

    def __init__(self, reduction: ReductionMode = ReductionMode.MEAN):
        # This line is critical - it initializes all the Module machinery.
        super().__init__()
        if reduction not in ReductionMode.members_list():
            msg = (
                f"Invalid reduction {reduction}. "
                f"Valid reductions are: {', '.join(ReductionMode.members_list())}"
            )
            raise InvalidReductionParam(msg)
        self.reduction = reduction

    def forward(
        self, retrieval_scores: torch.Tensor, lm_scores: torch.Tensor
    ) -> torch.Tensor:
        """Compute the LSR loss.

        Args:
            retrieval_scores: Retriever similarity scores, shape (batch, num_docs).
            lm_scores: LM likelihood scores of y under each retrieved document,
                shape (batch, num_docs). In REPLUG the LM is a black box, so
                these are treated as constants.

        Returns:
            Scalar loss under the configured reduction.
        """
        # Normalize both score sets into distributions over the retrieved documents.
        retrieval_log_probs = F.log_softmax(retrieval_scores, dim=1)
        lm_probs = F.softmax(lm_scores, dim=1)
        # F.kl_div(input, target) computes target * (log(target) - input)
        # elementwise, i.e. KL(Q_LM || P_R) here once summed over documents.
        kl_div = F.kl_div(retrieval_log_probs, lm_probs, reduction="none").sum(dim=-1)
        match self.reduction:
            case ReductionMode.MEAN:
                return kl_div.mean()
            case ReductionMode.SUM:
                return kl_div.sum()
            case _:  # pragma: no cover
                assert_never(self.reduction)
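

# Minimal usage sketch (not from the paper's codebase): a batch of 4 queries
# with 8 retrieved documents each. In practice retrieval_scores would come
# from the retriever's query-document similarities and lm_scores from the
# frozen LM's likelihood of y under each document; random tensors stand in
# for both here.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, num_docs = 4, 8
    retrieval_scores = torch.randn(batch_size, num_docs, requires_grad=True)
    lm_scores = torch.randn(batch_size, num_docs)  # treated as constants

    loss_fn = LSRLoss(reduction=ReductionMode.MEAN)
    loss = loss_fn(retrieval_scores, lm_scores)
    loss.backward()  # gradients flow only to the retriever scores
    print(f"LSR loss: {loss.item():.4f}")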