Introduction

Welcome back, future LLM expert! In our previous chapters, we laid the groundwork by setting up Tunix and understanding its core philosophy. Now, it’s time to peek under the hood and explore how Tunix, built on the powerful JAX ecosystem, handles the intricate dance of model architectures and their ever-evolving state.

Understanding how your Large Language Model (LLM) is represented and how its parameters (the “knowledge” it holds) are managed is absolutely crucial for effective post-training. Unlike traditional imperative frameworks where model state might be implicitly updated, JAX operates on a functional paradigm. This means state management is explicit, predictable, and incredibly powerful when you know how to wield it. Tunix leverages this power, often integrating with libraries like Flax NNX, to give you granular control over your LLM’s internal workings.

By the end of this chapter, you’ll have a solid grasp of JAX’s functional approach to models, how Tunix utilizes Flax NNX for defining and managing LLM components, and the critical concept of explicit state management. This knowledge is fundamental for building sophisticated post-training routines, debugging effectively, and ultimately, achieving peak performance with your models. Let’s dive in!

Core Concepts: The Functional Heart of Tunix

Tunix’s strength comes from its foundation in JAX, a high-performance numerical computation library. JAX’s design principles, particularly its functional programming paradigm, significantly influence how models are built and how their state is handled.

JAX’s Functional Approach

At its core, JAX treats computations as pure functions. This means a function, given the same inputs, will always produce the same outputs and have no side effects. This might seem restrictive at first, especially if you’re used to frameworks where model parameters are internal attributes that get modified in place.

Think of it like a meticulous chef:

  • Imperative Frameworks: The chef might grab a pot, add ingredients, stir, and modify the pot’s contents directly. The “pot” (model) itself changes.
  • JAX’s Functional Approach: The chef always takes fresh ingredients (input data and current model parameters), performs a cooking step (the model’s forward pass), and produces a new, updated dish (output and new parameters). The original ingredients and pot remain untouched.

This immutability has profound benefits:

  • Predictability: Easier to reason about and debug.
  • Parallelism: JAX can safely and efficiently parallelize operations (vmap) and compile them for various hardware (jit) because there are no hidden side effects.
  • Explicit State: You always know exactly what state is being used and produced.
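The contrast with in-place updates can be made concrete with a tiny pure function. This is a minimal sketch (the names `sgd_update`, `params`, and `grads` are illustrative, not part of any library): "updating" parameters returns a brand-new array while the inputs stay untouched, which is exactly what lets `jit` compile it safely.

```python
import jax
import jax.numpy as jnp

# A pure function: same inputs, same outputs, no side effects.
# "Updating" parameters returns a new array; the inputs are untouched.
@jax.jit
def sgd_update(params: jax.Array, grads: jax.Array, lr: float) -> jax.Array:
    return params - lr * grads

params = jnp.array([1.0, 2.0, 3.0])
grads = jnp.array([0.1, 0.1, 0.1])

new_params = sgd_update(params, grads, 0.5)

print(params)      # original array is unchanged
print(new_params)  # a new array: [0.95, 1.95, 2.95]
```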

Enter Flax NNX: Building Blocks for JAX Models

While JAX provides the numerical backbone, building complex neural networks efficiently often requires a higher-level API. This is where Flax comes in, and specifically Flax NNX, which Tunix integrates with. NNX is designed to give you maximum flexibility and control over your model’s state within JAX’s functional paradigm.

NNX introduces a few key concepts:

  • nnx.Module: The base class for all neural network layers and models. It’s a container for parameters and other states.
  • nnx.Param: This is how you declare trainable parameters (like weights and biases) within your nnx.Module. These are the values that your optimizer will update during training.
  • nnx.State: A tree of all the variables in your module – not just parameters, but also things like batch normalization statistics. You extract it explicitly with nnx.state(model) (or nnx.split(model)) and pass it around through JAX transformations. This State is what you explicitly pass around and update.
  • nnx.Rngs: JAX uses a deterministic pseudo-random number generator (PRNG). For operations that require randomness (like parameter initialization or dropout), you need to provide explicit random keys (rngs). NNX provides a convenient way to manage these.

Tunix’s “White-Box” Design

The combination of JAX’s functional purity and Flax NNX’s explicit state management underpins Tunix’s “white-box” design philosophy. What does “white-box” mean here?

It means that Tunix allows you to:

  1. See everything: Every parameter, every layer, every internal state variable is accessible and explicit.
  2. Modify anything: Because state is passed explicitly, you can inspect, modify, or even replace parts of your model’s state during the post-training process. This is incredibly powerful for advanced techniques like parameter-efficient fine-tuning (PEFT), model editing, or targeted knowledge injection.

Contrast this with a “black-box” approach where you might only interact with an LLM via its inputs and outputs, without direct access to its internal components. Tunix empowers you to dive deep.

Let’s visualize this flow:

flowchart TD
    A[Model Definition nnx Module] --> B{Initialization with nnx Rngs}
    B --> C[Initial nnx State]
    C --> D{JAX Functional Call Forward Pass}
    E[Input Data] --> D
    D --> F[Output]
    D --> G[Updated nnx State]
    G --> H{Optimizer Step}
    H --> I[New nnx State Next Iteration]
  • Explanation: The nnx.Module defines the architecture and is initialized using nnx.Rngs, which eagerly creates its parameters. Splitting the module (nnx.split) yields an explicit nnx.State, which is passed along with input data through JAX functional calls (like a forward pass). The call produces an output and a new, updated nnx.State. This updated state is then used for subsequent steps, like an optimizer applying gradient updates, which again produces a new state; nnx.merge rebuilds a module from a graph definition and state whenever you want the Pythonic view back.

Step-by-Step Implementation: Building a Simple NNX Module

Let’s put these concepts into practice by defining a simple multi-layer perceptron (MLP) using Flax NNX. We’ll see how to define parameters, initialize the model, and manage its state.

First, make sure you have Tunix and its dependencies installed. If you skipped Chapter 1, you can install it via pip:

pip install "tunix[full]>=0.1.0" jax flax

(Note: As of 2026-01-30, tunix version 0.1.0 is a stable reference. Always check the official Tunix GitHub releases for the absolute latest version.)

Now, let’s create our Python file, say mlp_model.py.

1. Setting up a Basic Flax NNX Module

We’ll start by importing nnx and defining a simple MLP class.

# mlp_model.py

import jax
import jax.numpy as jnp
from flax import nnx  # on older Flax versions: from flax.experimental import nnx

# Let's define a simple Multi-Layer Perceptron (MLP)
# using Flax NNX.
class SimpleMLP(nnx.Module):
    # The __init__ method is where we define the layers
    # and any parameters our module will have.
    def __init__(self, features_out: int, *, rngs: nnx.Rngs):
        # We'll create a linear layer.
        # nnx.Linear automatically creates nnx.Param objects
        # for its weights and biases.
        # It also needs an RNG key for initialization.
        self.layer1 = nnx.Linear(2, 4, rngs=rngs)
        self.layer2 = nnx.Linear(4, features_out, rngs=rngs)

    # The __call__ method defines the forward pass of our module.
    # It takes the input data and applies the layers.
    def __call__(self, x: jax.Array) -> jax.Array:
        # Apply the first layer, then a ReLU activation.
        x = nnx.relu(self.layer1(x))
        # Apply the second layer.
        x = self.layer2(x)
        return x

print("SimpleMLP module defined!")

Explanation:

  • import jax, import jax.numpy as jnp: Standard JAX imports.
  • from flax import nnx: Imports the nnx module. NNX started life under flax.experimental but has since moved to the stable flax namespace and is Flax’s recommended API for explicit state management.
  • class SimpleMLP(nnx.Module): Our MLP inherits from nnx.Module, making it an NNX-compatible component.
  • def __init__(self, features_out: int, *, rngs: nnx.Rngs):
    • features_out: The number of output features for our MLP.
    • rngs: nnx.Rngs: This is crucial! We’re explicitly requiring an nnx.Rngs object. This object holds the random number generator keys needed for operations like parameter initialization.
    • self.layer1 = nnx.Linear(2, 4, rngs=rngs): We define our first linear layer. nnx.Linear is a pre-built NNX module that creates nnx.Param objects for its weights and biases internally. We pass rngs so it can initialize these parameters randomly. Our input will have 2 features, and this layer will output 4.
    • self.layer2 = nnx.Linear(4, features_out, rngs=rngs): Our second linear layer, taking 4 features from the first layer and outputting features_out.
  • def __call__(self, x: jax.Array) -> jax.Array: This method defines how data flows through our model. It’s like the forward method in other frameworks.
    • x = nnx.relu(self.layer1(x)): We apply the first linear layer and then a ReLU activation function.
    • x = self.layer2(x): We apply the second linear layer.

2. Initializing Model State

Now that we have our SimpleMLP definition, let’s create an instance of it and see how its state is initialized.

Add the following code to mlp_model.py:

# ... (previous code for SimpleMLP) ...

# To initialize our model, we need to provide a random number generator (RNG) key.
# JAX's PRNG is deterministic; we need to explicitly manage keys.
# nnx.Rngs manages multiple RNG streams for different purposes (e.g., params, dropout).
# We'll create an initial key for the 'params' stream.
key = jax.random.PRNGKey(0) # A fixed seed for reproducibility
rngs = nnx.Rngs(params=key)

# Instantiate our SimpleMLP module. Its parameters are created
# eagerly in __init__ and live on the instance as nnx.Param objects.
features_out = 1
model = SimpleMLP(features_out, rngs=rngs)

print("\nModel initialized!")
print(f"Model type: {type(model)}")

# Extract the model's variables as an explicit nnx.State object.
state = nnx.state(model)
print(f"Model state type: {type(state)}")

# We can filter the state down to just the trainable parameters.
# nnx.State is a tree structure, and its repr prints nicely.
params = nnx.state(model, nnx.Param)
print("\nModel parameters:")
print(params)

# Let's inspect a specific parameter, e.g., the weights of layer1
print("\nWeights of layer1:")
print(model.layer1.kernel.value)

Explanation:

  • key = jax.random.PRNGKey(0): We create a JAX PRNG key with a seed of 0. Using a fixed seed ensures that our parameter initialization is reproducible.
  • rngs = nnx.Rngs(params=key): We wrap our PRNG key in an nnx.Rngs object. We assign it to the params group. If we had dropout, we might also have dropout=jax.random.PRNGKey(1).
  • model = SimpleMLP(features_out, rngs=rngs): We instantiate our SimpleMLP. During this step, the __init__ method runs, and nnx.Linear uses the rngs object to initialize its weights and biases, which are stored as nnx.Param objects within the model instance’s state.
  • nnx.state(model): This extracts an nnx.State object holding all the variables of our SimpleMLP instance.
  • nnx.state(model, nnx.Param): This filters the state down to just the trainable nnx.Param objects. Printed, you’ll see kernel (weights) and bias for both layer1 and layer2.
  • model.layer1.kernel.value: We can directly access the underlying jax.Array value of a specific parameter on the instance itself.

3. Performing a Forward Pass and Understanding State Updates

Now, let’s perform a forward pass with some dummy data. This is where the relationship between NNX and JAX’s functional core becomes clear: calling the module returns just the output, and any stateful layers (say, batch normalization) update their variables in place on the module object. When you need the purely functional view – for jit, grad, and friends – you split the module into a static graph definition and an explicit nnx.State.

Add the following to mlp_model.py:

# ... (previous code for initialization) ...

# Let's create some dummy input data.
# Our first layer expects 2 input features.
dummy_input = jnp.array([[1.0, 2.0], [3.0, 4.0]]) # A batch of 2 samples

print(f"\nInput data shape: {dummy_input.shape}")

# Calling the model runs the forward pass and returns the output.
# If the model had stateful layers (like BatchNorm), their variables
# would be updated in place on the module object.
output = model(dummy_input)

print(f"Output from model (shape {output.shape}):")
print(output)

# For the explicitly functional view, split the module into a static
# graph definition and its nnx.State. This (graphdef, state) pair is
# what flows through JAX transformations.
graphdef, state = nnx.split(model)
print(f"\nGraphDef type: {type(graphdef)}")
print(f"State type: {type(state)}")

# nnx.merge reconstructs an equivalent module from the pair.
restored_model = nnx.merge(graphdef, state)
restored_output = restored_model(dummy_input)

# The reconstructed model holds the same parameter values.
print(f"Are layer1 kernel values identical? {jnp.array_equal(model.layer1.kernel.value, restored_model.layer1.kernel.value)}")
print(f"Are outputs identical? {jnp.array_equal(output, restored_output)}")

Explanation:

  • dummy_input = jnp.array([[1.0, 2.0], [3.0, 4.0]]): We create a small batch of input data. Each sample has 2 features, matching our nnx.Linear(2, 4, ...) layer.
  • output = model(dummy_input): A plain Python call. The module carries its own variables, so only the output is returned. In this pure forward pass nothing changes; a stateful layer like Batch Normalization would update its moving averages in place on the module.
  • graphdef, state = nnx.split(model): This is where explicit state management appears. split separates the module into an immutable graph definition and an nnx.State of variables. JAX transformations operate on the state functionally, producing new states, and nnx.merge rebuilds a module from the pair. This split/merge round trip is what reconciles Pythonic modules with JAX’s functional paradigm.

This step-by-step example shows you how to define a model, initialize its state, and perform a forward pass using Flax NNX, which is the foundation for how Tunix manages LLMs.

Mini-Challenge: Adding Dropout to Your MLP

Now it’s your turn to get hands-on!

Challenge: Modify the SimpleMLP to include a dropout layer after the first relu activation. Observe how nnx.Rngs are used for this.

Task Description:

  1. Add self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) in the __init__ method of SimpleMLP. The layer keeps its own dropout key stream, drawn from rngs.
  2. In the __call__ method, apply self.dropout after nnx.relu and before self.layer2.
  3. Toggle between training and evaluation behavior with model.train() and model.eval(), which flip the layer’s deterministic flag.
  4. Print the model’s parameters again and verify the dropout layer itself doesn’t add trainable parameters (it’s a process, not a parameter holder).

Hint:

  • Define a separate RNG stream for dropout in your nnx.Rngs object, e.g., rngs = nnx.Rngs(params=key, dropout=jax.random.PRNGKey(1)).
  • The dropout call in __call__ is simply x = self.dropout(x); the layer pulls fresh keys from its stream on each call.

What to Observe/Learn:

  • How nnx.Rngs are compartmentalized for different random operations.
  • The difference in model behavior (output values will change due to dropout).
  • That nnx.Dropout itself doesn’t add nnx.Param objects to the state.

Feel free to experiment and try to solve this before looking up solutions!

Click for Solution Hint!
# mlp_model_solution.py

import jax
import jax.numpy as jnp
from flax import nnx

class SimpleMLP(nnx.Module):
    def __init__(self, features_out: int, *, rngs: nnx.Rngs):
        self.layer1 = nnx.Linear(2, 4, rngs=rngs)
        # Dropout draws its keys from the 'dropout' stream of rngs.
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.layer2 = nnx.Linear(4, features_out, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = nnx.relu(self.layer1(x))
        x = self.dropout(x)  # active unless deterministic is set
        x = self.layer2(x)
        return x

key = jax.random.PRNGKey(0)
# Separate key streams for parameter init and dropout
rngs = nnx.Rngs(params=key, dropout=jax.random.PRNGKey(1))

features_out = 1
model = SimpleMLP(features_out, rngs=rngs)

print("Model initialized with Dropout!")
print("\nModel parameters (dropout doesn't add params):")
print(nnx.state(model, nnx.Param))

dummy_input = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Training mode (the default): each call consumes fresh keys from
# the dropout stream, so repeated calls give different outputs.
output = model(dummy_input)
output_2 = model(dummy_input)
print(f"\nOutput with dropout (shape {output.shape}):")
print(output)
print(f"Are two training-mode outputs identical? {jnp.array_equal(output, output_2)}")

# Evaluation mode: model.eval() sets deterministic=True on Dropout,
# so no units are dropped and the output is stable.
model.eval()
output_eval = model(dummy_input)
print(f"\nOutput in evaluation mode (shape {output_eval.shape}):")
print(output_eval)

Common Pitfalls & Troubleshooting

Working with JAX and Flax NNX, especially when coming from other frameworks, can present a few unique challenges.

  1. Forgetting to Capture Updated State: This is perhaps the most common pitfall. Because JAX functions are pure and state is immutable, any function that “updates” state (like an optimizer step or a batch normalization layer during training) will return a new state object. If you don’t capture this new state and pass it to the next operation, you’ll be using stale parameters or statistics, leading to incorrect or non-converging training.

    • Fix: Always assign the returned state: new_state, output = my_model(state, input_data) or new_opt_state, new_params = optimizer.update(grads, opt_state, params).
  2. Incorrect RNG Management: JAX’s explicit PRNG system requires careful handling of random keys. Reusing the same key for multiple independent random operations can lead to correlated randomness, while forgetting to split a key for sequential random operations will result in the same “random” numbers being generated.

    • Fix:
      • For independent random calls (e.g., initializing multiple layers), use jax.random.split(key, num_splits).
      • For sequential calls (e.g., dropout at each training step), derive a fresh key per step with jax.random.fold_in(key, step_id), or carry the key forward with key, subkey = jax.random.split(key) and use subkey for the step.
      • nnx.Rngs helps by managing different key streams for different purposes (e.g., params, dropout). Ensure you pass the correct rngs object to your module’s __call__ method when needed.
  3. Confusion between Flax NNX and Flax Linen: Flax has two main API styles: Linen (the older, functional-core API) and NNX (the newer API built around explicit, Pythonic state management, now Flax’s recommended default). While they share some concepts, their usage patterns for state management are different. Tunix leans into the explicit state management of NNX.

    • Fix: Be mindful of which API you’re using. If you see nnx.Module, nnx.Param, nnx.State, you’re in NNX land. If you see flax.linen.Module, self.sow, self.param, that’s Linen. Stick to NNX when working with Tunix’s explicit state patterns.
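The RNG rules in pitfall 2 are easier to see in code. A minimal sketch of both patterns (names are illustrative): split once for independent operations, and fold in a step id for sequential ones.

```python
import jax

root = jax.random.PRNGKey(42)

# Independent random operations: split once, use each subkey exactly once.
k1, k2 = jax.random.split(root)
w1 = jax.random.normal(k1, (4,))
w2 = jax.random.normal(k2, (4,))
# w1 and w2 are uncorrelated; reusing `root` for both would
# have produced identical "random" draws.

# Sequential operations (e.g., dropout at each training step):
# derive a fresh key from the step id instead of reusing one key.
for step in range(3):
    step_key = jax.random.fold_in(root, step)
    noise = jax.random.normal(step_key, (2,))
    print(step, noise)
```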

Summary

Phew! You’ve navigated the functional depths of JAX and discovered how Tunix leverages Flax NNX to manage LLM architectures and their state. Here are the key takeaways from this chapter:

  • JAX’s Functional Core: JAX emphasizes pure functions and immutable data, meaning operations produce new states rather than modifying existing ones in place.
  • Flax NNX for Explicit Models: Tunix integrates with Flax NNX, which provides nnx.Module, nnx.Param, nnx.State, and nnx.Rngs for defining model architectures and managing their variables explicitly.
  • “White-Box” Control: This explicit state management enables Tunix’s “white-box” design, giving you fine-grained access and control over every part of your LLM during post-training.
  • RNGs are Crucial: JAX’s deterministic PRNG requires careful management of random keys, often handled conveniently by nnx.Rngs.
  • State is Immutable: Always remember to capture the new state returned by JAX/Flax NNX functions that perform updates.

Understanding these concepts is not just theoretical; it’s the bedrock for building robust, scalable, and highly customizable post-training workflows with Tunix. In the next chapter, we’ll start putting this model architecture knowledge to use as we explore how Tunix orchestrates the training loop for LLMs. Get ready to train some models!

This page is AI-assisted and reviewed. It references official documentation and recognized resources where relevant.