About

The d9d.lr_scheduler.piecewise module provides a flexible, builder-based system for constructing piecewise learning rate schedules.

Instead of writing a custom LRScheduler subclass or a hand-written LambdaLR function for every variation of a piecewise schedule (e.g. "Warmup + Hold + Decay"), you can construct any such schedule declaratively by chaining phases together.

Usage Example

Here is how to create a standard "Linear Warmup + Hold + Cosine Decay" schedule:

import torch
from d9d.lr_scheduler.piecewise import CurveCosine, CurveLinear, piecewise_schedule

optimizer: torch.optim.Optimizer = ...
total_steps: int = 1000

# Define the schedule:
# 1. Start at 0.0
# 2. Linear warmup to 1.0 * LR over 100 steps
# 3. Hold at 1.0 * LR until 50% of training steps
# 4. Cosine decay to 0.1 * LR (10% of LR) for the rest of training
scheduler = (
    piecewise_schedule(initial_multiplier=0.0, total_steps=total_steps)
    .for_steps(100, target_multiplier=1.0, curve=CurveLinear())
    .until_percentage(0.5, target_multiplier=1.0, curve=CurveLinear())
    .fill_rest(target_multiplier=0.1, curve=CurveCosine())
    .build(optimizer)
)
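
For comparison, here is roughly what the same shape looks like as a hand-written multiplier function, the kind you would otherwise pass to LambdaLR. This is a minimal sketch, not the library's implementation: the name `manual_multiplier` is hypothetical, and the phase boundaries (warmup ends at step 100, hold ends at step 500) are assumed from the comments in the example above.

```python
import math

TOTAL_STEPS = 1000
WARMUP_STEPS = 100
HOLD_UNTIL = int(0.5 * TOTAL_STEPS)  # assumed boundary: 50% of training


def manual_multiplier(step: int) -> float:
    """Hand-written LR multiplier for warmup + hold + cosine decay (illustrative)."""
    if step < WARMUP_STEPS:
        # Linear warmup from 0.0 to 1.0
        return step / WARMUP_STEPS
    if step < HOLD_UNTIL:
        # Hold at the full learning rate
        return 1.0
    # Cosine decay from 1.0 to 0.1 over the remaining steps (clamped past the end)
    progress = min((step - HOLD_UNTIL) / (TOTAL_STEPS - HOLD_UNTIL), 1.0)
    cos_out = (1 + math.cos(math.pi * progress)) / 2
    return 0.1 + (1.0 - 0.1) * cos_out
```

Every new variation (a different warmup length, another decay curve) means editing a function like this by hand, which is the boilerplate the builder is meant to remove.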

Available Curves

The following curve classes are available to interpolate values between phases:

| Curve Class | Description |
| --- | --- |
| CurveLinear | Standard straight-line interpolation. |
| CurveCosine | Half-period cosine interpolation (Cosine Annealing). |
| CurvePoly(power) | Polynomial interpolation. power=1 is linear, power=2 is quadratic, etc. |
| CurveExponential | Exponential (log-linear) interpolation. |
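
To get a feel for how the curves differ, the sketch below re-implements the four formulas as plain functions and evaluates each halfway through a decay from 1.0 to 0.1. The formulas are copied from the source listings in the API reference so that the snippet runs without d9d installed; the function names are illustrative only.

```python
import math


def linear(start: float, end: float, p: float) -> float:
    return start + (end - start) * p


def cosine(start: float, end: float, p: float) -> float:
    return end + (start - end) * (1 + math.cos(math.pi * p)) / 2


def poly(start: float, end: float, p: float, power: float) -> float:
    return start + (end - start) * p ** power


def exponential(start: float, end: float, p: float, eps: float = 1e-8) -> float:
    s, e = max(start, eps), max(end, eps)
    return math.exp(math.log(s) + (math.log(e) - math.log(s)) * p)


# Value of each curve at 50% progress when decaying from 1.0 to 0.1
midpoint = {
    "linear": linear(1.0, 0.1, 0.5),
    "cosine": cosine(1.0, 0.1, 0.5),
    "poly(power=2)": poly(1.0, 0.1, 0.5, power=2.0),
    "exponential": exponential(1.0, 0.1, 0.5),
}
for name, value in midpoint.items():
    print(f"{name:>14}: {value:.4f}")
```

At the midpoint, linear and cosine agree (0.55), the quadratic curve has barely started decaying (0.775), and the exponential curve has already dropped to about 0.316, the geometric mean of the two endpoints.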

API Reference

d9d.lr_scheduler.piecewise

Implements flexible piecewise learning rate schedules via a builder pattern.

CurveBase

Bases: ABC

Abstract base class for interpolation curves used in scheduling.

Source code in d9d/lr_scheduler/piecewise/curves.py
class CurveBase(abc.ABC):
    """
    Abstract base class for interpolation curves used in scheduling.
    """

    @abc.abstractmethod
    def compute(self, start: float, end: float, step_p: float) -> float:
        """
        Calculates the interpolated value.

        Args:
            start: The value at the beginning of the phase.
            end: The value at the end of the phase.
            step_p: Progress fraction through the phase (0.0 to 1.0).

        Returns:
            The interpolated value.
        """

compute(start, end, step_p) abstractmethod

Calculates the interpolated value.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| start | float | The value at the beginning of the phase. | required |
| end | float | The value at the end of the phase. | required |
| step_p | float | Progress fraction through the phase (0.0 to 1.0). | required |

Returns:

| Type | Description |
| --- | --- |
| float | The interpolated value. |
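
Because compute is the only abstract method, a custom curve should only need to implement it. The sketch below defines a hypothetical `CurveSmoothstep` (not part of d9d) against a local stand-in for CurveBase that mirrors the compute(start, end, step_p) signature documented above, so it runs without the library. With d9d installed you would presumably subclass the real CurveBase instead; that the builder accepts arbitrary subclasses through its curve argument is an assumption based on the abstract base class design.

```python
import abc


class CurveBase(abc.ABC):
    # Local stand-in mirroring the documented interface (assumption: same signature).
    @abc.abstractmethod
    def compute(self, start: float, end: float, step_p: float) -> float: ...


class CurveSmoothstep(CurveBase):
    """Hypothetical custom curve: smoothstep easing, 3p^2 - 2p^3."""

    def compute(self, start: float, end: float, step_p: float) -> float:
        # Ease the progress so the curve is flat at both ends of the phase
        eased = step_p * step_p * (3.0 - 2.0 * step_p)
        return start + (end - start) * eased


curve = CurveSmoothstep()
```

A well-behaved curve returns start at step_p=0.0 and end at step_p=1.0, so that consecutive phases join without a jump.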

Source code in d9d/lr_scheduler/piecewise/curves.py
@abc.abstractmethod
def compute(self, start: float, end: float, step_p: float) -> float:
    """
    Calculates the interpolated value.

    Args:
        start: The value at the beginning of the phase.
        end: The value at the end of the phase.
        step_p: Progress fraction through the phase (0.0 to 1.0).

    Returns:
        The interpolated value.
    """

CurveCosine

Bases: CurveBase

Interpolates using a cosine annealing schedule (half-period cosine).

Source code in d9d/lr_scheduler/piecewise/curves.py
class CurveCosine(CurveBase):
    """
    Interpolates using a cosine annealing schedule (half-period cosine).
    """

    def compute(self, start: float, end: float, step_p: float) -> float:
        cos_out = (1 + math.cos(math.pi * step_p)) / 2
        return end + (start - end) * cos_out

CurveExponential

Bases: CurveBase

Interpolates exponentially between start and end values (log-space linear).

Source code in d9d/lr_scheduler/piecewise/curves.py
class CurveExponential(CurveBase):
    """
    Interpolates exponentially between start and end values (log-space linear).
    """

    def compute(self, start: float, end: float, step_p: float) -> float:
        eps = 1e-8
        safe_start = max(start, eps)
        safe_end = max(end, eps)

        out_log = math.log(safe_start) + (math.log(safe_end) - math.log(safe_start)) * step_p
        return math.exp(out_log)
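
Two consequences of this formula are worth seeing with numbers. The sketch below copies the computation into a standalone function (the name `exponential` is illustrative) and shows that equal progress increments scale the value by a constant factor, and that a start of 0.0 is clamped to eps rather than treated as zero.

```python
import math


def exponential(start: float, end: float, step_p: float) -> float:
    # Same formula as CurveExponential.compute above, copied so the snippet is standalone.
    eps = 1e-8
    safe_start = max(start, eps)
    safe_end = max(end, eps)
    out_log = math.log(safe_start) + (math.log(safe_end) - math.log(safe_start)) * step_p
    return math.exp(out_log)


# Decay from 1.0 to 0.01: each half of the phase divides the value by 10.
values = [exponential(1.0, 0.01, p) for p in (0.0, 0.5, 1.0)]

# A start of 0.0 is clamped to eps, so the curve begins near 1e-8, not at 0.0.
clamped_start = exponential(0.0, 1.0, 0.0)
halfway_from_zero = exponential(0.0, 1.0, 0.5)
```

Here `values` is approximately [1.0, 0.1, 0.01], `clamped_start` is about 1e-8, and `halfway_from_zero` is about 1e-4. An exponential warmup from a multiplier of 0.0 therefore spends the first half of the phase at or below 1e-4, so a small non-zero starting multiplier may be a better fit for warmup phases.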

CurveLinear

Bases: CurveBase

Linearly interpolates between start and end values.

Source code in d9d/lr_scheduler/piecewise/curves.py
class CurveLinear(CurveBase):
    """
    Linearly interpolates between start and end values.
    """

    def compute(self, start: float, end: float, step_p: float) -> float:
        return start + (end - start) * step_p

CurvePoly

Bases: CurveBase

Interpolates using a polynomial function.

Source code in d9d/lr_scheduler/piecewise/curves.py
class CurvePoly(CurveBase):
    """
    Interpolates using a polynomial function.
    """

    def __init__(self, power: float):
        """
        Constructs a polynomial curve.

        Args:
            power: The exponent of the polynomial. 1.0 is linear, 2.0 is quadratic, etc.
        """

        self._power = power

    def compute(self, start: float, end: float, step_p: float) -> float:
        p_transformed = step_p ** self._power
        return start + (end - start) * p_transformed

__init__(power)

Constructs a polynomial curve.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| power | float | The exponent of the polynomial. 1.0 is linear, 2.0 is quadratic, etc. | required |

Source code in d9d/lr_scheduler/piecewise/curves.py
def __init__(self, power: float):
    """
    Constructs a polynomial curve.

    Args:
        power: The exponent of the polynomial. 1.0 is linear, 2.0 is quadratic, etc.
    """

    self._power = power
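
The power argument controls where in the phase most of the change happens. As a small illustration, the sketch below applies the same formula as CurvePoly.compute in a standalone function (the name `poly` is illustrative) and samples a warmup from 0.0 to 1.0 a quarter of the way through the phase.

```python
def poly(start: float, end: float, step_p: float, power: float) -> float:
    # Same formula as CurvePoly.compute above, copied so the snippet is standalone.
    return start + (end - start) * step_p ** power


# Warmup from 0.0 to 1.0, sampled at 25% progress.
slow_start = poly(0.0, 1.0, 0.25, power=2.0)  # 0.25 ** 2
fast_start = poly(0.0, 1.0, 0.25, power=0.5)  # 0.25 ** 0.5
```

With power=2.0 the multiplier has only reached 0.0625 at this point, while power=0.5 has already reached 0.5. Powers above 1 start slowly and finish quickly; powers below 1 do the opposite.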

piecewise_schedule(initial_multiplier, total_steps=None)

Entry point for creating a piecewise learning rate schedule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| initial_multiplier | float | The initial learning rate multiplier. | required |
| total_steps | int \| None | Total training steps. Required for percentage-based scheduling. | None |

Returns:

| Type | Description |
| --- | --- |
| PiecewiseScheduleBuilder | A builder instance to configure phases. |

Source code in d9d/lr_scheduler/piecewise/builder.py
def piecewise_schedule(
        initial_multiplier: float,
        total_steps: int | None = None
) -> PiecewiseScheduleBuilder:
    """
    Entry point for creating a piecewise learning rate schedule.

    Args:
        initial_multiplier: The initial learning rate multiplier.
        total_steps: Total training steps. Required for percentage-based scheduling.

    Returns:
        A builder instance to configure phases.
    """

    return PiecewiseScheduleBuilder(
        initial_multiplier=initial_multiplier,
        total_steps=total_steps
    )
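
The reason total_steps is needed for percentage-based phases is that a fraction of training can only be turned into a step index once the total is known. The sketch below illustrates that arithmetic with a hypothetical helper, `percentage_to_step`, which is not part of d9d; both the truncating rounding and the error raised when total_steps is missing are assumptions, and the library may behave differently.

```python
from typing import Optional


def percentage_to_step(fraction: float, total_steps: Optional[int]) -> int:
    """Hypothetical helper: convert a fraction of training into an absolute step."""
    if total_steps is None:
        raise ValueError("total_steps is required for percentage-based phases")
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be within [0.0, 1.0]")
    # Assumed rounding: truncate toward zero. The library may round differently.
    return int(fraction * total_steps)


hold_boundary = percentage_to_step(0.5, total_steps=1000)
```

With total_steps=1000, a phase that runs until 50% of training would end at step 500 under these assumptions.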