## About

The `d9d.dataset` package provides specialized PyTorch `Dataset` wrappers designed for distributed training scenarios.
## Core Concepts

### Why Not Auto-Wrap Datasets Automatically?
d9d provides explicit, composable wrappers rather than relying on the implicit "magic" or automatic `Sampler` injection often found in other frameworks:

- **Flexible Composition and Order of Operations:** The behavior of a data pipeline changes significantly depending on the order of composition. By stacking wrappers manually, you control the data flow logic.
- **Granular Configuration:** Different datasets have different physical constraints that require specific configurations. A dataset loaded from network storage might require contiguous reads to be performant (`ShardIndexingMode.chunked`), while an in-memory dataset might prefer round-robin access (`ShardIndexingMode.sequential`). Explicit wrappers ensure that these configuration options are exposed to the user rather than buried in global trainer arguments (see the sketch below).
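A minimal sketch of that configuration difference, using the `ShardedDataset` API documented later on this page (the tiny dataset and shard counts are purely illustrative):

```python
import torch
from torch.utils.data import TensorDataset

from d9d.dataset import ShardedDataset, ShardIndexingMode

base = TensorDataset(torch.arange(8))

for mode in (ShardIndexingMode.sequential, ShardIndexingMode.chunked):
    shard0 = ShardedDataset(
        dataset=base,
        total_shards=2,
        current_shard=0,
        indexing_mode=mode,
        pad_to_equal_size_across_shards=True,
    )
    # sequential -> items 0, 2, 4, 6 (round-robin)
    # chunked    -> items 0, 1, 2, 3 (one contiguous read)
    print(mode, [int(shard0[i][0]) for i in range(len(shard0))])
```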
## Features

### Smart Bucketing
In NLP and sequence processing, batches often contain items of varying lengths. Standard random sampling forces each batch to be padded to the length of its longest sequence, wasting computational resources on padding tokens.

`BufferSortedDataset` implements a "Smart Bucketing" strategy that ensures items within a specific micro-batch have similar lengths (minimizing padding), while the order of micro-batches remains random enough to preserve training stability.
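To see the scale of the waste, a quick back-of-the-envelope calculation (plain arithmetic, not library code):

```python
# One randomly sampled micro-batch of sequence lengths:
lengths = [5, 100, 7, 6]

padded = max(lengths) * len(lengths)  # every item padded to the longest: 400 slots
useful = sum(lengths)                 # 118 real tokens

print(f"useful fraction: {useful / padded:.0%}")  # ~30% - most compute is padding
```

With bucketing, a micro-batch like `[97, 98, 99, 100]` would instead be roughly 98% useful.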
#### Usage Example

To use `BufferSortedDataset`, your underlying dataset must implement the `DatasetImplementingSortKeyProtocol` (i.e., it must have a `sort_key(index)` method).
```python
from torch.utils.data import DataLoader, Dataset

from d9d.dataset import BufferSortedDataset


class MyTextDataset(Dataset):
    def __init__(self, data: list[str]):
        self.data = data  # list of strings

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    # ! You need to implement this one \/ !
    def sort_key(self, index):
        return len(self.data[index])


# Create Base Dataset
raw_data = ["short", "very very very long phrase", "tiny", "medium size"] * 100
base_ds = MyTextDataset(raw_data)

# Wrap with Smart Bucketing
# - buffer_size=100: Look at 100 items at a time to find similar lengths
# - pack_size=4: Group them into batches of 4
sorted_ds = BufferSortedDataset(
    base_dataset=base_ds,
    buffer_size=100,
    pack_size=4,
    init_seed=42,
)
```
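You can then feed `sorted_ds` to a regular `DataLoader`. A sketch continuing the example above (setting `batch_size` equal to `pack_size` is an assumption for illustration); keep `shuffle=False`, since shuffling at the loader level would scatter the length-sorted packs:

```python
# Batches line up with the length-sorted packs; shuffling here would undo them
loader = DataLoader(sorted_ds, batch_size=4, shuffle=False)

for batch in loader:
    # Each batch holds 4 strings of similar length, so padding stays low
    print(batch)
    break
```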
### Sharding

When using Data Parallelism, each GPU processes a subset of the data. `ShardedDataset` provides a deterministic view of a specific shard of the data based on the rank (shard ID).
It supports:

- **Sequential Sharding:** Round-robin distribution (`0, 4, 8...` for rank 0).
- **Chunked Sharding:** Contiguous blocks (`0, 1, 2...` for rank 0).
- **Optional Padding:** Ensures all shards have exactly the same length. This is critical for distributed training loops, where uneven dataset sizes can cause process hangs.
#### Usage Example
```python
import torch
from torch.utils.data import TensorDataset

from d9d.core.dist_context import DistributedContext, BATCH_DOMAIN
from d9d.dataset import ShardedDataset, ShardIndexingMode

# You can infer your Data Parallel size and rank from the DistributedContext object
context: DistributedContext  # provided by your training setup
batch_mesh = context.mesh_for(BATCH_DOMAIN)
dp_size = batch_mesh.size('dp')
dp_rank = batch_mesh.get_local_rank('dp')

# Create Full Dataset
base_ds = TensorDataset(torch.randn(100, 10))

# Shard it
sharded_ds = ShardedDataset(
    dataset=base_ds,
    total_shards=dp_size,
    current_shard=dp_rank,
    indexing_mode=ShardIndexingMode.chunked,
    # Crucial for preventing distributed hangs
    pad_to_equal_size_across_shards=True,
)

print(f"I am rank {dp_rank}, I see {len(sharded_ds)} items.")
```
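Each rank then wraps its own shard in a `DataLoader` as usual; a minimal sketch continuing the example (the batch size is illustrative). Because padding keeps shard lengths equal, every rank runs the same number of steps, which is what prevents collective hangs:

```python
from torch.utils.data import DataLoader

loader = DataLoader(sharded_ds, batch_size=8, shuffle=True)

for (batch,) in loader:
    ...  # forward/backward as usual; all ranks iterate in lockstep
```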
## d9d.dataset

This package provides utilities and `torch.utils.data.Dataset` implementations.
### BufferSortedDataset

Bases: `Dataset[_T_co]`, `Stateful`
A dataset wrapper that groups items into buffers, sorts them, and yields them with local shuffling.
This prevents extreme padding in variable-length training (by grouping similar lengths) while maintaining enough randomness to ensure statistical variance in updates.
Algorithm:

1. Select a range of indices (size `buffer_size`).
2. Sort these indices based on `base_dataset.sort_key()`.
3. Break the sorted list into packs of size `pack_size`.
4. Shuffle the order of these packs.
5. Flatten the list and serve items.
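A standalone sketch of these five steps in plain Python (this mirrors the described algorithm for illustration; it is not the library's implementation):

```python
import random


def buffer_sorted_order(keys: list[int], buffer_size: int, pack_size: int, seed: int = 0) -> list[int]:
    rng = random.Random(seed)
    order: list[int] = []
    for start in range(0, len(keys), buffer_size):
        # Step 1: take one buffer of indices
        buffer = list(range(start, min(start + buffer_size, len(keys))))
        # Step 2: sort the buffer by its sort key
        buffer.sort(key=lambda i: keys[i])
        # Step 3: cut the sorted buffer into packs
        packs = [buffer[i:i + pack_size] for i in range(0, len(buffer), pack_size)]
        # Step 4: shuffle pack order (randomness between micro-batches)
        rng.shuffle(packs)
        # Step 5: flatten and serve
        order.extend(index for pack in packs for index in pack)
    return order


# Items with similar keys end up adjacent, but pack order is random:
print(buffer_sorted_order(keys=[5, 1, 4, 2, 8, 3, 7, 6], buffer_size=8, pack_size=2))
```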
Source code in d9d/dataset/buffer_sorted.py
#### `__getitem__(index)`

Retrieves an item from the locally sorted/shuffled buffer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `index` | `int` | The global index. | *required* |

Returns:

| Type | Description |
|---|---|
| `_T_co` | The dataset item. |
Source code in d9d/dataset/buffer_sorted.py
#### `__init__(base_dataset, buffer_size, pack_size, init_seed=None)`

Constructs a BufferSortedDataset object.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `base_dataset` | `DatasetImplementingSortKeyProtocol[_T_co]` | The underlying dataset implementing the `sort_key` protocol. | *required* |
| `buffer_size` | `int` | The number of items to load into the buffer for sorting. | *required* |
| `pack_size` | `int` | The size of local groups (batches/micro-batches) that remain contiguous after sorting, but are shuffled relative to other packs. | *required* |
| `init_seed` | `int \| None` | Seed for the random number generator used for shuffling packs. | `None` |
Source code in d9d/dataset/buffer_sorted.py
#### `__len__()`
Returns the length of the base dataset.
Source code in d9d/dataset/buffer_sorted.py
### DatasetImplementingSortKeyProtocol

Bases: `Protocol[_T_co]`
Protocol for datasets that support retrieval of a specific key for sorting purposes.
This is typically used for length-based bucketing/sorting where the dataset needs to expose the length of an item without loading the full item.
Source code in d9d/dataset/buffer_sorted.py
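For example, a conforming dataset might keep lengths in precomputed metadata so that `sort_key` stays cheap. A sketch (the `load_tokens` helper is hypothetical):

```python
from typing import Any

from torch.utils.data import Dataset


class MetadataBackedDataset(Dataset):
    """Keeps token counts alongside file paths so sorting never loads items."""

    def __init__(self, paths: list[str], token_counts: list[int]):
        self.paths = paths
        self.token_counts = token_counts  # precomputed offline

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int) -> Any:
        return load_tokens(self.paths[index])  # hypothetical, expensive I/O

    def sort_key(self, index: int) -> int:
        return self.token_counts[index]  # cheap: no item load
```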
#### `__getitem__(item)`

Retrieves the item at the given index.
Source code in d9d/dataset/buffer_sorted.py
#### `__len__()`
Returns the total number of items in the dataset.
Source code in d9d/dataset/buffer_sorted.py
#### `sort_key(index)`

Returns a value used for sorting the dataset at the given index.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `index` | `int` | The index of the item. | *required* |

Returns:

| Type | Description |
|---|---|
| `Any` | A comparable value (e.g., int length) used for sorting. |
Source code in d9d/dataset/buffer_sorted.py
### ShardIndexingMode

Bases: `StrEnum`

Defines how the dataset is split across shards.

Modes:

`sequential`: Round-robin distribution.

```
shard0: 0, 4, 8, 12
shard1: 1, 5, 9, 13
shard2: 2, 6, 10
shard3: 3, 7, 11
```

`chunked`: Contiguous blocks.

```
shard0: 0, 1, 2, 3
shard1: 4, 5, 6, 7
shard2: 8, 9, 10, 11
shard3: 12, 13
```
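The mapping from shard-local positions to physical indices implied by these examples can be written out as follows (a sketch of the implied arithmetic, not the library's internal code):

```python
import math


def shard_indices(n_items: int, total_shards: int, shard: int, chunked: bool) -> list[int]:
    if chunked:
        # Contiguous block of size ceil(n / shards), clipped to the dataset end
        chunk = math.ceil(n_items / total_shards)
        return [i for i in range(shard * chunk, (shard + 1) * chunk) if i < n_items]
    # Round-robin: every total_shards-th item, starting at the shard id
    return list(range(shard, n_items, total_shards))


assert shard_indices(14, 4, 0, chunked=False) == [0, 4, 8, 12]
assert shard_indices(14, 4, 2, chunked=False) == [2, 6, 10]
assert shard_indices(14, 4, 0, chunked=True) == [0, 1, 2, 3]
assert shard_indices(14, 4, 3, chunked=True) == [12, 13]
```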
Source code in d9d/dataset/sharded.py
### ShardedDataset

Bases: `Dataset[_T_co]`, `Stateful`
A dataset wrapper that acts as a view onto a specific shard of the underlying dataset.
This is useful for Data Parallel training where each process should only see a subset of the data. It supports different indexing modes and optional padding to ensure all shards have equal length (preventing hangs in distributed collectives).
Source code in d9d/dataset/sharded.py
#### `__getitem__(index)`

Retrieves an item from the underlying dataset, mapping the logical shard index to a physical index.

If padding is enabled and the index exceeds the valid data for this shard, the last item in the dataset is returned.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `index` | `int` | The index relative to this shard. | *required* |

Returns:

| Type | Description |
|---|---|
| `_T_co` | The data item. |
Source code in d9d/dataset/sharded.py
#### `__init__(dataset, total_shards, current_shard, indexing_mode, pad_to_equal_size_across_shards)`

Constructs a ShardedDataset object.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `Dataset[_T_co]` | The underlying dataset to shard. | *required* |
| `total_shards` | `int` | The total number of shards (e.g., number of DP ranks). | *required* |
| `current_shard` | `int` | The index of the current shard (e.g., current DP rank). | *required* |
| `indexing_mode` | `ShardIndexingMode` | How indices are assigned to shards (sequential/round-robin or chunked). | *required* |
| `pad_to_equal_size_across_shards` | `bool` | If True, the length of the dataset is padded so that all shards report the same length; the last element is repeated as padding. | *required* |
Source code in d9d/dataset/sharded.py
#### `__len__()`

Returns the number of items in this specific shard.

If `pad_to_equal_size_across_shards` is True, this returns the ceiling length (the maximum length across all shards).
Source code in d9d/dataset/sharded.py
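A quick worked example of that ceiling length (simple arithmetic consistent with the modes above; not library code):

```python
import math

# 10 items over 4 shards: chunked raw lengths are [3, 3, 3, 1],
# sequential raw lengths are [3, 3, 2, 2]; with padding enabled,
# every shard reports the maximum, ceil(10 / 4) = 3.
n_items, total_shards = 10, 4
print(math.ceil(n_items / total_shards))  # -> 3
```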