About

It is often difficult to picture a complex multi-phase learning rate schedule, such as a linear warmup followed by cosine decay, just by reading its definition.

To address this, d9d lets you plot the resulting learning rate curve interactively using Plotly.

Usage Example

The visualize_lr_scheduler function takes a factory function that constructs your scheduler. It then simulates a training run with that scheduler and plots the learning rate history.

import torch

from d9d.lr_scheduler.visualizer import visualize_lr_scheduler
from d9d.lr_scheduler.piecewise import piecewise_schedule, CurveLinear, CurveCosine

def create_scheduler(optimizer: torch.optim.Optimizer):
    return (
        piecewise_schedule(initial_multiplier=0.0, total_steps=100)
        .for_steps(10, 1.0, CurveLinear())
        .fill_rest(0.0, CurveCosine())
        .build(optimizer)
    )

# Opens an interactive plot in browser/notebook
visualize_lr_scheduler(
    factory=create_scheduler,
    num_steps=100,      # Number of steps to simulate
    init_lr=1e-3,       # Base LR set on the dummy optimizer
)
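The factory only has to accept an optimizer and return a scheduler, so a factory built from PyTorch's own schedulers should also work. The sketch below is a hypothetical torch-only factory (the name create_torch_scheduler and its warmup_steps and total_steps parameters are illustrative, not part of d9d) that approximates the shape of the example above using torch.optim.lr_scheduler.LambdaLR: a linear ramp of the multiplier from 0 to 1 over the first 10 steps, then a half-cosine decay to 0. It assumes a standard half-cosine formula; d9d's CurveCosine may differ in detail.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def create_torch_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int = 10,
    total_steps: int = 100,
) -> LambdaLR:
    """Hypothetical factory: linear warmup 0 -> 1, then half-cosine decay 1 -> 0.

    This approximates, but is not guaranteed to match, the d9d piecewise example.
    """

    def multiplier(step: int) -> float:
        if step < warmup_steps:
            # Linear ramp from 0.0 at step 0 towards 1.0 at step `warmup_steps`
            return step / warmup_steps
        # Fraction of the decay phase that has elapsed, clamped to [0, 1]
        decay_steps = max(total_steps - warmup_steps, 1)
        progress = min((step - warmup_steps) / decay_steps, 1.0)
        # Half-cosine from 1.0 (progress = 0) down to 0.0 (progress = 1)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    # LambdaLR sets each group's LR to base_lr * multiplier(step)
    return LambdaLR(optimizer, lr_lambda=multiplier)
```

Under these assumptions it could be passed as visualize_lr_scheduler(factory=create_torch_scheduler, num_steps=100, init_lr=1e-3) to compare its curve with the piecewise_schedule version.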

API Reference

d9d.lr_scheduler.visualizer

visualize_lr_scheduler(factory, num_steps, init_lr=1.0)

Visualizes the learning rate schedule using Plotly.

This function simulates num_steps training steps, records the learning rate at each step, and displays the result as an interactive plot.

Parameters:

Name       Type              Description                                                        Default
factory    SchedulerFactory  A callable that accepts an Optimizer and returns an LRScheduler.   required
num_steps  int               The number of steps to simulate.                                   required
init_lr    float             The initial learning rate to set on the dummy optimizer.           1.0

Raises:

Type         Description
ImportError  If the plotly library is not installed.
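The simulation step can be pictured as a small loop. The following is a minimal sketch, assuming PyTorch 2.x; it is not d9d's internal _get_history helper, whose body is not shown on this page, and the names simulate_lr_history and LRSchedulerFactory are hypothetical. It builds an SGD optimizer over a single dummy parameter with lr=init_lr, passes it to the factory, and records the learning rate of the first parameter group before each optimizer and scheduler step.

```python
from typing import Callable

import torch
from torch.optim.lr_scheduler import LRScheduler

# Hypothetical alias for the documented contract: Optimizer -> LRScheduler
LRSchedulerFactory = Callable[[torch.optim.Optimizer], LRScheduler]


def simulate_lr_history(
    factory: LRSchedulerFactory, num_steps: int, init_lr: float = 1.0
) -> list[float]:
    # A single dummy parameter is enough to drive an optimizer; no model is needed
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([dummy_param], lr=init_lr)
    scheduler = factory(optimizer)

    history: list[float] = []
    for _ in range(num_steps):
        # Record the LR that would be used for the current step
        history.append(optimizer.param_groups[0]["lr"])
        # Optimizer first (no gradients, so the parameter is unchanged), then scheduler
        optimizer.step()
        scheduler.step()
    return history
```

Recording before stepping yields exactly one value per step index, which lines up with the x-axis list(range(num_steps)) used in the plotting code.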

Source code in d9d/lr_scheduler/visualizer.py
def visualize_lr_scheduler(factory: SchedulerFactory, num_steps: int, init_lr: float = 1.0):
    """
    Visualizes the learning rate schedule using Plotly.

    This function simulates the training process for `num_steps` to record the LR changes
    and generates an interactive plot.

    Args:
        factory: A callable that accepts an Optimizer and returns an LRScheduler.
        num_steps: The number of steps to simulate.
        init_lr: The initial learning rate to set on the dummy optimizer.

    Raises:
        ImportError: If the `plotly` library is not installed.
    """

    try:
        import plotly.graph_objects as go  # noqa: PLC0415
    except ImportError as e:
        raise ImportError("You have to install `plotly` dependency to use scheduler visualization") from e
    lrs = _get_history(factory, num_steps, init_lr)
    steps = list(range(num_steps))

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=steps,
        y=lrs,
        mode="lines",
        name="Learning Rate",
        line={"color": "#636EFA", "width": 3},
        hovertemplate="<b>Step:</b> %{x}<br><b>LR:</b> %{y:.6f}<extra></extra>"
    ))

    fig.update_layout(
        title={
            "text": "Scheduler",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top"
        },
        xaxis_title="Steps",
        yaxis_title="Learning Rate",
        template="plotly_white",
        hovermode="x unified",
        height=500
    )

    fig.show()