Fused Triton kernels and custom autograd functions for DoRA composition, norm assembly, and forward-and-inner computation. Each public function dispatches to a Triton kernel when available and falls back to an equivalent PyTorch implementation otherwise.
Module Overview¶
peft.tuners.lora.dora_fused¶
Fused Triton kernels and custom autograd functions for DoRA composition.
This module provides four optimizations:
1. Fused Element-wise Composition Kernel: replaces 4 sequential element-wise ops
with a single Triton kernel computing out = (mag - 1) * base + mag * (scale * lora).
2. Fused Norm Assembly Kernel: fuses the final norm computation
sqrt(clamp(w_norm_sq + 2*s*cross + s^2*ba_norm_sq, min=0)) into a single kernel.
3. Fused Forward-and-Inner Kernel: dual-output kernel that computes both the
composition output and inner = scale * lora + base in a single pass, used by
the custom autograd forward to eliminate intermediate VRAM allocations.
4. Custom Autograd Function: wraps the DoRA forward composition with hand-written
backward pass that fuses gradient computation into a single kernel.
All kernels gracefully fall back to plain PyTorch when Triton is unavailable or when tensors are not on CUDA.
Canonical evaluation order¶
The DoRA composition formula is:
out = (mag - 1) * base + mag * (scale * lora)
All PyTorch paths must parenthesise the LoRA term as mag * (scale * lora).
This means scale * lora is computed first, then multiplied by mag.
The in-place path achieves this via lora.mul_(scale).mul_(mag).
This contract matters because bf16/fp16 float multiplication is not associative:
(mag * scale) * lora and mag * (scale * lora) produce different rounding.
Enforcing a single order across all PyTorch paths (out-of-place, in-place, and
autograd forward-and-inner) guarantees bitwise parity between them when they
operate under the same dtype contract. Mixed-dtype eager AMP paths additionally
materialize the promoted stable-form result and copy it back into the activation
buffer so they remain bitwise-equal to the eager out-of-place reference.
DoraLinearLayer still casts mag before the fused-autograd dispatch, so
mixed-dtype eager-vs-fused parity is a separate, higher-level concern.
Triton kernels also use scale * lora first (explicit scaled_lora local),
but FMA hardware may still fuse or reorder ops, so Triton-vs-PyTorch agreement
is within O(epsilon) tolerance, not bitwise.
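The evaluation-order contract can be illustrated in plain PyTorch (a minimal sketch with illustrative shapes, not the module's own code):

```python
import torch

torch.manual_seed(0)
mag = torch.rand(1, 8, dtype=torch.bfloat16) + 0.5   # magnitude/norm scale
lora = torch.randn(4, 8, dtype=torch.bfloat16)       # LoRA branch output
base = torch.randn(4, 8, dtype=torch.bfloat16)       # base layer output
scale = 2.0

# Canonical order: scale * lora is computed first, then multiplied by mag.
out_canonical = (mag - 1) * base + mag * (scale * lora)

# The in-place path reproduces the same order, so parity is bitwise.
buf = lora.clone()
out_inplace = buf.mul_(scale).mul_(mag).add_((mag - 1) * base)
assert torch.equal(out_canonical, out_inplace)

# Reassociating as (mag * scale) * lora may round differently in bf16,
# which is why a single order is enforced across all PyTorch paths.
out_reordered = (mag - 1) * base + (mag * scale) * lora
```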
Dispatch Functions¶
Public API functions that select between Triton and PyTorch backends at runtime.
peft.tuners.lora.dora_fused.fused_dora_compose(lora, base, mag_norm_scale, scale, inplace=True)¶
Fused DoRA composition: out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora).
Algebraically equivalent to mag * (scale * lora + base) - base but uses
the numerically stable form that avoids catastrophic cancellation in bf16/fp16.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `lora` | `Tensor` | LoRA result tensor `[..., out_features]` | *required* |
| `base` | `Tensor` | Base layer result tensor, same shape as `lora` | *required* |
| `mag_norm_scale` | `Tensor` | Magnitude/norm scale, broadcastable to the `lora` shape. Expected shape `[1, out_features]` or `[out_features]`. | *required* |
| `scale` | `float` | Scalar LoRA scaling factor. | *required* |
| `inplace` | `bool` | If `True`, write the result into the `lora` tensor. | `True` |
Returns:

| Type | Description |
|---|---|
| `Tensor` | The composed result tensor (may be `lora` if `inplace=True`). |
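A pure-PyTorch sketch of what this dispatch computes (the helper name `dora_compose_reference` is hypothetical; the real function additionally routes to a Triton kernel on CUDA):

```python
import torch

def dora_compose_reference(lora, base, mag_norm_scale, scale, inplace=True):
    # Stable form: (mag - 1) * base + mag * (scale * lora),
    # with scale * lora evaluated first (canonical order).
    if inplace:
        return lora.mul_(scale).mul_(mag_norm_scale).add_((mag_norm_scale - 1) * base)
    return (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)

torch.manual_seed(0)
lora = torch.randn(2, 3, 16)
base = torch.randn(2, 3, 16)
mag = torch.rand(1, 16) + 0.5
out = dora_compose_reference(lora.clone(), base, mag, scale=0.5, inplace=True)

# Algebraically equal to the naive form, minus its cancellation risk.
naive = mag * (0.5 * lora + base) - base
assert torch.allclose(out, naive, atol=1e-5)
```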
Source code in peft/tuners/lora/dora_fused.py
peft.tuners.lora.dora_fused.fused_dora_forward_and_inner(lora, base, mag_norm_scale, scale)¶
Compute both the DoRA composition output and the inner tensor in one pass.
Returns:

| Type | Description |
|---|---|
| `tuple` | `(out, inner)` where `out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)` and `inner = scale * lora + base`. |
When Triton is available and tensors are on CUDA, a single fused kernel
computes both outputs simultaneously — the intermediate scaled_lora
never leaves SRAM registers. This eliminates the VRAM spike caused by
sequential PyTorch ops in FusedDoRAComposeFunction.forward.
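The two outputs can be sketched in eager PyTorch (hypothetical helper; the real fused kernel keeps `scaled_lora` in registers instead of a temporary):

```python
import torch

def forward_and_inner_reference(lora, base, mag_norm_scale, scale):
    # scaled_lora is a short-lived temporary here; the fused Triton
    # kernel never materializes it in global memory.
    scaled_lora = scale * lora
    inner = scaled_lora + base
    out = (mag_norm_scale - 1) * base + mag_norm_scale * scaled_lora
    return out, inner

torch.manual_seed(0)
lora, base = torch.randn(4, 8), torch.randn(4, 8)
mag = torch.rand(1, 8) + 0.5
out, inner = forward_and_inner_reference(lora, base, mag, 2.0)

# out is algebraically mag * inner - base, up to rounding.
assert torch.allclose(out, mag * inner - base, atol=1e-5)
```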
Source code in peft/tuners/lora/dora_fused.py
peft.tuners.lora.dora_fused.fused_norm_assembly(w_norm_sq, cross_term, ba_norm_sq, scale)¶
Fused norm assembly: compute weight_norm from components.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `w_norm_sq` | `Tensor` | Per-row squared norm of `W`, shape `[out_features]` | *required* |
| `cross_term` | `Tensor` | Per-row cross term (the `cross` in `2*s*cross`), shape `[out_features]` | *required* |
| `ba_norm_sq` | `Tensor` | Per-row squared norm of `BA`, shape `[out_features]` | *required* |
| `scale` | `float` | LoRA scaling factor | *required* |
Returns:

| Type | Description |
|---|---|
| `tuple` | `(weight_norm,)` — the magnitude division is always done in PyTorch by the caller. |
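The assembled quantity is the per-row norm of `W + s*BA`; a minimal fp64 sketch (hypothetical helper name) cross-checked against a direct norm:

```python
import torch

def norm_assembly_reference(w_norm_sq, cross_term, ba_norm_sq, scale):
    # ||W + s*BA||^2 = ||W||^2 + 2*s*<W, BA> + s^2*||BA||^2 per row;
    # the clamp guards against tiny negative values from rounding.
    return torch.sqrt(
        torch.clamp(w_norm_sq + 2 * scale * cross_term + scale**2 * ba_norm_sq, min=0.0)
    )

torch.manual_seed(0)
W = torch.randn(16, 32, dtype=torch.float64)
BA = torch.randn(16, 32, dtype=torch.float64)
s = 0.5

direct = torch.linalg.vector_norm(W + s * BA, dim=1)
assembled = norm_assembly_reference(
    (W * W).sum(dim=1), (W * BA).sum(dim=1), (BA * BA).sum(dim=1), s
)
assert torch.allclose(assembled, direct)
```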
Source code in peft/tuners/lora/dora_fused.py
peft.tuners.lora.dora_fused.fused_dora_compose_autograd(lora, base, mag_norm_scale, scale)¶
DoRA composition with custom autograd for fused backward.
This is the main entry point for training-time composition that benefits from the fused backward pass.
When torch.compile is active and a torch.library.custom_op is
registered (PyTorch 2.4+), this routes through peft::fused_dora_compose
so that Dynamo sees a single opaque graph node. In eager mode, this
always uses FusedDoRAComposeFunction.apply with Triton kernels.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `lora` | `Tensor` | LoRA output tensor `[..., out_features]` | *required* |
| `base` | `Tensor` | Base result tensor, same shape as `lora` | *required* |
| `mag_norm_scale` | `Tensor` | Magnitude/norm scale `[1, out_features]` | *required* |
| `scale` | `float` | Scalar LoRA scaling factor | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | `out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)` |
Source code in peft/tuners/lora/dora_fused.py
Custom Autograd¶
Custom torch.autograd.Function that fuses the DoRA compose forward and backward passes.
It computes the output and the inner tensor in a single kernel pass so that the intermediate scale * lora never hits global memory, and it saves only the activation-sized inner tensor for backward.
peft.tuners.lora.dora_fused.FusedDoRAComposeFunction¶
Bases: Function
Custom autograd function for DoRA composition with fused backward.
Forward (stable form): out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)
Algebraically equivalent to mag_norm_scale * (scale * lora + base) - base,
but avoids catastrophic cancellation when mag_norm_scale ≈ 1.
Backward: fused gradient computation for d_lora, d_base, d_mag
VRAM tradeoff: This path saves inner = scale * lora + base
(one tensor, same size as lora) via ctx.save_for_backward only
when mag_norm_scale.requires_grad is True. When mag is frozen
(e.g. during warmup or partial fine-tuning), inner is never
allocated, reclaiming 100% of the fused backward VRAM overhead.
When inner is needed, the forward pass uses
fused_dora_forward_and_inner to compute both out and inner
in a single fused kernel (when Triton is available), so the intermediate
scaled_lora = scale * lora never leaves SRAM registers and is never
allocated as a global-memory tensor. This eliminates the VRAM spike
from sequential PyTorch ops. When mag is frozen, a forward-only fused
compose is used instead (no inner allocation at all). Versus the
in-place unfused baseline, overhead is at most 1x lora-sized
activation per layer (the saved inner).
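A minimal eager sketch of the documented forward/backward contract (no Triton; this reduced version always saves `inner` and assumes a 2D activation, whereas the real class skips saving `inner` when `mag` is frozen):

```python
import torch

class DoRAComposeSketch(torch.autograd.Function):
    # Hypothetical reduced version of the documented contract.

    @staticmethod
    def forward(ctx, lora, base, mag, scale):
        scaled_lora = scale * lora          # lives only inside forward
        inner = scaled_lora + base          # saved for d_mag
        ctx.scale = scale
        ctx.save_for_backward(mag, inner)
        return (mag - 1) * base + mag * scaled_lora

    @staticmethod
    def backward(ctx, d_out):
        mag, inner = ctx.saved_tensors
        d_lora = mag * ctx.scale * d_out
        d_base = (mag - 1) * d_out
        d_mag = (inner * d_out).sum(dim=0, keepdim=True)
        return d_lora, d_base, d_mag, None  # no grad for the float scale

lora = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
base = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
mag = (torch.rand(1, 8, dtype=torch.float64) + 0.5).requires_grad_()
assert torch.autograd.gradcheck(
    lambda l, b, m: DoRAComposeSketch.apply(l, b, m, 2.0), (lora, base, mag)
)
```

`gradcheck` in fp64 confirms that the hand-written backward matches the stable-form forward numerically.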
Source code in peft/tuners/lora/dora_fused.py
forward(ctx, lora, base, mag_norm_scale, scale) staticmethod¶
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `lora` | `Tensor` | LoRA output tensor `[..., out_features]` | *required* |
| `base` | `Tensor` | Base result tensor, same shape as `lora` | *required* |
| `mag_norm_scale` | `Tensor` | Magnitude/norm scale `[1, out_features]` or `[out_features]` | *required* |
| `scale` | `float` | Scalar LoRA scaling factor | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | `out = (mag_norm_scale - 1) * base + mag_norm_scale * (scale * lora)` |
Source code in peft/tuners/lora/dora_fused.py
backward(ctx, d_out) staticmethod¶
Fused backward pass.
Gradients (derived from out = mag * inner - base where
inner = scale * lora + base):
d_lora = mag * scale * d_out
d_base = (mag - 1) * d_out
d_mag = sum_over_broadcast_dims(inner * d_out)
Numerical note — bf16/fp16 precision gap:
The forward pass uses the numerically stable form
out = (mag - 1) * base + mag * (scale * lora) to avoid
catastrophic cancellation when mag ≈ 1. The backward,
however, is derived from the algebraically equivalent form
out = mag * inner - base (where inner = scale * lora + base).
In exact arithmetic the two are identical, but in bf16/fp16 they
differ by O(eps_bf16) per element per layer due to rounding in
intermediate accumulations. This gap is expected and benign for
typical training workloads.
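The three gradient formulas can be checked against PyTorch autograd on the stable-form forward (fp64 so rounding noise is negligible):

```python
import torch

torch.manual_seed(0)
lora = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
base = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
mag = (torch.rand(1, 8, dtype=torch.float64) + 0.5).requires_grad_()
scale = 2.0

out = (mag - 1) * base + mag * (scale * lora)   # stable-form forward
d_out = torch.randn_like(out)
out.backward(d_out)

# Hand-derived fused gradients (inner = scale * lora + base).
inner = scale * lora.detach() + base.detach()
assert torch.allclose(lora.grad, mag.detach() * scale * d_out)
assert torch.allclose(base.grad, (mag.detach() - 1) * d_out)
assert torch.allclose(mag.grad, (inner * d_out).sum(dim=0, keepdim=True))
```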
Source code in peft/tuners/lora/dora_fused.py
Triton Kernels¶
Triton kernel functions implementing the fused operations. Docstrings contain the mathematical formulations; source is hidden as Triton DSL is not useful as API reference.
Note
These kernels are not called directly. Use the dispatch functions above, which handle backend selection, input validation, and autotuning.
peft.tuners.lora.dora_fused._fused_dora_compose_triton(lora, base, mag_norm_scale, scale, inplace=True)¶
Triton implementation of fused DoRA composition.
peft.tuners.lora.dora_fused._fused_dora_forward_and_inner_triton(lora, base, mag_norm_scale, scale)¶
Triton implementation of fused forward-and-inner computation.
peft.tuners.lora.dora_fused._fused_norm_assembly_triton(w_norm_sq, cross_term, ba_norm_sq, scale)¶
Triton implementation of fused norm assembly (norm-only).
peft.tuners.lora.dora_fused._fused_backward_triton(d_out, inner, mag_norm_scale, scale, needs_lora_grad, needs_base_grad, needs_mag_grad)¶
Triton implementation of fused backward.
Uses a Triton kernel for the element-wise d_lora / d_base computation
and a plain PyTorch .sum() for the d_mag reduction. This avoids the
tl.atomic_add contention that scaled poorly with large num_rows in
the previous single-kernel approach.
Only allocates output tensors and launches the kernel when at least one of d_lora / d_base is needed. When only one is needed, we fall back to a single PyTorch elementwise op (cheaper than a kernel launch for one output). When neither is needed we skip the kernel entirely.
PyTorch Fallbacks¶
Pure PyTorch implementations used when Triton is unavailable or during torch.compile tracing.
peft.tuners.lora.dora_fused._fused_dora_compose_torch(lora, base, mag_norm_scale, scale, inplace=True)¶
Pure PyTorch fallback for fused DoRA composition.
Uses the numerically stable form (mag - 1) * base + mag * (scale * lora)
to avoid catastrophic cancellation in bf16/fp16 when mag ≈ 1.
Note
The caller (dora.py) is responsible for ensuring lora is a
freshly-computed tensor (e.g. lora_B(lora_A(x))) and does not
share storage with other live tensors when inplace=True.
Source code in peft/tuners/lora/dora_fused.py
peft.tuners.lora.dora_fused._fused_dora_forward_and_inner_torch(lora, base, mag_norm_scale, scale)¶
Pure PyTorch fallback for fused forward-and-inner computation.
Computes both outputs without Triton. scaled_lora is a temporary
that is freed after inner and out are computed.
Source code in peft/tuners/lora/dora_fused.py
peft.tuners.lora.dora_fused._fused_norm_assembly_torch(w_norm_sq, cross_term, ba_norm_sq, scale)¶
Pure PyTorch fallback for fused norm assembly.
Source code in peft/tuners/lora/dora_fused.py
peft.tuners.lora.dora_fused._fused_backward_torch(d_out, inner, mag_norm_scale, scale, needs_lora_grad, needs_base_grad, needs_mag_grad)¶
Pure PyTorch fallback for fused backward.
inner is scale * lora + base (precomputed in forward).
Source code in peft/tuners/lora/dora_fused.py
Autotune Configuration¶
Triton autotune configuration generators. These produce lists of triton.Config objects
controlling block sizes, warps, and pipeline stages for each kernel.
peft.tuners.lora.dora_fused._compose_configs()¶
Autotune configs for compose and forward_and_inner kernels (RPP=1).
6-GPU analysis: BS=4096/8192 dominate forward (195/216 wins of 626), BS=2048/1024 are significant secondary winners. BS=16384 and BS=32768 scored 0 forward wins and are dropped. RPP=1 won 96% of entries.
Warp counts are pinned per block size from the dominant winners:

- BS=512 → W=1 (unanimous B200/B300/H100/H200/RTX6000)
- BS=1024 → W=2 (H200/L40S dominant, covers B200/RTX6000 adequately)
- BS=2048 → W=8 (B300/H100/H200 dominant)
- BS=4096 → W=4,8 (W=8 dominant on B200/B300/H100/H200; W=4 secondary)
- BS=8192 → W=4,8 (W=4 dominant forward; W=8 dominant compose on H200/L40S)
Stages left at Triton default (S=2) — bandwidth-bound kernels show <5% sensitivity to stage count across most shapes.
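The pinning above amounts to a small table of (block size, warp count) pairs. A hedged pure-Python sketch of the generator (dicts stand in for `triton.Config` so it runs without Triton; the names are hypothetical):

```python
# Pinned (BLOCK_SIZE, num_warps) candidates for the compose and
# forward_and_inner kernels, per the 6-GPU analysis above.
# num_stages is left at Triton's default.
_COMPOSE_PINNED_WARPS = {
    512: (1,),
    1024: (2,),
    2048: (8,),
    4096: (4, 8),
    8192: (4, 8),
}

def compose_configs_sketch():
    # The real generator would wrap each entry in a triton.Config.
    return [
        {"BLOCK_SIZE": bs, "num_warps": w}
        for bs, warps in _COMPOSE_PINNED_WARPS.items()
        for w in warps
    ]

# Five block sizes, two of them with two warp candidates: 7 configs total.
assert len(compose_configs_sketch()) == 7
```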
peft.tuners.lora.dora_fused._backward_configs()¶
Autotune configs for the backward kernel (RPP=1).
6-GPU analysis: BS=16384 dominates on H100 (43 wins), BS=8192/4096 are strong secondary winners. BS=32768 scored 0 wins and is dropped. BS=512 is marginal (6 total wins) and dropped. RPP=1 won 93% of entries.
Warp counts pinned per block size:

- BS=1024 → W=1 (B300 7/9, H100 2/3)
- BS=2048 → W=2 (B300/H100/H200 dominant)
- BS=4096 → W=2,4 (W=2 dominant H200/RTX6000; W=4 dominant B300)
- BS=8192 → W=4 (dominant on most GPUs, W=8 secondary)
- BS=16384 → W=8,16 (W=16 unanimous H200, dominant B300; W=8 dominant H100)
Stages left at Triton default — see compose docstring.
peft.tuners.lora.dora_fused._norm_configs()¶
Autotune configs for norm assembly kernels.
Returns None in default mode (norm kernels use @triton.jit with fixed BS=256 — they are launch-latency bound on modern GPUs and autotune cannot differentiate configs).