About
It is often difficult to mentally visualize complex multiphase learning rate schedules.
To address this, d9d lets you plot the resulting learning rate curve interactively with Plotly.
Usage Example
The visualize_lr_scheduler function takes a factory function that constructs your scheduler. It then simulates a training run with that scheduler and plots the resulting learning rate history.
```python
import torch
from d9d.lr_scheduler.visualizer import visualize_lr_scheduler
from d9d.lr_scheduler.piecewise import piecewise_schedule, CurveLinear, CurveCosine


def create_scheduler(optimizer: torch.optim.Optimizer):
    return (
        piecewise_schedule(initial_multiplier=0.0, total_steps=100)
        .for_steps(10, 1.0, CurveLinear())
        .fill_rest(0.0, CurveCosine())
        .build(optimizer)
    )


# Opens an interactive plot in browser/notebook
visualize_lr_scheduler(
    factory=create_scheduler,
    num_steps=100,  # Duration to simulate
    init_lr=1e-3,  # Base LR to visualize
)
```
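As a cross-check, a curve of the same shape can be built with a stock PyTorch LambdaLR, which helps when judging whether the plot looks as expected. This is a minimal sketch, not d9d's implementation. It assumes that for_steps(10, 1.0, CurveLinear()) ramps the multiplier linearly from initial_multiplier to 1.0 over the first 10 steps, and that fill_rest(0.0, CurveCosine()) decays it along a half-cosine to 0.0 over the remaining steps of total_steps. The names warmup_cosine_multiplier and create_reference_scheduler are hypothetical.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR

TOTAL_STEPS = 100
WARMUP_STEPS = 10


def warmup_cosine_multiplier(step: int) -> float:
    # Assumed semantics of the piecewise schedule above; not d9d's implementation.
    if step < WARMUP_STEPS:
        # Linear ramp from 0.0 (initial_multiplier) towards 1.0.
        return step / WARMUP_STEPS
    # Half-cosine decay from 1.0 down to 0.0 over the remaining steps.
    progress = min((step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS), 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


def create_reference_scheduler(optimizer: torch.optim.Optimizer) -> LambdaLR:
    # LambdaLR multiplies each param group's base LR by the returned factor.
    return LambdaLR(optimizer, lr_lambda=warmup_cosine_multiplier)
```

Passing factory=create_reference_scheduler to visualize_lr_scheduler with the same num_steps and init_lr should give a plot to set beside the piecewise one. A visible difference points at one of the assumptions above, not necessarily at a bug.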
API Reference
d9d.lr_scheduler.visualizer
visualize_lr_scheduler(factory, num_steps, init_lr=1.0)
Visualizes the learning rate schedule using Plotly.
This function simulates the training process for num_steps steps, records how the LR changes at each step, and generates an interactive plot.
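The simulation described above can be sketched in plain PyTorch. The sketch builds a throwaway parameter and optimizer at init_lr, lets the factory wrap the optimizer, and then alternates optimizer.step() and scheduler.step() while recording the LR. This is an illustrative sketch under stated assumptions, not the source of visualize_lr_scheduler. The use of SGD as the dummy optimizer, reading only the first param group, and the helper names record_lr_history and plot_lr_history are all hypothetical. The sketch also assumes PyTorch 2.x, where LRScheduler is a public name.

```python
from typing import Callable, List

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler


def record_lr_history(
    factory: Callable[[Optimizer], LRScheduler],
    num_steps: int,
    init_lr: float = 1.0,
) -> List[float]:
    # A single dummy parameter is enough; no real gradients are computed.
    dummy = torch.nn.Parameter(torch.zeros(1))
    # Assumption: SGD stands in for whatever dummy optimizer d9d uses.
    optimizer = torch.optim.SGD([dummy], lr=init_lr)
    scheduler = factory(optimizer)

    history = []
    for _ in range(num_steps):
        # Record the LR in effect for this step (first param group only).
        history.append(optimizer.param_groups[0]["lr"])
        optimizer.step()  # Skips the parameter, since it has no gradient.
        scheduler.step()
    return history


def plot_lr_history(history: List[float]) -> None:
    # Imported lazily so that recording still works without plotly installed.
    import plotly.graph_objects as go

    fig = go.Figure(data=[go.Scatter(x=list(range(len(history))), y=history, mode="lines")])
    fig.update_layout(xaxis_title="step", yaxis_title="learning rate")
    fig.show()
```

For example, recording partial(torch.optim.lr_scheduler.StepLR, step_size=30, gamma=0.1) for 100 steps at init_lr=1e-3 yields 100 values that drop by a factor of 10 at steps 30, 60 and 90. Keeping the plotting import inside plot_lr_history mirrors the documented behavior of raising ImportError only when plotting is actually attempted.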
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| factory | SchedulerFactory | A callable that accepts an Optimizer and returns an LRScheduler. | required |
| num_steps | int | The number of steps to simulate. | required |
| init_lr | float | The initial learning rate to set on the dummy optimizer. | 1.0 |
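Because factory only has to accept an Optimizer and return an LRScheduler, a stock PyTorch scheduler can be adapted with functools.partial by binding every argument except the optimizer. The alias below is an assumption about the shape of SchedulerFactory, written out for illustration. d9d's own definition may differ. The names cosine_factory and make_probe_optimizer are hypothetical, and the sketch assumes PyTorch 2.x, where LRScheduler is public.

```python
from functools import partial
from typing import Callable

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler

# Hypothetical alias mirroring the documented contract: Optimizer -> LRScheduler.
SchedulerFactory = Callable[[Optimizer], LRScheduler]

# Binding every argument except the optimizer yields a valid factory.
cosine_factory: SchedulerFactory = partial(CosineAnnealingLR, T_max=100, eta_min=0.0)


def make_probe_optimizer(lr: float = 1.0) -> Optimizer:
    # Hypothetical helper: a throwaway optimizer for checking that a factory is well-formed.
    return torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=lr)
```

With this alias in place, cosine_factory could be passed as factory=cosine_factory to visualize_lr_scheduler in the same way as create_scheduler in the usage example.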
Raises:
| Type | Description |
|---|---|
| ImportError | If the |
Source code in d9d/lr_scheduler/visualizer.py