Fused Triton kernels and custom autograd for DoRA composition, norm assembly, and forward+inner-product computation. Each public function dispatches to a Triton kernel when available, falling back to an equivalent PyTorch implementation.

Module Overview

peft.tuners.lora.dora_fused

Fused Triton kernels and custom autograd functions for DoRA composition.

This module provides four optimizations:

1. Fused Element-wise Composition Kernel: replaces 4 sequential element-wise ops with a single Triton kernel computing out = (mag - 1) * base + mag * (scale * lora).
2. Fused Norm Assembly Kernel: fuses the final norm computation sqrt(clamp(w_norm_sq + 2*s*cross + s^2*ba_norm_sq, min=0)) into a single kernel.
3. Fused Forward-and-Inner Kernel: a dual-output kernel that computes both the composition output and inner = scale * lora + base in a single pass; used by the custom autograd forward to eliminate intermediate VRAM allocations.
4. Custom Autograd Function: wraps the DoRA forward composition with a hand-written backward pass that fuses gradient computation into a single kernel.

All kernels gracefully fall back to plain PyTorch when Triton is unavailable or when tensors are not on CUDA.

Canonical evaluation order

The DoRA composition formula is::

out = (mag - 1) * base + mag * (scale * lora)

All PyTorch paths must parenthesise the LoRA term as mag * (scale * lora). This means scale * lora is computed first, then multiplied by mag. The in-place path achieves this via lora.mul_(scale).mul_(mag).

This contract matters because bf16/fp16 float multiplication is not associative: (mag * scale) * lora and mag * (scale * lora) produce different rounding. Enforcing a single order across all PyTorch paths (out-of-place, in-place, and autograd forward-and-inner) guarantees bitwise parity between them when they operate under the same dtype contract. Mixed-dtype eager AMP paths additionally materialize the promoted stable-form result and copy it back into the activation buffer so they remain bitwise-equal to the eager out-of-place reference. DoraLinearLayer still casts mag before the fused-autograd dispatch, so mixed-dtype eager-vs-fused parity is a separate, higher-level concern.

Triton kernels also use scale * lora first (explicit scaled_lora local), but FMA hardware may still fuse or reorder ops, so Triton-vs-PyTorch agreement is within O(epsilon) tolerance, not bitwise.
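The order contract can be illustrated with plain Python floats, where multiplication is likewise non-associative (a minimal sketch; `dora_compose_ref` is an illustrative name, not part of the module):

```python
def dora_compose_ref(lora: float, base: float, mag: float, scale: float) -> float:
    # Canonical evaluation order: scale * lora is computed first, then
    # multiplied by mag, mirroring mag * (scale * lora) in the module.
    return (mag - 1.0) * base + mag * (scale * lora)

# Floating-point multiplication is not associative, so the grouping matters:
assert (0.1 * 0.1) * 10 != 0.1 * (0.1 * 10)

# The same inputs composed with the canonical order give a deterministic result:
out = dora_compose_ref(lora=0.1, base=1.0, mag=1.5, scale=2.0)
assert out == (1.5 - 1.0) * 1.0 + 1.5 * (2.0 * 0.1)
```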


Dispatch Functions

Public API functions that select between Triton and PyTorch backends at runtime.

peft.tuners.lora.dora_fused.fused_dora_compose(lora, base, mag_norm_scale, scale, inplace=True)

Fused DoRA composition: out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora).

Algebraically equivalent to mag * (scale * lora + base) - base but uses the numerically stable form that avoids catastrophic cancellation in bf16/fp16.
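The stability claim can be demonstrated without a GPU by emulating bfloat16 rounding in pure Python (a sketch; `to_bf16` truncates rather than rounds to nearest, which is enough to show the effect, and all names here are illustrative):

```python
import struct

def to_bf16(x: float) -> float:
    # Emulate bfloat16 by keeping only the top 16 bits of the float32 encoding
    # (truncation; real hardware rounds to nearest, but the effect is similar).
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

def compose_stable(lora, base, mag, scale):
    # Stable form: (mag - 1) * base + mag * (scale * lora), rounded per op.
    return to_bf16(to_bf16(to_bf16(mag - 1.0) * base) + to_bf16(mag * to_bf16(scale * lora)))

def compose_naive(lora, base, mag, scale):
    # Naive form: mag * (scale * lora + base) - base. When mag is close to 1,
    # mag * inner is close to base, and the subtraction cancels catastrophically.
    inner = to_bf16(to_bf16(scale * lora) + base)
    return to_bf16(to_bf16(mag * inner) - base)

lora, base, mag, scale = 0.001, 100.0, 1.00390625, 1.0  # mag exactly representable in bf16
exact = (mag - 1.0) * base + mag * (scale * lora)       # fp64 reference
assert abs(compose_stable(lora, base, mag, scale) - exact) < abs(compose_naive(lora, base, mag, scale) - exact)
```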

Parameters:

lora (Tensor, required): LoRA result tensor [..., out_features]
base (Tensor, required): Base layer result tensor, same shape as lora
mag_norm_scale (Tensor, required): Magnitude/norm scale, broadcastable to lora shape. Expected shape [1, out_features] or [out_features].
scale (float, required): Scalar LoRA scaling factor.
inplace (bool, default True): If True, write result into lora tensor.

Returns:

Tensor: The composed result tensor (may be lora if inplace=True).

Source code in peft/tuners/lora/dora_fused.py
def fused_dora_compose(
    lora: torch.Tensor,
    base: torch.Tensor,
    mag_norm_scale: torch.Tensor,
    scale: float,
    inplace: bool = True,
) -> torch.Tensor:
    """
    Fused DoRA composition: out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora).

    Algebraically equivalent to ``mag * (scale * lora + base) - base`` but uses
    the numerically stable form that avoids catastrophic cancellation in bf16/fp16.

    Args:
        lora: LoRA result tensor [..., out_features]
        base: Base layer result tensor, same shape as lora
        mag_norm_scale: Magnitude/norm scale, broadcastable to lora shape.
                        Expected shape [1, out_features] or [out_features].
        scale: Scalar LoRA scaling factor.
        inplace: If True, write result into lora tensor.

    Returns:
        The composed result tensor (may be lora if inplace=True).
    """
    if (
        not _is_dynamo_compiling()
        and _TRITON_AVAILABLE
        and lora.is_cuda
        and lora.is_contiguous()
        and base.is_contiguous()
        and mag_norm_scale.is_contiguous()
        and _mag_broadcasts_last_dim(mag_norm_scale, lora)
        and lora.dtype == base.dtype == mag_norm_scale.dtype
    ):
        return _fused_dora_compose_triton(lora, base, mag_norm_scale, scale, inplace)
    return _fused_dora_compose_torch(lora, base, mag_norm_scale, scale, inplace)
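The broadcast guard in the dispatch condition can be pictured at the shape level. A hypothetical pure-Python sketch (the real `_mag_broadcasts_last_dim` operates on tensors; this shape-only version is for illustration and may differ from the actual helper):

```python
def mag_broadcasts_last_dim(mag_shape: tuple, ref_shape: tuple) -> bool:
    # mag must be a per-output-feature vector: all of its elements live in the
    # last dimension, and that dimension matches ref's last dimension, so it
    # broadcasts cleanly over every leading (row) dimension of ref.
    numel = 1
    for d in mag_shape:
        numel *= d
    return bool(mag_shape) and mag_shape[-1] == ref_shape[-1] and numel == ref_shape[-1]

assert mag_broadcasts_last_dim((1, 16), (8, 16))      # [1, F] over [B, F]
assert mag_broadcasts_last_dim((16,), (8, 16))        # [F] over [B, F]
assert not mag_broadcasts_last_dim((8, 16), (8, 16))  # full matrix: not a broadcast
```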

peft.tuners.lora.dora_fused.fused_dora_forward_and_inner(lora, base, mag_norm_scale, scale)

Compute both the DoRA composition output and the inner tensor in one pass.

Returns:

tuple: (out, inner) where:
    out   = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
    inner = scale * lora + base

When Triton is available and tensors are on CUDA, a single fused kernel computes both outputs simultaneously — the intermediate scaled_lora never leaves SRAM registers. This eliminates the VRAM spike caused by sequential PyTorch ops in FusedDoRAComposeFunction.forward.

Source code in peft/tuners/lora/dora_fused.py
def fused_dora_forward_and_inner(
    lora: torch.Tensor,
    base: torch.Tensor,
    mag_norm_scale: torch.Tensor,
    scale: float,
) -> tuple:
    """
    Compute both the DoRA composition output and the ``inner`` tensor in one pass.

    Returns:
        (out, inner) where:
            out   = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
            inner = scale * lora + base

    When Triton is available and tensors are on CUDA, a single fused kernel
    computes both outputs simultaneously — the intermediate ``scaled_lora``
    never leaves SRAM registers.  This eliminates the VRAM spike caused by
    sequential PyTorch ops in ``FusedDoRAComposeFunction.forward``.
    """
    if (
        not _is_dynamo_compiling()
        and _TRITON_AVAILABLE
        and lora.is_cuda
        and lora.is_contiguous()
        and base.is_contiguous()
        and mag_norm_scale.is_contiguous()
        and _mag_broadcasts_last_dim(mag_norm_scale, lora)
        and lora.dtype == base.dtype == mag_norm_scale.dtype
    ):
        return _fused_dora_forward_and_inner_triton(lora, base, mag_norm_scale, scale)
    return _fused_dora_forward_and_inner_torch(lora, base, mag_norm_scale, scale)
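The relationship between the two outputs can be sanity-checked with scalar arithmetic (a sketch; `forward_and_inner_ref` is an illustrative name):

```python
def forward_and_inner_ref(lora: float, base: float, mag: float, scale: float):
    # Scalar reference for the (out, inner) contract of fused_dora_forward_and_inner.
    out = (mag - 1.0) * base + mag * (scale * lora)
    inner = scale * lora + base
    return out, inner

out, inner = forward_and_inner_ref(lora=0.5, base=2.0, mag=1.25, scale=2.0)
# out is the stable rearrangement of mag * inner - base; exact for these values:
assert out == 1.25 * inner - 2.0
assert inner == 2.0 * 0.5 + 2.0
```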

peft.tuners.lora.dora_fused.fused_norm_assembly(w_norm_sq, cross_term, ba_norm_sq, scale)

Fused norm assembly: compute weight_norm from components.

Parameters:

w_norm_sq (Tensor, required): ||W||^2 per row, shape [out_features]
cross_term (Tensor, required): <W, BA> per row, shape [out_features]
ba_norm_sq (Tensor, required): ||BA||^2 per row, shape [out_features]
scale (float, required): LoRA scaling factor

Returns:

tuple: (weight_norm,). Magnitude division is always done in PyTorch by the caller.

Source code in peft/tuners/lora/dora_fused.py
def fused_norm_assembly(
    w_norm_sq: torch.Tensor,
    cross_term: torch.Tensor,
    ba_norm_sq: torch.Tensor,
    scale: float,
) -> tuple:
    """
    Fused norm assembly: compute weight_norm from components.

    Args:
        w_norm_sq: ||W||^2 per row, shape [out_features]
        cross_term: <W, BA> per row, shape [out_features]
        ba_norm_sq: ||BA||^2 per row, shape [out_features]
        scale: LoRA scaling factor

    Returns:
        (weight_norm,) — magnitude division is always done in PyTorch by the caller.
    """
    if (
        not _is_dynamo_compiling()
        and _TRITON_AVAILABLE
        and w_norm_sq.is_cuda
        and w_norm_sq.is_contiguous()
        and cross_term.is_contiguous()
        and ba_norm_sq.is_contiguous()
        # The Triton norm kernel uses inline PTX (sqrt.rn.f32) which is
        # NVIDIA-specific.  On ROCm/HIP (torch.version.hip is set), fall
        # back to PyTorch to avoid a crash from NVIDIA-only assembly.
        and torch.version.hip is None
    ):
        return _fused_norm_assembly_triton(
            w_norm_sq,
            cross_term,
            ba_norm_sq,
            scale,
        )
    return _fused_norm_assembly_torch(
        w_norm_sq,
        cross_term,
        ba_norm_sq,
        scale,
    )
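The assembled norm can be cross-checked against a direct per-row computation of ||W + s*BA||. A pure-Python sketch (list-based; the real function operates on tensors):

```python
import math

def norm_assembly_ref(w_norm_sq, cross_term, ba_norm_sq, scale):
    # weight_norm[i] = sqrt(max(||W_i||^2 + 2*s*<W_i, BA_i> + s^2*||BA_i||^2, 0))
    s = float(scale)
    return [math.sqrt(max(w + 2.0 * s * c + s * s * b, 0.0))
            for w, c, b in zip(w_norm_sq, cross_term, ba_norm_sq)]

# One row: W = [3, 4], BA = [1, 0], s = 2 -> W + s*BA = [5, 4], norm = sqrt(41)
W, BA, s = [3.0, 4.0], [1.0, 0.0], 2.0
w_sq = sum(x * x for x in W)               # ||W||^2  = 25
cross = sum(x * y for x, y in zip(W, BA))  # <W, BA>  = 3
ba_sq = sum(y * y for y in BA)             # ||BA||^2 = 1
(assembled,) = norm_assembly_ref([w_sq], [cross], [ba_sq], s)
direct = math.sqrt(sum((x + s * y) ** 2 for x, y in zip(W, BA)))
assert abs(assembled - direct) < 1e-12
```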

peft.tuners.lora.dora_fused.fused_dora_compose_autograd(lora, base, mag_norm_scale, scale)

DoRA composition with custom autograd for fused backward.

This is the main entry point for training-time composition that benefits from the fused backward pass.

When torch.compile is active and a torch.library.custom_op is registered (PyTorch 2.4+), this routes through peft::fused_dora_compose so that Dynamo sees a single opaque graph node. In eager mode, this always uses FusedDoRAComposeFunction.apply with Triton kernels.

Parameters:

lora (Tensor, required): LoRA output tensor [..., out_features]
base (Tensor, required): Base result tensor, same shape as lora
mag_norm_scale (Tensor, required): Magnitude/norm scale [1, out_features]
scale (float, required): Scalar LoRA scaling factor

Returns:

Tensor: out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)

Source code in peft/tuners/lora/dora_fused.py
def fused_dora_compose_autograd(
    lora: torch.Tensor,
    base: torch.Tensor,
    mag_norm_scale: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    DoRA composition with custom autograd for fused backward.

    This is the main entry point for training-time composition that benefits
    from the fused backward pass.

    When ``torch.compile`` is active and a ``torch.library.custom_op`` is
    registered (PyTorch 2.4+), this routes through ``peft::fused_dora_compose``
    so that Dynamo sees a single opaque graph node.  In eager mode, this
    always uses ``FusedDoRAComposeFunction.apply`` with Triton kernels.

    Args:
        lora: LoRA output tensor [..., out_features]
        base: Base result tensor, same shape as lora
        mag_norm_scale: Magnitude/norm scale [1, out_features]
        scale: Scalar LoRA scaling factor

    Returns:
        out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
    """
    if _is_dynamo_compiling():
        if _HAS_CUSTOM_OP:
            return _fused_dora_compose_custom_op(lora, base, mag_norm_scale, scale)
        # Fallback for older PyTorch without custom_op support.
        return (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
    return FusedDoRAComposeFunction.apply(lora, base, mag_norm_scale, scale)

Custom Autograd

Custom torch.autograd.Function that fuses the DoRA compose forward and backward passes. Computes the output and inner product in a single kernel pass so that intermediate scale * lora never hits global memory; saves only the activation-sized inner tensor for backward.

peft.tuners.lora.dora_fused.FusedDoRAComposeFunction

Bases: Function

Custom autograd function for DoRA composition with fused backward.

Forward (stable form): out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)

Algebraically equivalent to mag_norm_scale * (scale * lora + base) - base, but avoids catastrophic cancellation when mag_norm_scale ≈ 1.

Backward: fused gradient computation for d_lora, d_base, d_mag

VRAM tradeoff: This path saves inner = scale * lora + base (one tensor, same size as lora) via ctx.save_for_backward only when mag_norm_scale.requires_grad is True. When mag is frozen (e.g. during warmup or partial fine-tuning), inner is never allocated, reclaiming 100% of the fused backward VRAM overhead.

When inner is needed, the forward pass uses fused_dora_forward_and_inner to compute both out and inner in a single fused kernel (when Triton is available), so the intermediate scaled_lora = scale * lora never leaves SRAM registers and is never allocated as a global-memory tensor. This eliminates the VRAM spike from sequential PyTorch ops. When mag is frozen, a forward-only fused compose is used instead (no inner allocation at all). Versus the in-place unfused baseline, overhead is at most 1x lora-sized activation per layer (the saved inner).

Source code in peft/tuners/lora/dora_fused.py
class FusedDoRAComposeFunction(torch.autograd.Function):
    """
    Custom autograd function for DoRA composition with fused backward.

    Forward (stable form): out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)

    Algebraically equivalent to ``mag_norm_scale * (scale * lora + base) - base``,
    but avoids catastrophic cancellation when ``mag_norm_scale ≈ 1``.

    Backward: fused gradient computation for d_lora, d_base, d_mag

    **VRAM tradeoff**: This path saves ``inner = scale * lora + base``
    (one tensor, same size as ``lora``) via ``ctx.save_for_backward`` only
    when ``mag_norm_scale.requires_grad`` is True.  When mag is frozen
    (e.g. during warmup or partial fine-tuning), ``inner`` is never
    allocated, reclaiming 100% of the fused backward VRAM overhead.

    When ``inner`` is needed, the forward pass uses
    ``fused_dora_forward_and_inner`` to compute both ``out`` and ``inner``
    in a single fused kernel (when Triton is available), so the intermediate
    ``scaled_lora = scale * lora`` never leaves SRAM registers and is never
    allocated as a global-memory tensor.  This eliminates the VRAM spike
    from sequential PyTorch ops.  When mag is frozen, a forward-only fused
    compose is used instead (no ``inner`` allocation at all).  Versus the
    in-place unfused baseline, overhead is at most 1x ``lora``-sized
    activation per layer (the saved ``inner``).
    """

    @staticmethod
    def forward(
        ctx,
        lora: torch.Tensor,
        base: torch.Tensor,
        mag_norm_scale: torch.Tensor,
        scale: float,
    ) -> torch.Tensor:
        """
        Args:
            lora: LoRA output tensor [..., out_features]
            base: Base result tensor, same shape as lora
            mag_norm_scale: Magnitude/norm scale [1, out_features] or [out_features]
            scale: Scalar LoRA scaling factor (float)

        Returns:
            out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
        """
        # The entire forward body runs under no_grad because gradients are
        # hand-computed in backward().
        with torch.no_grad():
            # Only materialise ``inner`` when we need d_mag in backward.
            # When mag is frozen (requires_grad=False) — e.g. during warmup
            # or partial fine-tuning — this skips a full activation-sized
            # allocation (up to 34 GB on 70B models).
            if mag_norm_scale.requires_grad:
                # fused_dora_forward_and_inner computes both ``out`` and
                # ``inner`` in a single Triton kernel (when available),
                # keeping ``scaled_lora`` in SRAM registers only.
                out, inner = fused_dora_forward_and_inner(lora, base, mag_norm_scale, scale)
                ctx.save_for_backward(inner, mag_norm_scale)
                ctx.needs_mag = True
            else:
                # mag frozen — no inner needed, forward-only compose.
                out = fused_dora_compose(lora, base, mag_norm_scale, scale, inplace=False)
                ctx.save_for_backward(
                    mag_norm_scale,
                )
                ctx.needs_mag = False

        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, d_out):
        """
        Fused backward pass.

        Gradients (derived from ``out = mag * inner - base`` where
        ``inner = scale * lora + base``):
            d_lora = mag * scale * d_out
            d_base = (mag - 1) * d_out
            d_mag  = sum_over_broadcast_dims(inner * d_out)

        Numerical note — bf16/fp16 precision gap:
            The forward pass uses the numerically stable form
            ``out = (mag - 1) * base + mag * (scale * lora)`` to avoid
            catastrophic cancellation when ``mag ≈ 1``.  The backward,
            however, is derived from the algebraically equivalent form
            ``out = mag * inner - base`` (where ``inner = scale * lora + base``).
            In exact arithmetic the two are identical, but in bf16/fp16 they
            differ by O(eps_bf16) per element per layer due to rounding in
            intermediate accumulations.  This gap is expected and benign for
            typical training workloads.
        """
        scale = ctx.scale
        needs_mag = ctx.needs_mag

        if needs_mag:
            inner, mag_norm_scale = ctx.saved_tensors
        else:
            (mag_norm_scale,) = ctx.saved_tensors
            inner = None  # not saved — mag was frozen

        d_lora = d_base = d_mag = None

        needs_lora_grad = ctx.needs_input_grad[0]
        needs_base_grad = ctx.needs_input_grad[1]
        needs_mag_grad = ctx.needs_input_grad[2]

        if (
            not _is_dynamo_compiling()
            and _TRITON_AVAILABLE
            and d_out.is_cuda
            and d_out.is_contiguous()
            and (inner is None or inner.is_contiguous())
            and _mag_broadcasts_last_dim(mag_norm_scale, d_out)
            and d_out.dtype == mag_norm_scale.dtype
            # inner dtype intentionally not checked: Triton backward only
            # reads d_out and mag; inner is used in d_mag reduction which
            # has its own .to() cast.
        ):
            d_lora, d_base, d_mag = _fused_backward_triton(
                d_out,
                inner,
                mag_norm_scale,
                scale,
                needs_lora_grad,
                needs_base_grad,
                needs_mag_grad,
            )
        else:
            d_lora, d_base, d_mag = _fused_backward_torch(
                d_out,
                inner,
                mag_norm_scale,
                scale,
                needs_lora_grad,
                needs_base_grad,
                needs_mag_grad,
            )

        return d_lora, d_base, d_mag, None  # None for scale (not a Tensor)

forward(ctx, lora, base, mag_norm_scale, scale) staticmethod

Parameters:

lora (Tensor, required): LoRA output tensor [..., out_features]
base (Tensor, required): Base result tensor, same shape as lora
mag_norm_scale (Tensor, required): Magnitude/norm scale [1, out_features] or [out_features]
scale (float, required): Scalar LoRA scaling factor (float)

Returns:

Tensor: out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)

Source code in peft/tuners/lora/dora_fused.py
@staticmethod
def forward(
    ctx,
    lora: torch.Tensor,
    base: torch.Tensor,
    mag_norm_scale: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Args:
        lora: LoRA output tensor [..., out_features]
        base: Base result tensor, same shape as lora
        mag_norm_scale: Magnitude/norm scale [1, out_features] or [out_features]
        scale: Scalar LoRA scaling factor (float)

    Returns:
        out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
    """
    # The entire forward body runs under no_grad because gradients are
    # hand-computed in backward().
    with torch.no_grad():
        # Only materialise ``inner`` when we need d_mag in backward.
        # When mag is frozen (requires_grad=False) — e.g. during warmup
        # or partial fine-tuning — this skips a full activation-sized
        # allocation (up to 34 GB on 70B models).
        if mag_norm_scale.requires_grad:
            # fused_dora_forward_and_inner computes both ``out`` and
            # ``inner`` in a single Triton kernel (when available),
            # keeping ``scaled_lora`` in SRAM registers only.
            out, inner = fused_dora_forward_and_inner(lora, base, mag_norm_scale, scale)
            ctx.save_for_backward(inner, mag_norm_scale)
            ctx.needs_mag = True
        else:
            # mag frozen — no inner needed, forward-only compose.
            out = fused_dora_compose(lora, base, mag_norm_scale, scale, inplace=False)
            ctx.save_for_backward(
                mag_norm_scale,
            )
            ctx.needs_mag = False

    ctx.scale = scale

    return out

backward(ctx, d_out) staticmethod

Fused backward pass.

Gradients (derived from out = mag * inner - base where inner = scale * lora + base):

    d_lora = mag * scale * d_out
    d_base = (mag - 1) * d_out
    d_mag  = sum_over_broadcast_dims(inner * d_out)

Numerical note — bf16/fp16 precision gap: The forward pass uses the numerically stable form out = (mag - 1) * base + mag * (scale * lora) to avoid catastrophic cancellation when mag ≈ 1. The backward, however, is derived from the algebraically equivalent form out = mag * inner - base (where inner = scale * lora + base). In exact arithmetic the two are identical, but in bf16/fp16 they differ by O(eps_bf16) per element per layer due to rounding in intermediate accumulations. This gap is expected and benign for typical training workloads.

Source code in peft/tuners/lora/dora_fused.py
@staticmethod
def backward(ctx, d_out):
    """
    Fused backward pass.

    Gradients (derived from ``out = mag * inner - base`` where
    ``inner = scale * lora + base``):
        d_lora = mag * scale * d_out
        d_base = (mag - 1) * d_out
        d_mag  = sum_over_broadcast_dims(inner * d_out)

    Numerical note — bf16/fp16 precision gap:
        The forward pass uses the numerically stable form
        ``out = (mag - 1) * base + mag * (scale * lora)`` to avoid
        catastrophic cancellation when ``mag ≈ 1``.  The backward,
        however, is derived from the algebraically equivalent form
        ``out = mag * inner - base`` (where ``inner = scale * lora + base``).
        In exact arithmetic the two are identical, but in bf16/fp16 they
        differ by O(eps_bf16) per element per layer due to rounding in
        intermediate accumulations.  This gap is expected and benign for
        typical training workloads.
    """
    scale = ctx.scale
    needs_mag = ctx.needs_mag

    if needs_mag:
        inner, mag_norm_scale = ctx.saved_tensors
    else:
        (mag_norm_scale,) = ctx.saved_tensors
        inner = None  # not saved — mag was frozen

    d_lora = d_base = d_mag = None

    needs_lora_grad = ctx.needs_input_grad[0]
    needs_base_grad = ctx.needs_input_grad[1]
    needs_mag_grad = ctx.needs_input_grad[2]

    if (
        not _is_dynamo_compiling()
        and _TRITON_AVAILABLE
        and d_out.is_cuda
        and d_out.is_contiguous()
        and (inner is None or inner.is_contiguous())
        and _mag_broadcasts_last_dim(mag_norm_scale, d_out)
        and d_out.dtype == mag_norm_scale.dtype
        # inner dtype intentionally not checked: Triton backward only
        # reads d_out and mag; inner is used in d_mag reduction which
        # has its own .to() cast.
    ):
        d_lora, d_base, d_mag = _fused_backward_triton(
            d_out,
            inner,
            mag_norm_scale,
            scale,
            needs_lora_grad,
            needs_base_grad,
            needs_mag_grad,
        )
    else:
        d_lora, d_base, d_mag = _fused_backward_torch(
            d_out,
            inner,
            mag_norm_scale,
            scale,
            needs_lora_grad,
            needs_base_grad,
            needs_mag_grad,
        )

    return d_lora, d_base, d_mag, None  # None for scale (not a Tensor)
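The hand-written gradients can be verified against central finite differences on the scalar composition (an illustrative sketch with d_out = 1):

```python
def compose(lora, base, mag, scale):
    # Scalar version of the stable forward form.
    return (mag - 1.0) * base + mag * (scale * lora)

lora, base, mag, scale, eps = 0.3, 1.7, 1.05, 2.0, 1e-6

# Analytic gradients from the backward docstring (with d_out = 1):
d_lora = mag * scale          # 2.1
d_base = mag - 1.0            # 0.05
d_mag = scale * lora + base   # inner = 2.3

def fd(f, x):
    # Central finite difference df/dx at x.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

assert abs(fd(lambda v: compose(v, base, mag, scale), lora) - d_lora) < 1e-6
assert abs(fd(lambda v: compose(lora, v, mag, scale), base) - d_base) < 1e-6
assert abs(fd(lambda v: compose(lora, base, v, scale), mag) - d_mag) < 1e-6
```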

Triton Kernels

Triton kernel functions implementing the fused operations. Docstrings contain the mathematical formulations; source is hidden as Triton DSL is not useful as API reference.

Note

These kernels are not called directly. Use the dispatch functions above, which handle backend selection, input validation, and autotuning.

peft.tuners.lora.dora_fused._fused_dora_compose_triton(lora, base, mag_norm_scale, scale, inplace=True)

Triton implementation of fused DoRA composition.

peft.tuners.lora.dora_fused._fused_dora_forward_and_inner_triton(lora, base, mag_norm_scale, scale)

Triton implementation of fused forward-and-inner computation.

peft.tuners.lora.dora_fused._fused_norm_assembly_triton(w_norm_sq, cross_term, ba_norm_sq, scale)

Triton implementation of fused norm assembly (norm-only).

peft.tuners.lora.dora_fused._fused_backward_triton(d_out, inner, mag_norm_scale, scale, needs_lora_grad, needs_base_grad, needs_mag_grad)

Triton implementation of fused backward.

Uses a Triton kernel for the element-wise d_lora / d_base computation and a plain PyTorch .sum() for the d_mag reduction. This avoids the tl.atomic_add contention that scaled poorly with large num_rows in the previous single-kernel approach.

Only allocates output tensors and launches the kernel when at least one of d_lora / d_base is needed. When only one is needed, we fall back to a single PyTorch elementwise op (cheaper than a kernel launch for one output). When neither is needed we skip the kernel entirely.


PyTorch Fallbacks

Pure PyTorch implementations used when Triton is unavailable or during torch.compile tracing.

peft.tuners.lora.dora_fused._fused_dora_compose_torch(lora, base, mag_norm_scale, scale, inplace=True)

Pure PyTorch fallback for fused DoRA composition.

Uses the numerically stable form (mag - 1) * base + mag * (scale * lora) to avoid catastrophic cancellation in bf16/fp16 when mag ≈ 1.

Note

The caller (dora.py) is responsible for ensuring lora is a freshly-computed tensor (e.g. lora_B(lora_A(x))) and does not share storage with other live tensors when inplace=True.

Source code in peft/tuners/lora/dora_fused.py
def _fused_dora_compose_torch(
    lora: torch.Tensor,
    base: torch.Tensor,
    mag_norm_scale: torch.Tensor,
    scale: float,
    inplace: bool = True,
) -> torch.Tensor:
    """Pure PyTorch fallback for fused DoRA composition.

    Uses the numerically stable form ``(mag - 1) * base + mag * (scale * lora)``
    to avoid catastrophic cancellation in bf16/fp16 when mag ≈ 1.

    Note:
        The caller (``dora.py``) is responsible for ensuring ``lora`` is a
        freshly-computed tensor (e.g. ``lora_B(lora_A(x))``) and does not
        share storage with other live tensors when ``inplace=True``.
    """
    if inplace:
        if torch.promote_types(torch.promote_types(lora.dtype, base.dtype), mag_norm_scale.dtype) != lora.dtype:
            result = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
            lora.copy_(result)
            return lora

        # Numerically stable in-place: lora = (mag-1)*base + mag*(scale*lora)
        # Two mul_ calls preserve canonical evaluation order mag*(scale*lora),
        # matching the out-of-place path and the Triton kernel for same-dtype
        # bitwise parity.
        lora.mul_(scale)
        lora.mul_(mag_norm_scale)
        lora.add_(base * (mag_norm_scale - 1))
        return lora
    else:
        result = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
        # Cast to lora.dtype for consistent output contract (matches inplace branch).
        if result.dtype != lora.dtype:
            result = result.to(lora.dtype)
        return result

peft.tuners.lora.dora_fused._fused_dora_forward_and_inner_torch(lora, base, mag_norm_scale, scale)

Pure PyTorch fallback for fused forward-and-inner computation.

Computes both outputs without Triton. scaled_lora is a temporary that is freed after inner and out are computed.

Source code in peft/tuners/lora/dora_fused.py
def _fused_dora_forward_and_inner_torch(
    lora: torch.Tensor,
    base: torch.Tensor,
    mag_norm_scale: torch.Tensor,
    scale: float,
) -> tuple:
    """Pure PyTorch fallback for fused forward-and-inner computation.

    Computes both outputs without Triton.  ``scaled_lora`` is a temporary
    that is freed after ``inner`` and ``out`` are computed.
    """
    result_dtype = lora.dtype
    scaled_lora = scale * lora
    inner = scaled_lora + base
    out = (mag_norm_scale - 1) * base + mag_norm_scale * scaled_lora
    if out.dtype != result_dtype:
        out = out.to(result_dtype)
    # Cast inner to match the Triton path's output contract (inner in
    # lora.dtype).  Without this, mixed-dtype inputs produce an fp32 inner
    # that doubles the activation VRAM saved for backward (d_mag reduction).
    if inner.dtype != result_dtype:
        inner = inner.to(result_dtype)
    return out, inner

peft.tuners.lora.dora_fused._fused_norm_assembly_torch(w_norm_sq, cross_term, ba_norm_sq, scale)

Pure PyTorch fallback for fused norm assembly.

Source code in peft/tuners/lora/dora_fused.py
def _fused_norm_assembly_torch(
    w_norm_sq: torch.Tensor,
    cross_term: torch.Tensor,
    ba_norm_sq: torch.Tensor,
    scale: float,
) -> tuple:
    """Pure PyTorch fallback for fused norm assembly."""
    # Use Python float directly instead of torch.as_tensor to avoid allocating
    # a new scalar tensor on every call.  PyTorch handles scalar-tensor
    # promotion natively in the arithmetic below.
    s = float(scale)
    norm_sq = w_norm_sq + (2.0 * s) * cross_term + (s * s) * ba_norm_sq
    # Keep this out-of-place to avoid in-place autograd hazards if this helper
    # is called with grad tracking enabled.
    norm_sq = norm_sq.clamp_min(0)
    weight_norm = torch.sqrt(norm_sq)

    return (weight_norm,)

peft.tuners.lora.dora_fused._fused_backward_torch(d_out, inner, mag_norm_scale, scale, needs_lora_grad, needs_base_grad, needs_mag_grad)

Pure PyTorch fallback for fused backward.

inner is scale * lora + base (precomputed in forward).

Source code in peft/tuners/lora/dora_fused.py
def _fused_backward_torch(d_out, inner, mag_norm_scale, scale, needs_lora_grad, needs_base_grad, needs_mag_grad):
    """Pure PyTorch fallback for fused backward.

    ``inner`` is ``scale * lora + base`` (precomputed in forward).
    """
    d_lora = d_base = d_mag = None

    if needs_lora_grad:
        d_lora = mag_norm_scale * scale * d_out

    if needs_base_grad:
        d_base = (mag_norm_scale - 1) * d_out

    if needs_mag_grad:
        assert inner is not None, "inner must be saved when mag requires grad"
        # inner was computed under no_grad in forward and may have a different
        # dtype than d_out when AMP autocast is active (e.g. inner in fp32
        # while d_out in fp16).  Align dtypes to avoid silent precision loss.
        if inner.dtype != d_out.dtype:
            inner = inner.to(d_out.dtype)
        d_mag_full = inner * d_out
        # Reduce over dimensions that were broadcast from mag_norm_scale.
        # For Linear [B, F] with mag [1, F] → reduce dim 0.
        # For Conv2d [N, C, H, W] with mag [1, C, 1, 1] → reduce dims 0,2,3.
        sum_dims = _broadcast_reduce_dims(d_out.shape, mag_norm_scale.shape)
        if sum_dims:
            d_mag = d_mag_full.sum(dim=sum_dims, keepdim=True)
        else:
            d_mag = d_mag_full
        # Reshape to match mag_norm_scale shape
        d_mag = d_mag.reshape(mag_norm_scale.shape)

    return d_lora, d_base, d_mag
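The reduction-dimension logic used for d_mag (`_broadcast_reduce_dims`) can be sketched in pure Python over shapes (a hypothetical reimplementation for illustration; the actual helper may differ):

```python
def broadcast_reduce_dims(out_shape: tuple, mag_shape: tuple) -> tuple:
    # Dims where mag_shape was broadcast (size 1, or absent entirely) while
    # out_shape is larger must be summed over in the d_mag reduction.
    offset = len(out_shape) - len(mag_shape)
    dims = list(range(offset))  # leading dims absent from mag_shape
    for i, m in enumerate(mag_shape):
        if m == 1 and out_shape[offset + i] != 1:
            dims.append(offset + i)
    return tuple(dims)

# Linear: d_out [B, F], mag [1, F] -> reduce dim 0
assert broadcast_reduce_dims((8, 16), (1, 16)) == (0,)
# Conv2d: d_out [N, C, H, W], mag [1, C, 1, 1] -> reduce dims 0, 2, 3
assert broadcast_reduce_dims((2, 4, 5, 5), (1, 4, 1, 1)) == (0, 2, 3)
```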

Autotune Configuration

Triton autotune configuration generators. These produce lists of triton.Config objects controlling block sizes, warps, and pipeline stages for each kernel.

peft.tuners.lora.dora_fused._compose_configs()

Autotune configs for compose and forward_and_inner kernels (RPP=1).

6-GPU analysis: BS=4096/8192 dominate forward (195/216 wins of 626), BS=2048/1024 are significant secondary winners. BS=16384 and BS=32768 scored 0 forward wins and are dropped. RPP=1 won 96% of entries.

Warp counts are pinned per block size from the dominant winners

BS=512 → W=1 (unanimous B200/B300/H100/H200/RTX6000) BS=1024 → W=2 (H200/L40S dominant, covers B200/RTX6000 adequately) BS=2048 → W=8 (B300/H100/H200 dominant) BS=4096 → W=4,8 (W=8 dominant on B200/B300/H100/H200; W=4 secondary) BS=8192 → W=4,8 (W=4 dominant forward; W=8 dominant compose on H200/L40S)

Stages left at Triton default (S=2) — bandwidth-bound kernels show <5% sensitivity to stage count across most shapes.
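The pinning described above amounts to a small table mapping block size to warp counts. A hypothetical sketch (plain dicts stand in for `triton.Config`; the real generator may differ):

```python
# Warp counts pinned per block size, mirroring the analysis above.
_COMPOSE_WARPS = {
    512: (1,),
    1024: (2,),
    2048: (8,),
    4096: (4, 8),
    8192: (4, 8),
}

def compose_configs():
    # One entry per (block size, warp count) pair; stages are left at the
    # Triton default, so they are not spelled out here.
    return [
        {"BLOCK_SIZE": bs, "num_warps": w}
        for bs, warps in _COMPOSE_WARPS.items()
        for w in warps
    ]

assert len(compose_configs()) == 7
```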

peft.tuners.lora.dora_fused._backward_configs()

Autotune configs for the backward kernel (RPP=1).

6-GPU analysis: BS=16384 dominates on H100 (43 wins), BS=8192/4096 are strong secondary winners. BS=32768 scored 0 wins and is dropped. BS=512 is marginal (6 total wins) and dropped. RPP=1 won 93% of entries.

Warp counts pinned per block size:

    BS=1024  → W=1     (B300 7/9, H100 2/3)
    BS=2048  → W=2     (B300/H100/H200 dominant)
    BS=4096  → W=2,4   (W=2 dominant H200/RTX6000; W=4 dominant B300)
    BS=8192  → W=4     (dominant on most GPUs, W=8 secondary)
    BS=16384 → W=8,16  (W=16 unanimous H200, dominant B300; W=8 dominant H100)

Stages left at Triton default — see compose docstring.

peft.tuners.lora.dora_fused._norm_configs()

Autotune configs for norm assembly kernels.

Returns None in default mode (norm kernels use @triton.jit with fixed BS=256 — they are launch-latency bound on modern GPUs and autotune cannot differentiate configs).