Datasets
About
The d9d.dataset package provides specialized PyTorch Dataset wrappers designed for distributed training scenarios.
Core Concepts
Why Not Auto-Wrap Datasets Automatically?
d9d provides explicit composable wrappers rather than relying on implicit "magic" or automatic Sampler injection often found in other frameworks.
- Flexible Composition and Order-of-Operations: The behavior of a data pipeline changes significantly depending on the order of composition. By stacking wrappers manually, you control the data flow logic.
- Granular Configuration: Different datasets have different physical constraints that require specific configurations. A dataset loaded from network storage might require contiguous reads to be performant (ShardIndexingMode.chunked), while an in-memory dataset might prefer round-robin access (ShardIndexingMode.sequential). Explicit wrappers ensure that these configuration options are exposed to the user rather than buried in global trainer arguments.
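To see why composition order matters, here is a plain-Python sketch (not the d9d API; the helper names are illustrative) contrasting shard-then-sort with sort-then-shard on a toy index stream:

```python
# Toy index streams to show why wrapper order matters.
data = list(range(8))

def shard(seq, rank, world):
    # Sequential (round-robin) shard view: rank r sees r, r + world, ...
    return seq[rank::world]

def sort_desc(seq):
    # Stand-in for a length-sorted bucketing step.
    return sorted(seq, reverse=True)

# Shard first, then sort: each rank sorts only its own slice.
a = sort_desc(shard(data, 0, 2))
# Sort first, then shard: ranks slice the globally sorted stream.
b = shard(sort_desc(data), 0, 2)
```

The two orderings produce different per-rank streams from the same data, which is exactly the degree of freedom explicit wrapper stacking preserves.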
Features
Smart Bucketing
In NLP and Sequence processing, batches often contain items of varying lengths. Standard random sampling forces the batch to be padded to the length of the longest sequence, wasting computational resources on padding tokens.
BufferSortedDataset implements a "Smart Bucketing" strategy to balance efficiency and statistical variance. It ensures that items within a specific micro-batch have similar lengths (minimizing padding), while preventing the data stream from becoming strictly deterministic or sorted.
Usage Example
To use BufferSortedDataset, your underlying dataset must implement the DatasetImplementingSortKeyProtocol (i.e., it must have a sort_key(index) method).
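In real code you would wrap your dataset with BufferSortedDataset(base_dataset, buffer_size, pack_size, init_seed). The self-contained sketch below instead reimplements the documented algorithm on a toy corpus (treating the whole corpus as one buffer), so the effect can be seen directly; the corpus and helper names are illustrative, not part of d9d:

```python
import random

# Toy corpus: each item is a token list; sort_key(i) is its length.
corpus = [[0] * n for n in [5, 1, 4, 2, 8, 3, 7, 6]]

def sort_key(i):
    # Expose length without materialising the full item.
    return len(corpus[i])

def buffer_sort(indices, pack_size, rng):
    # Sort the buffer by (sort_key, random tie-breaker).
    order = sorted(indices, key=lambda i: (sort_key(i), rng.random()))
    # Group into packs of similar-length items.
    packs = [order[i:i + pack_size] for i in range(0, len(order), pack_size)]
    # Inter-pack shuffle, then intra-pack shuffle.
    rng.shuffle(packs)
    for p in packs:
        rng.shuffle(p)
    # Flatten and serve.
    return [i for p in packs for i in p]

rng = random.Random(0)
served = buffer_sort(list(range(len(corpus))), pack_size=2, rng=rng)
```

Every served index appears exactly once, and items sharing a pack have near-identical lengths, which is what keeps padding low without making the stream deterministic.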
Sharding
When using Data Parallelism, each GPU processes a subset of the data. ShardedDataset provides a deterministic view of a specific shard of the data based on the rank (shard ID).
It supports:
- Sequential Sharding: Round-robin distribution (indices 0, 4, 8, ... for rank 0 of 4 shards).
- Chunked Sharding: Contiguous blocks (indices 0, 1, 2, ... for rank 0).
- Optional Padding: Ensures all shards have exactly the same length. This is critical for distributed training loops, where uneven dataset sizes can cause process hangs.
Usage Example (for Data Parallel)
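In application code, shard_dataset_data_parallel(dataset, dist_context) derives the rank and world size from the dp mesh dimension. The sketch below is illustrative (not part of d9d) and shows the index view the helper's defaults (sequential mode, padding enabled) would produce for a hypothetical 4-rank data-parallel group:

```python
import math

def dp_shard_indices(dataset_len, dp_rank, dp_world_size):
    # Sequential (round-robin) assignment: rank r sees r, r + world, ...
    idx = list(range(dp_rank, dataset_len, dp_world_size))
    # Pad with the last dataset item so every rank reports the same length,
    # preventing hangs in distributed collectives.
    target = math.ceil(dataset_len / dp_world_size)
    idx += [dataset_len - 1] * (target - len(idx))
    return idx

shards = [dp_shard_indices(10, r, 4) for r in range(4)]
```

All four ranks end up with the same number of items, even though 10 does not divide evenly by 4.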
Usage Example (Manual)
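Manually, the real call is ShardedDataset(dataset, total_shards, current_shard, indexing_mode, pad_to_equal_size_across_shards). The ShardView class below is a simplified stand-in (chunked mode only, padding omitted) to show the shape of such a view, not the d9d implementation:

```python
class ShardView:
    """Simplified chunked shard view (no padding), for illustration only."""

    def __init__(self, base, total_shards, current_shard):
        per_shard = -(-len(base) // total_shards)  # ceil division
        self._base = base
        self._start = current_shard * per_shard
        self._len = max(0, min(len(base) - self._start, per_shard))

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return self._base[self._start + index]

# Shard 2 of 4 over a 10-item dataset sees the contiguous block 6, 7, 8.
view = ShardView(list(range(10)), total_shards=4, current_shard=2)
items = [view[i] for i in range(len(view))]
```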
Padding Utilities
When creating batches from variable-length sequences, tensors must be padded to the same length to form a valid tensor stack.
pad_stack_1d provides a robust utility for this, specifically designed to help when writing a collate_fn.
Usage Example
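The behavior of pad_stack_1d can be sketched as follows; this is a simplified reimplementation for illustration, not the d9d source (in real code you would call pad_stack_1d directly inside your collate_fn):

```python
import torch

def pad_stack_1d_sketch(items, pad_value, padding_side="right", pad_to_multiple_of=None):
    # Target length: the longest item, optionally rounded up to a multiple.
    target = max(t.shape[0] for t in items)
    if pad_to_multiple_of is not None:
        target = -(-target // pad_to_multiple_of) * pad_to_multiple_of
    out = torch.full((len(items), target), pad_value, dtype=items[0].dtype)
    for row, t in zip(out, items):
        if padding_side == "right":
            row[: t.shape[0]] = t
        else:  # left padding
            row[target - t.shape[0]:] = t
    return out

batch = pad_stack_1d_sketch(
    [torch.tensor([1, 2, 3]), torch.tensor([4])],
    pad_value=0,
    pad_to_multiple_of=4,
)
```

Here pad_to_multiple_of=4 rounds the target length from 3 up to 4, which can matter for kernels that prefer aligned sequence lengths.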
d9d.dataset
This package provides utilities and torch.utils.data.Dataset implementations.
BufferSortedDataset
Bases: Dataset[_T_co], Stateful
A dataset wrapper that groups items into buffers, sorts them, and yields them with local shuffling.
This prevents extreme padding in variable-length training (by grouping similar lengths) while maintaining enough randomness to ensure statistical variance in updates.
Algorithm:
1. Select a range of indices (size buffer_size).
2. Generate sort keys: (base_dataset.sort_key(), random_tie_breaker).
3. Sort indices by this tuple.
4. Group the sorted list into packs of size pack_size.
5. Shuffle the order of these packs (inter-pack shuffle).
6. Shuffle the items within each pack (intra-pack shuffle).
7. Flatten and serve.
__init__(base_dataset, buffer_size, pack_size, init_seed=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| base_dataset | DatasetImplementingSortKeyProtocol[_T_co] | The underlying dataset. | required |
| buffer_size | int | The number of items to load into the buffer for sorting. | required |
| pack_size | int | The size of local groups (batches/micro-batches). | required |
| init_seed | int \| None | Seed for the random number generator. | None |
DatasetImplementingSortKeyProtocol
Bases: Protocol[_T_co]
Protocol for datasets that support retrieval of a specific key for sorting purposes.
This is typically used for length-based bucketing/sorting where the dataset needs to expose the length of an item without loading the full item.
__getitem__(item)
Retrieves the item at the specific index.
__len__()
Returns the total number of items in the dataset.
PaddingSide1D
Bases: StrEnum
Enum specifying the side for padding 1D sequences.
Attributes:

| Name | Description |
|---|---|
| left | Pad on the left side. |
| right | Pad on the right side. |
ShardIndexingMode
Bases: StrEnum
Defines how the dataset is split across shards.
Modes
- sequential: Round-robin distribution (e.g. with 4 shards, shard 0 sees indices 0, 4, 8, ...).
- chunked: Contiguous blocks (e.g. with 4 shards, shard 0 sees indices 0, 1, 2, ...).
ShardedDataset
Bases: Dataset[_T_co], Stateful
A dataset wrapper that acts as a view onto a specific shard of the underlying dataset.
This is useful for Data Parallel training where each process should only see a subset of the data. It supports different indexing modes and optional padding to ensure all shards have equal length (preventing hangs in distributed collectives).
__getitem__(index)
Retrieves an item from the underlying dataset, mapping the logical shard index to a physical index.
If padding is enabled and the index exceeds the valid data for this shard, the last item in the dataset is returned.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| index | int | The index relative to this shard. | required |

Returns:

| Type | Description |
|---|---|
| _T_co | The data item. |
__init__(dataset, total_shards, current_shard, indexing_mode, pad_to_equal_size_across_shards)
Constructs a ShardedDataset object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset[_T_co] | The underlying dataset to shard. | required |
| total_shards | int | The total number of shards (e.g., number of DP ranks). | required |
| current_shard | int | The index of the current shard (e.g., current DP rank). | required |
| indexing_mode | ShardIndexingMode | How indices are assigned to shards (sequential/round-robin or chunked). | required |
| pad_to_equal_size_across_shards | bool | If True, the length of the dataset is padded so that all shards report the same length; the last element is repeated. | required |
__len__()
Returns the number of items in this specific shard.
If pad_to_equal_size_across_shards is True, this returns the ceiling
length (max length across all shards).
TokenPoolingType
Bases: StrEnum
Enumeration of supported token pooling strategies.
Attributes:

| Name | Description |
|---|---|
| first | Selects the first token of the sequence (e.g., [CLS] token). |
| last | Selects the last non-padding token of the sequence (e.g., for a Transformer decoder). |
| all | Selects all non-padding tokens (e.g., for mean pooling). |
pad_stack_1d(items, pad_value, padding_side=PaddingSide1D.right, pad_to_multiple_of=None)
Stacks 1D tensors into a batch, applying padding.
Calculates the maximum length among the input tensors (optionally aligning to a multiple), pads elements to match this length on the specified side, and stacks them.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| items | Sequence[Tensor] | A sequence of 1D tensors to be stacked. | required |
| pad_value | int | The value used for padding. | required |
| padding_side | PaddingSide1D | The side on which to apply padding (left or right). | right |
| pad_to_multiple_of | int \| None | If provided, ensures the target length is a multiple of this value. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | A single stacked tensor of shape (batch, max_length). |

Raises:

| Type | Description |
|---|---|
| ValueError | If no items are provided or if … |
shard_dataset_data_parallel(dataset, dist_context, indexing_mode=ShardIndexingMode.sequential, pad_to_equal_size_across_shards=True)
Wraps a dataset into a ShardedDataset based on the Data Parallel dimension of the distributed context.
This is a helper function to automatically determine the correct rank and world size from the 'dp' (Data Parallel) mesh dimension within the batch domain DeviceMesh.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset[_T_co] | The source dataset to shard. | required |
| dist_context | DistributedContext | The distributed context. | required |
| indexing_mode | ShardIndexingMode | The strategy for splitting data indices (sequential/round-robin or chunked). | sequential |
| pad_to_equal_size_across_shards | bool | If True, ensures all shards have the same length by padding. | True |

Returns:

| Type | Description |
|---|---|
| Dataset[_T_co] | A dataset instance representing the local shard. |
token_pooling_mask_from_attention_mask(attention_mask, pooling_type)
Generates a binary mask for token pooling based on the specified strategy.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| attention_mask | Tensor | A binary mask indicating valid tokens (1) and padding (0). Expected shape is (batch_size, sequence_length). | required |
| pooling_type | TokenPoolingType | The strategy to use for selecting tokens. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | A LongTensor of the same shape as the input, containing 1s at positions to be included in pooling and 0s elsewhere. |

Raises:

| Type | Description |
|---|---|
| ValueError | If the provided pooling type is not supported. |
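The three strategies can be sketched as follows; this is a simplified reimplementation for illustration (not the d9d source), and it assumes right padding, matching the [CLS]-style semantics of first:

```python
import torch

def pooling_mask_sketch(attention_mask, pooling_type):
    # attention_mask: (batch, seq_len), 1 = valid token, 0 = padding.
    mask = torch.zeros_like(attention_mask)
    if pooling_type == "first":
        mask[:, 0] = 1
    elif pooling_type == "last":
        last = attention_mask.sum(dim=1) - 1  # last valid position per row
        mask[torch.arange(mask.shape[0]), last] = 1
    elif pooling_type == "all":
        mask = attention_mask.clone()
    else:
        raise ValueError(f"unsupported pooling type: {pooling_type!r}")
    return mask

am = torch.tensor([[1, 1, 0], [1, 1, 1]])
```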