Source code for fl4health.feature_alignment.string_columns_transformer

from __future__ import annotations

import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin

from fl4health.feature_alignment.constants import TextFeatureTransformer


class TextMulticolumnTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, transformer: TextFeatureTransformer):
        """
        Adapter that lets a text feature transformer operate on a dataframe with several string columns.

        Each row's columns are merged into one space-separated string, and the wrapped transformer is
        fit on / applied to the resulting series of merged strings.

        Args:
            transformer (TextFeatureTransformer): Text feature transformer to wrap.
        """
        self.transformer = transformer

    @staticmethod
    def _merge_text_columns(X: pd.DataFrame) -> pd.Series:
        """
        Collapse every row of ``X`` into a single string, with the column values separated by one space.

        Args:
            X (pd.DataFrame): Dataframe whose row values are all strings (``str.join`` rejects anything else).

        Returns:
            pd.Series: One merged string per row of ``X``.
        """

        def concatenate_row(row: pd.Series) -> str:
            return " ".join(row)

        # axis=1 hands each row (not each column) to the callable.
        return X.apply(concatenate_row, axis=1)

    def fit(self, X: pd.DataFrame, y: pd.DataFrame | None = None) -> TextMulticolumnTransformer:
        """
        Fit the wrapped transformer on the row-wise concatenation of all columns in ``X``.

        Args:
            X (pd.DataFrame): Dataframe of string columns to fit on.
            y (pd.DataFrame | None, optional): Ignored; accepted only for signature compatibility.
                Defaults to None.

        Returns:
            TextMulticolumnTransformer: This instance, with its wrapped transformer fitted.
        """
        merged_text = self._merge_text_columns(X)
        self.transformer.fit(merged_text)
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """
        Apply the wrapped transformer to the row-wise concatenation of all columns in ``X``.

        Args:
            X (pd.DataFrame): Dataframe of string columns to transform.

        Returns:
            pd.DataFrame: The output of the wrapped transformer's ``transform`` call, returned as is.
            NOTE(review): the concrete type is decided by the wrapped transformer - confirm it matches
            the annotation.
        """
        merged_text = self._merge_text_columns(X)
        return self.transformer.transform(merged_text)
class TextColumnTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, transformer: TextFeatureTransformer):
        """
        Adapter that lets a text feature transformer operate on a dataframe holding exactly one string column.

        The lone column is pulled out as a series and handed to the wrapped transformer.

        Args:
            transformer (TextFeatureTransformer): Text feature transformer to wrap.
        """
        self.transformer = transformer

    @staticmethod
    def _extract_single_column(X: pd.DataFrame) -> pd.Series:
        """
        Check that ``X`` is a one-column dataframe and return that column.

        Args:
            X (pd.DataFrame): Dataframe expected to contain exactly one column.

        Returns:
            pd.Series: The only column of ``X``.

        Raises:
            AssertionError: If ``X`` is not a ``pd.DataFrame`` or does not have exactly one column.
        """
        # The type check comes first so ``shape`` is only read on an actual dataframe.
        assert isinstance(X, pd.DataFrame)
        assert X.shape[1] == 1
        only_column = X.columns[0]
        return X[only_column]

    def fit(self, X: pd.DataFrame, y: pd.DataFrame | None = None) -> TextColumnTransformer:
        """
        Fit the wrapped transformer on the text in the single column of ``X``.

        Args:
            X (pd.DataFrame): Single-column dataframe to fit on.
            y (pd.DataFrame | None, optional): Ignored; accepted only for signature compatibility.
                Defaults to None.

        Returns:
            TextColumnTransformer: This instance, with its wrapped transformer fitted.
        """
        text_column = self._extract_single_column(X)
        self.transformer.fit(text_column)
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """
        Apply the wrapped transformer to the text in the single column of ``X``.

        Args:
            X (pd.DataFrame): Single-column dataframe to transform.

        Returns:
            pd.DataFrame: The output of the wrapped transformer's ``transform`` call, returned as is.
            NOTE(review): the concrete type is decided by the wrapped transformer - confirm it matches
            the annotation.
        """
        text_column = self._extract_single_column(X)
        return self.transformer.transform(text_column)