Skip to content

llmcompressor.modifiers.obcq.sgpt_base

SparsityModifierBase

Bases: Modifier

Abstract base class which implements functionality related to oneshot sparsity. Subclasses must implement `calibrate_module` and `compress_modules`.

Source code in llmcompressor/modifiers/obcq/sgpt_base.py
class SparsityModifierBase(Modifier):
    """
    Abstract base class which implements functionality related to oneshot sparsity.
    Inheriters must implement `calibrate_module` and `compress_modules`

    Lifecycle as implemented here:

    - `on_initialize` resolves sequential/target layers and, for the "owl"
      profile, infers per-layer sparsities
    - `on_start` registers `calibrate_module` as a forward hook on every prunable
      module of the target layers that is not ignored
    - `on_event` calls `compress_modules` at the end of each sequential epoch and
      at the end of the calibration epoch
    - `on_end` removes the registered hooks
    """

    # modifier arguments
    # NOTE: when `sparsity_profile="owl"`, `on_initialize` overwrites `sparsity`
    # with a dict mapping layer name -> inferred sparsity
    sparsity: Optional[Union[float, List[float]]]
    sparsity_profile: Optional[str] = None
    mask_structure: str = "0:0"
    owl_m: Optional[int] = None
    owl_lmbda: Optional[float] = None

    # data pipeline arguments
    sequential_update: Optional[bool] = False  # deprecated
    sequential_targets: Union[str, List[str], None] = None
    targets: Union[str, List[str]] = ["Linear"]
    ignore: List[str] = Field(default_factory=list)

    # private variables
    # `_prune_n`/`_prune_m` are parsed from `mask_structure` ("N:M")
    _prune_n: Optional[int] = PrivateAttr(default=None)
    _prune_m: Optional[int] = PrivateAttr(default=None)
    # hooked module -> fully qualified name
    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
    # layer name -> layer, as resolved from `targets`
    _target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
    # hooked module -> sparsity of the layer which contains it
    _module_sparsities: Dict[torch.nn.Module, float] = PrivateAttr(
        default_factory=dict
    )

    @field_validator("sequential_update", mode="before")
    def validate_sequential_update(cls, value: bool) -> bool:
        """
        Coerce the deprecated `sequential_update` argument to True, warning if a
        falsy value was provided

        :param value: value provided for `sequential_update`
        :return: always True
        """
        if not value:
            warnings.warn(
                "`sequential_update=False` is no longer supported, setting "
                "sequential_update=True",
                DeprecationWarning,
            )

        return True

    @field_validator("sparsity_profile", mode="before")
    def validate_sparsity_profile(cls, value: Optional[str]) -> Optional[str]:
        """
        Lowercase the profile name and check it against the supported profiles

        :param value: value provided for `sparsity_profile`
        :return: None if no profile was given, otherwise the lowercased name
        :raises ValueError: if the profile is not one of the supported options
        """
        if value is None:
            return value

        value = value.lower()

        profile_options = ["owl"]
        if value not in profile_options:
            raise ValueError(f"Please choose profile from {profile_options}")

        return value

    @model_validator(mode="after")
    def validate_model_after(model: "SparsityModifierBase") -> "SparsityModifierBase":
        """
        Check that the OWL arguments are provided all together or not at all, then
        parse `mask_structure` into `_prune_n` and `_prune_m`

        :param model: the modifier instance being validated
        :return: the same instance with `_prune_n` and `_prune_m` populated
        :raises ValueError: if only some of the OWL arguments were given, or if
            `mask_structure` cannot be parsed
        """
        profile = model.sparsity_profile
        owl_m = model.owl_m
        owl_lmbda = model.owl_lmbda
        mask_structure = model.mask_structure

        has_owl_m = owl_m is not None
        has_owl_lmbda = owl_lmbda is not None
        has_owl = profile == "owl"
        owl_args = (has_owl_m, has_owl_lmbda, has_owl)
        if any(owl_args) and not all(owl_args):
            raise ValueError(
                'Must provide all of `profile="owl"`, `owl_m` and `owl_lmbda` or none'
            )

        model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

        return model

    @abstractmethod
    def calibrate_module(
        self,
        module: torch.nn.Module,
        args: Tuple[torch.Tensor, ...],
        _output: torch.Tensor,
    ):
        """
        Forward hook registered on each tracked module by `on_start`

        :param module: module whose forward pass triggered the hook
        :param args: positional inputs of the forward pass
        :param _output: output of the forward pass
        """
        raise NotImplementedError()

    @abstractmethod
    def compress_modules(self):
        """
        Called by `on_event` at the end of each sequential epoch and at the end of
        the calibration epoch
        """
        raise NotImplementedError()

    def on_initialize(self, state: "State", **kwargs) -> bool:
        """
        Resolve sequential and target layers from the model, infer layer-wise
        sparsities when the "owl" profile is selected, and validate that
        per-layer sparsities match the number of target layers

        :param state: session state storing input model and calibration data
        :return: True once initialization has completed
        :raises ValueError: if `sparsity` is a list or dict whose length differs
            from the number of target layers
        """
        model: torch.nn.Module = state.model
        dataloader: torch.utils.data.DataLoader = state.data.calib

        # infer module and sequential targets
        self.sequential_targets = self._infer_sequential_targets(model)
        layers = get_layers(self.sequential_targets, model)
        self._target_layers = get_layers(
            self.targets, model
        )  # layers containing targets

        # infer layer sparsities
        if self.sparsity_profile == "owl":
            logger.info(
                "Using OWL to infer target layer-wise sparsities from "
                f"{len(dataloader) if dataloader else 0} calibration samples..."
            )
            self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)

        # get layers and validate sparsity
        num_target_layers = len(self._target_layers)
        if (
            isinstance(self.sparsity, (list, dict))
            and len(self.sparsity) != num_target_layers
        ):
            # report the class name and the count which was actually compared
            raise ValueError(
                f"{type(self).__name__} was initialized with {len(self.sparsity)} "
                f"sparsities values, but model has {num_target_layers} target layers"
            )

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        """
        Mark the modifier as started and register `calibrate_module` as a forward
        hook on every prunable module of the target layers which is not ignored

        :param state: session state (not read by this method)
        :param event: triggering event (not read by this method)
        """
        self.started_ = True

        # register hooks
        for index, (layer_name, layer) in enumerate(self._target_layers.items()):
            # sparsity may be given per layer name, per layer index, or globally
            if isinstance(self.sparsity, dict):
                layer_sparsity = self.sparsity[layer_name]
            elif isinstance(self.sparsity, list):
                layer_sparsity = self.sparsity[index]
            else:
                layer_sparsity = self.sparsity

            for name, module in get_prunable_layers(layer).items():
                name = f"{layer_name}.{name}"

                if match_targets(name, self.ignore)[0]:
                    continue

                # HACK: previously, embeddings were not quantized because they were not
                # accessible by the layer compressor. For now, we manually ignore it,
                # but in the FUTURE this should be ignored by the user
                if isinstance(module, torch.nn.Embedding):
                    continue

                # lm_head is only warned about, it is still registered below
                if name.endswith("lm_head"):
                    logger.warning(
                        "`lm_head` was previously auto-ignored by SparseGPT and Wanda "
                        "modifiers and is not advised. Please add `re:.*lm_head` to "
                        "your ignore list if this was unintentional"
                    )

                self._module_names[module] = name
                self._module_sparsities[module] = layer_sparsity
                self.register_hook(module, self.calibrate_module, "forward")

    def on_event(self, state: State, event: Event, **kwargs):
        """
        Dispatch calibration events: start the modifier if needed, compress
        modules at epoch boundaries, and end the modifier after calibration

        :param state: session state, forwarded to `on_start` and `on_end`
        :param event: event whose `type_` selects the action
        """
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            self.compress_modules()

        if event.type_ == EventType.CALIBRATION_EPOCH_END:
            self.compress_modules()

            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        """
        Mark the modifier as ended and remove its registered hooks

        :param state: session state (not read by this method)
        :param event: triggering event (not read by this method)
        """
        self.ended_ = True
        self.remove_hooks()

    def _infer_sequential_targets(
        self, model: torch.nn.Module
    ) -> Union[str, List[str]]:
        """
        Resolve `sequential_targets`, falling back to `get_no_split_params(model)`
        when unset and wrapping a single string into a list

        :param model: model used to infer targets when none were provided
        :return: the resolved sequential targets
        """
        if self.sequential_targets is None:
            return get_no_split_params(model)
        if isinstance(self.sequential_targets, str):
            return [self.sequential_targets]
        return self.sequential_targets

    def _infer_owl_layer_sparsity(
        self,
        model: torch.nn.Module,
        layers: Dict[str, torch.nn.Module],
        dataloader: torch.utils.data.DataLoader,
    ) -> Dict[str, float]:
        """
        Infer one sparsity per layer for the "owl" profile. Layers with a larger
        share of outlier scores receive a lower sparsity; the offsets are centered
        on their mean, so the inferred sparsities average to `self.sparsity`.

        :param model: model to calibrate
        :param layers: mapping of layer name to layer
        :param dataloader: calibration data used to collect activations
        :return: mapping of layer name to inferred sparsity
        """
        activations = self._get_activations(model, dataloader)

        # per layer, flatten |weight| * activation of every prunable submodule
        # NOTE(review): `unsqueeze(0)` presumably broadcasts a per-input-feature
        # activation norm across the rows of `weight` -- confirm shapes
        groups = {}
        for name, layer in layers.items():
            prunable_layers = get_prunable_layers(layer)
            z = [
                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                for n, m in prunable_layers.items()
            ]
            groups[name] = torch.cat([item.flatten().cpu() for item in z])

        del activations

        # percentage of scores which exceed `owl_m` times the layer's mean score
        outlier_ratios = {}
        for group_name, scores in groups.items():
            threshold = torch.mean(scores) * self.owl_m
            outlier_ratios[group_name] = (
                100 * (scores > threshold).sum().item() / scores.numel()
            )

        # rescale ratios to the range [0, 2 * owl_lmbda]
        ratios_arr = numpy.array(list(outlier_ratios.values()))
        ratio_min = ratios_arr.min()
        ratio_range = ratios_arr.max() - ratio_min
        if ratio_range > 0:
            scale = 1 / ratio_range * self.owl_lmbda * 2
        else:
            # all layers have the same outlier ratio (e.g. a single layer): apply
            # no per-layer offset rather than dividing by zero and producing NaNs
            scale = 0.0
        outlier_ratios = {
            k: (ratio - ratio_min) * scale for k, ratio in outlier_ratios.items()
        }

        mean_ratio = numpy.mean(numpy.array(list(outlier_ratios.values())))
        density = 1 - float(self.sparsity)
        sparsities = {
            k: 1 - (ratio - mean_ratio + density)
            for k, ratio in outlier_ratios.items()
        }

        logger.info(f"OWL sparsities for sp={self.sparsity} are:")
        for k in sparsities:
            logger.info(f"Sparsity for {k}: {sparsities[k]}")
        return sparsities

    def _get_activations(
        self, model, dataloader, nsamples=128
    ) -> Dict[str, torch.Tensor]:
        """
        Run calibration with only temporary hooks enabled and accumulate, for each
        `torch.nn.Linear` whose name does not contain "lm_head", the norm of its
        inputs scaled by `1 / nsamples`

        :param model: model to calibrate
        :param dataloader: calibration data passed to `run_calibration`
        :param nsamples: divisor applied to each accumulated input norm
        :return: mapping of module name to accumulated activation norm
        """
        from llmcompressor.pipelines.basic import run_calibration

        # int default so that the first `+=` yields the tensor itself
        acts = defaultdict(int)

        def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
            if isinstance(input, tuple):
                input = input[0]
            # reduces over dims 0 and 1
            # NOTE(review): assumes input is (batch, seq, hidden) -- TODO confirm
            acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

        hooks = {
            self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
            for name, mod in model.named_modules()
            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
        }
        try:
            with HooksMixin.disable_hooks(keep=hooks):
                run_calibration(model, dataloader)
        finally:
            # do not leave the temporary hooks registered if calibration fails
            self.remove_hooks(hooks)

        return acts

    def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
        """
        Parse a mask structure of the form "N:M" into its two integers

        :param mask_structure: string such as "0:0" or "2:4"
        :return: tuple of (n, m)
        :raises ValueError: if the string is not two integers separated by ":"
        """
        try:
            n, m = mask_structure.split(":")
            return int(n), int(m)
        except ValueError as err:
            raise ValueError(
                "`mask_structure` must be two integers separated by ':' "
                f"(for example '2:4'), got {mask_structure!r}"
            ) from err

on_initialize(state, **kwargs)

Initialize and run the OBCQ algorithm on the current state

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | session state storing input model and calibration data | required |
Source code in llmcompressor/modifiers/obcq/sgpt_base.py
def on_initialize(self, state: "State", **kwargs) -> bool:
    """
    Resolve sequential and target layers from the model, infer layer-wise
    sparsities when the "owl" profile is selected, and validate that per-layer
    sparsities match the number of target layers

    :param state: session state storing input model and calibration data
    :return: True once initialization has completed
    :raises ValueError: if `sparsity` is a list or dict whose length differs
        from the number of target layers
    """
    model: torch.nn.Module = state.model
    dataloader: torch.utils.data.DataLoader = state.data.calib

    # infer module and sequential targets
    self.sequential_targets = self._infer_sequential_targets(model)
    layers = get_layers(self.sequential_targets, model)
    self._target_layers = get_layers(
        self.targets, model
    )  # layers containing targets

    # infer layer sparsities
    if self.sparsity_profile == "owl":
        logger.info(
            "Using OWL to infer target layer-wise sparsities from "
            f"{len(dataloader) if dataloader else 0} calibration samples..."
        )
        self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)

    # get layers and validate sparsity
    num_target_layers = len(self._target_layers)
    if (
        isinstance(self.sparsity, (list, dict))
        and len(self.sparsity) != num_target_layers
    ):
        # report the class name and the count which was actually compared
        raise ValueError(
            f"{type(self).__name__} was initialized with {len(self.sparsity)} "
            f"sparsities values, but model has {num_target_layers} target layers"
        )

    return True