Welcome back, future LLM master! In the previous chapters, we laid the groundwork by understanding Tunix’s architecture and setting up our development environment. Now, it’s time to talk about the fuel that powers any Large Language Model: data!

This chapter is all about getting your data ready for Tunix. We’ll dive deep into the crucial steps of preparing your text-based datasets, understanding how to tokenize them, and setting up efficient data loading pipelines that play nicely with JAX and Tunix. Think of this as preparing a delicious meal – you need to carefully select, clean, and chop your ingredients before you can even think about cooking!

Why does this matter so much? Because the quality and format of your input data directly impact the performance and success of your LLM post-training. Poorly prepared data can lead to models that don’t learn effectively, exhibit biases, or simply fail to train. By the end of this chapter, you’ll have a solid grasp of how to transform raw text into the structured, optimized format Tunix expects, ensuring your models get the best possible start.

Ready to dig into the data? Let’s go!

Core Concepts: The Pillars of Data Preparation

Before we start writing any code, let’s understand the fundamental concepts that underpin data preparation for LLMs in the JAX ecosystem.

What is “Data Preparation” for LLMs?

At its heart, data preparation for LLMs involves transforming raw, human-readable text into a numerical format that a machine learning model can understand and process. This isn’t just about simple conversion; it’s a multi-step process that includes:

  1. Cleaning: Removing irrelevant characters, formatting issues, or noise.
  2. Structuring: Organizing data into examples suitable for training (e.g., input-output pairs for instruction tuning).
  3. Tokenization: Breaking down text into discrete units (tokens) and mapping them to numerical IDs.
  4. Padding & Truncation: Ensuring all input sequences have a consistent length.
  5. Batching: Grouping multiple processed examples into batches for efficient parallel processing on accelerators like GPUs/TPUs.
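The five steps above can be sketched end to end in plain Python. This is a toy, not how Tunix or Hugging Face implement them. The cleaning rules, the `Instruction:/Response:` template, the whitespace tokenizer, and all function names (`clean`, `structure`, `build_vocab`, `tokenize`, `pad_or_truncate`, `batches`) are hypothetical stand-ins chosen to make each step visible.

```python
import re

def clean(text):
    # Step 1: replace control characters (tabs, stray \x00, ...) with spaces, collapse whitespace
    text = re.sub(r"[\x00-\x1f\x7f]", " ", text)
    return re.sub(r"\s+", " ", text).strip()

def structure(instruction, response):
    # Step 2: join an instruction/response pair into one training string (hypothetical template)
    return f"Instruction: {instruction}\nResponse: {response}"

def build_vocab(texts):
    # Step 3a: a toy word-level vocabulary; IDs 0 and 1 are reserved for padding and unknowns
    vocab = {"[PAD]": 0, "[UNK]": 1}
    for text in texts:
        for word in text.lower().split():
            vocab.setdefault(word, len(vocab))
    return vocab

def tokenize(text, vocab):
    # Step 3b: map each whitespace-separated word to its ID, falling back to [UNK]
    return [vocab.get(word, vocab["[UNK]"]) for word in text.lower().split()]

def pad_or_truncate(ids, max_length, pad_id=0):
    # Step 4: cut long sequences, pad short ones, so every row has exactly max_length IDs
    ids = ids[:max_length]
    return ids + [pad_id] * (max_length - len(ids))

def batches(rows, batch_size):
    # Step 5: group consecutive rows into batches (the last one may be smaller)
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]

pairs = [("  Define   JAX ", "A numerical computing library."),
         ("What is Tunix?", "A post-training library\tfor LLMs.")]
texts = [structure(clean(q), clean(a)) for q, a in pairs]
vocab = build_vocab(texts)
rows = [pad_or_truncate(tokenize(t, vocab), max_length=8) for t in texts]
for batch in batches(rows, batch_size=2):
    print(batch)
```

Real pipelines swap the toy tokenizer for a pre-trained subword tokenizer, but the order of operations stays the same.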

Tokenization: Speaking the Model’s Language

Imagine you’re trying to teach a friend a new language, but instead of words, you teach them individual letters or parts of words. That’s somewhat analogous to tokenization. LLMs don’t understand human words directly; they understand numbers. Tokenization is the process of converting text into sequences of numerical IDs, where each ID represents a “token.”

A tokenizer is a special component that performs this conversion. It typically does the following:

  • Splits text: It breaks down sentences into subword units (e.g., “unbelievable” might become “un”, “believ”, “able”). This helps handle rare words and keeps the vocabulary small.
  • Maps to IDs: Each unique token is assigned a unique integer ID from the tokenizer’s vocabulary.
  • Adds special tokens: Model-specific tokens such as [CLS] (classification) and [SEP] (separator) are inserted around the text, [PAD] fills padding positions, and [UNK] stands in for pieces missing from the vocabulary. The exact set varies by architecture; these four are BERT’s.
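To make the splitting step concrete, here is a minimal sketch of greedy longest-match-first subword splitting, the rule WordPiece (BERT’s tokenizer) applies at inference time. The vocabulary is a tiny hypothetical one. Real tokenizers also normalize (e.g., lowercase) and split on whitespace and punctuation first, and BERT’s actual vocabulary may keep a common word like “unbelievable” whole.

```python
def wordpiece_split(word, vocab, unk="[UNK]", prefix="##"):
    """Greedy longest-match-first subword split; `vocab` is a toy set, not BERT's."""
    pieces, start = [], 0
    while start < len(word):
        end, match = len(word), None
        while end > start:                      # try the longest remaining substring first
            candidate = word[start:end]
            if start > 0:
                candidate = prefix + candidate  # word-internal pieces carry the ## marker
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]                        # nothing fits: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {"un", "##believ", "##able", "token", "##izer"}
print(wordpiece_split("unbelievable", toy_vocab))  # ['un', '##believ', '##able']
print(wordpiece_split("tokenizer", toy_vocab))     # ['token', '##izer']
print(wordpiece_split("xyz", toy_vocab))           # ['[UNK]']
```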

For example, with the bert-base-uncased tokenizer (which lowercases its input), the sentence “Hello world!” becomes [101, 7592, 2088, 999, 102], where 101 is [CLS], 7592 is hello, 2088 is world, 999 is !, and 102 is [SEP].
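The mapping and special-token steps for the “Hello world!” example can be sketched with a hypothetical seven-entry vocabulary that reuses bert-base-uncased’s IDs ([PAD]=0, [UNK]=100, [CLS]=101, [SEP]=102, hello=7592, world=2088, !=999). The regex split into words and punctuation is a simplification of BERT’s real pre-tokenizer, and `encode` is an illustrative name.

```python
import re

# Hypothetical mini-vocabulary reusing bert-base-uncased IDs for these tokens
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 100, "[CLS]": 101, "[SEP]": 102,
             "!": 999, "hello": 7592, "world": 2088}

def encode(text, vocab=TOY_VOCAB):
    # "uncased": lowercase first, then split into words and individual punctuation marks
    words = re.findall(r"\w+|[^\w\s]", text.lower())
    ids = [vocab.get(w, vocab["[UNK]"]) for w in words]  # unknown words map to [UNK]
    return [vocab["[CLS]"]] + ids + [vocab["[SEP]"]]       # wrap with special tokens

print(encode("Hello world!"))  # [101, 7592, 2088, 999, 102]
print(encode("Hello Tunix!"))  # [101, 7592, 100, 999, 102]
```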

The transformers library from Hugging Face provides a rich ecosystem of pre-trained tokenizers that are compatible with various LLMs. We’ll be leveraging these.

Dataset Abstraction: Efficiently Managing Your Data

When dealing with potentially massive datasets, you can’t just load everything into memory at once. You need a system that can efficiently stream data, shuffle it, batch it, and apply transformations without overwhelming your system. This is where dataset abstraction libraries come in.

  • Hugging Face datasets library: This library provides a unified way to load and process many publicly available datasets, and also to easily create custom datasets from various file formats (CSV, JSON, text). It’s highly optimized and integrates well with transformers.
  • tf.data (from TensorFlow): Although Tunix is JAX-native, tf.data is a powerful, widely used library for building high-performance input pipelines, and the data it prepares can be consumed by JAX models. It excels at handling large datasets, prefetching, and parallel processing. Note that tf.data yields TensorFlow tensors (or NumPy arrays via as_numpy_iterator()), so each batch needs a cheap conversion to JAX arrays before it reaches Tunix.
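The streaming idea behind these libraries can be sketched with Python generators: examples are read lazily, shuffled through a fixed-size buffer, and grouped into batches, so memory use depends on the buffer and batch size rather than on the dataset size. This is only a conceptual model of what tf.data’s `shuffle(buffer_size)` and `batch()` do; the function names are hypothetical, and the real implementations add parallelism and prefetching.

```python
import random
from itertools import islice

def shuffle_buffer(stream, buffer_size, seed=0):
    """Approximate shuffle with bounded memory: hold buffer_size items, emit a random one."""
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]        # emit a random buffered item...
        buffer[idx] = item       # ...and let the new item take its slot
    rng.shuffle(buffer)          # stream exhausted: drain the rest in random order
    yield from buffer

def batch(stream, batch_size, drop_remainder=False):
    """Group consecutive items into lists of batch_size."""
    it = iter(stream)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk or (drop_remainder and len(chunk) < batch_size):
            return
        yield chunk

examples = (f"example {i}" for i in range(10))  # a generator stands in for lazy disk reads
for b in batch(shuffle_buffer(examples, buffer_size=4), batch_size=3):
    print(b)
```

A larger buffer gives a more thorough shuffle at the cost of more memory; a buffer as large as the dataset is a full shuffle.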

A typical data pipeline might look like this:

flowchart LR
    A[Raw Text Data] -->|Load| B[Hugging Face Dataset]
    B -->|Tokenize| C[Tokenized Dataset]
    C -->|Format for JAX/TF.Data| D[tf.data.Dataset]
    D -->|Shuffle & Batch| E[Batched Data]
    E -->|Feed to Tunix| F[Tunix Model]

This diagram illustrates the journey of your raw text data: loading into a Hugging Face Dataset, tokenization, conversion into a tf.data pipeline, shuffling and batching, and finally conversion of each batch into JAX arrays for your Tunix model.

Tunix’s Data Expectations

Tunix, being JAX-native, expects data primarily as JAX arrays. When you provide data to a Tunix training loop, it will typically be in the form of a dictionary where keys map to input names (e.g., input_ids, attention_mask, labels) and values are JAX arrays of the appropriate shape and dtype.

For example, a single batch might look like:

{
    'input_ids': jax.Array,       # integer token IDs, shape (batch_size, sequence_length)
    'attention_mask': jax.Array,  # 1 for real tokens, 0 for padding, shape (batch_size, sequence_length)
    'labels': jax.Array,          # shape (batch_size,) or (batch_size, sequence_length), depending on the task
}

The exact keys and shapes depend on your model and the specific post-training task.
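A minimal sketch of packing per-example token lists into that dict-of-arrays layout, using NumPy. The helper names, the `int32` choice, and making `labels` a copy of `input_ids` (one common causal-LM convention, where the loss shifts targets by one position) are assumptions. Tunix’s own trainers may use different field names, and a classification task would use labels of shape `(batch_size,)`, which `check_batch` as written would reject.

```python
import numpy as np

def make_batch(token_lists, max_length, pad_id=0):
    """Pack per-example token-ID lists into a dict of (batch_size, max_length) arrays."""
    batch_size = len(token_lists)
    input_ids = np.full((batch_size, max_length), pad_id, dtype=np.int32)
    attention_mask = np.zeros((batch_size, max_length), dtype=np.int32)
    for row, ids in enumerate(token_lists):
        ids = ids[:max_length]              # truncate
        input_ids[row, :len(ids)] = ids     # real tokens; remaining slots stay pad_id
        attention_mask[row, :len(ids)] = 1  # 1 = real token, 0 = padding
    labels = input_ids.copy()               # causal-LM style: the loss applies the shift
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

def check_batch(batch):
    """Assert every field shares one shape and holds integers; return that shape."""
    shape = batch["input_ids"].shape
    for name, arr in batch.items():
        assert arr.shape == shape, f"{name} has shape {arr.shape}, expected {shape}"
        assert np.issubdtype(arr.dtype, np.integer), f"{name} should hold integer IDs"
    return shape

batch = make_batch([[101, 7592, 2088, 102], [101, 7592, 102]], max_length=6)
print(check_batch(batch))  # (2, 6)
print(batch["attention_mask"])
```

Passing each array through `jnp.asarray` turns the result into JAX arrays.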

Step-by-Step Implementation: Building a Data Pipeline

Let’s put these concepts into practice. We’ll create a simple, synthetic dataset and build a data pipeline to prepare it for Tunix.

Step 1: Install Necessary Libraries

First, ensure you have the required libraries. We’ll need transformers for tokenization, datasets for easy data handling, and tensorflow for tf.data to build our efficient pipeline.

Open your terminal or notebook and run:

pip install transformers==4.36.0 datasets==2.16.1 tensorflow==2.15.0 jax==0.4.23 jaxlib==0.4.23

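Why these versions? These pins, all released around the end of 2023, are known to work together, and pinning keeps the examples reproducible. Newer releases should generally work but may need small adjustments, and if you install Tunix in the same environment, its own requirements take precedence (it may need a newer JAX). The CPU build of jaxlib installs directly from PyPI. For an NVIDIA GPU, install a CUDA-enabled build instead: for the 0.4.x series that means the jax[cuda12_pip] extra together with -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, while recent JAX versions simplify this to pip install "jax[cuda12]". For simplicity, we’re using the CPU installation here.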
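After installing, a quick way to confirm what actually landed in your environment is the standard library’s `importlib.metadata`. This sketch compares installed versions against the pins above; the `check_versions` helper is hypothetical, and on a GPU install you would expect extra CUDA-related packages that it does not check.

```python
from importlib import metadata

PINNED = {"transformers": "4.36.0", "datasets": "2.16.1",
          "tensorflow": "2.15.0", "jax": "0.4.23", "jaxlib": "0.4.23"}

def check_versions(pinned):
    """Map each package to (installed_version_or_None, pinned_version, matches)."""
    report = {}
    for package, wanted in pinned.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None  # not installed in this environment
        report[package] = (installed, wanted, installed == wanted)
    return report

for package, (installed, wanted, ok) in check_versions(PINNED).items():
    status = "ok" if ok else "MISMATCH"
    print(f"{package:<13} installed={installed!s:<10} pinned={wanted:<8} {status}")
```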

Step 2: Create a Synthetic Dataset

Instead of loading a massive dataset, let’s create a small, in-memory dataset to demonstrate the process. This will represent some raw text we want to use for post-training.

# data_prep_tunix.py
import datasets
import jax
import jax.numpy as jnp
import tensorflow as tf
import transformers
from datasets import Dataset
from transformers import AutoTokenizer

# Suppress TensorFlow logging to keep output clean
tf.get_logger().setLevel('ERROR')

# Version strings live on the top-level modules, not on the imported classes
print(f"JAX version: {jax.__version__}")
print(f"TensorFlow version: {tf.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Datasets version: {datasets.__version__}")

# Our synthetic raw text data
raw_data = {
    "text": [
        "Tunix is a JAX-native library for LLM post-training.",
        "It focuses on efficient and scalable model alignment.",
        "Data preparation is a critical first step for successful training.",
        "Let's learn how to create a robust data pipeline!",
        "JAX provides high-performance numerical computing."
    ]
}

# Convert to a Hugging Face Dataset object
hf_dataset = Dataset.from_dict(raw_data)

print("\n--- Raw Hugging Face Dataset ---")
print(hf_dataset)
print(hf_dataset[0])

Explanation:

  1. We import Dataset from datasets, AutoTokenizer from transformers, and tensorflow for tf.data. We also import jax and jax.numpy as these will be our target data types.
  2. tf.get_logger().setLevel('ERROR') helps quiet down TensorFlow’s verbose output, which is useful when you’re primarily using JAX but leveraging tf.data.
  3. We define raw_data as a dictionary, which is a common format for creating a Dataset from scratch. Each entry in the “text” list is a separate example.
  4. Dataset.from_dict(raw_data) converts our dictionary into a Dataset object, which offers convenient methods for processing.
  5. We print the dataset and its first example to show its structure.
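`Dataset.from_dict` expects a columnar layout: one key per column, each holding an equal-length list. Raw data often arrives row by row instead (for example, one JSON record per line). This sketch pivots records into columns and rejects rows with mismatched keys; `rows_to_columns` and the extra `source` field are illustrative, and the datasets library also offers `Dataset.from_list` for row-wise input directly.

```python
def rows_to_columns(rows):
    """Pivot a list of records into the dict-of-lists layout Dataset.from_dict expects."""
    if not rows:
        return {}
    columns = {key: [] for key in rows[0]}
    for i, row in enumerate(rows):
        if row.keys() != columns.keys():  # every record must have the same fields
            raise ValueError(f"row {i} has keys {sorted(row)}, expected {sorted(columns)}")
        for key, value in row.items():
            columns[key].append(value)
    return columns

records = [
    {"text": "Tunix is a JAX-native library for LLM post-training.", "source": "docs"},
    {"text": "JAX provides high-performance numerical computing.", "source": "docs"},
]
columns = rows_to_columns(records)
print(columns["source"])  # ['docs', 'docs']
# hf_dataset = Dataset.from_dict(columns)  # equivalent to Dataset.from_list(records)
```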

Step 3: Load a Tokenizer

Now, let’s load a pre-trained tokenizer. Here we’ll use the widely available bert-base-uncased tokenizer because it is small and quick to download. For real post-training, load the tokenizer that belongs to the model you plan to train (for example, google/flan-t5-small if that is your model).

Add this code to data_prep_tunix.py:

# Choose a pre-trained tokenizer
# 'bert-base-uncased' is a good general-purpose tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

print(f"\n--- Tokenizer Loaded: {tokenizer.name_or_path} ---")
print(f"Tokenizer vocabulary size: {len(tokenizer)}")

Explanation:

  1. AutoTokenizer.from_pretrained("bert-base-uncased") downloads and loads the tokenizer associated with the bert-base-uncased model. AutoTokenizer reads the checkpoint’s configuration and picks the matching tokenizer class, so you don’t have to name it yourself.
  2. We print the tokenizer’s name and vocabulary size to confirm it loaded correctly.

Step 4: Tokenize the Dataset

Next, we’ll apply our tokenizer to the hf_dataset. This is where the magic of converting text to numerical IDs happens.

Add this code to data_prep_tunix.py:

# Define a tokenization function
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,      # Truncate sequences longer than max_length
        padding="max_length", # Pad shorter sequences to max_length
        max_length=128        # Set a fixed maximum sequence length
    )

# Apply the tokenization function to the dataset
tokenized_hf_dataset = hf_dataset.map(
    tokenize_function,
    batched=True,        # Process multiple examples at once for efficiency
    remove_columns=["text"] # Remove the original 'text' column
)

print("\n--- Tokenized Hugging Face Dataset ---")
print(tokenized_hf_dataset)
print(tokenized_hf_dataset[0])

Explanation:

  1. tokenize_function(examples): This function takes a batch of examples (a dictionary where examples["text"] is a list of strings) and applies the tokenizer.
    • truncation=True: If a sentence is longer than max_length, it will be cut off.
    • padding="max_length": If a sentence is shorter than max_length, it is padded with the padding token (ID 0 for BERT’s [PAD]; other tokenizers use different IDs) up to max_length. This is crucial for creating uniform batches.
    • max_length=128: We set a maximum sequence length. This should be chosen based on your model’s capabilities and the typical length of your input texts.
  2. hf_dataset.map(...): This method applies our tokenize_function to every example in the dataset.
    • batched=True: This tells the map function to pass multiple examples to tokenize_function at once, which is significantly faster.
    • remove_columns=["text"]: After tokenization, we no longer need the raw text, so we remove this column to save memory.
  3. We print the tokenized dataset and its first example. Notice the new input_ids, token_type_ids, and attention_mask columns.
    • input_ids: The numerical IDs of the tokens.
    • token_type_ids: Used by some models (like BERT) to distinguish between different segments in a sequence (e.g., question and answer).
    • attention_mask: A binary mask (1 for real tokens, 0 for padding tokens) that tells the model which tokens to pay attention to.
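To see what `truncation=True` and `padding="max_length"` produce without downloading anything, here is a pure-Python imitation for a single-segment input. It assumes BERT’s conventions ([CLS]=101 at the start, [SEP]=102 at the end, [PAD]=0, all-zero `token_type_ids`) and works on already-split token IDs. The real tokenizer likewise reserves room for special tokens when truncating, but supports many more options, such as pair inputs and other truncation strategies.

```python
def encode_fixed_length(token_ids, max_length, pad_id=0, cls_id=101, sep_id=102):
    """Imitate truncation=True + padding="max_length" for one single-segment input.
    Default IDs are bert-base-uncased's [PAD], [CLS], and [SEP]."""
    ids = [cls_id] + token_ids[: max_length - 2] + [sep_id]  # reserve 2 slots for specials
    n_pad = max_length - len(ids)
    return {
        "input_ids": ids + [pad_id] * n_pad,
        "token_type_ids": [0] * max_length,              # one segment, so every position is 0
        "attention_mask": [1] * len(ids) + [0] * n_pad,  # 1 = real token, 0 = padding
    }

short = encode_fixed_length([7592, 2088, 999], max_length=8)          # "hello world !"
long_ex = encode_fixed_length(list(range(2000, 2020)), max_length=8)  # 20 made-up IDs
print(short["input_ids"])       # [101, 7592, 2088, 999, 102, 0, 0, 0]
print(short["attention_mask"])  # [1, 1, 1, 1, 1, 0, 0, 0]
print(long_ex["input_ids"])     # [101, 2000, 2001, 2002, 2003, 2004, 2005, 102]
```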

Step 5: Prepare for tf.data and JAX

While tokenized_hf_dataset is great, we need to convert it into a tf.data.Dataset for efficient batching and then ensure the elements are JAX-compatible.

Add this code to data_prep_tunix.py:

# Set the format of the dataset to TensorFlow tensors
tokenized_hf_dataset.set_format("tf")

# Create a tf.data.Dataset from the tokenized Hugging Face Dataset
# .to_tf_dataset() is a convenience method provided by the datasets library
# With these three columns, each batch is a dict keyed by column name
tf_dataset = tokenized_hf_dataset.to_tf_dataset(
    columns=["input_ids", "attention_mask", "token_type_ids"],  # Features for the model
    # No label_cols needed for this basic example, as we're not doing a specific task yet
    shuffle=True,
    batch_size=2,  # Small batch size for demonstration
)

print("\n--- First Batch from tf.data.Dataset (TensorFlow Tensors) ---")
for batch in tf_dataset.take(1):
    # tf.data.Dataset yields TensorFlow tensors by default
    print("Input IDs shape:", batch["input_ids"].shape)
    print("Attention Mask shape:", batch["attention_mask"].shape)
    print("Token Type IDs shape:", batch["token_type_ids"].shape)
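    # Convert TensorFlow tensors to JAX arrays for Tunix
    # (.numpy() makes the host copy explicit; jax.tree_util.tree_map replaces the deprecated jax.tree_map)
    jax_batch = jax.tree_util.tree_map(lambda x: jnp.asarray(x.numpy()), batch)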
    print("\n--- First Batch from tf.data.Dataset (Converted to JAX Arrays) ---")
    print("Input IDs (JAX array type):", type(jax_batch["input_ids"]))
    print("Attention Mask (JAX array type):", type(jax_batch["attention_mask"]))
    print("Token Type IDs (JAX array type):", type(jax_batch["token_type_ids"]))
    print("Example JAX input_ids:\n", jax_batch["input_ids"][0][:10]) # Print first 10 tokens of first example
    break # Only take one batch for printing

Explanation:

  1. tokenized_hf_dataset.set_format("tf"): This tells the Hugging Face dataset to return TensorFlow tensors when accessed, which is ideal for tf.data.
  2. tokenized_hf_dataset.to_tf_dataset(...): This is a very convenient method that directly converts the Hugging Face dataset into a tf.data.Dataset.
    • columns: Specifies which columns of the tokenized dataset become model features. Which ones you need depends on the model: BERT-style models use input_ids, attention_mask, and token_type_ids, while many decoder-only LLMs need only input_ids and attention_mask.
    • shuffle=True: Shuffles the data each epoch, which is important for robust training.
    • batch_size=2: We set a small batch size for demonstration purposes. In a real scenario, this would be larger (e.g., 8, 16, 32, 64) depending on your hardware.
  3. We then iterate through one batch of the tf_dataset.
    • Initially, tf_dataset yields TensorFlow tensors.
    • The tree_map call applies a function (jnp.asarray here) to every leaf (tensor) in a nested structure, our batch dictionary, converting each TensorFlow tensor into a JAX array, which is what Tunix expects. Prefer the jax.tree_util.tree_map spelling: the top-level jax.tree_map alias is deprecated and has been removed in recent JAX releases.
    • We print the shapes and types to confirm the conversion.

Now you have a tf_dataset that yields batches of TensorFlow tensors, plus a one-line conversion that turns each batch into a dictionary of JAX arrays, ready for consumption by Tunix!
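One detail matters for JAX: with five examples and `batch_size=2`, each epoch yields batches of 2, 2, and 1, and that smaller final batch has a different shape, which makes a `jax.jit`-compiled training step recompile. `to_tf_dataset` accepts `drop_remainder=True` to discard it. The sketch below shows the same idea in plain Python, with a fresh, reproducible permutation per epoch; `epoch_batches` and its seeding scheme are hypothetical, not how tf.data seeds its shuffles.

```python
import random

def epoch_batches(examples, batch_size, epoch, seed=0, drop_remainder=True):
    """Yield one epoch of shuffled batches; each epoch gets its own permutation.
    With drop_remainder=True every batch has exactly batch_size items."""
    order = list(range(len(examples)))
    random.Random(seed + epoch).shuffle(order)  # reproducible for a given seed and epoch
    stop = len(order) - len(order) % batch_size if drop_remainder else len(order)
    for start in range(0, stop, batch_size):
        yield [examples[i] for i in order[start:start + batch_size]]

data = ["a", "b", "c", "d", "e"]  # five examples, like the synthetic dataset
for epoch in range(2):
    print(epoch, list(epoch_batches(data, batch_size=2, epoch=epoch)))
```

Dropping the remainder loses a few examples per epoch, but because the permutation changes every epoch, different examples are dropped each time.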

Full data_prep_tunix.py

Here’s the complete script:

import datasets
import jax
import jax.numpy as jnp
import tensorflow as tf
import transformers
from datasets import Dataset
from transformers import AutoTokenizer

# Suppress TensorFlow logging to keep output clean
tf.get_logger().setLevel('ERROR')

print(f"JAX version: {jax.__version__}")
print(f"TensorFlow version: {tf.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Datasets version: {datasets.__version__}")

# --- Step 2: Create a Synthetic Dataset ---
raw_data = {
    "text": [
        "Tunix is a JAX-native library for LLM post-training.",
        "It focuses on efficient and scalable model alignment.",
        "Data preparation is a critical first step for successful training.",
        "Let's learn how to create a robust data pipeline!",
        "JAX provides high-performance numerical computing."
    ]
}
hf_dataset = Dataset.from_dict(raw_data)
print("\n--- Raw Hugging Face Dataset ---")
print(hf_dataset)
print(hf_dataset[0])

# --- Step 3: Load a Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(f"\n--- Tokenizer Loaded: {tokenizer.name_or_path} ---")
print(f"Tokenizer vocabulary size: {len(tokenizer)}")

# --- Step 4: Tokenize the Dataset ---
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

tokenized_hf_dataset = hf_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"]
)
print("\n--- Tokenized Hugging Face Dataset ---")
print(tokenized_hf_dataset)
print(tokenized_hf_dataset[0])

# --- Step 5: Prepare for tf.data and JAX ---
tokenized_hf_dataset.set_format("tf")

tf_dataset = tokenized_hf_dataset.to_tf_dataset(
    columns=["input_ids", "attention_mask", "token_type_ids"],
    shuffle=True,
    batch_size=2,
)

print("\n--- First Batch from tf.data.Dataset (TensorFlow Tensors) ---")
for batch in tf_dataset.take(1):
    print("Input IDs shape:", batch["input_ids"].shape)
    print("Attention Mask shape:", batch["attention_mask"].shape)
    print("Token Type IDs shape:", batch["token_type_ids"].shape)
    
    jax_batch = jax.tree_util.tree_map(lambda x: jnp.asarray(x.numpy()), batch)
    
    print("\n--- First Batch from tf.data.Dataset (Converted to JAX Arrays) ---")
    print("Input IDs (JAX array type):", type(jax_batch["input_ids"]))
    print("Attention Mask (JAX array type):", type(jax_batch["attention_mask"]))
    print("Token Type IDs (JAX array type):", type(jax_batch["token_type_ids"]))
    print("Example JAX input_ids (first 10 tokens of first example):\n", jax_batch["input_ids"][0][:10])
    break

Run this script and observe the output. You’ll see how the raw text gets transformed into numerical IDs, padded, and then batched into JAX arrays. This is the foundation for feeding data into your Tunix post-training loops!

Mini-Challenge: Customize Your Tokenization

You’ve seen how to tokenize with a fixed max_length and padding="max_length". Now, let’s try a variation.

Challenge: Modify the tokenize_function to use padding="longest" instead of padding="max_length". What effect does this have on the length of the input_ids and attention_mask sequences in your batches? What are the pros and cons of padding="longest" versus padding="max_length"?

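Hint: Because padding happens inside hf_dataset.map(batched=True), padding="longest" pads each chunk that map processes (1,000 examples by default) to the longest sequence in that chunk, not each training batch of 2. With only five examples, everything is padded to the longest of the five. To pad each training batch individually, tokenize without padding and let a collator pad at batching time (for example, transformers’ DataCollatorWithPadding passed as collate_fn to to_tf_dataset). Either way, consider removing max_length from the tokenizer arguments, or setting it very large, so truncation doesn’t hide the effect.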

What to observe/learn:

  • How padding="longest" dynamically adjusts the sequence length per batch.
  • The trade-offs between dynamic padding (less computation wasted on padding) and static padding (fixed shapes, so jax.jit compiles once instead of recompiling for every new sequence length).
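To quantify the trade-off, this sketch counts padding tokens and distinct batch shapes (each new shape means another `jax.jit` compile) for three strategies over hypothetical sequence lengths. The third strategy, rounding each batch up to a multiple of 8, mirrors the tokenizer’s real `pad_to_multiple_of` argument; `padding_cost` itself and the length values are made up for illustration.

```python
def padding_cost(lengths, batch_size, strategy, max_length=128, multiple=8):
    """Return (pad tokens, distinct batch shapes) for a padding strategy.
    'max_length': every batch padded to max_length (static shape).
    'longest':    each training batch padded to its longest sequence (dynamic).
    'multiple':   like 'longest', rounded up to a multiple (fewer distinct shapes)."""
    pad_tokens, shapes = 0, set()
    for start in range(0, len(lengths), batch_size):
        chunk = [min(n, max_length) for n in lengths[start:start + batch_size]]
        if strategy == "max_length":
            width = max_length
        elif strategy == "longest":
            width = max(chunk)
        elif strategy == "multiple":
            width = min(max_length, -(-max(chunk) // multiple) * multiple)  # ceil to multiple
        else:
            raise ValueError(f"unknown strategy: {strategy}")
        pad_tokens += sum(width - n for n in chunk)
        shapes.add((len(chunk), width))
    return pad_tokens, len(shapes)

lengths = [14, 12, 15, 13, 11, 40, 9, 22]  # hypothetical token counts per example
for strategy in ("max_length", "longest", "multiple"):
    pads, n_shapes = padding_cost(lengths, batch_size=2, strategy=strategy)
    print(f"{strategy:<10} pad tokens={pads:<4} distinct shapes={n_shapes}")
```

For these lengths, static padding wastes 888 pad tokens with a single shape, `longest` wastes 46 but produces four shapes, and rounding to multiples of 8 lands in between (56 pad tokens, three shapes).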

Common Pitfalls & Troubleshooting

Data preparation can be tricky. Here are some common issues you might encounter:

  1. Tokenizer/Model Mismatch: Using a tokenizer from one model architecture (e.g., BERT) with a model from another (e.g., T5) can lead to unexpected token IDs, incorrect special tokens, or errors.
    • Fix: Always use the tokenizer specifically designed for the pre-trained model you intend to use with Tunix. If you’re using AutoTokenizer.from_pretrained(), ensure the model name matches your actual model.
  2. Incorrect Padding/Truncation:
    • If you don’t pad, you’ll get errors about inconsistent input shapes when batching.
    • If you truncate too aggressively (max_length is too small), you might lose critical information from your input texts.
    • If max_length is much larger than your typical text, most positions are padding and you waste computation; if it exceeds the model’s context window (512 tokens for BERT), the model may raise an error or behave incorrectly.
    • Fix: Choose max_length based on your data’s length distribution and your model’s context window, and set padding to match your batching strategy.
  3. Data Type Mismatches: JAX conventionally uses int32 for integer inputs such as input_ids and attention_mask, and float32 for activations. By default JAX runs with 64-bit types disabled, so int64 or float64 arrays from your pipeline are converted to their 32-bit counterparts when they enter JAX. The conversion is usually harmless, but relying on it can hide a dtype mistake.
    • Fix: When converting to JAX arrays, cast explicitly to the dtype you want, e.g., jnp.asarray(x, dtype=jnp.int32). Depending on the library and version, token IDs and masks may arrive as int8, int32, or int64, so an explicit cast keeps the pipeline predictable.
  4. Memory Issues with Large Datasets: Loading an entire large dataset into memory can crash your system.
    • Fix: Leverage the streaming capabilities of datasets (e.g., load_dataset("some_dataset", streaming=True)) and tf.data’s efficient disk-to-device pipeline. Only load what’s necessary for the current batch.
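For the padding and truncation pitfall, one way to pick `max_length` is from your data’s length distribution: take a high percentile of token counts, round up, and cap at the model’s context window. This is a hypothetical helper with made-up lengths, and the 512-token window is an assumption (BERT’s limit); in practice you would compute the lengths by tokenizing a sample of your dataset without truncation.

```python
import math

def choose_max_length(lengths, coverage=0.95, context_window=512, multiple=8):
    """Round the `coverage` percentile of lengths up to a multiple, capped at the window."""
    ordered = sorted(lengths)
    k = max(0, math.ceil(coverage * len(ordered)) - 1)  # index of the coverage percentile
    rounded = math.ceil(ordered[k] / multiple) * multiple
    return min(rounded, context_window)

def truncated_fraction(lengths, max_length):
    # Share of examples that would lose tokens at this max_length
    return sum(n > max_length for n in lengths) / len(lengths)

lengths = [14, 12, 15, 13, 11, 40, 9, 22, 18, 700]  # hypothetical token counts, one outlier
for coverage in (0.85, 0.95):
    max_len = choose_max_length(lengths, coverage=coverage)
    print(coverage, max_len, truncated_fraction(lengths, max_len))
```

Here the single 700-token outlier dominates: 95% coverage of ten examples means covering all of them, so the result hits the 512-token cap while still truncating the outlier, whereas 85% coverage settles on 40 tokens and truncates the same 10% of examples.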

Summary

Phew! We’ve covered a lot of ground in this chapter. Here are the key takeaways:

  • Data preparation is paramount: It’s the foundation for successful LLM post-training with Tunix.
  • Tokenization is key: It converts human text into numerical IDs that models understand, using tools like Hugging Face’s transformers library.
  • datasets and tf.data are your friends: These libraries provide robust and efficient ways to load, process, shuffle, and batch large datasets for JAX-native models.
  • Tunix expects JAX arrays: Your final data pipeline should yield dictionaries of JAX arrays (input_ids, attention_mask, etc.) to feed into your Tunix training loops.
  • Padding and Truncation: These are critical for creating uniform input sequences for batching, with choices like padding="max_length" and padding="longest" impacting performance and memory.

You’ve now mastered the art of preparing your data for Tunix! This is a huge step, as a well-prepared dataset can make all the difference. In the next chapter, we’ll shift our focus from data to models. We’ll learn how to load pre-trained LLMs, configure them for post-training tasks, and integrate them with Tunix. Get ready to bring your data and models together!
