llmcompressor.transformers.finetune.data.base

TextGenerationDataset

Bases: RegistryMixin

Base class for text datasets. Applies the following transformations, in order, to prepare a dataset for loading by a dataloader (a usage sketch follows the list):

  1. Load the dataset from the Hugging Face Hub or a local cache
  2. Preprocess the dataset according to a preprocessing function or a chat/dataset template
  3. Tokenize the dataset using the model tokenizer/processor
  4. Apply post-processing, such as grouping text and/or adding labels for finetuning
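
A minimal usage sketch of this pipeline is below; the registry name, model id, and import paths are illustrative assumptions and may vary by llmcompressor version:

    # Hedged sketch: the "wikitext" registration, model id, and import paths
    # are assumptions, not a fixed API contract.
    from transformers import AutoTokenizer

    from llmcompressor.args import DatasetArguments
    from llmcompressor.transformers.finetune.data import TextGenerationDataset

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    dataset_args = DatasetArguments(dataset="wikitext", max_seq_length=512)

    # RegistryMixin resolves the subclass registered under the dataset name
    wikitext = TextGenerationDataset.load_from_registry(
        dataset_args.dataset,
        dataset_args=dataset_args,
        split="train[:5%]",
        processor=tokenizer,
    )
    tokenized_dataset = wikitext(add_labels=True)  # runs steps 1-4 above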

Parameters:

    dataset_args (DatasetArguments): configuration settings for dataset loading (required)
    split (str): split from dataset to load, for instance test or train[:5%] (required)
    processor (Processor): processor or tokenizer to use on dataset (required)
Source code in llmcompressor/transformers/finetune/data/base.py
class TextGenerationDataset(RegistryMixin):
    """
    Base class for text datasets. Applies the following transformations, in
    order, to prepare a dataset for loading by a dataloader:

    1. Load the dataset from the Hugging Face Hub or a local cache
    2. Preprocess the dataset according to a preprocessing function or a
       chat/dataset template
    3. Tokenize the dataset using the model tokenizer/processor
    4. Apply post-processing, such as grouping text and/or adding labels for
       finetuning

    :param dataset_args: configuration settings for dataset loading
    :param split: split from dataset to load, for instance `test` or `train[:5%]`
    :param processor: processor or tokenizer to use on dataset
    """

    # used to mask out the prompt so prompt tokens do not contribute to training loss
    PROMPT_KEY = "prompt"

    def __init__(
        self,
        dataset_args: DatasetArguments,
        split: str,
        processor: Processor,
    ):
        self.dataset_args = dataset_args
        self.split = split
        self.processor = processor

        # get tokenizer
        self.tokenizer = getattr(self.processor, "tokenizer", self.processor)

        if self.tokenizer is not None:
            # fill in pad token
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # configure sequence length
            max_seq_length = dataset_args.max_seq_length
            if dataset_args.max_seq_length > self.tokenizer.model_max_length:
                logger.warning(
                    f"The max_seq_length passed ({max_seq_length}) is larger than "
                    f"maximum length for model ({self.tokenizer.model_max_length}). "
                    f"Using max_seq_length={self.tokenizer.model_max_length}."
                )
            self.max_seq_length = min(
                dataset_args.max_seq_length, self.tokenizer.model_max_length
            )

            # configure padding
            if self.dataset_args.concatenate_data:
                # no padding when sequences will be concatenated and re-chunked
                self.padding = False
            elif self.dataset_args.pad_to_max_length:
                self.padding = "max_length"
            else:
                self.padding = False

        else:
            self.max_seq_length = None
            self.padding = False

    def __call__(self, add_labels: bool = True) -> DatasetType:
        dataset = self.dataset_args.dataset

        if isinstance(dataset, str):
            # load dataset: load from huggingface or disk
            dataset = self.load_dataset()
        logger.debug(f"Raw dataset: {get_columns(dataset)}")

        if self.preprocess is not None:
            # preprocess: apply template or preprocessing function
            dataset = self.map(
                dataset,
                self.preprocess,
                batched=False,
                num_proc=self.dataset_args.preprocessing_num_workers,
                load_from_cache_file=not self.dataset_args.overwrite_cache,
                desc="Preprocessing",
            )
            logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}")

        # rename and remove columns to match processor kwargs
        dataset = self.rename_columns(dataset)
        logger.debug(f"Dataset after column renaming: {get_columns(dataset)}")

        # use processor.model_input_names to determine if the ds is already tokenized
        model_input_names = getattr(self.processor, "model_input_names", ["input_ids"])
        if not any(col_name in model_input_names for col_name in get_columns(dataset)):
            # tokenize/ process
            dataset = self.filter_tokenizer_args(dataset)
            logger.debug(f"Tokenizer args after filtering: {get_columns(dataset)}")
            dataset = self.map(
                dataset,
                self.tokenize,
                batched=False,  # batching is not well supported for vision processors
                keep_in_memory=True,  # bug occurs when not batched and not in memory,
                # subsequent ds.map calls are always batched,
                # regardless of `batched` argument
                remove_columns=get_columns(dataset),  # assumes that input names
                # and output names are disjoint
                num_proc=self.dataset_args.preprocessing_num_workers,
                load_from_cache_file=not self.dataset_args.overwrite_cache,
                desc="Tokenizing",
            )
            logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}")

        if self.dataset_args.concatenate_data:
            # postprocess: group text
            dataset = self.map(
                dataset,
                self.group_text,
                batched=True,
                num_proc=self.dataset_args.preprocessing_num_workers,
                load_from_cache_file=not self.dataset_args.overwrite_cache,
                desc="Concatenating data",
            )
            logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}")

        if add_labels:
            # postprocess: add labels
            dataset = self.map(
                dataset,
                self.add_labels,
                batched=False,  # not compatible with batching, need row lengths
                num_proc=self.dataset_args.preprocessing_num_workers,
                load_from_cache_file=not self.dataset_args.overwrite_cache,
                desc="Adding labels",
            )
            logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}")

        elif self.PROMPT_KEY in get_columns(dataset):
            dataset = dataset.remove_columns(self.PROMPT_KEY)
            logger.debug("Removed prompt key")

        logger.debug(f"Model kwargs after postprocessing: {get_columns(dataset)}")
        return dataset

    def load_dataset(self):
        """
        Load the raw dataset from Hugging Face, using cached copy if available

        :param cache_dir: disk location to search for cached dataset
        :return: the requested dataset
        """
        if self.dataset_args.dataset_path is not None:
            if self.dataset_args.dvc_data_repository is not None:
                self.dataset_args.raw_kwargs["storage_options"] = {
                    "url": self.dataset_args.dvc_data_repository
                }
                self.dataset_args.raw_kwargs["data_files"] = (
                    self.dataset_args.dataset_path
                )
            else:
                self.dataset_args.raw_kwargs["data_files"] = (
                    get_custom_datasets_from_path(
                        self.dataset_args.dataset_path,
                        self.dataset_args.dataset
                        if hasattr(self.dataset_args, "dataset")
                        else self.dataset_args.dataset_name,
                    )
                )

        logger.debug(f"Loading dataset {self.dataset_args.dataset}")
        return get_raw_dataset(
            self.dataset_args,
            None,
            split=self.split,
            streaming=self.dataset_args.streaming,
            **self.dataset_args.raw_kwargs,
        )

    @cached_property
    def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
        """
        The preprocessing function must return a mapping whose keys correspond
        to processor/tokenizer kwargs, optionally including PROMPT_KEY
        """
        preprocessing_func = self.dataset_args.preprocessing_func

        if callable(preprocessing_func):
            return preprocessing_func

        if isinstance(preprocessing_func, str):
            if ":" in preprocessing_func:
                # load func_name from "/path/to/file.py:func_name"
                return import_from_path(preprocessing_func)
            else:
                # load from the registry
                return PreprocessingFunctionRegistry.get_value_from_registry(
                    name=preprocessing_func
                )

        return self.dataset_template

    @property
    def dataset_template(self) -> Union[Callable[[Any], Any], None]:
        return None

    def rename_columns(self, dataset: DatasetType) -> DatasetType:
        # rename columns to match processor/tokenizer kwargs
        column_names = get_columns(dataset)
        if self.dataset_args.text_column in column_names and "text" not in column_names:
            logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`")
            dataset = dataset.rename_column(self.dataset_args.text_column, "text")

        return dataset

    def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
        # keep only the named parameters of self.processor.__call__; assumes that
        # inputs are not passed via its *args or **kwargs
        signature = inspect.signature(self.processor.__call__)
        tokenizer_args = set(
            key
            for key, param in signature.parameters.items()
            if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD)
        )
        logger.debug(
            f"Found processor args `{tokenizer_args}`. Removing all other columns"
        )

        column_names = get_columns(dataset)
        return dataset.remove_columns(
            list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY]))
        )

    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
        # separate prompt
        prompt = data.pop(self.PROMPT_KEY, None)

        # tokenize
        data = self.processor(
            **data,
            padding=self.padding,
            max_length=self.max_seq_length,
            truncation=True,
        )

        # store unpadded prompt so we can mask out correct number of elements in labels
        if prompt is not None:
            data[self.PROMPT_KEY] = self.processor(
                text=prompt,
                max_length=self.max_seq_length,
                truncation=True,
            )["input_ids"]

        return data

    def group_text(self, data: LazyRow) -> Dict[str, Any]:
        concatenated_data = {k: sum(data[k], []) for k in data.keys()}
        total_length = len(concatenated_data[list(data.keys())[0]])
        total_length = (total_length // self.max_seq_length) * self.max_seq_length
        result = {
            k: [
                t[i : i + self.max_seq_length]
                for i in range(0, total_length, self.max_seq_length)
            ]
            for k, t in concatenated_data.items()
        }
        return result

    def add_labels(self, data: LazyRow) -> LazyRow:
        if "pixel_values" in data:
            raise NotImplementedError(
                "Label masking for vision datasets has not been implemented yet"
            )

        # if the dataset uses prompts, mask them out so they don't contribute
        # to the loss calculation
        prompt_len = 0
        if self.PROMPT_KEY in data:
            prompt_len = len(data[self.PROMPT_KEY])
        data["labels"] = data["input_ids"].copy()
        data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len

        # mask out padding in the labels as well
        padding = len(data["attention_mask"]) - sum(data["attention_mask"])
        if padding > 0:
            data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding
        return data

    def map(
        self,
        dataset: Union[Dataset, IterableDataset],
        function: Callable[[Any], Any],
        **kwargs,
    ) -> Union[Dataset, IterableDataset]:
        """
        Wrapper function around Dataset.map and IterableDataset.map.

        If the dataset is streaming (in the case of IterableDataset), non-applicable
        arguments are ignored and the dataset features are resolved
        """
        if isinstance(dataset, IterableDataset):
            # remove arguments that don't apply to streaming
            kwargs.pop("num_proc", None)
            kwargs.pop("load_from_cache_file", None)
            kwargs.pop("desc", None)
            kwargs.pop("keep_in_memory", None)

        dataset = dataset.map(function, **kwargs)

        if isinstance(dataset, IterableDataset):
            dataset = dataset._resolve_features()

        return dataset

preprocess cached property

The preprocessing function must return a mapping whose keys correspond to processor/tokenizer kwargs, optionally including PROMPT_KEY; an illustrative sketch follows.
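
As a hedged illustration, a custom preprocessing function might look like the sketch below; the input column names are assumptions about the raw dataset:

    # Illustrative sketch: assumes raw rows with "instruction" and "output" columns.
    def preprocess(example):
        # returned keys must be valid processor/tokenizer kwargs; "prompt"
        # (PROMPT_KEY) marks the prefix that add_labels masks out of the loss
        return {
            "text": example["instruction"] + "\n" + example["output"],
            "prompt": example["instruction"],
        }

Such a function can be supplied directly as dataset_args.preprocessing_func, referenced as "/path/to/file.py:func_name", or registered by name in PreprocessingFunctionRegistry.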

load_dataset()

Load the raw dataset from Hugging Face, using a cached copy if available (a configuration sketch follows the Returns section)

Returns:

    the requested dataset
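
A hedged configuration sketch of the local-path branch in the source below, with illustrative field values:

    # Illustrative sketch: a local file is routed through raw_kwargs["data_files"];
    # setting dvc_data_repository would instead add DVC storage_options.
    dataset_args = DatasetArguments(
        dataset="json",                  # loader name forwarded to get_raw_dataset
        dataset_path="data/train.json",  # becomes raw_kwargs["data_files"]
        dvc_data_repository=None,        # set a DVC repo URL to read DVC-tracked data
    )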

Source code in llmcompressor/transformers/finetune/data/base.py
def load_dataset(self):
    """
    Load the raw dataset from Hugging Face, using cached copy if available

    :param cache_dir: disk location to search for cached dataset
    :return: the requested dataset
    """
    if self.dataset_args.dataset_path is not None:
        if self.dataset_args.dvc_data_repository is not None:
            self.dataset_args.raw_kwargs["storage_options"] = {
                "url": self.dataset_args.dvc_data_repository
            }
            self.dataset_args.raw_kwargs["data_files"] = (
                self.dataset_args.dataset_path
            )
        else:
            self.dataset_args.raw_kwargs["data_files"] = (
                get_custom_datasets_from_path(
                    self.dataset_args.dataset_path,
                    self.dataset_args.dataset
                    if hasattr(self.dataset_args, "dataset")
                    else self.dataset_args.dataset_name,
                )
            )

    logger.debug(f"Loading dataset {self.dataset_args.dataset}")
    return get_raw_dataset(
        self.dataset_args,
        None,
        split=self.split,
        streaming=self.dataset_args.streaming,
        **self.dataset_args.raw_kwargs,
    )

map(dataset, function, **kwargs)

Wrapper function around Dataset.map and IterableDataset.map.

If the dataset is streaming (an IterableDataset), arguments that do not apply to streaming (num_proc, load_from_cache_file, desc, keep_in_memory) are dropped, and the dataset features are resolved after mapping; a sketch follows.
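
For context, a hedged sketch of the streaming case this wrapper guards against; the toy data is illustrative:

    from datasets import Dataset

    ds = Dataset.from_dict({"text": ["hello", "world"]})
    streaming_ds = ds.to_iterable_dataset()  # an IterableDataset

    def upper(row):
        return {"text": row["text"].upper()}

    ds = ds.map(upper, num_proc=2, desc="Uppercasing")  # fine on a map-style Dataset
    streaming_ds = streaming_ds.map(upper)  # IterableDataset.map rejects num_proc/desc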

Source code in llmcompressor/transformers/finetune/data/base.py
def map(
    self,
    dataset: Union[Dataset, IterableDataset],
    function: Callable[[Any], Any],
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    """
    Wrapper function around Dataset.map and IterableDataset.map.

    If the dataset is streaming (in the case of IterableDataset), non-applicable
    arguments are ignored and the dataset features are resolved
    """
    if isinstance(dataset, IterableDataset):
        # remove arguments that don't apply to streaming
        kwargs.pop("num_proc", None)
        kwargs.pop("load_from_cache_file", None)
        kwargs.pop("desc", None)
        kwargs.pop("keep_in_memory", None)

    dataset = dataset.map(function, **kwargs)

    if isinstance(dataset, IterableDataset):
        dataset = dataset._resolve_features()

    return dataset