DoRA layer classes that wrap standard PEFT LoRA layers with weight-decomposed adaptation. Each class computes factored column norms of the composed weight and applies magnitude/direction decomposition during the forward pass.

Inheritance

DoraLinearLayer (nn.Module)
├── DoraEmbeddingLayer
└── _DoraConvNdLayer
    ├── DoraConv1dLayer
    ├── DoraConv2dLayer
    └── DoraConv3dLayer

All other classes inherit from DoraLinearLayer and share its composition logic (dispatch, chunking, fused/eager selection). _DoraConvNdLayer adds conv-specific weight reshaping before delegating to the linear factored-norm path.


The primary DoRA layer for nn.Linear modules. Handles weight norm computation via factored column norms, dispatches to fused Triton kernels or eager PyTorch, and supports FSDP2/ZeRO-3 parameter gathering.

peft.tuners.lora.dora.DoraLinearLayer

Bases: Module

Source code in peft/tuners/lora/dora.py
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
class DoraLinearLayer(nn.Module):
    def __init__(self, fan_in_fan_out):
        super().__init__()
        self.fan_in_fan_out = fan_in_fan_out
        self._last_chunk_size: Optional[int] = None
        self._last_forward_chunk_size: Optional[int] = None

    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        """Return the per-row L2 norm of ``weight + scaling * lora_weight``.

        ``weight`` is first passed through ``transpose`` with
        ``self.fan_in_fan_out``.  Half-precision weights (fp16/bf16) are
        summed and normed in float32; the result is cast back to the
        (transposed) weight's dtype before it is returned.

        Args:
            weight: Base weight matrix.
            lora_weight: Dense LoRA delta; moved to ``weight``'s device.
            scaling: Multiplier applied to ``lora_weight``.

        Returns:
            Tensor of norms taken along ``dim=1``.
        """
        base = transpose(weight, self.fan_in_fan_out)
        out_dtype = base.dtype
        if out_dtype in (torch.float16, torch.bfloat16):
            work_dtype = torch.float32
        else:
            work_dtype = out_dtype

        merged = base.to(dtype=work_dtype) + scaling * lora_weight.to(device=base.device, dtype=work_dtype)
        norms = torch.linalg.vector_norm(merged, dim=1)
        return norms if norms.dtype == out_dtype else norms.to(dtype=out_dtype)

    @torch.no_grad()
    def _get_weight_norm_linear(
        self,
        *,
        base_weight: torch.Tensor,
        lora_A_w: torch.Tensor,
        lora_B_w: torch.Tensor,
        scaling: float,
        chunk_size: Optional[int] = None,
    ):
        """
        Compute ||W + s·(B A)||_row-wise without materializing B A.

        We use:
          ||W + s·(B A)||^2 = ||W||^2 + 2 s ⟨W, B A⟩ + s^2 ||B A||^2
        Let U := W A^T  (shape: [out, r])  and  G := A A^T  (shape: [r, r]).
        Then:
          ⟨W, B A⟩ per-row equals (B ⊙ U).sum(dim=1), since U_jk = ⟨W_j, A_k⟩ and (B A)_j = Σ_k B_jk A_k.
          ||B A||^2 per-row equals (B G ⊙ B).sum(dim=1).

        This avoids constructing the dense [out, in] product B A, reduces memory,
        and allows chunking along 'in' to cap working set size.

        Args:
            base_weight: 2-D base weight; transposed according to
                ``self.fan_in_fan_out`` and then read as ``[out, in]``.
            lora_A_w: LoRA A weight, read as ``[rank, in]``.
            lora_B_w: LoRA B weight, read as ``[out, rank]``.
            scaling: LoRA scale ``s``.  When it is exactly zero only ``||W||``
                is computed and ``lora_A_w`` / ``lora_B_w`` are not touched.
            chunk_size: Number of input columns processed per step.  ``None``
                derives it from ``_get_norm_memory_threshold_bytes()``; an
                explicit value is clamped to ``[1, in_features]``.

        Returns:
            weight_norm: tensor of shape [out_features]. Magnitude division is
            always done by the caller in PyTorch, ensuring identical precision
            regardless of whether Triton kernels are used.

        Raises:
            ValueError: If the compute dtype has no entry in
                ``_DTYPE_TO_ELEMENT_SIZE``.

        Side effects:
            Stores the chunk size that was used in ``self._last_chunk_size``.
        """

        W_t = transpose(base_weight, self.fan_in_fan_out)
        device = W_t.device
        dtype = W_t.dtype

        # Compute all norms in float32 for numerical stability when weights are
        # bf16/fp16.  This function does not touch autocast itself: ``forward``
        # wraps the call in ``_disable_autocast`` so the backend cannot
        # downcast the matmuls below.
        compute_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
        if compute_dtype not in _DTYPE_TO_ELEMENT_SIZE:
            raise ValueError(
                f"Unsupported compute_dtype {compute_dtype} for DoRA norm computation. "
                f"Expected one of {set(_DTYPE_TO_ELEMENT_SIZE.keys())}."
            )
        element_size = _DTYPE_TO_ELEMENT_SIZE[compute_dtype]

        out_features, in_features = W_t.shape
        rank = lora_A_w.shape[0]
        memory_threshold = _get_norm_memory_threshold_bytes()
        if chunk_size is None:
            # Include all live buffers: W_chunk (out*chunk), A_chunk (r*chunk),
            # plus once-per-call U (out*r) and Gram (r*r), all in compute_dtype.
            full_bytes = ((out_features + rank) * in_features + (out_features * rank) + (rank * rank)) * element_size
            if full_bytes <= memory_threshold:
                # Clamp to >= 1 (as the explicit-chunk branch does): with an
                # empty input dimension a zero chunk size would make the
                # ``range`` step below zero and raise ValueError.
                chunk_size = max(1, in_features)
            else:
                const_bytes = (out_features * rank + rank * rank) * element_size
                denom = (out_features + rank) * element_size
                if memory_threshold > const_bytes:
                    raw = (memory_threshold - const_bytes) // max(denom, 1)
                    chunk_size = int(max(1, min(in_features, raw)))
                else:
                    chunk_size = 1
                # Align for accelerator tensor core kernels when possible
                if device.type in ("cuda", "xpu") and chunk_size > 64:
                    chunk_size = (chunk_size // 64) * 64
        else:
            chunk_size = max(1, min(in_features, chunk_size))

        self._last_chunk_size = chunk_size
        logger.debug(
            "DoRA: chunk_size=%d (out=%d, in=%d, rank=%d, thresholdMB=%.1f)",
            chunk_size,
            out_features,
            in_features,
            rank,
            memory_threshold / (1024 * 1024),
        )

        scale_value = float(scaling)
        scale_is_zero = scale_value == 0.0

        # Accumulators over the input dimension.  U and the Gram matrix are
        # only needed when the LoRA branch actually contributes.
        w_norm_sq = torch.zeros(out_features, device=device, dtype=compute_dtype)
        U = None if scale_is_zero else torch.zeros(out_features, rank, device=device, dtype=compute_dtype)
        gram = None if scale_is_zero else torch.zeros(rank, rank, device=device, dtype=compute_dtype)

        for start in range(0, in_features, chunk_size):
            end = min(start + chunk_size, in_features)
            W_chunk = W_t[:, start:end]
            W_chunk = W_chunk.to(dtype=compute_dtype)

            w_norm_sq += (W_chunk * W_chunk).sum(dim=1)

            if scale_is_zero:
                continue

            A_chunk = lora_A_w[:, start:end]
            A_chunk = A_chunk.to(device=device, dtype=compute_dtype)

            # U += W_chunk A_chunk^T ; G += A_chunk A_chunk^T
            U.addmm_(W_chunk, A_chunk.transpose(0, 1))
            gram.addmm_(A_chunk, A_chunk.transpose(0, 1))

        if scale_is_zero:
            norm_sq = w_norm_sq
            norm_sq = norm_sq.clamp_min_(0)  # in-place safe: function runs under @torch.no_grad()
            weight_norm = torch.sqrt(norm_sq)
        else:
            B_comp = lora_B_w.to(device=device, dtype=compute_dtype)
            cross_term = (B_comp * U).sum(dim=1)
            BA = B_comp @ gram
            ba_norm_sq = (BA * B_comp).sum(dim=1)

            if _use_fused_kernels() and w_norm_sq.is_cuda:
                (weight_norm,) = fused_norm_assembly(
                    w_norm_sq,
                    cross_term,
                    ba_norm_sq,
                    scale_value,
                )
            else:
                # Use Python float directly — PyTorch handles scalar-tensor
                # promotion natively, avoiding a tiny CUDA alloc per call.
                norm_sq = w_norm_sq + (2.0 * scale_value) * cross_term + (scale_value * scale_value) * ba_norm_sq
                # Rounding can push the assembled sum slightly below zero;
                # clamp before the sqrt.  In-place is safe under no_grad.
                norm_sq = norm_sq.clamp_min_(0)
                weight_norm = torch.sqrt(norm_sq)

        # Hand the result back in the (transposed) base weight's dtype.
        if weight_norm.dtype != dtype:
            weight_norm = weight_norm.to(dtype=dtype)

        return weight_norm

    @dynamo_disable
    def _compose_with_base_chunks(
        self,
        *,
        x: torch.Tensor,
        lora_result: torch.Tensor,
        base_weight_t: torch.Tensor,
        mag_norm_scale: torch.Tensor,
        scale: float,
    ) -> None:
        """Compose DoRA output chunk-wise to cap peak memory.

        Recomputes ``base_result`` from ``x @ base_weight_t`` in chunks,
        applying the stable composition form per chunk so that only one
        chunk's worth of temporaries is live at a time.

        Args:
            x: Layer input; fed to ``F.linear`` together with each row slice
                of ``base_weight_t``.
            lora_result: LoRA branch output.  Its last dimension is sliced in
                step with the weight rows and each slice is handed to the
                in-place compose routine, so the composed result ends up in
                this tensor.
            base_weight_t: Base weight whose first dimension is the output
                feature dimension (it is sliced as ``[start:end, :]``).
            mag_norm_scale: Magnitude / weight-norm factor, sliced along its
                last dimension per chunk.
            scale: LoRA scaling factor forwarded to the compose routine.

        Returns:
            None.  The chosen chunk size is recorded in
            ``self._last_forward_chunk_size`` (``0`` when there are no output
            features).
        """

        if dynamo_graph_break is not None and dynamo_is_compiling is not None and dynamo_is_compiling():
            # The loop below depends on Python control flow and small-integer guards that
            # frequently change across layers; let Dynamo drop to eager to avoid runaway
            # recompilations when torch.compile is enabled.
            dynamo_graph_break()

        out_features = base_weight_t.shape[0]
        if out_features == 0:
            # Nothing to compose; also avoids the division by zero below.
            self._last_forward_chunk_size = 0
            return

        # Number of rows in the linear output (product of non-feature dims)
        prefix_rows = lora_result.numel() // out_features
        needs_grad = lora_result.requires_grad or mag_norm_scale.requires_grad or x.requires_grad
        # The fused kernel is used only for CUDA tensors outside of autograd.
        use_fused = _use_fused_kernels() and lora_result.is_cuda and not needs_grad
        threshold = _get_forward_chunk_threshold_bytes()

        # The eager mixed-dtype path materializes the stable compose result in
        # the promoted dtype before copying it back into ``lora_result``. Budget
        # chunking against that wider temporary so AMP eager chunking does not
        # assume peak memory still scales only with the activation dtype.
        if use_fused:
            working_element_size = lora_result.element_size()
        else:
            compose_dtype = _promoted_compose_dtype(lora_result.dtype, lora_result.dtype, mag_norm_scale.dtype)
            working_element_size = _dtype_element_size(compose_dtype)

        # Pick the largest number of output features per chunk whose
        # [prefix_rows, chunk] temporary fits in ``threshold`` bytes.
        if prefix_rows == 0:
            chunk_size = out_features
        else:
            denom = prefix_rows * max(working_element_size, 1)
            if denom == 0:
                chunk_size = out_features
            else:
                capacity = threshold // denom
                if capacity <= 0:
                    chunk_size = 1
                else:
                    chunk_size = int(min(out_features, capacity))

        if chunk_size <= 0:
            chunk_size = 1

        # Round down to a multiple of 64 on cuda/xpu devices.
        device_type = lora_result.device.type
        if device_type in ("cuda", "xpu") and chunk_size > 64:
            aligned = (chunk_size // 64) * 64
            chunk_size = max(64, aligned)
            chunk_size = min(chunk_size, out_features)

        self._last_forward_chunk_size = chunk_size
        logger.debug(
            "DoRA: forward chunk_size=%d (rows=%d, out=%d, thresholdMB=%.1f)",
            chunk_size,
            prefix_rows,
            out_features,
            threshold / (1024 * 1024),
        )

        # NOTE: chunked composition mutates lora_result slices in-place, so it
        # cannot use fused_dora_compose_autograd (which saves ``inner`` for
        # backward). Fall back to eager PyTorch compose when grads are needed.
        if _use_fused_backward() and needs_grad:
            logger.debug(
                "DoRA: chunked compose falling back to eager path because "
                "fused backward is incompatible with in-place chunk mutation."
            )

        # Cast mag only for the Triton inference path (use_fused requires
        # not needs_grad).  The eager chunk path keeps fp32 mag so mixed-dtype
        # eager chunks follow the same promoted reference as eager training.
        if use_fused and mag_norm_scale.dtype != lora_result.dtype:
            mag_norm_scale = mag_norm_scale.to(lora_result.dtype)

        for start in range(0, out_features, chunk_size):
            end = min(start + chunk_size, out_features)
            # Base output for this block of output features only.
            base_slice = F.linear(x, base_weight_t[start:end, :])
            chunk = lora_result[..., start:end]
            mag_chunk = mag_norm_scale[..., start:end]
            if use_fused:
                fused_dora_compose(chunk, base_slice, mag_chunk, scale, inplace=True)
            else:
                # INVARIANT: In-place mutation of ``chunk`` (a view of ``lora_result``)
                # is safe here because ``lora_result`` is a non-leaf intermediate
                # (produced by ``lora_B(lora_A(x))``).  PyTorch allows in-place ops
                # on non-leaf intermediates whose own data is not needed by any other
                # autograd node.  The only grad-requiring tensor whose value propagates
                # through this in-place op is ``mag_norm_scale`` (via multiplication),
                # and its gradient only needs the *result* of the in-place op, not the
                # original ``lora_result`` value.  If this invariant changes (e.g. if
                # ``lora_result`` becomes a leaf or is reused elsewhere), this in-place
                # path will raise an autograd error at backward time.
                _compose_eager_inplace(chunk, base_slice, mag_chunk, scale)

    def _compose_with_dispatch(
        self,
        *,
        lora_out: torch.Tensor,
        base_result: torch.Tensor,
        mag_norm_scale: torch.Tensor,
        scale: float,
    ) -> torch.Tensor:
        """Select a compose backend and run it on a precomputed ``base_result``.

        Backends, in order of preference:

        * training + fused backward available -> ``fused_dora_compose_autograd``
        * training otherwise                  -> eager out-of-place PyTorch
        * inference on CUDA with fused kernels -> ``fused_dora_compose`` (in place)
        * inference otherwise                 -> ``_compose_eager_inplace``

        The forward-only fused kernel creates no autograd nodes and is
        therefore inference-only.  Training uses the custom autograd path by
        default; disable it with ``PEFT_DORA_FUSED_BACKWARD=0``.

        Every branch is decided from tensor metadata and cached env-var
        booleans, never from tensor data, so Dynamo can guard on the choices
        without a graph break.

        Args:
            lora_out: LoRA branch output (activation dtype).
            base_result: Base layer output to be re-weighted.
            mag_norm_scale: Magnitude / weight-norm factor.
            scale: LoRA scaling factor.

        Returns:
            The composed DoRA tensor in ``lora_out``'s dtype on the training
            paths; on the inference paths, whatever the in-place compose
            routine returns.
        """
        # Any operand can carry gradient -- base_result may need it even when
        # the LoRA weights and the magnitude are frozen.
        requires_autograd = mag_norm_scale.requires_grad or base_result.requires_grad or lora_out.requires_grad

        # Under AMP the magnitude factor is fp32 (it is computed with autocast
        # disabled) while activations are bf16/fp16.  The fused Triton kernels
        # need homogeneous dtypes, so they receive a cast copy.  The eager
        # paths keep the fp32 factor so that (g-1) is formed in fp32 and small
        # corrections are not rounded to zero in bf16.  Mixed-dtype parity
        # between eager and fused therefore remains an open issue until the
        # fused-autograd path adopts the same dtype contract.
        act_dtype = lora_out.dtype
        mag_for_fused = mag_norm_scale if mag_norm_scale.dtype == act_dtype else mag_norm_scale.to(act_dtype)

        if requires_autograd:
            if _should_use_fused_backward_for_tensor(lora_out, mag_for_fused):
                # Fused autograd path: still on the historical
                # homogeneous-dtype contract, so mixed-dtype AMP results can
                # differ from eager.
                return fused_dora_compose_autograd(lora_out, base_result, mag_for_fused, scale)

            # Eager training.  The fp32 factor preserves (g-1) precision; the
            # result is cast back to the activation dtype so activations do
            # not stay in fp32.
            #
            # VRAM note: with an fp32 factor and bf16 activations, type
            # promotion creates transient fp32 intermediates of size
            # [batch*seq, out_features] (~2x the activation size at fp32, 4x
            # vs bf16).  base_result is precomputed here, so the chunk budget
            # does not bound this.  The fused autograd path (default) avoids
            # it by casting the factor first; keep PEFT_DORA_FUSED_BACKWARD=1
            # (default) to stay off this path.
            composed = mag_norm_scale * (scale * lora_out) + (mag_norm_scale - 1) * base_result
            return composed if composed.dtype == act_dtype else composed.to(act_dtype)

        if _use_fused_kernels() and lora_out.is_cuda:
            # Forward-only fused kernel: no autograd nodes, inference only.
            return fused_dora_compose(lora_out, base_result, mag_for_fused, scale, inplace=True)

        # Eager in-place inference.  An fp32 factor is fine here: the in-place
        # ops (mul_, add_) truncate to lora_out's dtype at each step and (g-1)
        # is formed in fp32 before the final truncation.
        return _compose_eager_inplace(lora_out, base_result, mag_norm_scale, scale)

    def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
        """Initialise the trainable DoRA magnitude from the current weights.

        Computes the norm of ``W + scaling * (B A)`` for the dequantized base
        weight and stores it as ``self.weight`` (an ``nn.Parameter`` with
        ``requires_grad=True``).

        Args:
            base_layer: Module whose weight is read through
                ``dequantize_module_weight``.  A ``Linear4bit`` layer is
                deep-copied first.
            lora_A: LoRA A weight tensor.
            lora_B: LoRA B weight tensor.
            scaling: LoRA scaling factor applied to ``B A``.
            place_on_cpu: If true, the computed norm is moved to CPU before it
                is wrapped in the parameter.
        """
        dtype_is_fp16 = lora_A.dtype == torch.float16

        # Include lora_A/lora_B in the gather scope — under ZeRO-3 the adapter
        # parameters can also be sharded (e.g. mid-training adapter swaps).
        # The context must be given the original tensors: a ``.float()`` copy
        # made beforehand is a new tensor, not the parameter to be gathered.
        with _maybe_gather_base_params_ctx(base_layer, lora_A, lora_B):
            # temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
            # (done inside the gather scope so the copy is taken from the gathered values)
            if dtype_is_fp16:
                lora_A = lora_A.float()
                lora_B = lora_B.float()

            if base_layer.__class__.__name__ == "Linear4bit":
                # We have to create a copy of the base layer, otherwise, FSDP will throw an error. 8bit does not work
                # yet because Int8Params cannot be correctly deep-copied (attributes vanish)
                base_layer = deepcopy(base_layer)

            weight = dequantize_module_weight(base_layer)
            weight = weight.to(lora_A.device)
            if weight.data.ndim >= 3:  # For handling LoRAs applied to Conv layers.
                weight_norm = self._get_weight_norm_conv_factored(
                    base_weight=weight,
                    lora_A_w=lora_A,
                    lora_B_w=lora_B,
                    scaling=scaling,
                )
            else:
                weight_norm = self._get_weight_norm_linear(
                    base_weight=weight,
                    lora_A_w=lora_A,
                    lora_B_w=lora_B,
                    scaling=scaling,
                )

            # Undo the temporary upcast so the magnitude matches the adapter dtype.
            if dtype_is_fp16:
                weight_norm = weight_norm.half()

        if place_on_cpu:
            weight_norm = weight_norm.to("cpu")
        self.weight = nn.Parameter(weight_norm, requires_grad=True)

    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output.
        Norm path runs under no_grad in fp32 and is detached (DoRA §4.3).

        Compile-friendly: the ``base_result is not None`` path (common case)
        is fully traceable by Dynamo.  The ``base_result is None`` path
        routes through ``_compose_with_base_chunks`` which has
        ``@dynamo_disable`` due to its data-dependent chunk loop.

        Args:
            x: Layer input; passed through ``lora_A`` then ``lora_B`` and, on
                the ``base_result is None`` path, also multiplied with the
                base weight.
            lora_A: LoRA A module (called on ``x``; its ``.weight`` feeds the
                norm computation).
            lora_B: LoRA B module (called on ``lora_A(x)``; its ``.weight``
                feeds the norm computation).
            scaling: LoRA scaling factor.
            base_layer: Wrapped base module; its weight is dequantized for the
                norm and its ``bias`` (if any) is subtracted from
                ``base_result``.
            base_result: Optional precomputed base layer output.  When given,
                composition goes through ``_compose_with_dispatch``; when
                ``None``, the base output is recomputed chunk-wise from a
                snapshot of the base weight.

        Returns:
            The result of ``_compose_with_dispatch`` when ``base_result`` is
            given; otherwise ``lora_result`` after
            ``_compose_with_base_chunks`` has composed into it.
        """
        # Compute weight_norm in a memory-efficient way without materializing full lora_weight.
        magnitude = self.weight
        device_type = base_layer.weight.device.type
        base_weight = None
        # _fsdp_full_param_ctx receives lora_A/lora_B alongside base_layer
        # because they are nn.Module instances (Linear sub-layers of the LoRA
        # adapter) that FSDP1 can individually wrap — summon_full_params needs
        # the module handle to unshard their parameters.  In contrast, the
        # embedding path passes only base_layer to _fsdp_full_param_ctx because
        # its lora_A/lora_B are raw tensors (transposed nn.Parameters from a
        # ParameterDict), not nn.Module instances, so summon_full_params cannot
        # act on them.  The embedding path gathers those raw tensors via
        # _maybe_gather_base_params_ctx instead.
        with (
            torch.no_grad(),
            _maybe_gather_base_params_ctx(base_layer, lora_A, lora_B),
            _fsdp_full_param_ctx(base_layer, lora_A, lora_B),
            _disable_autocast(device_type),
        ):
            base_weight = dequantize_module_weight(base_layer)
            # Norm reads weight in-place — no clone needed yet.
            weight_norm = self._get_weight_norm_linear(
                base_weight=base_weight,
                lora_A_w=lora_A.weight,
                lora_B_w=lora_B.weight,
                scaling=scaling,
            )
            # Snapshot AFTER norm computation: the base_result=None path needs
            # the weight to survive the gather scope, but deferring the clone
            # avoids a peak-VRAM spike during the norm computation (which already
            # allocates chunked intermediates).  For large MoE layers (32K×128K)
            # this halves the transient allocation inside the gather scope.
            if base_result is None:
                base_weight = _snapshot_dequantized_weight(base_layer, base_weight)
            # see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
            # Computed under ``no_grad`` so the norm stays constant during backpropagation.
        # Division always in PyTorch — identical precision regardless of Triton availability.
        # Ensure weight_norm is on the same device as magnitude (CPU-offloaded
        # gathers can leave weight_norm on CPU while magnitude lives on GPU).
        if weight_norm.device != magnitude.device:
            weight_norm = weight_norm.to(device=magnitude.device)
        if weight_norm.is_floating_point():
            # eps depends on weight_norm's dtype: bf16/fp16 can't represent
            # 1e-12, so use 1e-6 to prevent overflow after the m/||W|| cast.
            eps = 1e-12 if weight_norm.dtype in (torch.float32, torch.float64) else 1e-6
            weight_norm = weight_norm.clamp_min(eps)
        # Shape [1, out] so it broadcasts over the leading dims of the activations.
        mag_norm_scale = (magnitude / weight_norm).view(1, -1)

        # Compute LoRA output and compose result with minimal temporaries
        lora_result = lora_B(lora_A(x))
        scale = scaling

        if base_result is not None:
            bias = base_layer.bias
            if bias is not None:
                # Move bias to base_result's device (CPU-offloaded base_layer
                # keeps bias on CPU while base_result lives on GPU), and cast
                # to base_result dtype to avoid fp32 type promotion under AMP
                # (where base_result is bf16 but bias is fp32).  Without the
                # dtype cast, base_result would be promoted to fp32, defeating
                # Triton dispatch in _compose_with_dispatch.
                #
                # Assumption: base_result is a standard floating dtype (bf16,
                # fp16, fp32) from F.linear under autocast.  If a quantized
                # base layer produces int8/fp8 base_result, this cast would
                # silently convert the bias to a quantized type.
                if bias.device != base_result.device or bias.dtype != base_result.dtype:
                    bias = bias.to(device=base_result.device, dtype=base_result.dtype)
                # Out-of-place subtraction: the caller's base_result tensor is left untouched.
                base_result = base_result - bias
            # Release dequantized base weight early on the common base_result path.
            base_weight = None
            # No chunked compose ran on this call, so clear the diagnostic.
            self._last_forward_chunk_size = None
            return self._compose_with_dispatch(
                lora_out=lora_result,
                base_result=base_result,
                mag_norm_scale=mag_norm_scale,
                scale=scale,
            )

        # Note: this creates a full copy of the base weight on GPU. For large
        # layers this is a non-trivial allocation, but inherently unavoidable
        # when base_result is not precomputed. The common base_result path
        # (above) avoids this allocation entirely.
        if base_weight.device != x.device or base_weight.dtype != x.dtype:
            base_weight = base_weight.to(device=x.device, dtype=x.dtype)
        base_weight_t = transpose(base_weight, self.fan_in_fan_out)
        self._compose_with_base_chunks(
            x=x,
            lora_result=lora_result,
            base_weight_t=base_weight_t,
            mag_norm_scale=mag_norm_scale,
            scale=scale,
        )
        return lora_result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora.dora." + rep

__init__(fan_in_fan_out)

Source code in peft/tuners/lora/dora.py
1006
1007
1008
1009
1010
def __init__(self, fan_in_fan_out):
    """Set up the DoRA layer's bookkeeping state.

    Args:
        fan_in_fan_out: Stored on the instance and handed to ``transpose``
            whenever this layer reads a weight matrix.
    """
    super().__init__()
    # Diagnostics: the most recent chunk sizes chosen by the chunked
    # forward compose and by the factored norm computation.  Both stay
    # ``None`` until the corresponding code path has run.
    self._last_forward_chunk_size: Optional[int] = None
    self._last_chunk_size: Optional[int] = None
    self.fan_in_fan_out = fan_in_fan_out

get_weight_norm(weight, lora_weight, scaling)

Source code in peft/tuners/lora/dora.py
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
    """Return the per-row L2 norm of ``weight + scaling * lora_weight``.

    ``weight`` is first passed through ``transpose`` with
    ``self.fan_in_fan_out``.  Half-precision weights (fp16/bf16) are
    summed and normed in float32; the result is cast back to the
    (transposed) weight's dtype before it is returned.

    Args:
        weight: Base weight matrix.
        lora_weight: Dense LoRA delta; moved to ``weight``'s device.
        scaling: Multiplier applied to ``lora_weight``.

    Returns:
        Tensor of norms taken along ``dim=1``.
    """
    base = transpose(weight, self.fan_in_fan_out)
    out_dtype = base.dtype
    if out_dtype in (torch.float16, torch.bfloat16):
        work_dtype = torch.float32
    else:
        work_dtype = out_dtype

    merged = base.to(dtype=work_dtype) + scaling * lora_weight.to(device=base.device, dtype=work_dtype)
    norms = torch.linalg.vector_norm(merged, dim=1)
    return norms if norms.dtype == out_dtype else norms.to(dtype=out_dtype)

update_layer(*, base_layer, lora_A, lora_B, scaling, place_on_cpu=False)

Source code in peft/tuners/lora/dora.py
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
    """Initialise the trainable DoRA magnitude from the current weights.

    Computes the norm of ``W + scaling * (B A)`` for the dequantized base
    weight and stores it as ``self.weight`` (an ``nn.Parameter`` with
    ``requires_grad=True``).

    Args:
        base_layer: Module whose weight is read through
            ``dequantize_module_weight``.  A ``Linear4bit`` layer is
            deep-copied first.
        lora_A: LoRA A weight tensor.
        lora_B: LoRA B weight tensor.
        scaling: LoRA scaling factor applied to ``B A``.
        place_on_cpu: If true, the computed norm is moved to CPU before it
            is wrapped in the parameter.
    """
    dtype_is_fp16 = lora_A.dtype == torch.float16

    # Include lora_A/lora_B in the gather scope — under ZeRO-3 the adapter
    # parameters can also be sharded (e.g. mid-training adapter swaps).
    # The context must be given the original tensors: a ``.float()`` copy
    # made beforehand is a new tensor, not the parameter to be gathered.
    with _maybe_gather_base_params_ctx(base_layer, lora_A, lora_B):
        # temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
        # (done inside the gather scope so the copy is taken from the gathered values)
        if dtype_is_fp16:
            lora_A = lora_A.float()
            lora_B = lora_B.float()

        if base_layer.__class__.__name__ == "Linear4bit":
            # We have to create a copy of the base layer, otherwise, FSDP will throw an error. 8bit does not work
            # yet because Int8Params cannot be correctly deep-copied (attributes vanish)
            base_layer = deepcopy(base_layer)

        weight = dequantize_module_weight(base_layer)
        weight = weight.to(lora_A.device)
        if weight.data.ndim >= 3:  # For handling LoRAs applied to Conv layers.
            weight_norm = self._get_weight_norm_conv_factored(
                base_weight=weight,
                lora_A_w=lora_A,
                lora_B_w=lora_B,
                scaling=scaling,
            )
        else:
            weight_norm = self._get_weight_norm_linear(
                base_weight=weight,
                lora_A_w=lora_A,
                lora_B_w=lora_B,
                scaling=scaling,
            )

        # Undo the temporary upcast so the magnitude matches the adapter dtype.
        if dtype_is_fp16:
            weight_norm = weight_norm.half()

    if place_on_cpu:
        weight_norm = weight_norm.to("cpu")
    self.weight = nn.Parameter(weight_norm, requires_grad=True)

forward(x, *, lora_A, lora_B, scaling, base_layer, base_result=None)

For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer output. Norm path runs under no_grad in fp32 and is detached (DoRA §4.3).

Compile-friendly: the base_result is not None path (common case) is fully traceable by Dynamo. The base_result is None path routes through _compose_with_base_chunks which has @dynamo_disable due to its data-dependent chunk loop.

Source code in peft/tuners/lora/dora.py
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
    """
    For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
    output.
    Norm path runs under no_grad in fp32 and is detached (DoRA §4.3).

    Compile-friendly: the ``base_result is not None`` path (common case)
    is fully traceable by Dynamo.  The ``base_result is None`` path
    routes through ``_compose_with_base_chunks`` which has
    ``@dynamo_disable`` due to its data-dependent chunk loop.

    Args:
        x: Input activations; fed through ``lora_A`` then ``lora_B``.
        lora_A: LoRA down-projection module (its ``.weight`` is read for
            the norm, and it is called on ``x``).
        lora_B: LoRA up-projection module (its ``.weight`` is read for
            the norm, and it is called on ``lora_A(x)``).
        scaling: LoRA scaling factor, used both in the norm and in the
            final composition.
        base_layer: The wrapped base module; its ``weight`` (dequantized
            if needed) and ``bias`` are read.
        base_result: Optional precomputed output of ``base_layer``.  When
            given, the bias (if any) is subtracted from it and the base
            weight is never applied to ``x`` here.

    Returns:
        The value returned by ``_compose_with_dispatch`` when
        ``base_result`` is provided; otherwise ``lora_result`` after it
        has been handed to ``_compose_with_base_chunks`` (presumably
        composed in place — confirm against that helper).
    """
    # Compute weight_norm in a memory-efficient way without materializing full lora_weight.
    magnitude = self.weight
    device_type = base_layer.weight.device.type
    base_weight = None
    # _fsdp_full_param_ctx receives lora_A/lora_B alongside base_layer
    # because they are nn.Module instances (Linear sub-layers of the LoRA
    # adapter) that FSDP1 can individually wrap — summon_full_params needs
    # the module handle to unshard their parameters.  In contrast, the
    # embedding path passes only base_layer to _fsdp_full_param_ctx because
    # its lora_A/lora_B are raw tensors (transposed nn.Parameters from a
    # ParameterDict), not nn.Module instances, so summon_full_params cannot
    # act on them.  The embedding path gathers those raw tensors via
    # _maybe_gather_base_params_ctx instead.
    with (
        torch.no_grad(),
        _maybe_gather_base_params_ctx(base_layer, lora_A, lora_B),
        _fsdp_full_param_ctx(base_layer, lora_A, lora_B),
        _disable_autocast(device_type),
    ):
        base_weight = dequantize_module_weight(base_layer)
        # Norm reads weight in-place — no clone needed yet.
        weight_norm = self._get_weight_norm_linear(
            base_weight=base_weight,
            lora_A_w=lora_A.weight,
            lora_B_w=lora_B.weight,
            scaling=scaling,
        )
        # Snapshot AFTER norm computation: the base_result=None path needs
        # the weight to survive the gather scope, but deferring the clone
        # avoids a peak-VRAM spike during the norm computation (which already
        # allocates chunked intermediates).  For large MoE layers (32K×128K)
        # this halves the transient allocation inside the gather scope.
        if base_result is None:
            base_weight = _snapshot_dequantized_weight(base_layer, base_weight)
        # see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
        # Computed under ``no_grad`` so the norm stays constant during backpropagation.
    # Division always in PyTorch — identical precision regardless of Triton availability.
    # Ensure weight_norm is on the same device as magnitude (CPU-offloaded
    # gathers can leave weight_norm on CPU while magnitude lives on GPU).
    if weight_norm.device != magnitude.device:
        weight_norm = weight_norm.to(device=magnitude.device)
    if weight_norm.is_floating_point():
        # The clamp floor depends on weight_norm's dtype: fp32/fp64 use
        # 1e-12, every other floating dtype (fp16 in particular cannot
        # represent 1e-12) uses the larger 1e-6 so the m/||W|| division
        # below never divides by zero.
        eps = 1e-12 if weight_norm.dtype in (torch.float32, torch.float64) else 1e-6
        weight_norm = weight_norm.clamp_min(eps)
    # Reshape to a (1, N) row so it can broadcast against the layer outputs.
    mag_norm_scale = (magnitude / weight_norm).view(1, -1)

    # Compute LoRA output and compose result with minimal temporaries
    lora_result = lora_B(lora_A(x))
    scale = scaling

    if base_result is not None:
        bias = base_layer.bias
        if bias is not None:
            # Move bias to base_result's device (CPU-offloaded base_layer
            # keeps bias on CPU while base_result lives on GPU), and cast
            # to base_result dtype to avoid fp32 type promotion under AMP
            # (where base_result is bf16 but bias is fp32).  Without the
            # dtype cast, base_result would be promoted to fp32, defeating
            # Triton dispatch in _compose_with_dispatch.
            #
            # Assumption: base_result is a standard floating dtype (bf16,
            # fp16, fp32) from F.linear under autocast.  If a quantized
            # base layer produces int8/fp8 base_result, this cast would
            # silently convert the bias to a quantized type.
            if bias.device != base_result.device or bias.dtype != base_result.dtype:
                bias = bias.to(device=base_result.device, dtype=base_result.dtype)
            # Out-of-place subtraction: the caller's base_result tensor is not mutated.
            base_result = base_result - bias
        # Release dequantized base weight early on the common base_result path.
        base_weight = None
        # No chunked composition happened on this path, so clear the recorded chunk size.
        self._last_forward_chunk_size = None
        return self._compose_with_dispatch(
            lora_out=lora_result,
            base_result=base_result,
            mag_norm_scale=mag_norm_scale,
            scale=scale,
        )

    # Note: this creates a full copy of the base weight on GPU. For large
    # layers this is a non-trivial allocation, but inherently unavoidable
    # when base_result is not precomputed. The common base_result path
    # (above) avoids this allocation entirely.
    if base_weight.device != x.device or base_weight.dtype != x.dtype:
        base_weight = base_weight.to(device=x.device, dtype=x.dtype)
    base_weight_t = transpose(base_weight, self.fan_in_fan_out)
    # The helper's return value is ignored; lora_result is what gets returned
    # (presumably updated in place by the helper — confirm against its definition).
    self._compose_with_base_chunks(
        x=x,
        lora_result=lora_result,
        base_weight_t=base_weight_t,
        mag_norm_scale=mag_norm_scale,
        scale=scale,
    )
    return lora_result

DoRA layer for nn.Embedding modules. Inherits all composition logic from DoraLinearLayer and overrides only the forward pass to handle embedding lookups.

peft.tuners.lora.dora.DoraEmbeddingLayer

Bases: DoraLinearLayer

Source code in peft/tuners/lora/dora.py
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
class DoraEmbeddingLayer(DoraLinearLayer):
    """DoRA layer for embedding modules; overrides ``forward`` and ``__repr__`` of ``DoraLinearLayer``."""

    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn, base_result=None):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output.

        Args:
            x: Lookup indices handed to ``embed_fn`` (token indices, not a
                floating-point tensor).
            lora_A: Raw LoRA tensor used as the lookup table for
                ``embed_fn(x, lora_A)``.
            lora_B: Raw LoRA tensor right-multiplied onto the looked-up rows.
            scaling: LoRA scaling factor, used in the norm and the final
                composition.
            base_layer: The wrapped base module; its (dequantized) weight is
                read for the norm and, if needed, for the base lookup.
            embed_fn: Callable invoked as ``embed_fn(indices, weight)``.
            base_result: Optional precomputed base output.  When ``None`` it
                is recomputed here as ``embed_fn(x, weight)``.

        Returns:
            A tuple ``(mag_norm_scale, result_dora)``: the magnitude/norm
            ratio and the value returned by ``_compose_with_dispatch``.
        """
        magnitude = self.weight
        device_type = base_layer.weight.device.type
        # lora_A/lora_B are raw tensors (transposed nn.Parameters from the
        # parent Embedding module's ParameterDict), not nn.Module instances.
        # _maybe_gather_base_params_ctx handles raw tensors directly.
        #
        # _fsdp_full_param_ctx only receives base_layer (not lora_A/lora_B)
        # because FSDP1 wraps nn.Module instances — raw tensors stored in a
        # ParameterDict are not individually FSDP-wrapped, so summon_full_params
        # wouldn't apply to them.  The FSDP2 detection check runs on
        # base_layer only via _get_module_state.  If FSDP2 wraps only a
        # parent module (not base_layer directly), detection may miss it —
        # but in that case FSDP2's pre-forward hooks unshard parameters
        # before this forward runs, so norms see full weights.
        with _maybe_gather_base_params_ctx(base_layer, lora_A, lora_B), _fsdp_full_param_ctx(base_layer):
            gathered_lora_A = _refresh_embedding_lora_view(lora_A)
            gathered_lora_B = _refresh_embedding_lora_view(lora_B)
            if _is_zero3_active():
                # Clone while the full tensors are materialized, then use those
                # stable clones after the gather scope exits.  CloneBackward
                # still routes gradients back to the original Parameters.
                # Note: the clones double the transient LoRA allocation inside
                # the gather scope (e.g. ~64 MB per rank-64 256K-vocab adapter
                # in bf16).  This is inherent — the originals must stay alive
                # for GatheredParameters cleanup.
                lora_A_forward = gathered_lora_A.clone()
                lora_B_forward = gathered_lora_B.clone()
            else:
                lora_A_forward = gathered_lora_A
                lora_B_forward = gathered_lora_B
            with torch.no_grad(), _disable_autocast(device_type):
                # Build the dense embedding delta under no_grad so the norm path
                # matches linear/conv detach semantics and does not allocate an
                # unnecessary autograd graph for the temporary product.
                # Note: this materializes the full [num_embeddings, embedding_dim]
                # LoRA delta — O(V×d) allocation.  Unlike linear/conv, there is no
                # factored norm path for embeddings yet.  Fine for typical vocab
                # sizes, but worth being aware of at 256k+ tokens.
                #
                # VRAM budget: peak inside this scope is the sum of:
                #   1. lora_A/B clones (ZeRO-3 path): 2 × rank × dim × elem_size
                #   2. lora_weight below: num_embeddings × embedding_dim × elem_size
                #   3. gathered base weight (dequantize_module_weight): num_embeddings × embedding_dim × elem_size
                # For a 256K-vocab, 4096-dim, rank-64 adapter in bf16 this is
                # ~2 GB transient.  The clones are unavoidable under ZeRO-3
                # (originals must stay alive for GatheredParameters cleanup).
                lora_weight = (lora_A_forward @ lora_B_forward).T
                weight = dequantize_module_weight(base_layer)
                weight_norm = self.get_weight_norm(weight, lora_weight, scaling)
                # Defer the snapshot until after the norm so the extra copy
                # does not coexist with the norm's temporaries; it is only
                # needed when the base output must be recomputed below.
                if base_result is None:
                    weight = _snapshot_dequantized_weight(base_layer, weight)
        # see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
        # "[...] we suggest treating ||V +∆V ||_c in
        # Eq. (5) as a constant, thereby detaching it from the gradient
        # graph. This means that while ||V + ∆V ||_c dynamically
        # reflects the updates of ∆V , it won’t receive any gradient
        # during backpropagation"
        # weight_norm is already detached: torch.no_grad() above prevents graph
        # construction for the entire norm computation block, so no separate
        # .detach() is needed.
        # Ensure weight_norm is on the same device as magnitude (CPU-offloaded
        # gathers can leave weight_norm on CPU while magnitude lives on GPU).
        if weight_norm.device != magnitude.device:
            weight_norm = weight_norm.to(device=magnitude.device)
        if weight_norm.is_floating_point():
            # Floor the norm so the division below cannot hit zero; the floor
            # is 1e-12 for fp32/fp64 and 1e-6 for any other floating dtype.
            eps = 1e-12 if weight_norm.dtype in (torch.float32, torch.float64) else 1e-6
            weight_norm = weight_norm.clamp_min(eps)
        mag_norm_scale = magnitude / weight_norm
        if base_result is None:
            # Ensure weight is on the execution device (CPU-offloaded gathers
            # may return weights still on CPU while x lives on GPU).
            # Note: allocates a full copy of the embedding matrix on GPU;
            # the common base_result path avoids this entirely.
            # Only transfer device — x.dtype is Long (token indices), not a
            # floating-point type suitable for the weight matrix.
            if weight.device != x.device:
                weight = weight.to(device=x.device)
            base_result = embed_fn(x, weight)
        # Route through _compose_with_dispatch so embedding layers benefit
        # from fused Triton kernels when available (same path as linear layers).
        lora_out = embed_fn(x, lora_A_forward) @ lora_B_forward
        result_dora = self._compose_with_dispatch(
            lora_out=lora_out,
            base_result=base_result,
            mag_norm_scale=mag_norm_scale,
            scale=scaling,
        )
        return mag_norm_scale, result_dora

    def __repr__(self) -> str:
        """Return the parent representation prefixed with ``lora.dora.``."""
        rep = super().__repr__()
        return "lora.dora." + rep

forward(x, *, lora_A, lora_B, scaling, base_layer, embed_fn, base_result=None)

For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer output.

Source code in peft/tuners/lora/dora.py
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn, base_result=None):
    """
    For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
    output.

    Args:
        x: Lookup indices handed to ``embed_fn`` (token indices, not a
            floating-point tensor).
        lora_A: Raw LoRA tensor used as the lookup table for
            ``embed_fn(x, lora_A)``.
        lora_B: Raw LoRA tensor right-multiplied onto the looked-up rows.
        scaling: LoRA scaling factor, used in the norm and the final
            composition.
        base_layer: The wrapped base module; its (dequantized) weight is
            read for the norm and, if needed, for the base lookup.
        embed_fn: Callable invoked as ``embed_fn(indices, weight)``.
        base_result: Optional precomputed base output.  When ``None`` it is
            recomputed here as ``embed_fn(x, weight)``.

    Returns:
        A tuple ``(mag_norm_scale, result_dora)``: the magnitude/norm ratio
        and the value returned by ``_compose_with_dispatch``.
    """
    magnitude = self.weight
    device_type = base_layer.weight.device.type
    # lora_A/lora_B are raw tensors (transposed nn.Parameters from the
    # parent Embedding module's ParameterDict), not nn.Module instances.
    # _maybe_gather_base_params_ctx handles raw tensors directly.
    #
    # _fsdp_full_param_ctx only receives base_layer (not lora_A/lora_B)
    # because FSDP1 wraps nn.Module instances — raw tensors stored in a
    # ParameterDict are not individually FSDP-wrapped, so summon_full_params
    # wouldn't apply to them.  The FSDP2 detection check runs on
    # base_layer only via _get_module_state.  If FSDP2 wraps only a
    # parent module (not base_layer directly), detection may miss it —
    # but in that case FSDP2's pre-forward hooks unshard parameters
    # before this forward runs, so norms see full weights.
    with _maybe_gather_base_params_ctx(base_layer, lora_A, lora_B), _fsdp_full_param_ctx(base_layer):
        gathered_lora_A = _refresh_embedding_lora_view(lora_A)
        gathered_lora_B = _refresh_embedding_lora_view(lora_B)
        if _is_zero3_active():
            # Clone while the full tensors are materialized, then use those
            # stable clones after the gather scope exits.  CloneBackward
            # still routes gradients back to the original Parameters.
            # Note: the clones double the transient LoRA allocation inside
            # the gather scope (e.g. ~64 MB per rank-64 256K-vocab adapter
            # in bf16).  This is inherent — the originals must stay alive
            # for GatheredParameters cleanup.
            lora_A_forward = gathered_lora_A.clone()
            lora_B_forward = gathered_lora_B.clone()
        else:
            lora_A_forward = gathered_lora_A
            lora_B_forward = gathered_lora_B
        with torch.no_grad(), _disable_autocast(device_type):
            # Build the dense embedding delta under no_grad so the norm path
            # matches linear/conv detach semantics and does not allocate an
            # unnecessary autograd graph for the temporary product.
            # Note: this materializes the full [num_embeddings, embedding_dim]
            # LoRA delta — O(V×d) allocation.  Unlike linear/conv, there is no
            # factored norm path for embeddings yet.  Fine for typical vocab
            # sizes, but worth being aware of at 256k+ tokens.
            #
            # VRAM budget: peak inside this scope is the sum of:
            #   1. lora_A/B clones (ZeRO-3 path): 2 × rank × dim × elem_size
            #   2. lora_weight below: num_embeddings × embedding_dim × elem_size
            #   3. gathered base weight (dequantize_module_weight): num_embeddings × embedding_dim × elem_size
            # For a 256K-vocab, 4096-dim, rank-64 adapter in bf16 this is
            # ~2 GB transient.  The clones are unavoidable under ZeRO-3
            # (originals must stay alive for GatheredParameters cleanup).
            lora_weight = (lora_A_forward @ lora_B_forward).T
            weight = dequantize_module_weight(base_layer)
            weight_norm = self.get_weight_norm(weight, lora_weight, scaling)
            # Defer the snapshot until after the norm so the extra copy does
            # not coexist with the norm's temporaries; it is only needed when
            # the base output must be recomputed below.
            if base_result is None:
                weight = _snapshot_dequantized_weight(base_layer, weight)
    # see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
    # "[...] we suggest treating ||V +∆V ||_c in
    # Eq. (5) as a constant, thereby detaching it from the gradient
    # graph. This means that while ||V + ∆V ||_c dynamically
    # reflects the updates of ∆V , it won’t receive any gradient
    # during backpropagation"
    # weight_norm is already detached: torch.no_grad() above prevents graph
    # construction for the entire norm computation block, so no separate
    # .detach() is needed.
    # Ensure weight_norm is on the same device as magnitude (CPU-offloaded
    # gathers can leave weight_norm on CPU while magnitude lives on GPU).
    if weight_norm.device != magnitude.device:
        weight_norm = weight_norm.to(device=magnitude.device)
    if weight_norm.is_floating_point():
        # Floor the norm so the division below cannot hit zero; the floor is
        # 1e-12 for fp32/fp64 and 1e-6 for any other floating dtype.
        eps = 1e-12 if weight_norm.dtype in (torch.float32, torch.float64) else 1e-6
        weight_norm = weight_norm.clamp_min(eps)
    mag_norm_scale = magnitude / weight_norm
    if base_result is None:
        # Ensure weight is on the execution device (CPU-offloaded gathers
        # may return weights still on CPU while x lives on GPU).
        # Note: allocates a full copy of the embedding matrix on GPU;
        # the common base_result path avoids this entirely.
        # Only transfer device — x.dtype is Long (token indices), not a
        # floating-point type suitable for the weight matrix.
        if weight.device != x.device:
            weight = weight.to(device=x.device)
        base_result = embed_fn(x, weight)
    # Route through _compose_with_dispatch so embedding layers benefit
    # from fused Triton kernels when available (same path as linear layers).
    lora_out = embed_fn(x, lora_A_forward) @ lora_B_forward
    result_dora = self._compose_with_dispatch(
        lora_out=lora_out,
        base_result=base_result,
        mag_norm_scale=mag_norm_scale,
        scale=scaling,
    )
    return mag_norm_scale, result_dora

Base class for all convolution DoRA layers. Reshapes N-dimensional convolution weights to 2D, then delegates to the linear factored-norm path inherited from DoraLinearLayer.

peft.tuners.lora.dora._DoraConvNdLayer

Bases: DoraLinearLayer

Source code in peft/tuners/lora/dora.py
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
class _DoraConvNdLayer(DoraLinearLayer):
    """
    Shared DoRA logic for N-dimensional convolutions.

    Flattens conv weights to 2D and delegates the norm to the inherited
    ``_get_weight_norm_linear``.  ``forward`` calls ``self.conv_fn``, which
    this class does not define itself, to recompute the base output when
    none is supplied.
    """

    def _get_weight_norm_conv_factored(
        self,
        *,
        base_weight: torch.Tensor,
        lora_A_w: torch.Tensor,
        lora_B_w: torch.Tensor,
        scaling: float,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Compute the per-output-channel norm of the composed conv weight via the linear factored path.

        Args:
            base_weight: Conv weight; dimension 0 is treated as the output channels.
            lora_A_w: LoRA A weight; dimension 0 is treated as the rank.
            lora_B_w: LoRA B weight; reshaped to ``[out_channels, -1]``.
            scaling: LoRA scaling factor forwarded to ``_get_weight_norm_linear``.
            chunk_size: Optional chunk size forwarded to ``_get_weight_norm_linear``.

        Returns:
            The norms viewed as ``(1, out_channels, 1, ...)`` with one trailing
            singleton per dimension of ``base_weight`` beyond the first two.
        """
        out_channels = base_weight.shape[0]
        rank = lora_A_w.shape[0]
        flat_weight = base_weight.reshape(out_channels, -1)
        flat_A = lora_A_w.reshape(rank, -1)
        flat_B = lora_B_w.reshape(out_channels, -1)

        # Handle grouped convolutions by expanding B across group-specific rank bands
        # lora_A_w has shape [rank, in_per_group, kH, kW]
        # lora_B_w has shape [out_channels, rank_per_group, 1, 1] when groups>1
        rank_per_group = flat_B.shape[1]
        if rank_per_group > 0 and rank % rank_per_group == 0:
            # The group count is inferred from the ratio of A's rank to B's
            # per-group rank; a ratio of 1 means an ordinary (ungrouped) conv.
            groups = rank // rank_per_group
            if groups > 1 and (out_channels % groups == 0):
                out_per_group = out_channels // groups
                # Block-diagonal layout: group g's output rows only see group
                # g's band of rank columns; everything else stays zero.
                B_expanded = flat_B.new_zeros((out_channels, rank))
                for g in range(groups):
                    rows = slice(g * out_per_group, (g + 1) * out_per_group)
                    cols = slice(g * rank_per_group, (g + 1) * rank_per_group)
                    B_expanded[rows, cols] = flat_B[rows, :]
                flat_B = B_expanded

        norms = self._get_weight_norm_linear(
            base_weight=flat_weight,
            lora_A_w=flat_A,
            lora_B_w=flat_B,
            scaling=scaling,
            chunk_size=chunk_size,
        )

        view_shape = (1, out_channels) + (1,) * max(0, base_weight.dim() - 2)
        return norms.view(view_shape)

    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        """
        Compute the L2 norm of ``weight + scaling * lora_weight`` over every dimension except the first.

        fp16/bf16 weights are accumulated in fp32 and the result is cast back
        to ``weight.dtype``.  The result keeps reduced dimensions as size 1
        and has dimensions 0 and 1 swapped.
        """
        # calculate L2 norm of weight matrix, column-wise
        compute_dtype = torch.float32 if weight.dtype in (torch.float16, torch.bfloat16) else weight.dtype
        weight_comp = weight.to(dtype=compute_dtype)
        lora_weight_comp = lora_weight.to(device=weight.device, dtype=compute_dtype)

        total = weight_comp + scaling * lora_weight_comp
        # the following is needed to have compatibility with the 4/5D weight tensors of Conv2D/3D
        dim = tuple(range(1, weight.dim()))
        weight_norm = total.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
        if weight_norm.dtype != weight.dtype:
            weight_norm = weight_norm.to(dtype=weight.dtype)
        return weight_norm

    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output.
        Norm path runs under no_grad in fp32 and is detached (DoRA §4.3).

        Args:
            x: Input activations; fed through ``lora_A`` then ``lora_B``.
            lora_A: LoRA A module (``.weight`` is read, and it is called on ``x``).
            lora_B: LoRA B module (``.weight`` is read, and it is called on ``lora_A(x)``).
            scaling: LoRA scaling factor.
            base_layer: Wrapped conv module; its weight, bias, stride,
                padding, dilation and groups are read.
            base_result: Optional precomputed output of ``base_layer``.  When
                given, the bias (if any) is subtracted from it; when ``None``
                the bias-free base output is recomputed with ``self.conv_fn``.

        Returns:
            The value returned by ``_compose_with_dispatch``.
        """
        magnitude = self.weight
        device_type = base_layer.weight.device.type
        # lora_A/lora_B are passed to both gather contexts alongside base_layer.
        with (
            torch.no_grad(),
            _maybe_gather_base_params_ctx(base_layer, lora_A, lora_B),
            _fsdp_full_param_ctx(base_layer, lora_A, lora_B),
            _disable_autocast(device_type),
        ):
            weight = dequantize_module_weight(base_layer)
            weight_norm = self._get_weight_norm_conv_factored(
                base_weight=weight,
                lora_A_w=lora_A.weight,
                lora_B_w=lora_B.weight,
                scaling=scaling,
            )
            # Defer the snapshot until after the norm; it is only needed when
            # the base output must be recomputed outside the gather scope.
            if base_result is None:
                weight = _snapshot_dequantized_weight(base_layer, weight)
        # weight_norm is a derived tensor (from norm computation), not a view
        # of the (re-shardable) parameter — safe to use outside the gather scope.
        # Only transfer device when needed (CPU-offloaded gathers can leave
        # weight_norm on CPU while magnitude lives on GPU).
        if weight_norm.device != magnitude.device:
            weight_norm = weight_norm.to(device=magnitude.device)
        if weight_norm.is_floating_point():
            # Floor the norm so the division below cannot hit zero; the floor
            # is 1e-12 for fp32/fp64 and 1e-6 for any other floating dtype.
            eps = 1e-12 if weight_norm.dtype in (torch.float32, torch.float64) else 1e-6
            weight_norm = weight_norm.clamp_min(eps)
        mag_norm_scale = magnitude / weight_norm

        if base_result is None:
            # Ensure weight is on the execution device (CPU-offloaded gathers
            # may return weights still on CPU while x lives on GPU).
            # Note: allocates a full copy of the conv weight on GPU;
            # the common base_result path avoids this entirely.
            if weight.device != x.device or weight.dtype != x.dtype:
                weight = weight.to(device=x.device, dtype=x.dtype)
            base_result = self.conv_fn(
                x,
                weight,
                bias=None,
                stride=base_layer.stride,
                padding=base_layer.padding,
                dilation=base_layer.dilation,
                groups=base_layer.groups,
            )
        else:
            bias = base_layer.bias
            if bias is not None:
                # Move bias to base_result's device and dtype (CPU-offloaded
                # base_layer keeps bias on CPU; AMP keeps bias in fp32 while
                # base_result is bf16).
                if bias.device != base_result.device or bias.dtype != base_result.dtype:
                    bias = bias.to(device=base_result.device, dtype=base_result.dtype)
                # reshape bias to (1, -1, 1, ...)
                bias_shape = (1, -1) + (1,) * (base_result.dim() - 2)
                base_result = base_result - bias.view(*bias_shape)

        lora_out = lora_B(lora_A(x))
        return self._compose_with_dispatch(
            lora_out=lora_out,
            base_result=base_result,
            mag_norm_scale=mag_norm_scale,
            scale=scaling,
        )

    def __repr__(self) -> str:
        """Return the parent representation prefixed with ``lora.dora.``."""
        rep = super().__repr__()
        return "lora.dora." + rep

get_weight_norm(weight, lora_weight, scaling)

Source code in peft/tuners/lora/dora.py
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
    """
    Return the L2 norm of ``weight + scaling * lora_weight`` per slice along dimension 0.

    Half-precision inputs (fp16/bf16) are accumulated in fp32; the result is
    cast back to ``weight.dtype``.  Reduced dimensions are kept as size 1 and
    dimensions 0 and 1 of the result are swapped.
    """
    is_half = weight.dtype in (torch.float16, torch.bfloat16)
    accum_dtype = torch.float32 if is_half else weight.dtype

    base_part = weight.to(dtype=accum_dtype)
    delta_part = lora_weight.to(device=weight.device, dtype=accum_dtype)
    combined = base_part + scaling * delta_part

    # Reduce over everything but the leading axis so 3D/4D/5D conv weights
    # are handled uniformly.
    reduce_dims = tuple(range(1, weight.dim()))
    norms = combined.norm(p=2, dim=reduce_dims, keepdim=True).transpose(1, 0)

    if norms.dtype == weight.dtype:
        return norms
    return norms.to(dtype=weight.dtype)

forward(x, *, lora_A, lora_B, scaling, base_layer, base_result=None)

For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer output. Norm path runs under no_grad in fp32 and is detached (DoRA §4.3).

Source code in peft/tuners/lora/dora.py
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
    """
    For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
    output.
    Norm path runs under no_grad in fp32 and is detached (DoRA §4.3).

    Args:
        x: Input activations; fed through ``lora_A`` then ``lora_B``.
        lora_A: LoRA A module (``.weight`` is read, and it is called on ``x``).
        lora_B: LoRA B module (``.weight`` is read, and it is called on ``lora_A(x)``).
        scaling: LoRA scaling factor.
        base_layer: Wrapped conv module; its weight, bias, stride, padding,
            dilation and groups are read.
        base_result: Optional precomputed output of ``base_layer``.  When
            given, the bias (if any) is subtracted from it; when ``None`` the
            bias-free base output is recomputed with ``self.conv_fn``.

    Returns:
        The value returned by ``_compose_with_dispatch``.
    """
    magnitude = self.weight
    device_type = base_layer.weight.device.type
    # See linear forward for FSDP1 no-op note on passing lora_A/lora_B.
    with (
        torch.no_grad(),
        _maybe_gather_base_params_ctx(base_layer, lora_A, lora_B),
        _fsdp_full_param_ctx(base_layer, lora_A, lora_B),
        _disable_autocast(device_type),
    ):
        weight = dequantize_module_weight(base_layer)
        weight_norm = self._get_weight_norm_conv_factored(
            base_weight=weight,
            lora_A_w=lora_A.weight,
            lora_B_w=lora_B.weight,
            scaling=scaling,
        )
        # Defer snapshot to after norm — see linear forward for rationale.
        # Only needed when the base output is recomputed outside this scope.
        if base_result is None:
            weight = _snapshot_dequantized_weight(base_layer, weight)
    # weight_norm is a derived tensor (from norm computation), not a view
    # of the (re-shardable) parameter — safe to use outside the gather scope.
    # Only transfer device when needed (CPU-offloaded gathers can leave
    # weight_norm on CPU while magnitude lives on GPU).
    if weight_norm.device != magnitude.device:
        weight_norm = weight_norm.to(device=magnitude.device)
    if weight_norm.is_floating_point():
        # Floor the norm so the division below cannot hit zero; the floor is
        # 1e-12 for fp32/fp64 and 1e-6 for any other floating dtype.
        eps = 1e-12 if weight_norm.dtype in (torch.float32, torch.float64) else 1e-6
        weight_norm = weight_norm.clamp_min(eps)
    mag_norm_scale = magnitude / weight_norm

    if base_result is None:
        # Ensure weight is on the execution device (CPU-offloaded gathers
        # may return weights still on CPU while x lives on GPU).
        # Note: allocates a full copy of the conv weight on GPU;
        # the common base_result path avoids this entirely.
        if weight.device != x.device or weight.dtype != x.dtype:
            weight = weight.to(device=x.device, dtype=x.dtype)
        # bias=None: the recomputed base output is bias-free, matching the
        # bias-subtracted base_result produced by the other branch.
        base_result = self.conv_fn(
            x,
            weight,
            bias=None,
            stride=base_layer.stride,
            padding=base_layer.padding,
            dilation=base_layer.dilation,
            groups=base_layer.groups,
        )
    else:
        bias = base_layer.bias
        if bias is not None:
            # Move bias to base_result's device and dtype (CPU-offloaded
            # base_layer keeps bias on CPU; AMP keeps bias in fp32 while
            # base_result is bf16).
            if bias.device != base_result.device or bias.dtype != base_result.dtype:
                bias = bias.to(device=base_result.device, dtype=base_result.dtype)
            # reshape bias to (1, -1, 1, ...)
            bias_shape = (1, -1) + (1,) * (base_result.dim() - 2)
            # Out-of-place subtraction: the caller's base_result tensor is not mutated.
            base_result = base_result - bias.view(*bias_shape)

    lora_out = lora_B(lora_A(x))
    return self._compose_with_dispatch(
        lora_out=lora_out,
        base_result=base_result,
        mag_norm_scale=mag_norm_scale,
        scale=scaling,
    )

Convolution Variants

Concrete convolution DoRA layers. Each sets the appropriate conv_fn (e.g., F.conv1d) in __init__; all other logic is inherited from _DoraConvNdLayer.

peft.tuners.lora.dora.DoraConv1dLayer

Bases: _DoraConvNdLayer

Source code in peft/tuners/lora/dora.py
1727
1728
1729
1730
class DoraConv1dLayer(_DoraConvNdLayer):
    """DoRA layer for 1D convolutions; only binds ``conv_fn``, the rest is inherited from ``_DoraConvNdLayer``."""

    def __init__(self, fan_in_fan_out):
        super().__init__(fan_in_fan_out)
        # Functional conv the inherited forward calls to recompute the base
        # output when no ``base_result`` is supplied.
        self.conv_fn = F.conv1d

peft.tuners.lora.dora.DoraConv2dLayer

Bases: _DoraConvNdLayer

Source code in peft/tuners/lora/dora.py
1733
1734
1735
1736
class DoraConv2dLayer(_DoraConvNdLayer):
    """DoRA layer for 2D convolutions; only binds ``conv_fn``, the rest is inherited from ``_DoraConvNdLayer``."""

    def __init__(self, fan_in_fan_out):
        super().__init__(fan_in_fan_out)
        # Functional conv the inherited forward calls to recompute the base
        # output when no ``base_result`` is supplied.
        self.conv_fn = F.conv2d

peft.tuners.lora.dora.DoraConv3dLayer

Bases: _DoraConvNdLayer

Source code in peft/tuners/lora/dora.py
1739
1740
1741
1742
class DoraConv3dLayer(_DoraConvNdLayer):
    """DoRA layer for 3D convolutions; only binds ``conv_fn``, the rest is inherited from ``_DoraConvNdLayer``."""

    def __init__(self, fan_in_fan_out):
        super().__init__(fan_in_fan_out)
        # Functional conv the inherited forward calls to recompute the base
        # output when no ``base_result`` is supplied.
        self.conv_fn = F.conv3d