Introduction

Welcome to Chapter 15! As you dive deeper into the exciting world of post-training Large Language Models with Tunix and JAX, you’ll inevitably encounter moments where things don’t quite go as planned. Code doesn’t always run perfectly on the first try, especially with complex distributed systems and JIT compilation. This is where the crucial skill of debugging and troubleshooting comes into play.

In this chapter, we’ll equip you with the essential tools and techniques to effectively diagnose and resolve issues in your Tunix workflows. We’ll demystify common JAX error messages, explore Tunix’s built-in logging, and guide you through a systematic approach to pinpointing problems. By the end, you’ll feel confident tackling even the trickiest bugs, transforming frustration into a satisfying problem-solving experience.

Before we begin, make sure you’re comfortable with the core Tunix concepts, JAX fundamentals, and have a basic understanding of model training loops as covered in previous chapters. Let’s turn those head-scratching moments into “aha!” moments!

Core Concepts for Effective Debugging

Debugging in the JAX and Tunix ecosystem can sometimes feel like solving a puzzle with very few clues. JAX’s powerful Just-In-Time (JIT) compilation and automatic vectorization (vmap) or parallelization (pmap) can abstract away the exact execution flow, making errors appear far from their root cause. Tunix, built on top of JAX, inherits these characteristics while adding its own layers of abstraction. Let’s break down the core concepts you’ll need.

Understanding JAX’s “Just-In-Time” (JIT) Compilation

JAX’s jit transform is a cornerstone of its high performance. It compiles Python functions into optimized XLA (Accelerated Linear Algebra) computations. While fantastic for speed, it has a significant impact on debugging:

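  • Traced, Deferred Execution: JIT-compiled code doesn’t execute line by line. On the first call, and again whenever input shapes or dtypes change, JAX traces your Python function once to build a computation graph, which XLA then compiles and runs. Python side effects such as print() or logging therefore fire only during tracing, not on every call. Errors can also surface far from the Python line that caused them: shape errors appear at trace time with tracebacks that pass through JAX internals, while numerical problems (NaNs, Infs) and out-of-memory failures appear only when the compiled computation executes.
  • Abstract Values: During tracing, JAX replaces arrays with abstract tracers that carry a shape and dtype but no concrete values. Shape and dtype mismatches can therefore be detected during tracing, but any Python code that depends on an array’s value (for example, if x > 0:) fails with a ConcretizationTypeError or one of its subclasses, such as TracerBoolConversionError.
  • Limited Python Debugger Access: Standard Python debuggers (pdb, ipdb) can stop in the Python code of a jitted function only while it is being traced, where they see tracers rather than values; they cannot step inside the compiled XLA code. jax.debug.print and jax.debug.breakpoint are the JIT-compatible alternatives.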

Why it matters: When a JAX-related error pops up in a Tunix training loop, your first thought might be, “Where did this come from?” Understanding JIT helps you realize that the error might be an input issue to the compiled function, not necessarily a bug in the Python lines that define it.
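The short sketch below makes these effects visible outside Tunix; the function names are illustrative only. It counts how often a jitted function is traced, contrasts print() with jax.debug.print, and shows value-dependent Python control flow failing under jit. Treat it as a demonstration, not a pattern for real training code.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented by a Python side effect, i.e. only while tracing


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1
    print("print() during tracing sees:", x)         # an abstract tracer, not numbers
    jax.debug.print("jax.debug.print sees: {}", x)  # concrete values, on every call
    return jnp.sum(x * 2.0)


scaled_sum(jnp.arange(3.0))  # first call: trace, compile, run
scaled_sum(jnp.arange(3.0))  # same shape and dtype: cached, no retrace
scaled_sum(jnp.arange(4.0))  # new shape: traced again


@jax.jit
def positive_part(x):
    if x > 0:  # a Python `if` needs a concrete bool, but x is a tracer here
        return x
    return jnp.zeros_like(x)


try:
    positive_part(jnp.float32(1.0))
    control_flow_failed = False
except jax.errors.ConcretizationTypeError as err:
    control_flow_failed = True
    print("Under jit:", type(err).__name__)
```

After the three calls, trace_count is 2: the second call reuses the cached compilation, so only jax.debug.print fires. For value-dependent branching inside jitted code, use jnp.where or jax.lax.cond instead of a Python if.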

Tunix’s Logging System

Tunix, like most Python libraries, reports what it is doing through Python’s standard logging machinery, which is your window into its internals. Depending on the Tunix version and the components you use, raising the logging level can give you more verbose output about:

  • Model loading and initialization steps.
  • Data preprocessing stages.
  • Loss calculation and gradient accumulation.
  • Device allocation and synchronization.

Why it matters: When Tunix behaves unexpectedly, increasing the logging verbosity can often reveal intermediate states or warnings that point to the problem. For example, if a model isn’t loading correctly, INFO or DEBUG level logs might show specific file paths being accessed or configuration parameters being parsed.
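As a minimal sketch, the snippet below turns on verbose logging for your own code, for Tunix, and for JAX’s compilation events. The "tunix" logger name is an assumption (verify it against the logger names your installed version actually uses); the jax_log_compiles option and the "jax" logger namespace are part of JAX. ListHandler is a hypothetical helper that captures records so they can be inspected in code.

```python
import logging

import jax
import jax.numpy as jnp

# Print INFO and above from every logger that propagates to the root logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Assumption: Tunix's modules log under the "tunix" logger namespace.
logging.getLogger("tunix").setLevel(logging.DEBUG)

# JAX logs under the "jax" namespace. With jax_log_compiles enabled, it reports
# each trace/compile at WARNING level, which reveals unexpected recompilations.
jax.config.update("jax_log_compiles", True)


class ListHandler(logging.Handler):
    """Keeps log records in memory so they can be inspected programmatically."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append(record)


jax_records = ListHandler()
logging.getLogger("jax").addHandler(jax_records)


@jax.jit
def double(x):
    return 2 * x


double(jnp.ones(3))  # first call: traced and compiled, so compile messages are logged
for record in jax_records.records:
    print(record.name, "->", record.getMessage())
```

Compilation messages on every training step are a common sign of unintended retracing, for example when batch shapes keep changing.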

Decoding JAX Error Messages

JAX error messages can sometimes look intimidating, filled with stack traces and XLA specifics. However, they often contain crucial information if you know what to look for. Common JAX errors include:

  • Shape Mismatch Errors (e.g., ValueError: Incompatible shapes for broadcasting: ... or TypeError: sub got incompatible shapes for broadcasting: ...; the exact wording varies between JAX versions and operations): These are perhaps the most common. JAX is strict about array shapes: combining arrays with incompatible shapes (in element-wise arithmetic or matrix multiplication) raises an error, usually while the function is being traced. Beware, though, that dimensions of size 1 broadcast silently: subtracting a (4,) array from a (4, 1) array yields a (4, 4) array with no error at all.
    • Clue: The error message usually lists the offending shapes.
    • Strategy: Trace back the operation and print the shapes of intermediate arrays. x.shape and x.dtype are available even inside jit, because shapes are static during tracing, and jax.eval_shape computes a function’s output shapes without running it.
  • Device Memory Errors (e.g., OOM - Out Of Memory): This means your model or batch size is too large for the available GPU/TPU memory.
    • Clue: Look for RESOURCE_EXHAUSTED or “Out of memory” in the error message; JAX typically reports it as an XlaRuntimeError.
    • Strategy: Reduce batch size, use gradient accumulation, consider smaller model variants, or switch to a device with more memory.
  • NaN/Inf Errors (Not a Number / Infinity): These indicate numerical instability, often due to exploding gradients, division by zero, or log(0).
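    • Clue: By default JAX neither raises nor warns when a NaN appears (unlike NumPy, which emits RuntimeWarning: invalid value encountered in ...). The usual symptom is a loss that prints as nan, after which the NaN propagates through every later computation. Setting jax.config.update("jax_debug_nans", True) makes JAX raise a FloatingPointError at the first computation that produces a NaN.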
    • Strategy: Check your loss function for potential divisions by zero or log(0). Implement gradient clipping. Monitor loss values and intermediate activations during training.
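The sketch below reproduces each of these behaviours on toy arrays: a silent broadcast, a genuine shape error, shape inspection with jax.eval_shape, and NaN detection with the jax_debug_nans flag. Because the exception type for incompatible shapes has varied across JAX versions and operations, the example catches both TypeError and ValueError.

```python
import jax
import jax.numpy as jnp

logits = jnp.ones((4, 1))
targets = jnp.zeros((4,))

# 1) Size-1 dimensions broadcast silently: no error, but the result has the wrong shape.
diff = logits - targets
print("silently broadcast to", diff.shape)  # (4, 4), not (4, 1)

# 2) Truly incompatible shapes raise an error.
try:
    jnp.ones((4, 1)) - jnp.ones((3, 2))
    incompatible_raised = False
except (TypeError, ValueError) as err:
    incompatible_raised = True
    print(type(err).__name__, "->", err)

# 3) jax.eval_shape computes output shapes and dtypes without running anything.
out = jax.eval_shape(
    lambda a, b: a - b,
    jax.ShapeDtypeStruct((4, 1), jnp.float32),
    jax.ShapeDtypeStruct((4,), jnp.float32),
)
print("eval_shape:", out.shape, out.dtype)

# 4) NaNs propagate silently by default; jax_debug_nans turns them into errors.
print(jnp.log(jnp.array([-1.0, 0.0])))  # nan for the negative input, -inf for zero
jax.config.update("jax_debug_nans", True)
try:
    jax.jit(jnp.log)(jnp.array([-1.0]))
    nan_raised = False
except FloatingPointError as err:
    nan_raised = True
    print("jax_debug_nans:", err)
finally:
    jax.config.update("jax_debug_nans", False)  # the NaN check slows things down
```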

Why it matters: Learning to interpret these messages is like learning a new language. Each error type points you towards a specific category of problem, significantly narrowing down your search.

A General Debugging Workflow

A systematic approach can save hours of frustration. Here’s a general workflow you can adapt:

flowchart TD
    A[Issue Detected] --> B{JAX or Tunix error?}
    B -->|Yes| C[Check Tunix Logs]
    C --> D{Clues in logs?}
    D -->|Yes| E[Hypothesize and Inspect Code]
    D -->|No| F[Disable JIT]
    F --> G[Run Code Again]
    G --> H{Clearer error message?}
    H -->|Yes| E
    H -->|No| I[Add Print Statements or Debugger]
    I --> E
    E --> J[Formulate Fix and Test]
    J --> K{Issue resolved?}
    K -->|Yes| L[Re-enable JIT]
    K -->|No| C
    B -->|No| M[Standard Python Debugging]
    M --> E

Why it matters: Following a structured approach prevents aimless poking and ensures you gather information systematically.

Step-by-Step Implementation: Debugging a Tunix Workflow

Let’s walk through a practical example of how to debug a common issue in a Tunix-like scenario. We’ll simulate a simple Tunix Trainer setup and intentionally introduce a shape mismatch error.

First, ensure you have JAX and Tunix installed. Because Tunix is under active development, as of 2026-01-30 you would typically install it from its GitHub repository. We’ll assume a recent JAX release (jax 0.4.23 or newer) with the accelerator support that matches your hardware. Tunix’s API may still change between releases, so check the official GitHub repository for the latest stable version before relying on a specific interface.

pip install -U "jax[cuda12]" # NVIDIA GPU with CUDA 12 on recent JAX releases; see the JAX installation guide for TPU, CPU, or older versions
pip install flax transformers # Common dependencies for LLM work
# Install Tunix from source or latest stable release if available via pip
# As of 2026-01-30, the most reliable way might still be from source for bleeding edge:
# git clone https://github.com/google/tunix.git
# cd tunix
# pip install -e .

(Note: Always refer to the official JAX and Tunix documentation for the absolute latest installation instructions and version compatibility.)
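After installing, a quick check like the one below (a sketch, not an official verification script) confirms which JAX version and backend you are running and that jit works. It assumes the Tunix package is importable as tunix; if your installation uses a different import name, adjust that line.

```python
import jax
import jax.numpy as jnp

print("jax version:", jax.__version__)
print("default backend:", jax.default_backend())  # e.g. "cpu", "gpu" or "tpu"
print("devices:", jax.devices())

# A tiny jitted computation confirms that compilation and execution both work.
smoke = jax.jit(lambda v: v * 2)(jnp.arange(3))
print("smoke test:", smoke)

try:
    import tunix  # assumption: the package is importable under the name "tunix"
    print("tunix found at", getattr(tunix, "__file__", None))
except ImportError:
    print("tunix is not importable; revisit the installation steps")
```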

Now, let’s create a dummy script (debug_tunix.py) that mimics a Tunix training step, including a deliberate error.

1. Setting Up a Basic Tunix-like Scenario (with a Bug)

We’ll define a simple dummy model and a training step that will eventually lead to a shape mismatch.

# debug_tunix.py
import jax
import jax.numpy as jnp
import flax.linen as nn
import logging
from typing import Sequence

# Configure basic logging at INFO level so our own (and any library's) INFO messages are printed.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("TunixDebug")

# Dummy Tunix-like Model (a simple MLP)
class DummyLLM(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        # Under jax.jit this runs only while tracing; shapes are static, so they can be logged.
        logger.info(f"DummyLLM input shape: {x.shape}")
        for i, dim in enumerate(self.features):
            x = nn.Dense(dim, name=f'dense_{i}')(x)
            if i < len(self.features) - 1:
                x = nn.relu(x)
        return x

# Dummy Tunix-like Loss Function
def dummy_loss_fn(params, model_apply_fn, batch):
    inputs, targets = batch
    logits = model_apply_fn({'params': params}, inputs)

    logger.info(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
    # The intentional bug (set up in the main block): logits are (batch, 1) but targets are (batch,).
    # Broadcasting would silently turn (4, 1) - (4,) into a (4, 4) array and average over every
    # prediction/target pair, so we check shapes explicitly instead of computing a wrong loss.
    if logits.shape != targets.shape:
        raise ValueError(f"Shape mismatch: logits {logits.shape} vs targets {targets.shape}")
    loss = jnp.mean((logits - targets) ** 2)  # Simple MSE loss
    return loss

# Dummy Tunix-like training step
def make_update_step(model_apply_fn, learning_rate=0.01):
    # jax.jit only accepts arrays (and pytrees of arrays) as regular arguments, so the
    # model's apply function is captured in this closure instead of being passed in.
    @jax.jit
    def update_step(params, batch, opt_state):
        # In a real Tunix setup, this would use optax or a similar optimizer.
        grad_fn = jax.value_and_grad(dummy_loss_fn)
        loss, grads = grad_fn(params, model_apply_fn, batch)

        # Apply gradients (plain SGD, not a real optimizer step)
        new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return new_params, opt_state, loss
    return update_step

# Main execution block
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    model = DummyLLM(features=[128, 64, 1]) # Output feature is 1 for regression

    # Initialize model parameters
    dummy_input = jnp.ones((1, 10)) # Batch size 1, input dim 10
    params = model.init(key, dummy_input)['params']

    # Dummy optimizer state (not used in this simplified update_step, but common)
    opt_state = {}
    update_step = make_update_step(model.apply)

    # Prepare batch data
    # Input batch size 4, input dim 10
    input_data = jax.random.normal(key, (4, 10))
    # Target data (INTENTIONAL BUG: 1D array, but the model outputs 2D (4, 1))
    target_data = jax.random.normal(key, (4,))

    batch = (input_data, target_data)

    print("\n--- Attempting training step with bug ---")
    try:
        new_params, _, current_loss = update_step(params, batch, opt_state)
        print(f"Training step successful! Loss: {current_loss}")
    except Exception as e:
        print(f"\n--- Error caught during training step ---")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")

Explanation of the Code:

  1. logging.basicConfig: We set up basic logging at INFO level, so every logger.info() call prints to the console.
  2. DummyLLM: A simple Flax nn.Module that acts as our “model”. It passes its input through a few Dense layers, and a logger.info call reports the input shape.
  3. dummy_loss_fn: This is where the intentional bug surfaces. The model outputs logits with shape (batch_size, 1) because the last Dense layer has features=1, but the targets have shape (batch_size,). Broadcasting does not reject this pair: (4, 1) - (4,) silently becomes a (4, 4) array, and the “MSE” would average over every prediction/target combination. The explicit shape check turns this silent bug into a loud ValueError. We also log the logits and targets shapes.
  4. make_update_step / update_step: A simplified jitted training step that computes gradients and updates the parameters, mimicking a single Tunix training step. jax.jit only accepts arrays (or pytrees of arrays) as ordinary arguments, so model.apply is captured in a closure; passing it to the jitted function as an argument would fail with a TypeError. The jax.jit decorator is what makes this a realistic debugging scenario.
  5. Main Block (if __name__ == "__main__":):
    • Initializes the DummyLLM and its parameters, and builds update_step.
    • Creates input_data with shape (4, 10) (batch size 4, input features 10).
    • Crucially, creates target_data with shape (4,), a 1D array.
    • Runs update_step inside a try-except block to catch the error gracefully.

Now, run this script: python debug_tunix.py

You should see output similar to this (timestamps and extra library log lines will differ):

... (logging messages) ...
2026-01-30 10:00:00,000 - TunixDebug - INFO - DummyLLM input shape: (1, 10)

--- Attempting training step with bug ---
2026-01-30 10:00:00,000 - TunixDebug - INFO - DummyLLM input shape: (4, 10)
2026-01-30 10:00:00,000 - TunixDebug - INFO - Logits shape: (4, 1), Targets shape: (4,)

--- Error caught during training step ---
Error type: ValueError
Error message: Shape mismatch: logits (4, 1) vs targets (4,)

2. Analyzing the Initial Error and Tunix Logs

Notice the error message: ValueError: Shape mismatch: logits (4, 1) vs targets (4,). It was raised while JAX traced update_step, before anything was compiled or executed, because shapes are already known during tracing.

Looking at our custom TunixDebug logs, we see: Logits shape: (4, 1), Targets shape: (4,)

The logs pinpoint the problem immediately. The model produces logits of shape (4, 1) (4 examples, 1 output feature), but the targets are (4,) (4 examples, no feature axis). JAX would happily broadcast these two arrays to (4, 4) for the element-wise subtraction, so without the explicit check the training step would “succeed” and quietly optimize the wrong objective.

What we’ve learned:

  • Logging calls inside a jitted function run only while JAX traces it (once per new combination of input shapes and dtypes), so they can report static information such as shapes, but not array values.
  • A shape mismatch that broadcasting would otherwise hide is best caught with an explicit check whose error message names both shapes.

3. Disabling JIT for Deeper Inspection (If Logs Aren’t Enough)

What if the logs weren’t so clear, or the error was more subtle? This is where disabling JIT becomes your best friend. With JIT disabled, jitted functions run eagerly, one operation at a time, on concrete arrays instead of tracers: print() and logging show real values, pdb can stop inside the function, and tracebacks point at your own lines rather than at JAX’s tracing machinery.

Modify debug_tunix.py by adding one line at the beginning of the if __name__ == "__main__": block:

# ... (previous code) ...

# Main execution block
if __name__ == "__main__":
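    # Disable JIT globally for easier debugging. (jax.disable_jit() is a context manager,
    # used as `with jax.disable_jit():`; calling jax.disable_jit(True) on its own has no effect.)
    jax.config.update("jax_disable_jit", True)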

    key = jax.random.PRNGKey(0)
    model = DummyLLM(features=[128, 64, 1])

    # ... (rest of the code is the same) ...

Run the script again: python debug_tunix.py

The output will still show the same error, but the traceback now runs through ordinary Python frames, so it leads straight to the line in dummy_loss_fn where the problem occurs instead of winding through JAX’s tracing machinery. For this specific error the message is already clear, but for errors raised deep inside traced code, or for value-dependent bugs you want to inspect with print() or pdb, disabling JIT is a crucial step. Expect the script to run more slowly while JIT is off, since every operation is dispatched separately.
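Here is a small standalone illustration of the difference, using a made-up normalize function rather than the Tunix script. It also shows the context-manager form of jax.disable_jit, which limits the change to a single block instead of the whole program.

```python
import jax
import jax.numpy as jnp


@jax.jit
def normalize(x):
    total = jnp.sum(x)
    # float() needs a concrete number: under jit, total is a tracer and this line
    # raises ConcretizationTypeError; with JIT disabled it runs like ordinary Python.
    print("total inside normalize:", float(total))
    return x / total


x = jnp.array([1.0, -1.0])

try:
    normalize(x)
    jit_inspection_failed = False
except jax.errors.ConcretizationTypeError:
    jit_inspection_failed = True
    print("Under jit, total is an abstract tracer")

with jax.disable_jit():  # JIT is off only inside this block
    y = normalize(x)     # runs eagerly and prints the real total (0.0)

print("result:", y)      # dividing by a zero total gives inf and -inf
```

Because float(total) only works on concrete values, this kind of ad-hoc inspection is possible only with JIT disabled; inside jitted code, use jax.debug.print instead.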

4. Fixing the Bug

Based on the error message and logs, the fix is straightforward: ensure targets has a compatible shape with logits. Since logits is (4, 1), targets should also be (4, 1).

Modify the target_data creation in debug_tunix.py:

# ... (previous code) ...

    # Prepare batch data
    input_data = jax.random.normal(key, (4, 10))
    # FIX: Create target data with shape (4, 1) to match the logits
    # (for existing 1D targets, target_data[:, None] performs the same reshape)
    target_data = jax.random.normal(key, (4, 1)) # Now (4, 1) to match logits
    
    batch = (input_data, target_data)

# ... (rest of the code is the same) ...

Now, re-enable JIT (remove or comment out the JIT-disabling line you added in step 3) and run the script again: python debug_tunix.py

You should now see:

... (logging messages) ...

--- Attempting training step with bug ---
2026-01-30 10:00:00,000 - TunixDebug - INFO - DummyLLM input shape: (4, 10)
2026-01-30 10:00:00,000 - TunixDebug - INFO - Logits shape: (4, 1), Targets shape: (4, 1)
Training step successful! Loss: <float value>

Success! The training step executed without error and printed the loss, a scalar whose exact value depends on the random initialization and data.

Mini-Challenge

Now it’s your turn to apply what you’ve learned.

Challenge: Introduce a new bug into the dummy_loss_fn that causes NaN (Not a Number) values to appear in the loss. Your task is to modify the dummy_loss_fn to create this NaN issue and then use logging and conceptual understanding to identify and fix it.

Hint: Think about mathematical operations that easily produce NaNs, such as division by zero or log(0). You don’t need to change target_data back to its buggy state. Focus on the loss = ... line.

What to observe/learn: How NaNs propagate, and how to use logging to catch them early.

(Pause here, try to solve the challenge yourself before looking at the solution idea!)

Solution Idea for Mini-Challenge

One direct way to produce a NaN is to take the logarithm of a value that can be negative: jnp.log returns nan for negative inputs and -inf for zero. Division is the other classic trigger, since x / 0 gives inf and 0 / 0 gives nan.

1. Introduce the bug (in dummy_loss_fn):

# debug_tunix.py (Bugged version for Mini-Challenge)
# ... (imports and initial setup) ...

# Dummy Tunix-like Loss Function
def dummy_loss_fn(params, model_apply_fn, batch):
    inputs, targets = batch
    logits = model_apply_fn({'params': params}, inputs)

    logger.info(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")

    # INTENTIONAL BUG FOR MINI-CHALLENGE: log of the *signed* error.
    # jnp.log(x) is nan wherever a prediction falls below its target (x < 0),
    # and a single nan turns the mean into nan.
    loss = jnp.mean(jnp.log(logits - targets))

    # Under jax.jit, logger.info would only print an abstract tracer here.
    # jax.debug.print runs when the compiled code executes and shows the actual value.
    jax.debug.print("Intermediate loss value: {}", loss)
    return loss

# ... (rest of the script, with target_data fixed to (4, 1)) ...

Note that a Python if on an array value (for example, if loss < 0.1:) does not work inside jax.jit: during tracing the value is a tracer, so JAX raises a TracerBoolConversionError. Use jnp.where or jax.lax.cond for value-dependent logic.

Run this. With random data it is very likely that at least one of the four predictions falls below its target, in which case the debug print (and the final “Loss:” line) shows nan. If the loss happens to come out finite, try a different PRNG seed.

2. Identify the NaN using logs: The jax.debug.print line shows nan, which tells you the problem originates inside dummy_loss_fn rather than in the parameter update. To locate the exact operation, enable jax.config.update("jax_debug_nans", True): JAX then raises a FloatingPointError as soon as a computation produces a NaN.

3. Fix the bug: Make sure jnp.log only ever sees strictly positive inputs, or simply revert to the stable MSE loss. For divisions, the equivalent fix is to add a small epsilon to the denominator (e.g., jnp.mean(jnp.abs(logits)) + 1e-8) so it can never be exactly zero.

# debug_tunix.py (Fixed version for Mini-Challenge)
# ... (inside dummy_loss_fn) ...

    # FIX: Revert to the stable MSE loss
    loss = jnp.mean((logits - targets) ** 2)
    # If a log-scale error is really intended, keep the argument strictly positive:
    # loss = jnp.mean(jnp.log(jnp.square(logits - targets) + 1e-8))
    jax.debug.print("Intermediate loss value: {}", loss)
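To catch NaNs early in a real training loop, you can return a finiteness flag from the jitted step and check it on the host after every step. The sketch below uses a hypothetical linear model and a log of the signed error as the NaN source; all_finite, buggy_loss, and checked_step are illustrative names, not Tunix APIs.

```python
import jax
import jax.numpy as jnp


def all_finite(tree):
    """True only if every leaf of a pytree (params, grads, ...) is free of NaN/Inf."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.array([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]))


def buggy_loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean(jnp.log(pred - y))  # log of a signed error: nan when pred < y


@jax.jit
def checked_step(params, x, y):
    loss, grads = jax.value_and_grad(buggy_loss)(params, x, y)
    finite = jnp.isfinite(loss) & all_finite(grads)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    return new_params, loss, finite


params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((4, 3))

# Host-side check after each step: stop (or roll back) as soon as something is non-finite.
results = {}
for name, y in [("targets below predictions", -jnp.ones((4, 1))),
                ("targets above predictions", jnp.ones((4, 1)))]:
    _, loss, finite = checked_step(params, x, y)
    results[name] = (float(loss), bool(finite))
    print(f"{name}: loss={float(loss)}, finite={bool(finite)}")
```

In the second case only the loss is nan: the derivative of log at -1 is finite, so checking the gradients alone would miss the problem.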

Common Pitfalls & Troubleshooting

Beyond the shape mismatches and NaNs, here are a few other common issues you might encounter with Tunix and JAX, along with troubleshooting tips:

  1. Device Memory Errors (Out of Memory - OOM)

    • Symptom: Your program crashes with an error mentioning RESOURCE_EXHAUSTED or “Out of memory” (typically an XlaRuntimeError). This means your GPU/TPU ran out of memory.
    • Cause: Too large a batch size, a very large model, or accumulating too many intermediate activations during gradient computation.
    • Troubleshooting:
      • Reduce Batch Size: This is the quickest fix.
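      • Gradient Accumulation: If you need a large effective batch size, accumulate gradients over several smaller batches. Check the Tunix documentation for whether your trainer exposes this option; with plain optax, optax.MultiSteps provides it.
      • Model Parallelism/Sharding: For very large models, you may need to split parameters across multiple devices. In JAX (and therefore in Tunix, which runs on JAX) this is done with the sharding API (jax.sharding.Mesh and NamedSharding together with jax.jit). pmap replicates the full model on every device, so it does not reduce per-device parameter memory. Sharding adds complexity.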
      • Check for Leaky JAX Objects: Sometimes, JAX arrays can be inadvertently held onto, preventing memory from being freed. Ensure you’re not storing large intermediate arrays unnecessarily.
  2. Slow Performance or Unexpected CPU Fallback

    • Symptom: Your training is significantly slower than expected, or JAX reports “No GPU/TPU found, falling back to CPU”.
    • Cause:
      • Incorrect JAX/JAXlib Installation: JAXlib needs to be installed with the correct CUDA/ROCm/TPU support.
      • Host-Device Transfers: Frequent transfers of data between Python (host) and accelerator (device) can bottleneck performance. JAX operations are fast on the device, but moving data is slow.
      • Not JIT-compiling enough: Small operations outside of JIT can add up.
      • Inefficient XLA compilation: Sometimes, an operation might not be efficiently compiled by XLA.
    • Troubleshooting:
      • Verify Installation: Run pip list | grep jax and check that accelerator support is installed (on recent JAX releases this means packages such as jax-cuda12-plugin for NVIDIA GPUs, or libtpu for TPUs). Check jax.devices() to confirm your device is detected.
      • Minimize Host-Device Transfers: Keep data (inputs, parameters) on the device as much as possible. Use jax.device_put() for initial placement.
      • Profile Your Code: Use JAX’s profiling tools (e.g., jax.profiler) to identify bottlenecks. This can show you where time is spent.
      • Refactor for JIT: Ensure your core training and evaluation loops are JIT-compiled.
  3. Stuck Training / No Loss Improvement

    • Symptom: Loss is flat, or model performance isn’t improving.
    • Cause:
      • Learning Rate Issues: Too high (exploding gradients) or too low (no progress).
      • Gradient Vanishing/Exploding: Common in deep networks.
      • Incorrect Loss Function: Loss function doesn’t correctly reflect the task.
      • Data Issues: Bad data, incorrect preprocessing, or label errors.
      • Model Architecture Flaws: Model might be too simple, too complex, or have architectural problems.
    • Troubleshooting:
      • Learning Rate Schedule: Experiment with different learning rates and schedules. Use gradient clipping.
      • Monitor Gradients: Log gradient norms. If they are NaN, Inf, or extremely small/large, that’s a clue.
      • Sanity Checks: Train a tiny model on a tiny subset of data (e.g., 2-3 batches). It should overfit quickly. If it doesn’t, there’s a fundamental issue.
      • Inspect Data: Visualize inputs and outputs. Ensure preprocessing is correct.
      • Initialization: Check model parameter initialization.

Summary

Phew! You’ve navigated the tricky waters of debugging Tunix and JAX workflows. Here’s a quick recap of the most important takeaways:

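  • JAX’s JIT is powerful but hides errors: JIT traces first and executes later, so shape errors surface at trace time and numerical errors only when the compiled code runs, often far from the offending line.
  • Leverage Tunix and JAX logging: Configure logging at INFO or DEBUG level for insight into internal operations and variable shapes, and use jax.debug.print to see runtime values inside jitted code.
  • Disabling JIT is your primary weapon: When logs aren’t enough, temporarily turn off JIT with jax.config.update("jax_disable_jit", True) or a with jax.disable_jit(): block to get clearer Python tracebacks.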
  • Understand JAX error messages: Learn to interpret shape mismatches, OOM errors, and NaN/Inf warnings. They are your clues!
  • Follow a systematic debugging workflow: Don’t just randomly change code. Hypothesize, inspect, test, and iterate.
  • Common pitfalls: Be aware of device memory limits, performance bottlenecks from host-device transfers, and issues like flatlining loss.

Debugging is a skill that improves with practice. The more you encounter and solve problems, the better you’ll become at anticipating them and quickly finding solutions.

What’s Next?

In the next chapter, we’ll shift our focus from finding problems to preventing them and ensuring our Tunix models are ready for prime time. We’ll explore best practices for deploying your fine-tuned LLMs, covering topics like model serialization, serving, and monitoring in production environments. Get ready to share your amazing work with the world!

References

  1. Tunix Official GitHub Repository: The primary source for Tunix code, issues, and the latest releases.
  2. Tunix Official Documentation: Comprehensive guides and API references for Tunix.
  3. JAX Documentation - Debugging JAX Programs: Official guide on debugging JAX-specific issues.
  4. JAX Documentation - JIT Compilation: Deep dive into how JIT works in JAX.
  5. Python Logging HOWTO: Official Python documentation on the logging module.

This page is AI-assisted and reviewed. It references official documentation and recognized resources where relevant.