
added complete Qwen2.5-7B tensor parallel bounty implementation #958


Open
wants to merge 1 commit into main

Conversation

ekulkisnek

Ticket

Link to Github Issue

Problem description

Includes a JAX tensor-parallel implementation of Qwen2.5-7B for the referenced bounty. With deterministic settings it produces token-for-token identical outputs to a PyTorch implementation.

What's changed

Directly compared logits at each step between a PyTorch implementation and a single-device JAX implementation. Once the single-device JAX version matched, I used shard_map to parallelize it. Tested on a 1x8 mesh. A sketch of the per-step comparison is shown below.
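As a rough illustration of that verification loop (not the PR's exact harness; `jax_forward` is a hypothetical stand-in for the PR's JAX model call returning last-token logits):

```python
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical reference setup: greedy, deterministic decoding against the HF/PyTorch model.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct",
                                                 torch_dtype=torch.float32)

ids = tokenizer("what is the weather today", return_tensors="pt").input_ids
for step in range(10):
    with torch.no_grad():
        ref_logits = ref_model(ids).logits[:, -1, :].numpy()

    # `jax_forward` stands in for the PR's JAX model call that returns last-token logits.
    jax_logits = np.asarray(jax_forward(np.asarray(ids)))

    print(f"step {step}: max |logit diff| = {np.abs(ref_logits - jax_logits).max()}")

    # Advance both runs with the same greedy token so the comparison stays in lockstep.
    next_id = int(ref_logits[0].argmax())
    ids = torch.cat([ids, torch.tensor([[next_id]])], dim=1)
```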

Checklist

  • Working tensor-parallel JAX Qwen2.5-7B model with the same GSM8K performance as a PyTorch implementation.


@tapspatel left a comment


Great stuff. Thank you for your contribution. Verified the tests and the functionality. Some general comments inline + changes requested. Otherwise looks good.

PASS

python3 generate_multi_chip.py --model_path qwen2_5_7b/ --prompt "what is the weather today"

PASS

python test_gsm8k.py --model_path qwen2_5_7b/ --num_samples 100

For forge engineers (to the PR author: none of this is blocking; these are general notes and comments for our internal engineers)

  • the current model uses around 126 GB of RAM at its peak
  • by default, JAX uses GSPMD-style partitioning (string-based IRs), but you can use the Shardy partitioner instead. We support both in the Tenstorrent compiler today, but JAX is also moving to Shardy as the default in the short term, and we have only limited support for GSPMD (especially for automatic sharding). Enable it with:
jax.config.update("jax_use_shardy_partitioner", True)
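For reference, a minimal way to inspect the lowered module with this flag enabled is to dump the StableHLO text of a jitted function; the toy `matmul_fn` below is only illustrative:

```python
import jax
import jax.numpy as jnp

# Use the Shardy partitioner instead of the default GSPMD lowering.
jax.config.update("jax_use_shardy_partitioner", True)

def matmul_fn(x, w):
    # Toy stand-in for one projection in the model.
    return x @ w

x = jnp.ones((1, 34, 3584), dtype=jnp.bfloat16)
w = jnp.ones((3584, 3584), dtype=jnp.bfloat16)

# Lowered.as_text() returns the StableHLO module as a string; when the computation is
# sharded (e.g. via shard_map), it carries sdy.mesh / sdy.* ops instead of GSPMD strings.
print(jax.jit(matmul_fn).lower(x, w).as_text())
```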

Here are the Shardy files dumped. They were generated for a 1D mesh (tensor parallel), but you can adapt them to a 2D mesh by changing the mesh shape. All of these operations are supported in tt-mlir today.

python3 generate_multi_chip.py --model_path qwen2_5_7b/ --prompt "what is the weather today"

module @jit_matmul_fn attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["mp"=8]>
  func.func public @main(%arg0: tensor<1x34x3584xbf16>, %arg1: tensor<3584x3584xbf16>, %arg2: tensor<3584xbf16>) -> (tensor<1x34x3584xbf16> {jax.result_info = "result"}) {
    %0 = sdy.manual_computation(%arg0, %arg1, %arg2) in_shardings=[<@mesh, [{}, {}, {}]>, <@mesh, [{}, {"mp"}]>, <@mesh, [{"mp"}]>] out_shardings=[<@mesh, [{}, {}, {}]>] manual_axes={"mp"} (%arg3: tensor<1x34x3584xbf16>, %arg4: tensor<3584x448xbf16>, %arg5: tensor<448xbf16>) {
      %1 = stablehlo.dot_general %arg3, %arg4, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x34x3584xbf16>, tensor<3584x448xbf16>) -> tensor<1x34x448xbf16>
      %2 = stablehlo.broadcast_in_dim %arg5, dims = [2] : (tensor<448xbf16>) -> tensor<1x1x448xbf16>
      %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1, 2] : (tensor<1x1x448xbf16>) -> tensor<1x34x448xbf16>
      %4 = stablehlo.add %1, %3 : tensor<1x34x448xbf16>
      %5 = stablehlo.broadcast_in_dim %4, dims = [1, 2, 3] : (tensor<1x34x448xbf16>) -> tensor<1x1x34x448xbf16>
      %6 = "stablehlo.all_gather"(%5) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<1x1x34x448xbf16>) -> tensor<8x1x34x448xbf16>
      %7 = stablehlo.transpose %6, dims = [1, 2, 0, 3] : (tensor<8x1x34x448xbf16>) -> tensor<1x34x8x448xbf16>
      %8 = stablehlo.reshape %7 : (tensor<1x34x8x448xbf16>) -> tensor<1x34x3584xbf16>
      sdy.return %8 : tensor<1x34x3584xbf16>
    } : (tensor<1x34x3584xbf16>, tensor<3584x3584xbf16>, tensor<3584xbf16>) -> tensor<1x34x3584xbf16>
    return %0 : tensor<1x34x3584xbf16>
  }
}
  • the current implementation uses shard_map (which is fine), but we should wrap it under a jax.jit for better caching and more explicit parallelism. As written, every iteration of the generation loop builds a new shard_map'd computation to produce the next token, and our compiler will bring the intermediates back to host instead of keeping them on device. We should therefore either add support for keeping the intermediates on device or simply jit the entire call function; see https://github.com/tenstorrent/tt-xla/pull/747/files for an example. This also means being more explicit about how the weights/inputs are sharded with PartitionSpec, which lets us leverage the automatic parallelization pipeline in tt-mlir, just like the bounty above. A sketch of this pattern follows after the reference graph below.

For reference, when I jit the model.apply function:

jit_generate = jax.jit(model.apply)
lowered = jit_generate.lower(params, input_ids=input_ids, attention_mask=attention_mask,
                             position_ids=position_ids, past_key_values=past_key_values)

I get the following StableHLO graph (ignore the manual_computation ops, which come from the current code; essentially we want a full graph like this with the inputs/weights sharded, so that our automatic pipeline can take care of sharding propagation and collective insertion).
qwen.mlir.txt
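As a rough sketch of the requested pattern (not the PR's code: the mesh layout, the `shard_param` rule, and the `model`/`params` handles are assumptions), the whole forward can be jitted once with weights and inputs placed via NamedSharding/PartitionSpec:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1x8 mesh with a tensor-parallel "mp" axis (adjust the shape for other topologies).
mesh = Mesh(np.array(jax.devices()).reshape(1, 8), axis_names=("batch", "mp"))

def shard_param(path, leaf):
    # Illustrative placement rule only: column-shard 2-D projection weights on "mp",
    # replicate everything else. A real mapping should follow the model's actual layout.
    spec = P(None, "mp") if getattr(leaf, "ndim", 0) == 2 else P()
    return jax.device_put(leaf, NamedSharding(mesh, spec))

sharded_params = jax.tree_util.tree_map_with_path(shard_param, params)

# Jit the full forward once: the compiler then sees a single graph with sharded
# weights/inputs, and the automatic pipeline handles propagation + collective insertion.
jit_apply = jax.jit(model.apply)
logits = jit_apply(sharded_params, input_ids=input_ids,
                   attention_mask=attention_mask, position_ids=position_ids)
```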

- Datasets: GSM8K (used by test_gsm8k.py).
- Others: Standard.

3. Download weights: From Hugging Face (https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) in safetensors format to a local `--model_path`.


let's add a section on how to do this:

pip install huggingface_hub
huggingface-cli login
huggingface-cli download Qwen/Qwen2.5-7B-Instruct *some_file* --local-dir qwen2_5_7b/

1. Clone the repo and navigate to this directory.
2. Install dependencies (no internet access needed beyond initial pip):
```
pip install jax flax transformers safetensors psutil numpy datasets
```


add something like

python3 -m venv qwen_env
source qwen_env/bin/activate
pip install -r requirements.txt

add a separate requirements.txt file, e.g. as sketched below
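A possible requirements.txt, assuming the dependency list from the README above (pin versions as appropriate):

```
jax
flax
transformers
safetensors
psutil
numpy
datasets
jinja2  # see the jinja2 note further below
```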

## Usage
- **Inference Demo**: Run generation on multi-device.
```
python generate_multi_chip.py --model_path /path/to/weights --prompt "Your custom prompt here"
```


missing

pip install jinja2

in requirements.txt

- `--max_tokens`: Maximum number of tokens to generate (default: 500)
- `--dtype`: Choose "bfloat16" (default) or "float32"
- `--no_realtime`: Disable real-time token display
- **Simulate meshes**: Set `os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=N'` (e.g., N=8 for 1x8).


can you clean up this code to add this:

flags = os.environ.get("XLA_FLAGS", "")
flags += " --xla_force_host_platform_device_count=8"
os.environ["XLA_FLAGS"] = flags

Ideally, the user would just run a script and all environment flags and variables would be set automatically.

**Before running the script, set the environment variable:**
```bash
# Simulate 8 devices (1x8 mesh)
export XLA_FLAGS="--xla_force_host_platform_device_count=8"
```


change these export flags to argparse parameters; a rough sketch is below
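A minimal sketch of what that could look like (the flag names are placeholders, not the PR's actual interface); note that XLA_FLAGS must be set before JAX is imported:

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--num_devices", type=int, default=8,
                    help="number of host devices to simulate (1x8 mesh by default)")
parser.add_argument("--model_path", required=True)
parser.add_argument("--prompt", default="what is the weather today")
args = parser.parse_args()

# XLA reads this flag at backend initialization, so set it before importing jax.
flags = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = f"{flags} --xla_force_host_platform_device_count={args.num_devices}"

import jax  # imported only after the flags are in place
print(jax.devices())
```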

print("Generation complete.")
return full_output, peak_memory, avg_time_per_token

def main():


I get this error when attempting to run this

INFO:2025-08-15 20:04:13,920:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-08-15 20:04:13,920 - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory

It still runs fine on CPU, but can you triage why this error is happening?
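One likely explanation (an assumption on my part, not confirmed in this thread): JAX probes every available backend at startup and logs an INFO message for each one it cannot initialize, so the libtpu.so message is usually harmless on a machine without a TPU. If CPU-only runs are intended, restricting the platform list silences the probe:

```python
import os

# Tell JAX to initialize only the CPU backend instead of probing TPU/GPU plugins.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
print(jax.devices())  # lists only CPU devices
```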

Labels: community (issue was filed by a community member, not TT)
Successfully merging this pull request may close these issues.

[$3000 BOUNTY] Add a Tensor parallel JAX Qwen2.5-7B Model to TT-xla Model Demos