About

The d9d.dataset package provides specialized PyTorch Dataset wrappers designed for distributed training scenarios.

Core Concepts

Why Not Wrap Datasets Automatically?

d9d provides explicit composable wrappers rather than relying on implicit "magic" or automatic Sampler injection often found in other frameworks.

  • Flexible Composition and Order-of-Operations: The behavior of a data pipeline changes significantly depending on the order of composition. By stacking wrappers manually, you control the data flow logic, for example whether length-based bucketing happens over the full dataset or only within each shard (see the sketch after this list).

  • Granular Configuration: Different datasets have different physical constraints that require specific configurations. A dataset loaded from network storage might require contiguous reads to be performant (ShardIndexingMode.chunked), while an in-memory dataset might prefer round-robin access (ShardIndexingMode.sequential). Explicit wrappers ensure that these configuration options are exposed to the user rather than buried in global trainer arguments.
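
As a rough illustration of the order-of-operations point above, here is a minimal sketch using the wrapper signatures documented below. It assumes base_ds is a dataset that implements sort_key (like MyTextDataset in the next section) and that dp_size / dp_rank are the data-parallel world size and rank (see the Sharding example):

from d9d.dataset import BufferSortedDataset, ShardedDataset, ShardIndexingMode

# Option A: bucket over the FULL dataset, then give each rank a view of the result.
# Length-similar packs are formed globally; sharding only slices the resulting stream.
global_buckets = BufferSortedDataset(
    base_dataset=base_ds,
    buffer_size=100,
    pack_size=4,
    init_seed=42
)
per_rank_view = ShardedDataset(
    dataset=global_buckets,
    total_shards=dp_size,
    current_shard=dp_rank,
    indexing_mode=ShardIndexingMode.chunked,
    pad_to_equal_size_across_shards=True
)

# Option B (shard first, then bucket each shard) would require the sharded view itself
# to expose sort_key; the explicit wrappers surface that dependency instead of hiding it.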

Features

Smart Bucketing

In NLP and sequence processing, batches often contain items of varying lengths. Standard random sampling forces each batch to be padded to the length of its longest sequence, wasting computational resources on padding tokens.

BufferSortedDataset implements a "Smart Bucketing" strategy: items within a micro-batch have similar lengths (minimizing padding), while the order of micro-batches remains random enough to preserve training stability.

Usage Example

To use BufferSortedDataset, your underlying dataset must implement the DatasetImplementingSortKeyProtocol (i.e., it must have a sort_key(index) method).

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from d9d.dataset import BufferSortedDataset

class MyTextDataset(Dataset):
    def __init__(self, data: list[str]):
        self.data = data # list of strings

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    # NOTE: BufferSortedDataset requires this method (DatasetImplementingSortKeyProtocol)
    def sort_key(self, index):
        return len(self.data[index])

# Create Base Dataset
raw_data = ["short", "very very very long phrase", "tiny", "medium size"] * 100
base_ds = MyTextDataset(raw_data)

# Wrap with Smart Bucketing
# - buffer_size=100: Look at 100 items at a time to find similar lengths
# - pack_size=4: Group them into batches of 4
sorted_ds = BufferSortedDataset(
    base_dataset=base_ds,
    buffer_size=100,
    pack_size=4,
    init_seed=42
)
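
Since BufferSortedDataset only rearranges indices, the wrapped dataset plugs into a regular DataLoader. A minimal sketch, assuming you want each DataLoader batch to line up with one pack (batch_size equal to pack_size) and rely on in-order index access so the local sorting is preserved:

# batch_size matches pack_size so each batch is one length-similar pack;
# shuffle stays off because the wrapper already shuffles the pack order internally.
loader = DataLoader(sorted_ds, batch_size=4, shuffle=False)

for batch in loader:
    ...  # each batch contains strings of similar length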

Sharding

When using Data Parallelism, each GPU processes a subset of the data. ShardedDataset provides a deterministic view of a specific shard of the data based on the rank (shard ID).

It supports:

  • Sequential Sharding: Round-robin distribution (0, 4, 8... for rank 0).
  • Chunked Sharding: Contiguous blocks (0, 1, 2... for rank 0).
  • Optional Padding: Ensuring all shards have exactly the same length. This is critical for distributed training loops, where uneven shard sizes can cause process hangs.

Usage Example

import torch
from torch.utils.data import TensorDataset
from d9d.core.dist_context import DistributedContext, BATCH_DOMAIN
from d9d.dataset import ShardedDataset, ShardIndexingMode

# You can infer your data-parallel size and rank from the DistributedContext object
context: DistributedContext  # assumed to already exist in your training setup
batch_mesh = context.mesh_for(BATCH_DOMAIN)
dp_size = batch_mesh.size('dp')
dp_rank = batch_mesh.get_local_rank('dp')

# Create Full Dataset
base_ds = TensorDataset(torch.randn(100, 10))

# Shard it
sharded_ds = ShardedDataset(
    dataset=base_ds,
    total_shards=dp_size,
    current_shard=dp_rank,
    indexing_mode=ShardIndexingMode.chunked,
    # Crucial for preventing distributed hangs
    pad_to_equal_size_across_shards=True 
)

print(f"I am rank {dp_rank}, I see {len(sharded_ds)} items.")

d9d.dataset

This package provides utilities and torch.utils.data.Dataset implementations.

BufferSortedDataset

Bases: Dataset[_T_co], Stateful

A dataset wrapper that groups items into buffers, sorts them, and yields them with local shuffling.

This prevents extreme padding in variable-length training (by grouping similar lengths) while maintaining enough randomness to ensure statistical variance in updates.

Algorithm:

  1. Select a range of indices (size buffer_size).
  2. Sort these indices based on base_dataset.sort_key().
  3. Break the sorted list into packs of size pack_size.
  4. Shuffle the order of these packs.
  5. Flatten the list and serve items.
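
To make these steps concrete, here is a small illustrative sketch of the index transformation for a single buffer (toy values; the actual implementation is shown below):

import random

sort_keys = [5, 1, 4, 2, 3, 0]                       # e.g. sequence lengths for indices 0..5
buffer = list(range(6))                              # step 1: one buffer of size 6
buffer.sort(key=lambda i: sort_keys[i])              # step 2: -> [5, 1, 3, 4, 2, 0]
packs = [buffer[i:i + 2] for i in range(0, 6, 2)]    # step 3: pack_size=2 -> [[5, 1], [3, 4], [2, 0]]
random.Random(0).shuffle(packs)                      # step 4: shuffle pack order, not pack contents
serving_order = [i for pack in packs for i in pack]  # step 5: flatten; items are served in this order
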
Source code in d9d/dataset/buffer_sorted.py
class BufferSortedDataset(Dataset[_T_co], Stateful):
    """
    A dataset wrapper that groups items into buffers, sorts them, and yields them with local shuffling.

    This prevents extreme padding in variable-length training (by grouping similar lengths)
    while maintaining enough randomness to ensure statistical variance in updates.

    Algorithm:

    1. Select a range of indices (size `buffer_size`).
    2. Sort these indices based on `base_dataset.sort_key()`.
    3. Break the sorted list into packs of size `pack_size`.
    4. Shuffle the order of these packs.
    5. Flatten the list and serve items.
    """

    def __init__(
            self,
            base_dataset: DatasetImplementingSortKeyProtocol[_T_co],
            buffer_size: int,
            pack_size: int,
            init_seed: int | None = None
    ):
        """
        Constructs a BufferSortedDataset object.

        Args:
            base_dataset: The underlying dataset implementing the `DatasetImplementingSortKeyProtocol` protocol.
            buffer_size: The number of items to load into the buffer for sorting.
            pack_size: The size of local groups (batches/micro-batches) that remain
                contiguous after sorting, but are shuffled relative to other packs.
            init_seed: Seed for the random number generator used for shuffling packs.
        """

        self._base_dataset = base_dataset
        self._buffer_size = buffer_size
        self._pack_size = pack_size

        self._rng = random.Random(init_seed ^ 0x105E7 if init_seed is not None else None)
        self._buffer_indices: list[int] = []
        self._buffer_idx: int = -1

    def _update_buffer_idx(self, buffer_idx: int):
        select_start = buffer_idx * self._buffer_size
        select_end = (buffer_idx + 1) * self._buffer_size
        select_end = min(select_end, len(self._base_dataset))

        base_idx = list(range(select_start, select_end))
        base_sort_keys = [self._base_dataset.sort_key(idx) for idx in range(select_start, select_end)]

        local_idx = list(range(len(base_idx)))
        local_idx = sorted(local_idx, key=lambda local_id: base_sort_keys[local_id])

        local_idx_batch = [
            local_idx[i: i + self._pack_size]
            for i in range(0, len(local_idx), self._pack_size)
        ]
        self._rng.shuffle(local_idx_batch)
        local_idx = [y for x in local_idx_batch for y in x]

        self._buffer_indices = [base_idx[local_id] for local_id in local_idx]

        self._buffer_idx = buffer_idx

    def __getitem__(self, index: int) -> _T_co:
        """
        Retrieves an item from the locally sorted/shuffled buffer.

        Args:
            index: The global index.

        Returns:
            The dataset item.
        """

        needs_buffer_idx = index // self._buffer_size
        if self._buffer_idx != needs_buffer_idx:
            self._update_buffer_idx(needs_buffer_idx)

        take_id = self._buffer_indices[index % self._buffer_size]

        return self._base_dataset[take_id]

    def __len__(self) -> int:
        """Returns the length of the base dataset."""

        return len(self._base_dataset)

    def state_dict(self) -> dict[str, Any]:
        ret = {
            "seed": pickle.dumps(self._rng.getstate()),
            "buffer_idx": self._buffer_idx,
            "buffer_indices": self._buffer_indices,
        }
        if isinstance(self._base_dataset, Stateful):
            ret["base_dataset"] = self._base_dataset.state_dict()
        return ret

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._rng.setstate(pickle.loads(state_dict["seed"]))  # noqa: S301
        self._buffer_idx = state_dict["buffer_idx"]
        self._buffer_indices = state_dict["buffer_indices"]
        if isinstance(self._base_dataset, Stateful):
            self._base_dataset.load_state_dict(state_dict["base_dataset"])

__getitem__(index)

Retrieves an item from the locally sorted/shuffled buffer.

Parameters:

  • index (int, required): The global index.

Returns:

  • _T_co: The dataset item.

Source code in d9d/dataset/buffer_sorted.py
def __getitem__(self, index: int) -> _T_co:
    """
    Retrieves an item from the locally sorted/shuffled buffer.

    Args:
        index: The global index.

    Returns:
        The dataset item.
    """

    needs_buffer_idx = index // self._buffer_size
    if self._buffer_idx != needs_buffer_idx:
        self._update_buffer_idx(needs_buffer_idx)

    take_id = self._buffer_indices[index % self._buffer_size]

    return self._base_dataset[take_id]

__init__(base_dataset, buffer_size, pack_size, init_seed=None)

Constructs a BufferSortedDataset object.

Parameters:

  • base_dataset (DatasetImplementingSortKeyProtocol[_T_co], required): The underlying dataset implementing the DatasetImplementingSortKeyProtocol protocol.
  • buffer_size (int, required): The number of items to load into the buffer for sorting.
  • pack_size (int, required): The size of local groups (batches/micro-batches) that remain contiguous after sorting, but are shuffled relative to other packs.
  • init_seed (int | None, default None): Seed for the random number generator used for shuffling packs.
Source code in d9d/dataset/buffer_sorted.py
def __init__(
        self,
        base_dataset: DatasetImplementingSortKeyProtocol[_T_co],
        buffer_size: int,
        pack_size: int,
        init_seed: int | None = None
):
    """
    Constructs a BufferSortedDataset object.

    Args:
        base_dataset: The underlying dataset implementing the `DatasetImplementingSortKeyProtocol` protocol.
        buffer_size: The number of items to load into the buffer for sorting.
        pack_size: The size of local groups (batches/micro-batches) that remain
            contiguous after sorting, but are shuffled relative to other packs.
        init_seed: Seed for the random number generator used for shuffling packs.
    """

    self._base_dataset = base_dataset
    self._buffer_size = buffer_size
    self._pack_size = pack_size

    self._rng = random.Random(init_seed ^ 0x105E7 if init_seed is not None else None)
    self._buffer_indices: list[int] = []
    self._buffer_idx: int = -1

__len__()

Returns the length of the base dataset.

Source code in d9d/dataset/buffer_sorted.py
def __len__(self) -> int:
    """Returns the length of the base dataset."""

    return len(self._base_dataset)

DatasetImplementingSortKeyProtocol

Bases: Protocol[_T_co]

Protocol for datasets that support retrieval of a specific key for sorting purposes.

This is typically used for length-based bucketing/sorting where the dataset needs to expose the length of an item without loading the full item.
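
For instance, a dataset can keep a precomputed array of lengths so that sort_key never touches the underlying samples. This is a hypothetical sketch (JsonlTextDataset is not part of the package):

import json

class JsonlTextDataset:
    """Hypothetical dataset: items live on disk, but their lengths are cached in memory."""

    def __init__(self, path: str, lengths: list[int], offsets: list[int]):
        self._path = path
        self._lengths = lengths  # precomputed item lengths (e.g. token counts)
        self._offsets = offsets  # byte offset of each JSON line in the file

    def __len__(self) -> int:
        return len(self._lengths)

    def sort_key(self, index: int) -> int:
        # Cheap: answered from the in-memory cache, no disk access.
        return self._lengths[index]

    def __getitem__(self, index: int) -> str:
        # Expensive: only done when the item is actually requested.
        with open(self._path, "rb") as f:
            f.seek(self._offsets[index])
            return json.loads(f.readline())["text"]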

Source code in d9d/dataset/buffer_sorted.py
class DatasetImplementingSortKeyProtocol(Protocol[_T_co]):
    """
    Protocol for datasets that support retrieval of a specific key for sorting purposes.

    This is typically used for length-based bucketing/sorting where the dataset
    needs to expose the length of an item without loading the full item.
    """

    def __len__(self) -> int:
        """Returns the total number of items in the dataset."""
        ...

    def sort_key(self, index: int) -> Any:
        """
        Returns a value used for sorting the dataset at the given index.

        Args:
            index: The index of the item.

        Returns:
            A comparable value (e.g., int length) used for sorting.
        """
        ...

    def __getitem__(self, item: int) -> _T_co:
        """Retrieves the item at the specific index."""
        ...

__getitem__(item)

Retrieves the item at the specific index.

Source code in d9d/dataset/buffer_sorted.py
def __getitem__(self, item: int) -> _T_co:
    """Retrieves the item at the specific index."""
    ...

__len__()

Returns the total number of items in the dataset.

Source code in d9d/dataset/buffer_sorted.py
def __len__(self) -> int:
    """Returns the total number of items in the dataset."""
    ...

sort_key(index)

Returns a value used for sorting the dataset at the given index.

Parameters:

  • index (int, required): The index of the item.

Returns:

  • Any: A comparable value (e.g., int length) used for sorting.

Source code in d9d/dataset/buffer_sorted.py
def sort_key(self, index: int) -> Any:
    """
    Returns a value used for sorting the dataset at the given index.

    Args:
        index: The index of the item.

    Returns:
        A comparable value (e.g., int length) used for sorting.
    """
    ...

ShardIndexingMode

Bases: StrEnum

Defines how the dataset is split across shards.

Modes

sequential: Round-robin distribution.

shard0: 0, 4, 8, 12
shard1: 1, 5, 9, 13
shard2: 2, 6, 10
shard3: 3, 7, 11

chunked: Contiguous blocks.

shard0: 0, 1, 2, 3
shard1: 4, 5, 6, 7
shard2: 8, 9, 10, 11
shard3: 12, 13
Source code in d9d/dataset/sharded.py
class ShardIndexingMode(StrEnum):
    """
    Defines how the dataset is split across shards.

    Modes:
        sequential: Round-robin distribution.

            shard0: 0, 4, 8, 12
            shard1: 1, 5, 9, 13
            shard2: 2, 6, 10
            shard3: 3, 7, 11

        chunked: Contiguous blocks.

            shard0: 0, 1, 2, 3
            shard1: 4, 5, 6, 7
            shard2: 8, 9, 10, 11
            shard3: 12, 13
    """

    sequential = "sequential"
    chunked = "chunked"

ShardedDataset

Bases: Dataset[_T_co], Stateful

A dataset wrapper that acts as a view onto a specific shard of the underlying dataset.

This is useful for Data Parallel training where each process should only see a subset of the data. It supports different indexing modes and optional padding to ensure all shards have equal length (preventing hangs in distributed collectives).
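
To make the padding behaviour concrete, here is a small sketch (the shard lengths follow from the __len__ logic shown below) for a 10-item dataset split across 4 chunked shards:

import torch
from torch.utils.data import TensorDataset
from d9d.dataset import ShardedDataset, ShardIndexingMode

base = TensorDataset(torch.arange(10))

for shard in range(4):
    view = ShardedDataset(
        dataset=base,
        total_shards=4,
        current_shard=shard,
        indexing_mode=ShardIndexingMode.chunked,
        pad_to_equal_size_across_shards=False
    )
    print(shard, len(view))
# Without padding the chunked shard lengths are 3, 3, 3, 1. With
# pad_to_equal_size_across_shards=True every shard reports ceil(10 / 4) = 3,
# and the out-of-range indices on the last shard repeat the final dataset item.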

Source code in d9d/dataset/sharded.py
class ShardedDataset(Dataset[_T_co], Stateful):
    """
    A dataset wrapper that acts as a view onto a specific shard of the underlying dataset.

    This is useful for Data Parallel training where each process should only see
    a subset of the data. It supports different indexing modes and optional padding
    to ensure all shards have equal length (preventing hangs in distributed collectives).
    """

    def __init__(
            self,
            dataset: Dataset[_T_co],
            total_shards: int,
            current_shard: int,
            indexing_mode: ShardIndexingMode,
            pad_to_equal_size_across_shards: bool
    ):
        """
        Constructs a ShardedDataset object.

        Args:
            dataset: The underlying dataset to shard.
            total_shards: The total number of shards (e.g., number of DP ranks).
            current_shard: The index of the current shard (e.g., current DP rank).
            indexing_mode: How indices are assigned to shards (sequential/round-robin or chunked).
            pad_to_equal_size_across_shards: If True, the length of the dataset will be padded
                so that all shards report the same length. The last standard element is repeated.
        """

        if not isinstance(dataset, Sized):
            raise ValueError("Dataset should implement __len__ method")

        self._dataset = dataset

        self._total_shards = total_shards
        self._current_shard = current_shard

        self._indexing_mode = indexing_mode
        self._pad_to_equal_size_across_shards = pad_to_equal_size_across_shards

    def _compute_real_index_sequential(self, index: int) -> int:
        return index * self._total_shards + self._current_shard

    def _get_base_index_unsafe(self, index: int) -> int:
        """
        Calculates the underlying dataset index for a given shard index,
        without boundary checking.
        """

        match self._indexing_mode:
            case ShardIndexingMode.sequential:
                base_index = index * self._total_shards + self._current_shard

                return base_index
            case ShardIndexingMode.chunked:
                ceil_len = math.ceil(len(self._dataset) / self._total_shards)
                shard_start_offset = ceil_len * self._current_shard

                return shard_start_offset + index
            case _:
                raise ValueError(f"Unknown shard indexing mode: {self._indexing_mode}")

    def __getitem__(self, index: int) -> _T_co:
        """
        Retrieves an item from the underlying dataset mapping logic shard index to physical index.

        If padding is enabled and the index exceeds the valid data for this shard,
        the last item in the dataset is returned.

        Args:
            index: The index relative to this shard.

        Returns:
            The data item.
        """

        base_index = self._get_base_index_unsafe(index)
        if base_index >= len(self._dataset):
            base_index = len(self._dataset) - 1
        return self._dataset[base_index]

    def __len__(self) -> int:
        """
        Returns the number of items in this specific shard.

        If `pad_to_equal_size_across_shards` is True, this returns the ceiling
        length (max length across all shards).
        """

        ceil_len = math.ceil(len(self._dataset) / self._total_shards)

        if self._pad_to_equal_size_across_shards:
            return ceil_len

        shards_remainder = len(self._dataset) % self._total_shards
        match self._indexing_mode:
            case ShardIndexingMode.sequential:
                shards_full = len(self._dataset) // self._total_shards
                return shards_full + 1 if self._current_shard < shards_remainder else shards_full
            case ShardIndexingMode.chunked:
                is_shard_last = self._current_shard == self._total_shards - 1
                if not is_shard_last or shards_remainder == 0:
                    return ceil_len
                else:
                    return ceil_len - (self._total_shards - shards_remainder)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        if isinstance(self._dataset, Stateful):
            self._dataset.load_state_dict(state_dict["dataset"])

        # check whether env mismatched
        if state_dict["total_shards"] != self._total_shards:
            raise ValueError("Shard count mismatch")
        self._total_shards = state_dict["total_shards"]

        self._current_shard = state_dict["current_shard"]

    def state_dict(self) -> dict[str, Any]:
        dct: dict[str, Any] = {
            "total_shards": self._total_shards,
            "current_shard": self._current_shard
        }
        if isinstance(self._dataset, Stateful):
            dct["dataset"] = self._dataset.state_dict()
        return dct

__getitem__(index)

Retrieves an item from the underlying dataset by mapping the logical shard index to a physical index.

If padding is enabled and the index exceeds the valid data for this shard, the last item in the dataset is returned.

Parameters:

  • index (int, required): The index relative to this shard.

Returns:

  • _T_co: The data item.

Source code in d9d/dataset/sharded.py
def __getitem__(self, index: int) -> _T_co:
    """
    Retrieves an item from the underlying dataset mapping logic shard index to physical index.

    If padding is enabled and the index exceeds the valid data for this shard,
    the last item in the dataset is returned.

    Args:
        index: The index relative to this shard.

    Returns:
        The data item.
    """

    base_index = self._get_base_index_unsafe(index)
    if base_index >= len(self._dataset):
        base_index = len(self._dataset) - 1
    return self._dataset[base_index]

__init__(dataset, total_shards, current_shard, indexing_mode, pad_to_equal_size_across_shards)

Constructs a ShardedDataset object.

Parameters:

  • dataset (Dataset[_T_co], required): The underlying dataset to shard.
  • total_shards (int, required): The total number of shards (e.g., number of DP ranks).
  • current_shard (int, required): The index of the current shard (e.g., current DP rank).
  • indexing_mode (ShardIndexingMode, required): How indices are assigned to shards (sequential/round-robin or chunked).
  • pad_to_equal_size_across_shards (bool, required): If True, the length of the dataset is padded so that all shards report the same length; the last dataset element is repeated as padding.
Source code in d9d/dataset/sharded.py
def __init__(
        self,
        dataset: Dataset[_T_co],
        total_shards: int,
        current_shard: int,
        indexing_mode: ShardIndexingMode,
        pad_to_equal_size_across_shards: bool
):
    """
    Constructs a ShardedDataset object.

    Args:
        dataset: The underlying dataset to shard.
        total_shards: The total number of shards (e.g., number of DP ranks).
        current_shard: The index of the current shard (e.g., current DP rank).
        indexing_mode: How indices are assigned to shards (sequential/round-robin or chunked).
        pad_to_equal_size_across_shards: If True, the length of the dataset will be padded
            so that all shards report the same length. The last standard element is repeated.
    """

    if not isinstance(dataset, Sized):
        raise ValueError("Dataset should implement __len__ method")

    self._dataset = dataset

    self._total_shards = total_shards
    self._current_shard = current_shard

    self._indexing_mode = indexing_mode
    self._pad_to_equal_size_across_shards = pad_to_equal_size_across_shards

__len__()

Returns the number of items in this specific shard.

If pad_to_equal_size_across_shards is True, this returns the ceiling length (max length across all shards).

Source code in d9d/dataset/sharded.py
def __len__(self) -> int:
    """
    Returns the number of items in this specific shard.

    If `pad_to_equal_size_across_shards` is True, this returns the ceiling
    length (max length across all shards).
    """

    ceil_len = math.ceil(len(self._dataset) / self._total_shards)

    if self._pad_to_equal_size_across_shards:
        return ceil_len

    shards_remainder = len(self._dataset) % self._total_shards
    match self._indexing_mode:
        case ShardIndexingMode.sequential:
            shards_full = len(self._dataset) // self._total_shards
            return shards_full + 1 if self._current_shard < shards_remainder else shards_full
        case ShardIndexingMode.chunked:
            is_shard_last = self._current_shard == self._total_shards - 1
            if not is_shard_last or shards_remainder == 0:
                return ceil_len
            else:
                return ceil_len - (self._total_shards - shards_remainder)