Skip to content

Generators

Base Generator

BaseGenerator

Bases: BaseModel, ABC

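Abstract base class for generators. A generator pairs a model with a tokenizer, generates text from a query and context, and computes target-sequence probabilities.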

Source code in src/fed_rag/base/generator.py
class BaseGenerator(BaseModel, ABC):
    """Abstract base class for generators.

    A generator pairs a model with a tokenizer and exposes two operations:
    producing text from a query plus context, and scoring a target sequence
    conditioned on a prompt. Concrete subclasses must implement every
    abstract member declared here.
    """

    # Let pydantic accept field types it cannot validate natively
    # (presumably things like a torch.nn.Module held by subclasses).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    def generate(self, query: str, context: str, **kwargs: dict) -> str:
        """Produce an output string for ``query`` conditioned on ``context``.

        Args:
            query: The query to answer.
            context: Context text the generation is conditioned on.
            **kwargs: Implementation-specific options.

        Returns:
            The generated output text.
        """
        # NOTE(review): ``**kwargs: dict`` declares every keyword *value* as a
        # dict; ``typing.Any`` is probably what was meant -- confirm before
        # changing, since subclasses may mirror this annotation.
        ...

    @property
    @abstractmethod
    def model(self) -> torch.nn.Module:
        """The ``torch.nn.Module`` backing this generator."""
        ...

    @property
    @abstractmethod
    def tokenizer(self) -> BaseTokenizer:
        """The ``BaseTokenizer`` used by this generator."""
        ...

    @abstractmethod
    def compute_target_sequence_proba(
        self, prompt: str, target: str
    ) -> torch.Tensor:
        """Compute P(target | prompt).

        Used in LM Supervised Retriever fine-tuning.

        Args:
            prompt: The conditioning text.
            target: The sequence whose probability is scored.

        Returns:
            A tensor holding the probability of ``target`` given ``prompt``.
        """
        ...

`model` (abstract property)

model

The `torch.nn.Module` associated with this generator.

`tokenizer` (abstract property)

tokenizer

The `BaseTokenizer` associated with this generator.

`generate` (abstract method)

generate(query, context, **kwargs)

Generate an output from a given query and context.

Source code in src/fed_rag/base/generator.py
@abstractmethod
def generate(self, query: str, context: str, **kwargs: dict) -> str:
    """Produce an output string for ``query`` conditioned on ``context``.

    Args:
        query: The query to answer.
        context: Context text the generation is conditioned on.
        **kwargs: Implementation-specific options.

    Returns:
        The generated output text.
    """
    # NOTE(review): ``**kwargs: dict`` declares every keyword *value* as a
    # dict; ``typing.Any`` is probably what was meant -- confirm before
    # changing, since subclasses may mirror this annotation.
    ...

`compute_target_sequence_proba` (abstract method)

compute_target_sequence_proba(prompt, target)

Compute $P(\text{target} \mid \text{prompt})$.

NOTE: this is used in LM Supervised Retriever fine-tuning.

Source code in src/fed_rag/base/generator.py
@abstractmethod
def compute_target_sequence_proba(
    self, prompt: str, target: str
) -> torch.Tensor:
    """Compute P(target | prompt).

    Used in LM Supervised Retriever fine-tuning.

    Args:
        prompt: The conditioning text.
        target: The sequence whose probability is scored.

    Returns:
        A tensor holding the probability of ``target`` given ``prompt``.
    """
    ...