llmcompressor.modifiers.awq.base

AWQModifier

Bases: Modifier, QuantizationMixin

Implements the AWQ (Activation-aware Weight Quantization) algorithm, as described in https://arxiv.org/pdf/2306.00978. The algorithm significantly reduces quantization error by protecting only 1% of the most salient weight channels.

Instead of relying on raw weight values, AWQ identifies important channels by analyzing activation patterns, focusing on the channels in the weight tensor that are most responsive to the input. To reduce quantization error, it scales these channels in a way that preserves the model's original behavior, using scaling factors computed offline from activation statistics.

Because this modifier manipulates the weights of the model, it can only be used in one-shot and not during training. Activation ranges are determined by running a small set of calibration data through the model.

example recipe:

AWQModifier:
  mappings:
    - smooth_layer: "re:.*self_attn_layer_norm"
      balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
    - smooth_layer: "re:.*final_layer_norm"
      balance_layers: ["re:.*fc1"]
  ignore: ["lm_head"]
  config_groups:
    group_0:
      targets:
        - "Linear"
      input_activations: null
      output_activations: null
      weights:
        num_bits: 4
        type: int
        symmetric: false
        strategy: group
        group_size: 128
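
As a usage sketch, a recipe like the one above can be applied with llm-compressor's oneshot entrypoint. The model ID, dataset name, and calibration sizes below are illustrative assumptions, not prescribed by this API:

```python
# Minimal sketch, assuming the recipe above is saved as "recipe.yaml" and
# using placeholder model/dataset identifiers.
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot

MODEL_ID = "facebook/opt-125m"  # placeholder model

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# One-shot calibration: AWQ smoothing scales are computed from a small
# calibration set, then weights are quantized per the recipe's config_groups.
oneshot(
    model=model,
    dataset="open_platypus",       # placeholder calibration dataset
    recipe="recipe.yaml",          # the AWQModifier recipe shown above
    max_seq_length=512,
    num_calibration_samples=256,
)
```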

Lifecycle:

- on_initialize
    - resolve mappings
    - capture kwargs needed for forward passes into modules
- on_start
    - set up activation cache hooks to capture input activations to balance layers
- on sequential epoch end
    - apply smoothing to each smoothing layer
        - consume cached activations across all batches
            - clear cached activations as they are used
        - find best smoothing scale for each smoothing layer
        - apply to model weights
        - raise error if any unused activations remain
- on_end
    - re-run logic of sequential epoch end (in case of basic pipeline)
    - set scales and zero points
    - remove activation hooks
- on_finalize
    - clear resolved mappings and captured activations
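
The heart of the smoothing step is a grid search over candidate scales. The toy sketch below (illustrative tensors only, not the modifier's internals) shows how duo-scaling candidates are formed from per-channel activation and weight magnitudes; compare `_compute_best_scale` in the source that follows:

```python
# Toy illustration of the duo-scaling candidates searched by
# _compute_best_scale; the x_mean/w_mean values are made up.
import torch

x_mean = torch.tensor([2.0, 0.5, 1.0])  # mean |activation| per input channel
w_mean = torch.tensor([0.1, 0.4, 0.2])  # mean normalized |weight| per channel

n_grid = 20
for step in range(n_grid):
    ratio = step / n_grid
    # Salient channels (large activations) get larger scales, balanced
    # against weight magnitudes when duo_scaling is enabled.
    scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
    # Normalize so the geometric mean of the max and min scale is 1.
    scales = scales / (scales.max() * scales.min()).sqrt()
    # Each candidate is then scored by || Q(W * s)(s^-1 X) - W X ||^2 over the
    # cached calibration batches; the lowest-error candidate wins (omitted here).
```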

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sequential_targets | Union[str, List[str], None] | List of module names to compress in the same calibration pass. | None |
| mappings | Optional[List[AWQMapping]] | List of activation layers to smooth, and the layers whose outputs are scaled to achieve the smoothing. Each entry of the mapping list should be a list itself, in which the first entry is a list of layers that share the same input activation (the one to be smoothed) and the second entry is the layer whose output is scaled to achieve the smoothing. If regex is used, it matches layers with the largest overlap in module name. If None, mappings are inferred from the model architecture. | None |
| ignore | List[str] | List of layers to ignore, even if they match a regex in mappings. It should match the names of layers whose outputs are scaled to achieve smoothing (the second entry of the mappings list). | [] |
| offload_device | Optional[torch.device] | Offload cached args to this device, which reduces memory requirements but requires more time to move data between the CPU and the execution device. Defaults to None, so cached args are not offloaded. Consider setting to torch.device("cpu") if you are encountering OOM errors. | None |
| duo_scaling | bool | Whether to use duo scaling, which uses both input activations and weights to determine the scaling factor. | True |
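
The mappings argument can also be built programmatically. A minimal sketch follows; the AWQMapping import path and the W4A16_ASYM preset name are assumptions that may differ across versions:

```python
import torch

from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.awq.mappings import AWQMapping  # assumed path

modifier = AWQModifier(
    mappings=[
        # smooth_layer first, then the balance layers scaled to compensate
        AWQMapping(
            "re:.*input_layernorm",
            ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
        ),
        AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
    ],
    ignore=["lm_head"],
    scheme="W4A16_ASYM",   # assumed preset: 4-bit asymmetric, group_size 128
    targets=["Linear"],
    duo_scaling=True,      # default: use both activations and weights
    offload_device=torch.device("cpu"),  # optional: offload cached args on OOM
)
```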
Source code in llmcompressor/modifiers/awq/base.py
class AWQModifier(Modifier, QuantizationMixin):
    """
    Implements the AWQ (Activation-aware Weight Quantization) algorithm,
    as described in https://arxiv.org/pdf/2306.00978. The algorithm
    significantly reduces quantization error by protecting only 1%
    of the most salient weight channels.

    Instead of relying on raw weight values, AWQ identifies important channels by
    analyzing activation patterns, focusing on the channels in the weight tensor that
    are most responsive to the input. To reduce quantization error, it scales these
    channels in a way that preserves the model's original behavior, using scaling
    factors computed offline from activation statistics.

    Because this modifier manipulates the weights of the model, it can only be used
    in one-shot and not during training. Activation ranges are determined by running a
    small set of calibration data through the model.

    example recipe:
    ```yaml
    AWQModifier:
      mappings:
        - smooth_layer: "re:.*self_attn_layer_norm"
          balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
        - smooth_layer: "re:.*final_layer_norm"
          balance_layers: ["re:.*fc1"]
      ignore: ["lm_head"]
      config_groups:
        group_0:
          targets:
            - "Linear"
          input_activations: null
          output_activations: null
          weights:
            num_bits: 4
            type: int
            symmetric: false
            strategy: group
            group_size: 128
    ```

    Lifecycle:
        - on_initialize
            - resolve mappings
            - capture kwargs needed for forward passes into modules
        - on_start
            - set up activation cache hooks to capture input activations
                to balance layers
        - on sequential epoch end
            - apply smoothing to each smoothing layer
                - consume cached activations across all batches
                    - clear cached activations as they are used
                - find best smoothing scale for each smoothing layer
                - apply to model weights
                - raise error if any unused activations remain
        - on_end
            - re-run logic of sequential epoch end (in case of basic pipeline)
            - set scales and zero points
            - remove activation hooks
        - on_finalize
            - clear resolved mappings and captured activations

    :param sequential_targets: list of module names to compress in
        the same calibration pass
    :param mappings: list of activation layers to smooth, and the layers whose
        outputs are scaled to achieve the smoothing.
        Each entry of the mapping list should be a list itself, in which the first
        entry is a list of layers that share the same input activation (the one
        to be smoothed) and the second entry is the layer whose output is scaled
        to achieve the smoothing.
        If regex is used, it matches layers with the largest overlap in module name.
    :param ignore: list of layers to ignore, even if they match a regex in mappings.
        It should match the name of layers whose outputs are scaled to achieve
        smoothing (the second entry of the mappings list).
    :param offload_device: offload cached args to this device, which reduces memory
        requirements but requires more time to move data between cpu and execution
        device. Defaults to None, so cached args are not offloaded. Consider setting
        to torch.device("cpu") if you are encountering OOM errors
    :param duo_scaling: whether to use duo scaling, which uses both input activations
        and weights to determine the scaling factor
    """

    # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

    # User-provided vars (in addition to QuantizationMixin args)
    sequential_targets: Union[str, List[str], None] = None
    mappings: Optional[List[AWQMapping]] = None
    offload_device: Optional[torch.device] = None
    duo_scaling: bool = True

    # Private vars set during validation
    _num_bits: Optional[int] = PrivateAttr(default=None)
    _symmetric: Optional[bool] = PrivateAttr(default=None)
    _group_size: Optional[int] = PrivateAttr(default=None)

    # Private vars set during initialization, cleared during finalization
    _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
    # Cache list of forward input args for each parent module, one dict for each batch
    _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
        default_factory=dict
    )
    # Dict[smooth layer name, (activation means, activation counts)]
    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
        default_factory=dict
    )

    @model_validator(mode="after")
    def validate_model_after(model: "AWQModifier") -> "AWQModifier":
        """
        Confirm only one configuration for group_size, symmetric, and num_bits,
        as AWQ algorithm depends on it
        Confirm no activation quantization, as AWQ only works with WNA16
        """
        config = model.resolve_quantization_config()

        num_bits_set = set(
            group.weights.num_bits
            for group in config.config_groups.values()
            if group.weights is not None
        )
        assert (
            len(num_bits_set) == 1
        ), "In AWQ, all config groups must use the same configuration for num_bits"

        model._num_bits = next(iter(num_bits_set))

        symmetric_set = set(
            group.weights.symmetric
            for group in config.config_groups.values()
            if group.weights is not None
        )
        assert (
            len(symmetric_set) == 1
        ), "In AWQ, all config groups must use the same configuration for symmetric"

        model._symmetric = next(iter(symmetric_set))

        group_size_set = set(
            group.weights.group_size
            for group in config.config_groups.values()
            if group.weights is not None
        )
        assert (
            len(group_size_set) == 1
        ), "In AWQ, all config groups must use the same configuration for group_size"

        model._group_size = next(iter(group_size_set))

        in_num_bits_set = set(
            group.input_activations.num_bits
            for group in config.config_groups.values()
            if group.input_activations is not None
        )
        assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, (
            "AWQ activations must be 16-bit precision, "
            f"input activations {in_num_bits_set} not allowed"
        )

        out_num_bits_set = set(
            group.output_activations.num_bits
            for group in config.config_groups.values()
            if group.output_activations is not None
        )
        assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, (
            "AWQ activations must be 16-bit precision, "
            f"output activations {out_num_bits_set} not allowed"
        )

        return model

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Initialize AWQ on the given state
        Initialize quantization, resolve mappings, cache module kwargs

        :param state: state to run AWQ on
        :return: True on a successful run, False otherwise
        """

        # apply config to model and prepare calibration hooks
        if QuantizationMixin.has_config(self):
            QuantizationMixin.initialize_quantization(self, state.model)

        if self.mappings is None:
            logger.info("No AWQModifier.mappings provided, inferring from model...")
            self.mappings = get_layer_mappings_from_architecture(
                architecture=state.model.__class__.__name__
            )

        self._set_resolved_mappings(state.model)

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        # register quantization calibration hooks
        # assume quantization has been initialized by this modifier or one before it
        QuantizationMixin.start_calibration(self, state.model)
        # Unlike qmod, do not quantize as we calibrate
        # This choice does not seem to have a meaningful impact on accuracy
        state.model.apply(disable_quantization)

        self._setup_activation_cache_hooks()

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            # Run smoothing in case of sequential pipeline
            self._apply_smoothing(state.model)

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            # Run smoothing in case of basic pipeline
            self._apply_smoothing(state.model)

            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        """
        Finish calibrating by setting scales and zero-points,
         removing observers and calibration hooks
        """
        self._assert_all_activations_consumed()

        self.ended_ = True

        modules = list(state.model.modules())
        for module in tqdm(modules, desc="Calibrating weights"):
            update_weight_zp_scale(module)

        QuantizationMixin.end_calibration(self, state.model)

        # remove activation hooks
        self.remove_hooks()

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        Clean up by clearing the activations and mapping data

        :param state: unused
        :return: True
        """
        if not self.ended_:
            self.on_end(state, None)

        self._parent_args_cache.clear()
        self._smooth_activation_means.clear()
        self._resolved_mappings.clear()

        return True

    def _set_resolved_mappings(self, model: Module) -> None:
        """
        Transforms the list of activations to smooth and their corresponding weights
        into ResolvedMapping objects, resolving regular expressions.
        Result is stored in _resolved_mappings.

        For each activation in the mapping list, we find the corresponding weight to
        balance by searching for the longest substring. For instance, if our balance
        weight is "re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
        would match model.layer.0.q_proj to model.layer.0.self_attn_layer_norm and
        repeat for model.layer.1 and so on
        """
        resolved_mappings: list[ResolvedMapping] = []
        for mapping_idx, mapping in enumerate(self.mappings):
            smooth_layers = get_layers(
                mapping.smooth_layer, model, exclude_internal_modules=True
            )
            smooth_names = [
                smooth_name
                for smooth_name in smooth_layers
                if not find_name_or_class_matches(smooth_name, model, self.ignore)
            ]

            num_skipped_mappings = 0
            pbar = tqdm(smooth_names)
            for smooth_name in pbar:
                pbar.set_description(
                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
                    f" ({num_skipped_mappings} skipped)"
                )
                smooth_layer = smooth_layers[smooth_name]

                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
                smooth_parent = get_layer_by_name(smooth_parent_name, model)

                balance_layers, balance_names = [], []
                for balance_regex in mapping.balance_layers:
                    # find the submodules that match the activation layer
                    for balance_suffix, balance_layer in get_layers(
                        balance_regex,
                        smooth_parent,
                        exclude_internal_modules=True,
                    ).items():
                        balance_name = f"{smooth_parent_name}.{balance_suffix}"

                        # exclude v_proj->o_proj mappings whose shapes are incompatible
                        # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
                        if (
                            isinstance(smooth_layer, torch.nn.Linear)
                            and isinstance(balance_layer, torch.nn.Linear)
                            and balance_name.endswith(".o_proj")
                            and (
                                (
                                    smooth_name.endswith(".v_proj")
                                    and smooth_layer.out_features
                                    != balance_layer.in_features
                                )
                                or (
                                    smooth_name.endswith(".qkv_proj")
                                    and smooth_layer.out_features
                                    != 3 * balance_layer.in_features
                                )
                            )
                        ):
                            num_skipped_mappings += 1
                            continue

                        balance_layers.append(balance_layer)
                        balance_names.append(balance_name)

                if len(balance_layers) == 0:
                    continue

                elif len(balance_layers) == 1:
                    # for single balance layer, parent is the balance layer
                    parent_name, parent = balance_name, balance_layer
                else:
                    # for multiple balance layers, find lowest common parent
                    parent_name, parent = get_lowest_common_parent(balance_names, model)

                resolved_mappings.append(
                    ResolvedMapping(
                        smooth_name,
                        smooth_layer,
                        balance_layers,
                        balance_names=balance_names,
                        parent=parent,
                        parent_name=parent_name,
                    )
                )
        self._resolved_mappings = resolved_mappings
        return

    def _setup_activation_cache_hooks(self) -> None:
        """
        Attach a forward hook to each activation we want to smooth. This allows us to
        calculate the dynamic range during calibration
        """

        def cache_parent_kwargs_hook(
            module: torch.nn.Module,
            args: Tuple[torch.Tensor, ...],
            kwargs,
        ):
            values = inspect.signature(module.forward).bind(*args, **kwargs)
            self._parent_args_cache[module].append(values.arguments)

        def create_cache_smooth_activations_hook_fn(smooth_name):
            def cache_smooth_activations_hook(
                _module: torch.nn.Module,
                args: Tuple[torch.Tensor, ...],
                _output: torch.Tensor,
            ):
                self._smooth_activation_means[smooth_name] = _accumulate_mean(
                    # Assume that first argument is the input
                    args[0].cpu().detach().squeeze(),
                    self._smooth_activation_means.get(smooth_name, None),
                )

            return cache_smooth_activations_hook

        for mapping in self._resolved_mappings:
            # parent kwargs needed for future forward passes
            # same parent may appear multiple times in resolved mappings
            if mapping.parent not in self._parent_args_cache:
                self._parent_args_cache[mapping.parent] = IntermediatesCache(
                    None,
                    self.offload_device,
                )
                self.register_hook(
                    mapping.parent,
                    cache_parent_kwargs_hook,
                    "forward_pre",
                    with_kwargs=True,
                )

            # input activations to balance layers needed for loss function
            # storing inputs to first balance layer is sufficient
            # other balance layers get the same input
            self.register_hook(
                mapping.balance_layers[0],
                create_cache_smooth_activations_hook_fn(mapping.smooth_name),
                "forward",
            )

    @torch.no_grad()
    def _apply_smoothing(self, model: Module) -> None:
        """
        Calculate the best scaling factors for each layer to smooth activations and
        apply the scaling factors to the weights of the next layer to offset the
        smoothing

        :param model: model to apply smoothing to
        """
        # NOTE: When using SequentialPipeline, not all the mappings
        # will have cached activations in the segment being updated
        mappings_to_smooth = [
            mapping
            for mapping in self._resolved_mappings
            if mapping.smooth_name in self._smooth_activation_means
        ]
        for mapping in tqdm(mappings_to_smooth, desc="Smoothing"):
            smooth_layer = mapping.smooth_layer
            balance_layers = mapping.balance_layers
            parent_module = mapping.parent

            # [STEP 1]: Compute per-channel mean of normalized weights
            # All layer weights are concatted together
            weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
            org_shape = weight.shape
            # The weights are reshaped to be organized by quantization group
            weight = weight.view(-1, self._group_size)
            # Calculates the relative magnitude of the weights within
            # each of the quantization groups, and rescales each group
            # individually so that each group has weights on a 0-1 scale.
            weight.abs_()
            weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
            # Resizes the rescaled weight matrix back up to its original dimensions
            weight = weight.view(org_shape)
            # Gets the average rescaled magnitude for each output channel
            w_mean = weight.mean(0)
            del weight

            with calibration_forward_context(model), HooksMixin.disable_hooks():
                # [STEP 3]: Compute output of module
                # could cache from hook, rather than recomputing here
                fp16_outputs = self._run_samples(parent_module)
                if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
                    logger.info(
                        f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                        "found to scale. This can occasionally occur in MoE models "
                        "when certain experts are not activated by calibration samples."
                    )
                    del self._smooth_activation_means[mapping.smooth_name]
                    continue

                x_mean = self._smooth_activation_means[mapping.smooth_name][0]

                # [STEP 4]: Compute loss
                best_scales = self._compute_best_scale(
                    x_mean, w_mean, parent_module, balance_layers, fp16_outputs
                )

            @torch.no_grad()
            def smooth(module):
                with align_module_device(module):
                    scales = best_scales.to(module.weight.device)
                    if module in balance_layers:
                        update_offload_parameter(
                            module,
                            "weight",
                            module.weight.mul_(scales.view(1, -1)),
                        )
                    elif module == smooth_layer:
                        if module.weight.ndim == 1:
                            update_offload_parameter(
                                module,
                                "weight",
                                module.weight.div_(scales),
                            )
                        else:
                            # NOTE: edge case when smooth layer number of out_features
                            # is not equal to balance layer number of in_features
                            # e.g. when fused qkv_proj is used to smooth o_proj
                            # in this case, default to scaling the last output features
                            # because the desired smooth layer is v_proj
                            # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123
                            weight = module.weight
                            weight[-scales.size(0) :].div_(scales.view(-1, 1))
                            update_offload_parameter(module, "weight", weight)
                        if hasattr(module, "bias") and module.bias is not None:
                            update_offload_parameter(
                                module,
                                "bias",
                                module.bias.div_(scales),
                            )

            parent = get_fsdp_parent(mapping.smooth_name, model)
            if parent is not None:
                parent.apply(smooth)
            else:
                # if we're not running with FSDP we can apply smoothing directly
                for layer in balance_layers:
                    smooth(layer)
                smooth(smooth_layer)

            # remove caches needed to smooth this mapping
            del self._smooth_activation_means[mapping.smooth_name]

        for v in self._parent_args_cache.values():
            v.batch_intermediates.clear()
        self._assert_all_activations_consumed()

    def _run_samples(self, module: Module) -> List[torch.Tensor]:
        with align_module_device(module):
            outputs = [
                module(**batch_kwargs)
                for batch_kwargs in self._parent_args_cache[module]
            ]
            return [
                # If Tuple, assume that first argument is the input
                output[0] if isinstance(output, Tuple) else output
                for output in outputs
            ]

    def _compute_best_scale(
        self,
        x_mean: torch.Tensor,
        w_mean: torch.Tensor,
        parent_module: torch.nn.Module,
        linears2scale: List[torch.nn.Linear],
        fp16_outputs: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | _pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")

        org_sd = {k: v.cpu() for k, v in parent_module.state_dict().items()}

        device = get_execution_device(parent_module)
        x_mean = x_mean.view(-1).to(device)
        w_mean = w_mean.view(-1).to(device)

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
                    min=1e-4
                )
            else:
                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            _scalesview = scales.view(1, -1).to(device)

            # avoid scaling values that overflow
            scales[torch.isinf(scales)] = 1
            scales[torch.isnan(scales)] = 1

            # Q(W * s)
            for linear in linears2scale:
                with align_module_device(linear):
                    linear.weight.mul_(_scalesview)
                    update_offload_parameter(
                        linear,
                        "weight",
                        _pseudo_quantize_tensor(
                            w=linear.weight.data,
                            symmetric=self._symmetric,
                            bit_width=self._num_bits,
                            group_size=self._group_size,
                        )[0]
                        / _scalesview,
                    )

            # W * X
            with HooksMixin.disable_hooks():
                int_w_outputs = self._run_samples(parent_module)

            # compute mean squared error (L2 norm)
            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            parent_module.load_state_dict(org_sd)

        if best_ratio == -1:
            logger.debug(history)
            raise Exception("No best scaling ratio found during AWQ grid search")

        assert (
            torch.isnan(best_scales).sum() == 0
        ), f"Nan found in scales: {best_scales}"

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _compute_loss(
        self,
        fp16_outputs: List[torch.Tensor],
        int_w_outputs: List[torch.Tensor],
        device: torch.device,
    ) -> float:
        loss = 0.0
        num_elements = 0

        # Compute the MSE loss for each batch
        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
            batch_loss = (
                (fp16_batch.to(device) - int_w_batch.to(device))
                .view(-1)
                .float()
                .pow(2)
                .sum()
                .item()
            )
            loss += batch_loss
            num_elements += fp16_batch.numel()

        # Normalize the loss by the total number of elements
        loss /= num_elements

        return loss

    def _assert_all_activations_consumed(self):
        """
        Confirm all activations have been consumed
        If not, something has gone wrong
        """
        if len(self._smooth_activation_means) != 0:
            raise RuntimeError("Some cached activations were not used")

on_end(state, event, **kwargs)

Finish calibrating by setting scales and zero-points, removing observers and calibration hooks

Source code in llmcompressor/modifiers/awq/base.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by setting scales and zero-points,
     removing observers and calibration hooks
    """
    self._assert_all_activations_consumed()

    self.ended_ = True

    modules = list(state.model.modules())
    for module in tqdm(modules, desc="Calibrating weights"):
        update_weight_zp_scale(module)

    QuantizationMixin.end_calibration(self, state.model)

    # remove activation hooks
    self.remove_hooks()

on_finalize(state, **kwargs)

Clean up by clearing the activations and mapping data

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | State | unused | required |

Returns:

| Type | Description |
| --- | --- |
| bool | True |

Source code in llmcompressor/modifiers/awq/base.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    Clean up by clearing the activations and mapping data

    :param state: unused
    :return: True
    """
    if not self.ended_:
        self.on_end(state, None)

    self._parent_args_cache.clear()
    self._smooth_activation_means.clear()
    self._resolved_mappings.clear()

    return True

on_initialize(state, **kwargs)

Initialize AWQ on the given state Initialize quantization, resolve mappings, cache module kwargs

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | State | state to run AWQ on | required |

Returns:

| Type | Description |
| --- | --- |
| bool | True on a successful run, False otherwise |

Source code in llmcompressor/modifiers/awq/base.py
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Initialize AWQ on the given state
    Initialize quantization, resolve mappings, cache module kwargs

    :param state: state to run AWQ on
    :return: True on a successful run, False otherwise
    """

    # apply config to model and prepare calibration hooks
    if QuantizationMixin.has_config(self):
        QuantizationMixin.initialize_quantization(self, state.model)

    if self.mappings is None:
        logger.info("No AWQModifier.mappings provided, inferring from model...")
        self.mappings = get_layer_mappings_from_architecture(
            architecture=state.model.__class__.__name__
        )

    self._set_resolved_mappings(state.model)

    return True

validate_model_after(model)

Confirm only one configuration for group_size, symmetric, and num_bits, as the AWQ algorithm depends on it. Confirm no activation quantization, as AWQ only works with WNA16.

Source code in llmcompressor/modifiers/awq/base.py
@model_validator(mode="after")
def validate_model_after(model: "AWQModifier") -> "AWQModifier":
    """
    Confirm only one configuration for group_size, symmetric, and num_bits,
    as AWQ algorithm depends on it
    Confirm no activation quantization, as AWQ only works with WNA16
    """
    config = model.resolve_quantization_config()

    num_bits_set = set(
        group.weights.num_bits
        for group in config.config_groups.values()
        if group.weights is not None
    )
    assert (
        len(num_bits_set) == 1
    ), "In AWQ, all config groups must use the same configuration for num_bits"

    model._num_bits = next(iter(num_bits_set))

    symmetric_set = set(
        group.weights.symmetric
        for group in config.config_groups.values()
        if group.weights is not None
    )
    assert (
        len(symmetric_set) == 1
    ), "In AWQ, all config groups must use the same configuration for symmetric"

    model._symmetric = next(iter(symmetric_set))

    group_size_set = set(
        group.weights.group_size
        for group in config.config_groups.values()
        if group.weights is not None
    )
    assert (
        len(group_size_set) == 1
    ), "In AWQ, all config groups must use the same configuration for group_size"

    model._group_size = next(iter(group_size_set))

    in_num_bits_set = set(
        group.input_activations.num_bits
        for group in config.config_groups.values()
        if group.input_activations is not None
    )
    assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, (
        "AWQ activations must be 16-bit precision, "
        f"input activations {in_num_bits_set} not allowed"
    )

    out_num_bits_set = set(
        group.output_activations.num_bits
        for group in config.config_groups.values()
        if group.output_activations is not None
    )
    assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, (
        "AWQ activations must be 16-bit precision, "
        f"output activations {out_num_bits_set} not allowed"
    )

    return model
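
For example (hypothetical configuration, not from the source), a recipe that mixes 4-bit and 8-bit weight groups fails this validator:

```python
from llmcompressor.modifiers.awq import AWQModifier

# Hypothetical: two weight groups with different num_bits violate the
# requirement that all config groups share one num_bits setting. Depending
# on the pydantic version, the AssertionError may surface wrapped in a
# ValidationError, so we catch broadly here.
try:
    AWQModifier(
        config_groups={
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 4, "type": "int", "symmetric": False,
                    "strategy": "group", "group_size": 128,
                },
            },
            "group_1": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 8, "type": "int", "symmetric": False,
                    "strategy": "group", "group_size": 128,
                },
            },
        },
    )
except Exception as err:
    print(err)  # "...all config groups must use the same configuration for num_bits"
```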

get_lowest_common_parent(names, module)

Given a list of names, returns the lowest-scope common parent.

NOTE: this function excludes parents of type ModuleList, which don't play nicely with hooks because their forward method is never called directly for MoE models. See Qwen3MoeSparseMoeBlock for an example: experts are selected based on the router output, and each expert's forward method is called directly. https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233

Returns name of parent and pointer to parent module

Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix

Source code in llmcompressor/modifiers/awq/base.py
def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
    """
    Given a list of names, returns the lowest-scope common parent.

    NOTE: this function excludes parents of type ModuleList, which don't play
    nicely with hooks because their forward method is never called directly
    for MoE models. See Qwen3MoeSparseMoeBlock for an example: experts are
    selected based on the router output, and each expert's forward method is
    called directly.
    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233

    Returns name of parent and pointer to parent module

    Implementation is a small alteration of os.path.commonprefix
    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
    """
    s1 = min(names)
    s2 = max(names)
    parent_name = ""
    for i, c in enumerate(s1):
        if c != s2[i]:
            parent_name = s1[:i].rstrip(".")
            break

    while True:
        if parent_name == "":
            return "", module
        parent = get_layer_by_name(parent_name, module)
        if not isinstance(parent, torch.nn.ModuleList):
            return parent_name, parent
        parent_name = ".".join(parent_name.split(".")[:-1])