added complete Qwen2.5-7B tensor parallel bounty implementation #958
Conversation
Great stuff. Thank you for your contribution. Verified the tests and the functionality. Some general comments inlined + changes requested. Otherwise looks good.
PASS: `python3 generate_multi_chip.py --model_path qwen2_5_7b/ --prompt "what is the weather today"`
PASS: `python test_gsm8k.py --model_path qwen2_5_7b/ --num_samples 100`
For Forge engineers (to the PR author: none of this is blocking; these are general notes and comments for our internal engineers)
- the current model uses around 126 GB of RAM at its peak
- by default, JAX uses GSPMD-style partitioning (string-based IRs), but you can use the Shardy partitioner instead. We support both in the Tenstorrent compiler today, but JAX is also moving to Shardy as the default in the short term, and we have only limited support for GSPMD (especially in the case of automatic sharding):
jax.config.update("jax_use_shardy_partitioner", True)
Here are the shardy files dumped. They are generated for a 1d mesh (tensor parallel) but you can adapt for 2D mesh and change the mesh shape. All these operations are supported in tt-mlir today.
python3 generate_multi_chip.py --model_path qwen2_5_7b/ --prompt "what is the weather today"
module @jit_matmul_fn attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["mp"=8]>
  func.func public @main(%arg0: tensor<1x34x3584xbf16>, %arg1: tensor<3584x3584xbf16>, %arg2: tensor<3584xbf16>) -> (tensor<1x34x3584xbf16> {jax.result_info = "result"}) {
    %0 = sdy.manual_computation(%arg0, %arg1, %arg2) in_shardings=[<@mesh, [{}, {}, {}]>, <@mesh, [{}, {"mp"}]>, <@mesh, [{"mp"}]>] out_shardings=[<@mesh, [{}, {}, {}]>] manual_axes={"mp"} (%arg3: tensor<1x34x3584xbf16>, %arg4: tensor<3584x448xbf16>, %arg5: tensor<448xbf16>) {
      %1 = stablehlo.dot_general %arg3, %arg4, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x34x3584xbf16>, tensor<3584x448xbf16>) -> tensor<1x34x448xbf16>
      %2 = stablehlo.broadcast_in_dim %arg5, dims = [2] : (tensor<448xbf16>) -> tensor<1x1x448xbf16>
      %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1, 2] : (tensor<1x1x448xbf16>) -> tensor<1x34x448xbf16>
      %4 = stablehlo.add %1, %3 : tensor<1x34x448xbf16>
      %5 = stablehlo.broadcast_in_dim %4, dims = [1, 2, 3] : (tensor<1x34x448xbf16>) -> tensor<1x1x34x448xbf16>
      %6 = "stablehlo.all_gather"(%5) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<1x1x34x448xbf16>) -> tensor<8x1x34x448xbf16>
      %7 = stablehlo.transpose %6, dims = [1, 2, 0, 3] : (tensor<8x1x34x448xbf16>) -> tensor<1x34x8x448xbf16>
      %8 = stablehlo.reshape %7 : (tensor<1x34x8x448xbf16>) -> tensor<1x34x3584xbf16>
      sdy.return %8 : tensor<1x34x3584xbf16>
    } : (tensor<1x34x3584xbf16>, tensor<3584x3584xbf16>, tensor<3584xbf16>) -> tensor<1x34x3584xbf16>
    return %0 : tensor<1x34x3584xbf16>
  }
}
- the current implementation uses shard_map (which is fine), but we should wrap it under a jax.jit for better caching and more explicit parallelism. As it stands, the loop generates a new token on every iteration and creates a new shard_map each time; our compiler will then attempt to bring the intermediates back to the host instead of keeping them on device. We should therefore either add support for keeping the intermediates on device or just jit the entire call function. See https://github.com/tenstorrent/tt-xla/pull/747/files for an example. This also means we need to be more explicit about how the weights/inputs are sharded with PartitionSpec. That will also let us leverage the automatic parallelization pipeline we have in tt-mlir, just like the above bounty.
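A minimal sketch of the suggested pattern, assuming a 1x8 mesh and a stand-in step function (all names here are illustrative, not the PR's API): the whole step is jitted once with explicit shardings, so repeated calls reuse the compiled executable and outputs stay on device.

```python
# Simulate 8 host devices; must be set before importing jax.
import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(8), axis_names=("mp",))

def step(x, w):
    # Stand-in for model.apply; in the real model this would be one decode step.
    return x @ w

# Jit once, with explicit input/weight shardings, instead of rebuilding a
# shard_map every decode-loop iteration.
jitted_step = jax.jit(
    step,
    in_shardings=(
        NamedSharding(mesh, P()),            # activations replicated
        NamedSharding(mesh, P(None, "mp")),  # column-parallel weight
    ),
    out_shardings=NamedSharding(mesh, P(None, "mp")),
)

x = jnp.ones((34, 3584), jnp.bfloat16)
w = jnp.ones((3584, 3584), jnp.bfloat16)
y = jitted_step(x, w)  # compiled once; later calls hit the jit cache
print(y.shape)
```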
For reference, when I jit the model.apply function
jit_generate = jax.jit(
    model.apply,
)
lowered = jit_generate.lower(params, input_ids=input_ids, attention_mask=attention_mask,
                             position_ids=position_ids, past_key_values=past_key_values)
I get the following StableHLO graph (ignore the manual_computation ops, as those come from the current existing code; essentially we want a full graph like this with the inputs/weights sharded, and our automatic pipeline will take care of sharding propagation and collective insertion).
qwen.mlir.txt
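For what it's worth, a dump like that can be produced from any jitted function via `lower(...).as_text()`; a toy example (the function here is just a placeholder):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Trivial stand-in for model.apply, just to demonstrate the lowering API.
    return x * 2.0

lowered = jax.jit(f).lower(jnp.ones((2, 2)))
text = lowered.as_text()  # StableHLO module in MLIR text form
print(text.splitlines()[0])
```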
- Datasets: For GSM8K in test_gsm8k.py.
- Others: Standard.
3. Download weights: From Hugging Face (https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) in safetensors format to a local `--model_path`.
Let's add a section on how to do this:
pip install huggingface_hub
huggingface-cli login
huggingface-cli download Qwen/Qwen2.5-7B-Instruct *some_file* --local-dir qwen2_5_7b/
1. Clone the repo and navigate to this directory.
2. Install dependencies (no internet access needed beyond initial pip):
```
pip install jax flax transformers safetensors psutil numpy datasets
```
Add something like:
python3 -m venv qwen_env
source qwen_env/bin/activate
pip install -r requirements.txt
and add a separate requirements.txt file.
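A requirements.txt along these lines might look like the following (package set taken from the commands in this thread, plus jinja2 as noted elsewhere in the review; leaving versions unpinned is an assumption):

```
jax
flax
transformers
safetensors
psutil
numpy
datasets
jinja2
```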
## Usage
- **Inference Demo**: Run generation on multi-device.
```
python generate_multi_chip.py --model_path /path/to/weights --prompt "Your custom prompt here"
```
Missing
pip install jinja2
in requirements.txt.
- `--max_tokens`: Maximum number of tokens to generate (default: 500)
- `--dtype`: Choose "bfloat16" (default) or "float32"
- `--no_realtime`: Disable real-time token display
- **Simulate meshes**: Set `os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=N'` (e.g., N=8 for 1x8).
Can you clean up this code to add this:
flags = os.environ.get("XLA_FLAGS", "")
flags += " --xla_force_host_platform_device_count=8"
os.environ["XLA_FLAGS"] = flags
Ideally, the user would just run a script and all environment flags and variables would be set.
**Before running the script, set the environment variable:**
```bash
# Simulate 8 devices (1x8 mesh)
export XLA_FLAGS="--xla_force_host_platform_device_count=8"
```
Change these export flags to argparse parameters.
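A sketch of that change (the `--num_devices` flag name is illustrative): read the device count from argparse and set `XLA_FLAGS` before importing jax, so the user just runs the script with no manual exports.

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--num_devices", type=int, default=8,
                    help="number of simulated host devices (1xN mesh)")
# parse_known_args tolerates the script's other flags
args, _ = parser.parse_known_args()

# Append the simulation flag, preserving any pre-existing XLA_FLAGS.
flags = os.environ.get("XLA_FLAGS", "")
if "--xla_force_host_platform_device_count" not in flags:
    flags += f" --xla_force_host_platform_device_count={args.num_devices}"
os.environ["XLA_FLAGS"] = flags.strip()

import jax  # must come after XLA_FLAGS is set so the flag takes effect

print(jax.device_count())
```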
    print("Generation complete.")
    return full_output, peak_memory, avg_time_per_token

def main():
I get this error when attempting to run this:
INFO:2025-08-15 20:04:13,920:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-08-15 20:04:13,920 - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
It still runs fine on CPU, but can you triage why this error is happening?
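One possible triage sketch: the message comes from JAX probing every backend plugin (including TPU) at startup. Pinning the platform list via `JAX_PLATFORMS` skips the TPU probe entirely; note this silences the log by avoiding the probe, rather than fixing a missing libtpu.

```python
import os
os.environ["JAX_PLATFORMS"] = "cpu"  # must be set before jax is imported

import jax

# With the platform list pinned, jax never tries to open libtpu.so.
print(jax.default_backend())
```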
Ticket
Link to Github Issue
Problem description
Adds a JAX tensor-parallel implementation of Qwen2.5-7B for the mentioned bounty. With deterministic settings, it produces token-for-token identical outputs to a PyTorch implementation.
What's changed
Directly compared logits at each step between a PyTorch implementation and a single-device JAX implementation. Once the single-device JAX version matched, I used shard_map to parallelize it. Tested on a 1x8 mesh.
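The step-by-step comparison described here could be sketched as follows (the helper and tolerance are illustrative, not the code used in the PR):

```python
import numpy as np

def compare_logits(jax_logits, torch_logits, atol=1e-3):
    """Return (max absolute difference, whether logits match within atol)."""
    a = np.asarray(jax_logits, dtype=np.float32)
    b = np.asarray(torch_logits, dtype=np.float32)
    max_diff = float(np.max(np.abs(a - b)))
    return max_diff, bool(np.allclose(a, b, atol=atol))

# Toy check with identical logits (vocab size matches Qwen2.5's 151936):
logits = np.random.default_rng(0).standard_normal((1, 151936))
diff, ok = compare_logits(logits, logits.copy())
print(diff, ok)  # 0.0 True
```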
Checklist