Datasets

About

The d9d.dataset package provides specialized PyTorch Dataset wrappers designed for distributed training scenarios.

Core Concepts

Why Not Wrap Datasets Automatically?

d9d provides explicit composable wrappers rather than relying on implicit "magic" or automatic Sampler injection often found in other frameworks.

  • Flexible Composition and Order-of-Operations: The behavior of a data pipeline changes significantly depending on the order of composition. By stacking wrappers manually, you control the data flow logic: for example, sharding before bucketing lets each rank sort only its own slice of the data, while bucketing before sharding sorts globally and then splits the result.

  • Granular Configuration: Different datasets have different physical constraints that require specific configurations. A dataset loaded from network storage might require contiguous reads to be performant (ShardIndexingMode.chunked), while an in-memory dataset might prefer round-robin access (ShardIndexingMode.sequential). Explicit wrappers ensure that these configuration options are exposed to the user rather than buried in global trainer arguments.
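
For intuition, the order-of-operations point can be modeled with plain index transforms (a toy sketch; `shard` and `bucket` here are illustrative stand-ins for the wrappers documented below, not d9d APIs):

```python
# Toy model: each "wrapper" is a transform over a list of dataset indices.
def shard(indices, rank, world):
    # Sequential (round-robin) sharding
    return indices[rank::world]

def bucket(indices, keys):
    # Length-based sorting ("bucketing")
    return sorted(indices, key=lambda i: keys[i])

keys = [5, 1, 4, 2, 3, 6, 0, 7]  # hypothetical sequence lengths
idx = list(range(8))

# bucket then shard: all ranks split one globally sorted stream
a = shard(bucket(idx, keys), rank=0, world=2)
# shard then bucket: each rank sorts only its own slice
b = bucket(shard(idx, rank=0, world=2), keys)
print(a, b)  # the same rank sees two different index streams
```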

Features

Smart Bucketing

In NLP and sequence processing, batches often contain items of varying lengths. Standard random sampling forces each batch to be padded to the length of its longest sequence, wasting computational resources on padding tokens.

BufferSortedDataset implements a "Smart Bucketing" strategy to balance efficiency and statistical variance. It ensures that items within a specific micro-batch have similar lengths (minimizing padding), while preventing the data stream from becoming strictly deterministic or sorted.
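
As a back-of-the-envelope check, grouping similar lengths sharply reduces padding (a toy calculation independent of d9d; the lengths are made up):

```python
# Padding cost of one batch: batch_size * max_len - sum(lens)
def padding_tokens(lens):
    return len(lens) * max(lens) - sum(lens)

lengths = [3, 27, 4, 11, 5, 25, 10, 26]  # hypothetical sequence lengths

# Random order: batches of 4 mix short and long sequences
random_cost = padding_tokens(lengths[:4]) + padding_tokens(lengths[4:])
# Length-grouped: similar lengths share a batch
grouped = sorted(lengths)
grouped_cost = padding_tokens(grouped[:4]) + padding_tokens(grouped[4:])

print(random_cost, grouped_cost)  # 101 37
```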

Usage Example

To use BufferSortedDataset, your underlying dataset must implement the DatasetImplementingSortKeyProtocol (i.e., it must have a sort_key(index) method).

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from d9d.dataset import BufferSortedDataset

class MyTextDataset(Dataset):
    def __init__(self, data: list[str]):
        self.data = data # list of strings

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    # Required by BufferSortedDataset: returns a comparable value (here, length)
    def sort_key(self, index):
        return len(self.data[index])

# Create Base Dataset (Ideally globally shuffled beforehand)
raw_data = ["short", "very very very long phrase", "tiny", "medium size"] * 100
base_ds = MyTextDataset(raw_data)

# Wrap with Smart Bucketing
# - buffer_size=100: Look at 100 items at a time to find similar lengths
# - pack_size=4: Group them into batches of 4
sorted_ds = BufferSortedDataset(
    base_dataset=base_ds,
    buffer_size=100,
    pack_size=4,
    init_seed=42
)

# Use a DataLoader batch size equal to pack_size so that each
# micro-batch contains items of similar length
loader = DataLoader(sorted_ds, batch_size=4, shuffle=False)

Sharding

When using Data Parallelism, each GPU processes a subset of the data. ShardedDataset provides a deterministic view of a specific shard of the data based on the rank (shard ID).

It supports:

  • Sequential Sharding: Round-robin distribution (0, 4, 8... for rank 0).
  • Chunked Sharding: Contiguous blocks (0, 1, 2... for rank 0).
  • Optional Padding: Ensuring all shards have exactly the same length. This is critical for distributed training loops where uneven dataset sizes can cause process hangs.

Usage Example (for Data Parallel)

import torch
from torch.utils.data import TensorDataset
from d9d.core.dist_context import DistributedContext
from d9d.dataset import shard_dataset_data_parallel, ShardIndexingMode

# Assume a DistributedContext has already been initialised
context: DistributedContext

# Create Full Dataset
base_ds = TensorDataset(torch.randn(100, 10))

# Shard it
sharded_ds = shard_dataset_data_parallel(
    dataset=base_ds,
    dist_context=context,
    # Optional Parameters:
    indexing_mode=ShardIndexingMode.chunked,
    pad_to_equal_size_across_shards=True 
)

print(f"This rank sees {len(sharded_ds)} items.")

Usage Example (Manual)

import torch
from torch.utils.data import TensorDataset
from d9d.core.dist_context import DistributedContext, BATCH_DOMAIN
from d9d.dataset import ShardedDataset, ShardIndexingMode

# You can infer your Data Parallel size and rank from the DistributedContext object
context: DistributedContext
batch_mesh = context.mesh_for(BATCH_DOMAIN)
dp_size = batch_mesh.size('dp')
dp_rank = batch_mesh.get_local_rank('dp')

# Create Full Dataset
base_ds = TensorDataset(torch.randn(100, 10))

# Shard it
sharded_ds = ShardedDataset(
    dataset=base_ds,
    total_shards=dp_size,
    current_shard=dp_rank,
    indexing_mode=ShardIndexingMode.chunked,
    # Crucial for preventing distributed hangs
    pad_to_equal_size_across_shards=True 
)

print(f"I am rank {dp_rank}, I see {len(sharded_ds)} items.")

Padding Utilities

When creating batches from variable-length sequences, tensors must be padded to the same length to form a valid tensor stack.

pad_stack_1d provides a robust utility for this, designed specifically to help when writing a collate_fn.

Usage Example

import torch
from d9d.dataset import pad_stack_1d, PaddingSide1D

# Variable length sequences
items = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4]),
    torch.tensor([5, 6])
]

# 1. Standard Right Padding
batch = pad_stack_1d(items, pad_value=0, padding_side=PaddingSide1D.right)

# 2. Left Padding
batch_left = pad_stack_1d(items, pad_value=0, padding_side=PaddingSide1D.left)

# 3. Aligned Padding
# Ensures the dimensions are friendly to GPU kernels or for Context Parallel sharding
batch_aligned = pad_stack_1d(
    items, 
    pad_value=0, 
    pad_to_multiple_of=8
)

d9d.dataset

This package provides utilities and torch.utils.data.Dataset implementations.

BufferSortedDataset

Bases: Dataset[_T_co], Stateful

A dataset wrapper that groups items into buffers, sorts them, and yields them with local shuffling.

This prevents extreme padding in variable-length training (by grouping similar lengths) while maintaining enough randomness to ensure statistical variance in updates.

Algorithm:

  1. Select a range of indices (size buffer_size).
  2. Generate sort keys: (base_dataset.sort_key(), random_tie_breaker).
  3. Sort indices by this tuple.
  4. Group sorted list into packs of size pack_size.
  5. Shuffle the order of these packs (inter-pack shuffle).
  6. Shuffle the items within these packs (intra-pack shuffle).
  7. Flatten and serve.
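
The steps above can be sketched as follows (a simplified illustration of the index logic, not the actual implementation):

```python
import random

def buffer_sorted_indices(sort_keys, buffer_size, pack_size, seed=0):
    """Yield dataset indices in buffer-sorted, pack-shuffled order."""
    rng = random.Random(seed)
    for start in range(0, len(sort_keys), buffer_size):
        # 1. Select a range of indices (size buffer_size)
        buffer = list(range(start, min(start + buffer_size, len(sort_keys))))
        # 2-3. Sort by (sort_key, random tie-breaker)
        buffer.sort(key=lambda i: (sort_keys[i], rng.random()))
        # 4. Group into packs of size pack_size
        packs = [buffer[i:i + pack_size] for i in range(0, len(buffer), pack_size)]
        # 5. Inter-pack shuffle
        rng.shuffle(packs)
        # 6. Intra-pack shuffle
        for pack in packs:
            rng.shuffle(pack)
        # 7. Flatten and serve
        yield from (i for pack in packs for i in pack)

order = list(buffer_sorted_indices([5, 1, 4, 2, 3, 6, 0, 7], buffer_size=8, pack_size=2))
print(order)  # a permutation of 0..7 where consecutive pairs have adjacent sort keys
```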

__init__(base_dataset, buffer_size, pack_size, init_seed=None)

Parameters:

Name Type Description Default
base_dataset DatasetImplementingSortKeyProtocol[_T_co]

The underlying dataset.

required
buffer_size int

The number of items to load into the buffer for sorting.

required
pack_size int

The size of local groups (batches/micro-batches).

required
init_seed int | None

Seed for the random number generator.

None

DatasetImplementingSortKeyProtocol

Bases: Protocol[_T_co]

Protocol for datasets that support retrieval of a specific key for sorting purposes.

This is typically used for length-based bucketing/sorting where the dataset needs to expose the length of an item without loading the full item.

__getitem__(item)

Retrieves the item at the specific index.

__len__()

Returns the total number of items in the dataset.

sort_key(index)

Returns a value used for sorting the dataset at the given index.

Parameters:

Name Type Description Default
index int

The index of the item.

required

Returns:

Type Description
Any

A comparable value (e.g., int length) used for sorting.

PaddingSide1D

Bases: StrEnum

Enum specifying the side for padding 1D sequences.

Attributes:

Name Type Description
left

Pad on the left side.

right

Pad on the right side.

ShardIndexingMode

Bases: StrEnum

Defines how the dataset is split across shards.

Modes

sequential: Round-robin distribution.

shard0: 0, 4, 8, 12
shard1: 1, 5, 9, 13
shard2: 2, 6, 10
shard3: 3, 7, 11

chunked: Contiguous blocks.

shard0: 0, 1, 2, 3
shard1: 4, 5, 6, 7
shard2: 8, 9, 10, 11
shard3: 12, 13
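
Both modes reduce to simple index arithmetic (a minimal sketch of the mapping, not d9d's implementation):

```python
def shard_index(local_index, shard, total_shards, dataset_len, chunked):
    """Map an index within a shard to an index in the full dataset."""
    if chunked:
        # Contiguous blocks: shard k starts at k * ceil(n / total_shards)
        chunk = -(-dataset_len // total_shards)  # ceiling division
        return shard * chunk + local_index
    # Sequential (round-robin): every total_shards-th item, offset by shard id
    return local_index * total_shards + shard

# 14 items over 4 shards, matching the listings above:
print([shard_index(i, 0, 4, 14, chunked=False) for i in range(4)])  # [0, 4, 8, 12]
print([shard_index(i, 1, 4, 14, chunked=True) for i in range(4)])   # [4, 5, 6, 7]
```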

ShardedDataset

Bases: Dataset[_T_co], Stateful

A dataset wrapper that acts as a view onto a specific shard of the underlying dataset.

This is useful for Data Parallel training where each process should only see a subset of the data. It supports different indexing modes and optional padding to ensure all shards have equal length (preventing hangs in distributed collectives).

__getitem__(index)

Retrieves an item from the underlying dataset, mapping the logical shard index to a physical index.

If padding is enabled and the index exceeds the valid data for this shard, the last item in the dataset is returned.

Parameters:

Name Type Description Default
index int

The index relative to this shard.

required

Returns:

Type Description
_T_co

The data item.

__init__(dataset, total_shards, current_shard, indexing_mode, pad_to_equal_size_across_shards)

Constructs a ShardedDataset object.

Parameters:

Name Type Description Default
dataset Dataset[_T_co]

The underlying dataset to shard.

required
total_shards int

The total number of shards (e.g., number of DP ranks).

required
current_shard int

The index of the current shard (e.g., current DP rank).

required
indexing_mode ShardIndexingMode

How indices are assigned to shards (sequential/round-robin or chunked).

required
pad_to_equal_size_across_shards bool

If True, the length of the dataset will be padded so that all shards report the same length; the last item is repeated to fill the gap.

required

__len__()

Returns the number of items in this specific shard.

If pad_to_equal_size_across_shards is True, this returns the ceiling length (max length across all shards).

TokenPoolingType

Bases: StrEnum

Enumeration of supported token pooling strategies.

Attributes:

Name Type Description
first

Selects the first token of the sequence (e.g., [CLS] token).

last

Selects the last non-padding token of the sequence (e.g., for Transformer Decoder).

all

Selects all non-padding tokens (e.g., for mean pooling).

pad_stack_1d(items, pad_value, padding_side=PaddingSide1D.right, pad_to_multiple_of=None)

Stacks 1D tensors into a batch, applying padding.

Calculates the maximum length among the input tensors (optionally aligning to a multiple), pads elements to match this length on the specified side, and stacks them.

Parameters:

Name Type Description Default
items Sequence[Tensor]

A sequence of 1D tensors to be stacked.

required
pad_value int

The value used for padding.

required
padding_side PaddingSide1D

The side on which to apply padding (left or right).

right
pad_to_multiple_of int | None

Optional integer. If provided, ensures the target length is a multiple of this value.

None

Returns:

Type Description
Tensor

A single stacked tensor of shape (batch, max_length).

Raises:

Type Description
ValueError

If no items are provided or if pad_to_multiple_of is <= 0.
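
The semantics above can be sketched in plain PyTorch (an illustrative re-implementation, right padding only; pad_stack_1d itself also supports left padding):

```python
import torch

def pad_stack_1d_sketch(items, pad_value, pad_to_multiple_of=None):
    """Right-pad 1D tensors to a common length and stack them."""
    target = max(t.shape[0] for t in items)
    if pad_to_multiple_of is not None:
        # Round the target length up to the nearest multiple
        target = -(-target // pad_to_multiple_of) * pad_to_multiple_of
    padded = [
        torch.cat([t, t.new_full((target - t.shape[0],), pad_value)])
        for t in items
    ]
    return torch.stack(padded)

batch = pad_stack_1d_sketch(
    [torch.tensor([1, 2, 3]), torch.tensor([4])], pad_value=0, pad_to_multiple_of=4
)
print(batch.shape)  # torch.Size([2, 4])
```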

shard_dataset_data_parallel(dataset, dist_context, indexing_mode=ShardIndexingMode.sequential, pad_to_equal_size_across_shards=True)

Wraps a dataset into a ShardedDataset based on the Data Parallel dimension of the distributed context.

This is a helper function to automatically determine the correct rank and world size from the 'dp' (Data Parallel) mesh dimension within the batch domain DeviceMesh.

Parameters:

Name Type Description Default
dataset Dataset[_T_co]

The source dataset to shard.

required
dist_context DistributedContext

The distributed context.

required
indexing_mode ShardIndexingMode

The strategy for splitting data indices (sequential/round-robin or chunked).

sequential
pad_to_equal_size_across_shards bool

If True, ensures all shards have the same length by padding.

True

Returns:

Type Description
Dataset[_T_co]

A dataset instance representing the local shard.

token_pooling_mask_from_attention_mask(attention_mask, pooling_type)

Generates a binary mask for token pooling based on the specified strategy.

Parameters:

Name Type Description Default
attention_mask Tensor

A binary mask indicating valid tokens (1) and padding (0). Expected shape is (batch_size, sequence_length).

required
pooling_type TokenPoolingType

The strategy to use for selecting tokens.

required

Returns:

Type Description
Tensor

A LongTensor of the same shape as the input, containing 1s at positions to be included in pooling and 0s elsewhere.

Raises:

Type Description
ValueError

If the provided pooling type is not supported.
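
For intuition, the three strategies can be sketched directly in PyTorch (a hypothetical re-implementation of the masking logic, not the library function itself; assumes right padding and string-valued pooling types):

```python
import torch

def pooling_mask_sketch(attention_mask, pooling_type):
    """Build a token-pooling mask from a (batch, seq_len) attention mask."""
    if pooling_type == "all":
        # Keep every non-padding token (e.g. for mean pooling)
        return attention_mask.long()
    mask = torch.zeros_like(attention_mask, dtype=torch.long)
    if pooling_type == "first":
        # Assumes the first position is always a valid token (e.g. [CLS])
        mask[:, 0] = 1
    elif pooling_type == "last":
        # Index of the last non-padding token per row (right padding assumed)
        last = attention_mask.sum(dim=1) - 1
        mask[torch.arange(mask.shape[0]), last] = 1
    else:
        raise ValueError(f"unsupported pooling type: {pooling_type}")
    return mask

am = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(pooling_mask_sketch(am, "last").tolist())  # [[0, 0, 1, 0], [0, 1, 0, 0]]
```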