Optimize fused_moe_kernel with Split-K #9486

Status: Open · wants to merge 1 commit into base: main
Conversation

@yuan-luo (Contributor) commented Aug 22, 2025

Motivation

In tensor-parallel (TP) MoE, fused_moe_kernel is one of the most critical and time-consuming operators.

Per profiling, this kernel already achieves very high compute throughput, which explains why it outperforms the Triton v3.4.0 fused_moe kernel in the benchmark results in #9276.

The only weakness of the baseline is a small amount of warp-stall sampling on LDS (shared-memory) accesses, which seems unavoidable; the Split-K version encounters the same stalls.

The main idea of this PR is to introduce Split-K. The new kernel does not exceed the baseline, but it comes close. Per the benchmark results, SPLIT_K=4 performs best on H20-96GB.
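To make the Split-K idea concrete, here is a minimal CPU sketch in plain Python (shapes and names are illustrative, not the PR's actual kernel code): each of the SPLIT_K "programs" reduces a disjoint slice of the K dimension and adds its partial product into C, mirroring the tl.atomic_add accumulation the kernel performs.

```python
# Hypothetical small shapes for illustration only.
M, N, K, SPLIT_K = 4, 3, 8, 2
A = [[float(i * K + k) for k in range(K)] for i in range(M)]
B = [[float(k * N + j) for j in range(N)] for k in range(K)]

# C must be zero-initialized before launch, since every K-slice adds into it.
C = [[0.0] * N for _ in range(M)]
chunk = K // SPLIT_K
for pid_k in range(SPLIT_K):              # one extra grid dimension per K-slice
    k0, k1 = pid_k * chunk, (pid_k + 1) * chunk
    for i in range(M):
        for j in range(N):
            partial = sum(A[i][k] * B[k][j] for k in range(k0, k1))
            C[i][j] += partial            # stands in for tl.atomic_add on C

# Reference: the unsplit reduction over the full K dimension.
ref = [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)] for i in range(M)]
assert C == ref
```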

This PR currently sacrifices some performance by using tl.atomic_add to accumulate the partial results into C in parallel. There is therefore still room for improvement in the following areas, but they require a refactor that would grow the scope out of control, so I am opening this PR for discussion:

  1. Avoid tl.atomic_add and use tl.store instead. This requires introducing a host-allocated workspace, which in turn means refactoring the current load/store indexing.
  2. The workspace strides are passed from the host side as Python ints, but the kernel does not explicitly cast them to tl.int64. In Triton 3.4.0, pointer arithmetic treats an int32 scalar operand as a valid offset directly, so once the total workspace exceeds 2 GiB the 32-bit offset has its high bits truncated. The resulting pointer references unallocated memory, and the GPU immediately raises an illegal-memory-access error. All strides therefore need to be widened to int64, otherwise the kernel goes out of bounds: the workspace layout is computed as BLOCK_M * BLOCK_N, num_n * BLOCK_M * BLOCK_N, etc. from Python integers, which Triton down-casts to 32-bit constants, so an overflow occurs whenever the product of pid_* and stride_ws_* exceeds 2^32 − 1.
  3. TMA load/store would be worth trying.
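The overflow in point 2 can be reproduced on the host side with hypothetical block sizes (the values below are illustrative, not the PR's actual configuration): once the product of a program id and a workspace stride exceeds 2^32 − 1, keeping only the low 32 bits, as an int32 lane does, yields a completely different, out-of-bounds offset.

```python
# Hypothetical tile sizes and grid index for illustration.
BLOCK_M, BLOCK_N = 64, 128
num_n = 512
stride_ws_m = num_n * BLOCK_M * BLOCK_N   # 4_194_304 elements per pid_m slab
pid_m = 1200                              # large enough grid index to overflow

offset64 = pid_m * stride_ws_m            # correct 64-bit offset: 5_033_164_800
offset32 = offset64 & 0xFFFFFFFF          # low 32 bits, what an int32 lane keeps

assert offset64 > 2**32 - 1               # exceeds the 32-bit range...
assert offset32 != offset64               # ...so the truncated offset is wrong
```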

In general, even after refactoring it to the workspace-based Split-K version, the kernel would likely only achieve performance parity with the baseline. This also reflects the excellence of the baseline version, which was first introduced in vLLM 2024.1 by DeepSeek and several other open-source developers. Much respect to them.
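For discussion, here is a hedged CPU sketch of the workspace-based variant from point 1 above (names and shapes are illustrative, not the PR's code): each K-slice writes its partial result into a private slab of a host-allocated workspace using plain stores, and a second reduction pass sums the slabs into C, so no atomics are needed.

```python
# Hypothetical small shapes for illustration only.
M, N, K, SPLIT_K = 4, 3, 8, 2
A = [[float(i - k) for k in range(K)] for i in range(M)]
B = [[float(k + j) for j in range(N)] for k in range(K)]

# workspace[pid_k] is one M x N slab; slabs are disjoint, so plain stores suffice.
workspace = [[[0.0] * N for _ in range(M)] for _ in range(SPLIT_K)]
chunk = K // SPLIT_K
for pid_k in range(SPLIT_K):
    k0, k1 = pid_k * chunk, (pid_k + 1) * chunk
    for i in range(M):
        for j in range(N):
            # Plain store of this slice's partial product (no tl.atomic_add).
            workspace[pid_k][i][j] = sum(A[i][k] * B[k][j] for k in range(k0, k1))

# Second pass: reduce the slabs into C (the extra step the refactor requires).
C = [[sum(workspace[s][i][j] for s in range(SPLIT_K)) for j in range(N)]
     for i in range(M)]

ref = [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)] for i in range(M)]
assert C == ref
```

The trade-off is the extra workspace memory and reduction pass in exchange for contention-free stores, which is why the stride indexing (and its int64 widening) has to be reworked first.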

The following are this PR's benchmark results compared to the baseline. E2E tests passed with correct results for the different Split-K values.

=====================================
hidden_state=4096
=====================================
baseline:
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.401824                 0.385744
1       256.0                      0.545376                 0.458112
2       512.0                      0.882352                 0.602528
3      1024.0                      0.969728                 0.900032
4      2048.0                      1.903408                 1.492832
5      4096.0                      3.089696                 2.634720
6      8192.0                      5.529952                 4.992800

SPLIT_K=2:
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.401680                 0.412768
1       256.0                      0.546368                 0.498432
2       512.0                      0.882320                 0.707152
3      1024.0                      0.969664                 1.263072
4      2048.0                      1.906144                 2.008480
5      4096.0                      3.097024                 4.283952
6      8192.0                      5.535840                 8.186672

SPLIT_K=4:
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.401536                 0.384896
1       256.0                      0.545216                 0.459136
2       512.0                      0.875728                 0.609344
3      1024.0                      0.975616                 0.922224
4      2048.0                      1.908400                 1.500448
5      4096.0                      3.075136                 2.739712
6      8192.0                      5.478160                 5.133056

SPLIT_K=8:
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.401472                 0.393584
1       256.0                      0.546336                 0.483776
2       512.0                      0.881824                 0.619376
3      1024.0                      0.972064                 1.055808
4      2048.0                      1.909248                 1.616832
5      4096.0                      3.096736                 3.392016
6      8192.0                      5.521024                 6.393248

=====================================
hidden_state=8192
=====================================
baseline:
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.748768                 0.767296
1       256.0                      0.938736                 0.884320
2       512.0                      1.718560                 1.170704
3      1024.0                      1.874976                 1.752080
4      2048.0                      3.690912                 2.869088
5      4096.0                      6.121920                 5.267024
6      8192.0                     10.920896                 9.860880

SPLIT_K=2:
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.772992                 0.793344
1       256.0                      0.938864                 0.964128
2       512.0                      1.715552                 1.375456
3      1024.0                      1.879328                 2.464768
4      2048.0                      3.694688                 3.904160
5      4096.0                      6.126976                 8.517600
6      8192.0                     10.905152                16.197520

SPLIT_K=4:
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.758880                 0.738304
1       256.0                      0.940192                 0.877280
2       512.0                      1.722080                 1.171296
3      1024.0                      1.875840                 1.765936
4      2048.0                      3.681648                 2.866048
5      4096.0                      6.105984                 5.254512
6      8192.0                     10.923488                 9.877504

SPLIT_K=8:
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.760432                 0.745264
1       256.0                      0.941344                 0.902560
2       512.0                      1.716768                 1.168304
3      1024.0                      1.874848                 1.793536
4      2048.0                      3.679568                 2.875120
5      4096.0                      6.108928                 5.448640
6      8192.0                     10.897920                10.135168

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces an optimization for the fused_moe_kernel by implementing Split-K v1 support. The goal is to improve the performance of this critical operator in Mixture of Experts (MoE) models, particularly in the context of Tensor Parallelism (TP MoE). While the current implementation, which uses tl.atomic_add, achieves performance close to the baseline, it lays the groundwork for further optimizations and addresses a key area for performance improvement.

Highlights

  • Fused MoE Kernel Optimization: Implemented Split-K v1 support for the fused_moe_kernel to enhance performance in Mixture of Experts operations.
  • New Kernel Introduction: A new Triton kernel, fused_moe_kernel_splitk, has been added to handle the Split-K logic, enabling parallel computation across the K dimension.
  • Performance Characteristics: Benchmarking indicates that the Split-K implementation, particularly with SPLIT_K=4, achieves performance comparable to the existing baseline, with potential for further gains.
  • Identified Limitations: The current approach uses tl.atomic_add, which introduces a performance overhead, and there are identified issues with potential 32-bit integer overflow for strides in large workspaces, suggesting areas for future refactoring.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new Split-K Triton kernel for the fused Mixture of Experts (MoE) operation, aiming to optimize performance. The changes include the new fused_moe_kernel_splitk kernel, logic to invoke it based on configuration, and updates to default configurations to support this new feature. My review focuses on improving code quality and maintainability. I've identified several unused parameters in the new kernel, a redundant variable assignment, and a hardcoded configuration value that could be made more flexible. Addressing these points will enhance the clarity and long-term health of the codebase.

@yuan-luo yuan-luo changed the title Support Split-K v1 for fused_moe_kernel Refactor fused_moe_kernel with Split-K Aug 22, 2025
@yuan-luo yuan-luo changed the title Refactor fused_moe_kernel with Split-K Optimize fused_moe_kernel with Split-K Aug 22, 2025
@BBuf (Collaborator) commented Aug 22, 2025

It seems that increasing parallelism with Split-K doesn't provide significant benefits for fused MoE. The analysis above is cool, but from a practical and code-redundancy perspective, would it be better not to introduce this kernel implementation in fused_moe_triton?
