From cc1bb9fa2ee86568592e04090b1cbf9ccdf49ec2 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 19 Aug 2025 00:14:13 -0700 Subject: [PATCH] Revert "[feature] Ascend NPU graph support (#8027)" This reverts commit 94371dbbd6359139cb0a90912838a47a27f61bcc. --- .../benchmark_torch_compile_fused_moe.py | 2 +- .../sglang/srt/distributed/parallel_state.py | 14 +- .../srt/layers/attention/ascend_backend.py | 157 +--- python/sglang/srt/mem_cache/memory_pool.py | 2 +- .../srt/model_executor/cuda_graph_runner.py | 823 ++++++++++++++++- .../sglang/srt/model_executor/graph_runner.py | 860 ------------------ .../sglang/srt/model_executor/model_runner.py | 26 +- .../srt/model_executor/npu_graph_runner.py | 94 -- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/models/glm4_moe.py | 2 +- python/sglang/srt/models/mllama.py | 2 +- python/sglang/srt/models/qwen3.py | 2 +- python/sglang/srt/models/qwen3_moe.py | 2 +- .../eagle_draft_cuda_graph_runner.py | 18 +- .../eagle_draft_extend_cuda_graph_runner.py | 18 +- test/srt/run_suite.py | 11 - test/srt/test_ascend_graph_tp1_bf16.py | 95 -- test/srt/test_ascend_graph_tp2_bf16.py | 97 -- 18 files changed, 878 insertions(+), 1349 deletions(-) delete mode 100644 python/sglang/srt/model_executor/graph_runner.py delete mode 100644 python/sglang/srt/model_executor/npu_graph_runner.py delete mode 100644 test/srt/test_ascend_graph_tp1_bf16.py delete mode 100644 test/srt/test_ascend_graph_tp2_bf16.py diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py index 1fcea7cd49d..2b4faa24b1d 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py @@ -9,7 +9,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( fused_moe as fused_moe_triton, ) -from sglang.srt.model_executor.graph_runner import set_torch_compile_config +from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config def get_model_config(model_name: str, tp_size: int): diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index a8a8d20f667..286618d6bcd 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -55,7 +55,7 @@ @dataclass class GraphCaptureContext: - stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream + stream: torch.cuda.Stream TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) @@ -252,13 +252,9 @@ def __init__( if is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") - elif _is_npu: - self.device = torch.device(f"npu:{local_rank}") else: self.device = torch.device("cpu") - self.device_module = torch.get_device_module(self.device) - self.use_pynccl = use_pynccl self.use_pymscclpp = use_pymscclpp self.use_custom_allreduce = use_custom_allreduce @@ -406,7 +402,7 @@ def graph_capture( self, graph_capture_context: Optional[GraphCaptureContext] = None ): if graph_capture_context is None: - stream = self.device_module.Stream() + stream = torch.cuda.Stream() graph_capture_context = GraphCaptureContext(stream) else: stream = graph_capture_context.stream @@ -417,11 +413,11 @@ def graph_capture( # ensure all initialization operations complete before attempting to # capture the graph on another stream - curr_stream = self.device_module.current_stream() + curr_stream = torch.cuda.current_stream() if curr_stream != stream: stream.wait_stream(curr_stream) - with self.device_module.stream(stream), maybe_ca_context: + with torch.cuda.stream(stream), maybe_ca_context: # In graph mode, we have to be very careful about the collective # operations. The current status is: # allreduce \ Mode | Eager | Graph | @@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ) elif hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.empty_cache() - elif hasattr(torch, "npu") and torch.npu.is_available(): - torch.npu.empty_cache() def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 70ee79b25ae..020f04dcde0 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional import torch import torch_npu @@ -27,7 +27,6 @@ class ForwardMetadata: # seq len inputs extend_seq_lens_cpu_int: Optional[torch.Tensor] = None seq_lens_cpu_int: Optional[torch.Tensor] = None - seq_lens_cpu_list: Optional[List[int]] = None class AscendAttnBackend(AttentionBackend): @@ -52,7 +51,7 @@ def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16): def __init__(self, model_runner: ModelRunner): super().__init__() - self.forward_metadata = None + self.forward_metadata = ForwardMetadata() self.device = model_runner.device self.gen_attention_mask(128, model_runner.dtype) self.page_size = model_runner.page_size @@ -61,15 +60,9 @@ def __init__(self, model_runner: ModelRunner): self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim self.native_attn = TorchNativeAttnBackend(model_runner) - self.graph_metadata = {} - self.max_context_len = model_runner.model_config.context_len - self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.graph_mode = False def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" - self.forward_metadata = ForwardMetadata() - self.forward_metadata.block_tables = ( forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : forward_batch.seq_lens.max() @@ -82,63 +75,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() - self.graph_mode = False - - def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): - self.graph_metadata = { - "block_tables": torch.empty( - (max_bs, self.max_context_len // self.page_size), - dtype=torch.int32, - device=self.device, - ), - } - - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - num_tokens: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], - ): - metadata = ForwardMetadata() - - metadata.block_tables = self.graph_metadata["block_tables"][:bs, :] - metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist() - - self.graph_metadata[bs] = metadata - self.forward_metadata = metadata - - self.graph_mode = True - - def init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], - seq_lens_cpu: Optional[torch.Tensor], - ): - metadata = self.graph_metadata[bs] - max_len = seq_lens_cpu[:bs].max().item() - max_seq_pages = (max_len + self.page_size - 1) // self.page_size - - metadata.block_tables[:bs, :max_seq_pages].copy_( - self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size] - // self.page_size - ) - metadata.block_tables[:bs, max_seq_pages:].fill_(0) - metadata.block_tables[bs:, :].fill_(0) - - self.forward_metadata = metadata - - self.graph_mode = True - def get_cuda_graph_seq_len_fill_value(self): return 1 @@ -231,74 +167,28 @@ def forward_decode( layer, forward_batch.out_cache_loc, k, v ) if not self.use_mla: - if self.graph_mode: - k_cache = forward_batch.token_to_kv_pool.get_key_buffer( - layer.layer_id - ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim) - v_cache = forward_batch.token_to_kv_pool.get_value_buffer( - layer.layer_id - ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim) - query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim) - num_tokens = query.shape[0] - workspace = ( - torch_npu._npu_fused_infer_attention_score_get_max_workspace( - query, - k_cache, - v_cache, - block_table=self.forward_metadata.block_tables, - block_size=self.page_size, - num_heads=layer.tp_q_head_num, - num_key_value_heads=layer.tp_k_head_num, - input_layout="BSH", - scale=layer.scaling, - actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list, - ) - ) - output = torch.empty( - (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim), - dtype=q.dtype, - device=q.device, - ) - softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device) - torch_npu.npu_fused_infer_attention_score.out( - query, - k_cache, - v_cache, - block_table=self.forward_metadata.block_tables, - block_size=self.page_size, - num_heads=layer.tp_q_head_num, - num_key_value_heads=layer.tp_k_head_num, - input_layout="BSH", - scale=layer.scaling, - actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list, - workspace=workspace, - out=[output, softmax_lse], - ) - else: - k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - v_cache = forward_batch.token_to_kv_pool.get_value_buffer( - layer.layer_id - ) + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) - query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) - num_tokens = query.shape[0] - output = torch.empty( - (num_tokens, layer.tp_q_head_num, layer.v_head_dim), - dtype=query.dtype, - device=query.device, - ) + query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) + num_tokens = query.shape[0] + output = torch.empty( + (num_tokens, layer.tp_q_head_num, layer.v_head_dim), + dtype=query.dtype, + device=query.device, + ) - torch_npu._npu_paged_attention( - query=query, - key_cache=k_cache, - value_cache=v_cache, - num_heads=layer.tp_q_head_num, - num_kv_heads=layer.tp_k_head_num, - scale_value=layer.scaling, - block_table=self.forward_metadata.block_tables, - context_lens=self.forward_metadata.seq_lens_cpu_int, - out=output, - ) + torch_npu._npu_paged_attention( + query=query, + key_cache=k_cache, + value_cache=v_cache, + num_heads=layer.tp_q_head_num, + num_kv_heads=layer.tp_k_head_num, + scale_value=layer.scaling, + block_table=self.forward_metadata.block_tables, + context_lens=self.forward_metadata.seq_lens_cpu_int, + out=output, + ) return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) else: query = q.view(-1, layer.tp_q_head_num, layer.head_dim) @@ -330,6 +220,3 @@ def forward_decode( out=attn_output, ) return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank) - - def get_cuda_graph_seq_len_fill_value(self): - return 0 diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 07d7f5234cd..1653d4535da 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -376,7 +376,7 @@ def set_kv_buffer( v_scale: Optional[float] = None, layer_id_override: Optional[int] = None, ): - from sglang.srt.model_executor.graph_runner import get_is_capture_mode + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode if layer_id_override is not None: layer_id = layer_id_override diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index aeca8dcb7e2..cc87910ac10 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -15,22 +15,833 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import bisect +import gc +import inspect +import logging +import os +from contextlib import contextmanager +from typing import TYPE_CHECKING, Callable, Optional, Union import torch +import tqdm +from torch.profiler import ProfilerActivity, profile -from sglang.srt.model_executor.graph_runner import GraphRunner +from sglang.srt.custom_op import CustomOp +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + set_graph_pool_id, +) +from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture +from sglang.srt.layers.dp_attention import ( + DpPaddingMode, + get_attention_tp_rank, + get_attention_tp_size, + set_dp_buffer_len, +) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.torchao_utils import save_gemlite_cache +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, + PPProxyTensors, + enable_num_token_non_padded, +) +from sglang.srt.patch_torch import monkey_patch_torch_compile +from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin +from sglang.srt.utils import ( + empty_context, + get_available_gpu_memory, + get_device_memory_capacity, + rank0_log, + require_attn_tp_gather, + require_gathered_buffer, + require_mlp_sync, + require_mlp_tp_gather, +) + +logger = logging.getLogger(__name__) if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner +# Detect whether the current forward pass is in capture mode +is_capture_mode = False + + +def get_is_capture_mode(): + return is_capture_mode + + +@contextmanager +def model_capture_mode(): + global is_capture_mode + is_capture_mode = True + + yield + + is_capture_mode = False + + +@contextmanager +def freeze_gc(enable_cudagraph_gc: bool): + """ + Optimize garbage collection during CUDA graph capture. + Clean up, then freeze all remaining objects from being included + in future collections if GC is disabled during capture. + """ + gc.collect() + should_freeze = not enable_cudagraph_gc + if should_freeze: + gc.freeze() + try: + yield + finally: + if should_freeze: + gc.unfreeze() + + +def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): + for sub in model._modules.values(): + if isinstance(sub, CustomOp): + if reverse: + sub.leave_torch_compile() + else: + sub.enter_torch_compile(num_tokens=num_tokens) + if isinstance(sub, torch.nn.Module): + _to_torch(sub, reverse, num_tokens) + + +@contextmanager +def patch_model( + model: torch.nn.Module, + enable_compile: bool, + num_tokens: int, + tp_group: GroupCoordinator, +): + """Patch the model to make it compatible with with torch.compile""" + backup_ca_comm = None + + try: + if enable_compile: + _to_torch(model, reverse=False, num_tokens=num_tokens) + backup_ca_comm = tp_group.ca_comm + # Use custom-allreduce here. + # We found the custom allreduce is much faster than the built-in allreduce in torch, + # even with ENABLE_INTRA_NODE_COMM=1. + # tp_group.ca_comm = None + yield torch.compile( + torch.no_grad()(model.forward), + mode=os.environ.get( + "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs" + ), + dynamic=False, + ) + else: + yield model.forward + finally: + if enable_compile: + _to_torch(model, reverse=True, num_tokens=num_tokens) + tp_group.ca_comm = backup_ca_comm + + +def set_torch_compile_config(): + import torch._dynamo.config + import torch._inductor.config + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + # FIXME: tmp workaround + torch._dynamo.config.accumulated_cache_size_limit = 1024 + if hasattr(torch._dynamo.config, "cache_size_limit"): + torch._dynamo.config.cache_size_limit = 1024 + + monkey_patch_torch_compile() + + +def get_batch_sizes_to_capture(model_runner: ModelRunner): + server_args = model_runner.server_args + capture_bs = server_args.cuda_graph_bs + + if capture_bs is None: + if server_args.speculative_algorithm is None: + if server_args.disable_cuda_graph_padding: + capture_bs = list(range(1, 33)) + list(range(48, 161, 16)) + else: + capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8)) + else: + # Since speculative decoding requires more cuda graph memory, we + # capture less. + capture_bs = ( + list(range(1, 9)) + + list(range(10, 33, 2)) + + list(range(40, 64, 8)) + + list(range(80, 161, 16)) + ) + + gpu_mem = get_device_memory_capacity() + if gpu_mem is not None: + if gpu_mem > 90 * 1024: # H200, H20 + capture_bs += list(range(160, 257, 8)) + if gpu_mem > 160 * 1000: # B200, MI300 + capture_bs += list(range(256, 513, 16)) + + if max(capture_bs) > model_runner.req_to_token_pool.size: + # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests + # is very small. We add more values here to make sure we capture the maximum bs. + capture_bs += [model_runner.req_to_token_pool.size] + + mul_base = 1 -class CudaGraphRunner(GraphRunner): + if server_args.enable_two_batch_overlap: + mul_base *= 2 + + if require_gathered_buffer(server_args): + mul_base *= get_attention_tp_size() + + capture_bs = [bs for bs in capture_bs if bs % mul_base == 0] + + if server_args.cuda_graph_max_bs: + capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs] + if max(capture_bs) < server_args.cuda_graph_max_bs: + capture_bs += list( + range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16) + ) + capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] + capture_bs = list(sorted(set(capture_bs))) + assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}" + compile_bs = ( + [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] + if server_args.enable_torch_compile + else [] + ) + return capture_bs, compile_bs + + +# Reuse this memory pool across all cuda graph runners. +global_graph_memory_pool = None + + +def get_global_graph_memory_pool(): + return global_graph_memory_pool + + +def set_global_graph_memory_pool(val): + global global_graph_memory_pool + global_graph_memory_pool = val + + +class CudaGraphRunner: """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" def __init__(self, model_runner: ModelRunner): # Parse args - super().__init__(model_runner) + self.model_runner = model_runner + self.graphs = {} + self.output_buffers = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder + self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args) + self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args) + self.require_mlp_sync = require_mlp_sync(model_runner.server_args) + self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args) + self.enable_two_batch_overlap = ( + model_runner.server_args.enable_two_batch_overlap + ) + self.speculative_algorithm = model_runner.server_args.speculative_algorithm + self.enable_profile_cuda_graph = ( + model_runner.server_args.enable_profile_cuda_graph + ) + self.tp_size = model_runner.server_args.tp_size + self.dp_size = model_runner.server_args.dp_size + self.pp_size = model_runner.server_args.pp_size + + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + + # Batch sizes to capture + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) + rank0_log(f"Capture cuda graph bs {self.capture_bs}") + self.capture_forward_mode = ForwardMode.DECODE + self.capture_hidden_mode = CaptureHiddenMode.NULL + self.num_tokens_per_bs = 1 + if model_runner.spec_algorithm.is_eagle(): + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen") + else: + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) + + # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup + if model_runner.server_args.enable_return_hidden_states: + self.capture_hidden_mode = CaptureHiddenMode.FULL + + # Attention backend + self.max_bs = max(self.capture_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + self.model_runner.attn_backend.init_cuda_graph_state( + self.max_bs, self.max_num_token + ) + self.seq_len_fill_value = ( + self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() + ) + + # FIXME(lsyin): leave it here for now, I don't know whether it is necessary + self.encoder_len_fill_value = 0 + self.seq_lens_cpu = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + + if self.enable_torch_compile: + set_torch_compile_config() + + if self.model_runner.server_args.enable_lora: + self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) + + # Graph inputs + with torch.device("cuda"): + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) + self.seq_lens = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) + self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) + self.tbo_plugin = TboCudaGraphRunnerPlugin() + + # pipeline parallelism + if self.pp_size > 1: + self.pp_proxy_tensors = { + "hidden_states": torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=torch.bfloat16, + ), + "residual": torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=torch.bfloat16, + ), + } + + # Speculative_inference + if model_runner.spec_algorithm.is_eagle3(): + self.model_runner.model.set_eagle3_layers_to_capture() + + if self.is_encoder_decoder: + # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch + self.encoder_lens = torch.full( + (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32 + ) + else: + self.encoder_lens = None + + if self.require_gathered_buffer: + if self.require_mlp_tp_gather: + self.global_num_tokens_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + self.global_num_tokens_for_logprob_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + else: + assert self.require_attn_tp_gather + self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) + self.global_num_tokens_for_logprob_gpu = torch.zeros( + (1,), dtype=torch.int32 + ) + else: + self.global_num_tokens_gpu = None + self.global_num_tokens_for_logprob_gpu = None + + self.custom_mask = torch.ones( + ( + (self.seq_lens.sum().item() + self.max_num_token) + * self.num_tokens_per_bs + ), + dtype=torch.bool, + device="cuda", + ) + self.next_token_logits_buffer = torch.zeros( + (self.max_num_token, self.model_runner.model_config.vocab_size), + dtype=torch.float, + device="cuda", + ) + + # Capture + try: + with model_capture_mode(): + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" + ) + + def can_run(self, forward_batch: ForwardBatch): + if self.require_mlp_tp_gather: + cuda_graph_bs = ( + max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else max(forward_batch.global_num_tokens_cpu) + ) + else: + cuda_graph_bs = forward_batch.batch_size + + is_bs_supported = ( + cuda_graph_bs in self.graphs + if self.disable_padding + else cuda_graph_bs <= self.max_bs + ) + + if self.require_mlp_sync: + is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph + + # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) + # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph + # because the full_text_row_masked_out_mask tensor will always be ones + is_encoder_lens_supported = ( + torch.all(forward_batch.encoder_lens > 0) + if self.is_encoder_decoder + else True + ) + + requested_capture_hidden_mode = max( + forward_batch.capture_hidden_mode, + ( + forward_batch.spec_info.capture_hidden_mode + if getattr(forward_batch.spec_info, "capture_hidden_mode", None) + is not None + else CaptureHiddenMode.NULL + ), + ) + capture_hidden_mode_matches = ( + requested_capture_hidden_mode == CaptureHiddenMode.NULL + or requested_capture_hidden_mode == self.capture_hidden_mode + ) + is_tbo_supported = ( + forward_batch.can_run_tbo if self.enable_two_batch_overlap else True + ) + + return ( + is_bs_supported + and is_encoder_lens_supported + and is_tbo_supported + and capture_hidden_mode_matches + ) + + def capture(self) -> None: + profile_context = empty_context() + if self.enable_profile_cuda_graph: + profile_context = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + ) + + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + with freeze_gc( + self.model_runner.server_args.enable_cudagraph_gc + ), graph_capture() as graph_capture_context: + with profile_context as prof: + self.stream = graph_capture_context.stream + avail_mem = get_available_gpu_memory( + self.model_runner.device, + self.model_runner.gpu_id, + empty_cache=False, + ) + # Reverse the order to enable better memory sharing across cuda graphs. + capture_range = ( + tqdm.tqdm(list(reversed(self.capture_bs))) + if get_tensor_model_parallel_rank() == 0 + else reversed(self.capture_bs) + ) + for i, bs in enumerate(capture_range): + if get_tensor_model_parallel_rank() == 0: + avail_mem = get_available_gpu_memory( + self.model_runner.device, + self.model_runner.gpu_id, + empty_cache=False, + ) + capture_range.set_description( + f"Capturing batches ({bs=} {avail_mem=:.2f} GB)" + ) + + with patch_model( + self.model_runner.model, + bs in self.compile_bs, + num_tokens=bs * self.num_tokens_per_bs, + tp_group=self.model_runner.tp_group, + ) as forward: + ( + graph, + output_buffers, + ) = self.capture_one_batch_size(bs, forward) + self.graphs[bs] = graph + self.output_buffers[bs] = output_buffers + + # Save gemlite cache after each capture + save_gemlite_cache() + + if self.enable_profile_cuda_graph: + log_message = ( + "Sorted by CUDA Time:\n" + + prof.key_averages(group_by_input_shape=True).table( + sort_by="cuda_time_total", row_limit=10 + ) + + "\n\nSorted by CPU Time:\n" + + prof.key_averages(group_by_input_shape=True).table( + sort_by="cpu_time_total", row_limit=10 + ) + ) + logger.info(log_message) + + def capture_one_batch_size(self, bs: int, forward: Callable): + graph = torch.cuda.CUDAGraph() + stream = self.stream + num_tokens = bs * self.num_tokens_per_bs + + # Graph inputs + input_ids = self.input_ids[:num_tokens] + req_pool_indices = self.req_pool_indices[:bs] + seq_lens = self.seq_lens[:bs] + out_cache_loc = self.out_cache_loc[:num_tokens] + positions = self.positions[:num_tokens] + if self.is_encoder_decoder: + encoder_lens = self.encoder_lens[:bs] + else: + encoder_lens = None + mrope_positions = self.mrope_positions[:, :bs] + next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens] + self.num_token_non_padded[...] = num_tokens + + # pipeline parallelism + if self.pp_size > 1: + pp_proxy_tensors = PPProxyTensors( + {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()} + ) + + if self.require_mlp_tp_gather: + self.global_num_tokens_gpu.copy_( + torch.tensor( + [num_tokens] * self.dp_size, + dtype=torch.int32, + device=input_ids.device, + ) + ) + self.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [num_tokens] * self.dp_size, + dtype=torch.int32, + device=input_ids.device, + ) + ) + global_dp_buffer_len = num_tokens * self.dp_size + elif self.require_attn_tp_gather: + self.global_num_tokens_gpu.copy_( + torch.tensor( + [num_tokens], + dtype=torch.int32, + device=input_ids.device, + ) + ) + self.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [num_tokens], + dtype=torch.int32, + device=input_ids.device, + ) + ) + global_dp_buffer_len = num_tokens + else: + global_dp_buffer_len = None + + spec_info = self.get_spec_info(num_tokens) + if self.capture_hidden_mode != CaptureHiddenMode.FULL: + self.capture_hidden_mode = ( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ) + + if self.model_runner.server_args.enable_lora: + # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever + # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization). + lora_ids = [None] * bs + else: + lora_ids = None + + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + next_token_logits_buffer=next_token_logits_buffer, + orig_seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum().item(), + encoder_lens=encoder_lens, + return_logprob=False, + positions=positions, + global_num_tokens_gpu=self.global_num_tokens_gpu, + global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu, + dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), + global_dp_buffer_len=global_dp_buffer_len, + mrope_positions=mrope_positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=self.capture_hidden_mode, + num_token_non_padded=self.num_token_non_padded, + global_forward_mode=self.capture_forward_mode, + lora_ids=lora_ids, + ) + self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) + + if lora_ids is not None: + self.model_runner.lora_manager.prepare_lora_batch(forward_batch) + + # Attention backend + self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_batch.forward_mode, + forward_batch.spec_info, + ) + + # Run and capture + def run_once(): + # Clean intermediate result cache for DP attention + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + set_dp_buffer_len(global_dp_buffer_len, num_tokens) + + kwargs = {} + if ( + self.pp_size > 1 + and "pp_proxy_tensors" in inspect.signature(forward).parameters + ): + kwargs["pp_proxy_tensors"] = PPProxyTensors( + {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} + ) + + logits_output_or_pp_proxy_tensors = forward( + input_ids, + forward_batch.positions, + forward_batch, + **kwargs, + ) + return logits_output_or_pp_proxy_tensors + + for _ in range(2): + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + run_once() + + if get_global_graph_memory_pool() is None: + set_global_graph_memory_pool(torch.cuda.graph_pool_handle()) + # Set graph pool id globally to be able to use symmetric memory + set_graph_pool_id(get_global_graph_memory_pool()) + with torch.cuda.graph( + graph, pool=get_global_graph_memory_pool(), stream=stream + ): + out = run_once() + + return graph, out + + def recapture_if_needed(self, forward_batch: ForwardBatch): + + # If the required capture_hidden_mode changes, we need to recapture the graph + + # These are the different factors that can influence the capture_hidden_mode + capture_hidden_mode_required_by_forward_batch = ( + forward_batch.capture_hidden_mode + ) + capture_hidden_mode_required_by_spec_info = getattr( + forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL + ) + capture_hidden_mode_required_for_returning_hidden_states = ( + CaptureHiddenMode.FULL + if self.model_runner.server_args.enable_return_hidden_states + else CaptureHiddenMode.NULL + ) + + # Determine the highest capture_hidden_mode required + # (If we have FULL, we can emulate LAST or NULL) + # (If we have LAST, we can emulate NULL) + required_capture_hidden_mode = max( + capture_hidden_mode_required_by_forward_batch, + capture_hidden_mode_required_by_spec_info, + capture_hidden_mode_required_for_returning_hidden_states, + ) + + # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture + if self.capture_hidden_mode != required_capture_hidden_mode: + self.capture_hidden_mode = required_capture_hidden_mode + self.capture() + + def replay_prepare( + self, + forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ): + self.recapture_if_needed(forward_batch) + + raw_bs = forward_batch.batch_size + raw_num_token = raw_bs * self.num_tokens_per_bs + + # Pad + if self.require_mlp_tp_gather: + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = ( + max_num_tokens / self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else max_num_tokens + ) + index = bisect.bisect_left(self.capture_bs, max_batch_size) + else: + index = bisect.bisect_left(self.capture_bs, raw_bs) + bs = self.capture_bs[index] + if bs != raw_bs: + self.seq_lens.fill_(self.seq_len_fill_value) + self.out_cache_loc.zero_() + + # Common inputs + self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) + self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) + self.positions[:raw_num_token].copy_(forward_batch.positions) + + seq_lens_cpu = None + if forward_batch.seq_lens_cpu is not None: + if bs != raw_bs: + self.seq_lens_cpu.fill_(self.seq_len_fill_value) + self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) + seq_lens_cpu = self.seq_lens_cpu[:bs] + + if pp_proxy_tensors: + for key in self.pp_proxy_tensors.keys(): + dim = pp_proxy_tensors[key].shape[0] + self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key]) + + if self.is_encoder_decoder: + self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) + if forward_batch.mrope_positions is not None: + self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) + if self.require_gathered_buffer: + self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) + self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) + if enable_num_token_non_padded(self.model_runner.server_args): + num_token_non_padded = forward_batch.num_token_non_padded + if self.require_gathered_buffer: + tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs + num_local_token_non_padded = torch.clamp( + num_token_non_padded - tokens_per_rank * self.attn_tp_rank, + min=0, + max=tokens_per_rank, + ) + self.num_token_non_padded.copy_(num_local_token_non_padded) + else: + self.num_token_non_padded.copy_(num_token_non_padded) + if self.enable_two_batch_overlap: + self.tbo_plugin.replay_prepare( + forward_mode=self.capture_forward_mode, + bs=bs, + num_token_non_padded=len(forward_batch.input_ids), + spec_info=forward_batch.spec_info, + ) + if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None: + forward_batch.spec_info.custom_mask = self.custom_mask + # Attention backend + self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( + bs, + self.req_pool_indices[:bs], + self.seq_lens[:bs], + forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value, + self.encoder_lens[:bs] if self.is_encoder_decoder else None, + self.capture_forward_mode, + forward_batch.spec_info, + seq_lens_cpu=seq_lens_cpu, + ) + + # Store fields + self.raw_bs = raw_bs + self.raw_num_token = raw_num_token + self.bs = bs + + def replay( + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + if not skip_attn_backend_init: + self.replay_prepare(forward_batch, pp_proxy_tensors) + else: + # In speculative decoding, these two fields are still needed. + self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) + self.positions[: self.raw_num_token].copy_(forward_batch.positions) + + # Replay + self.graphs[self.bs].replay() + + output = self.output_buffers[self.bs] + if isinstance(output, LogitsProcessorOutput): + return LogitsProcessorOutput( + next_token_logits=output.next_token_logits[: self.raw_num_token], + hidden_states=( + output.hidden_states[: self.raw_num_token] + if output.hidden_states is not None + else None + ), + ) + else: + assert isinstance(output, PPProxyTensors) + return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()}) + + def get_spec_info(self, num_tokens: int): + spec_info = None + if self.model_runner.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_utils import EagleVerifyInput + + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen.") + else: + spec_info = EagleVerifyInput( + draft_token=None, + custom_mask=self.custom_mask, + positions=None, + retrive_index=None, + retrive_next_token=None, + retrive_next_sibling=None, + retrive_cum_len=None, + spec_steps=self.model_runner.server_args.speculative_num_steps, + topk=self.model_runner.server_args.speculative_eagle_topk, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, + seq_lens_sum=None, + seq_lens_cpu=None, + ) + + return spec_info + - def _create_device_graph(self): - return torch.cuda.CUDAGraph() +CUDA_GRAPH_CAPTURE_FAILED_MSG = ( + "Possible solutions:\n" + "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" +) diff --git a/python/sglang/srt/model_executor/graph_runner.py b/python/sglang/srt/model_executor/graph_runner.py deleted file mode 100644 index afcb00b4e76..00000000000 --- a/python/sglang/srt/model_executor/graph_runner.py +++ /dev/null @@ -1,860 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Run the model with device graph and torch.compile.""" - -from __future__ import annotations - -import bisect -import gc -import inspect -import logging -import os -from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable, Optional, Union - -import torch -import tqdm -from torch.profiler import ProfilerActivity, profile - -from sglang.srt.custom_op import CustomOp -from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - set_graph_pool_id, -) -from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture -from sglang.srt.layers.dp_attention import ( - DpPaddingMode, - get_attention_tp_rank, - get_attention_tp_size, - set_dp_buffer_len, -) -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.torchao_utils import save_gemlite_cache -from sglang.srt.model_executor.forward_batch_info import ( - CaptureHiddenMode, - ForwardBatch, - ForwardMode, - PPProxyTensors, - enable_num_token_non_padded, -) -from sglang.srt.patch_torch import monkey_patch_torch_compile -from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin -from sglang.srt.utils import ( - empty_context, - get_available_gpu_memory, - get_device_memory_capacity, - rank0_log, - require_attn_tp_gather, - require_gathered_buffer, - require_mlp_sync, - require_mlp_tp_gather, -) - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from sglang.srt.model_executor.model_runner import ModelRunner - -# Detect whether the current forward pass is in capture mode -is_capture_mode = False - - -def get_is_capture_mode(): - return is_capture_mode - - -@contextmanager -def model_capture_mode(): - global is_capture_mode - is_capture_mode = True - - yield - - is_capture_mode = False - - -@contextmanager -def freeze_gc(enable_cudagraph_gc: bool): - """ - Optimize garbage collection during CUDA graph capture. - Clean up, then freeze all remaining objects from being included - in future collections if GC is disabled during capture. - """ - gc.collect() - should_freeze = not enable_cudagraph_gc - if should_freeze: - gc.freeze() - try: - yield - finally: - if should_freeze: - gc.unfreeze() - - -def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): - for sub in model._modules.values(): - if isinstance(sub, CustomOp): - if reverse: - sub.leave_torch_compile() - else: - sub.enter_torch_compile(num_tokens=num_tokens) - if isinstance(sub, torch.nn.Module): - _to_torch(sub, reverse, num_tokens) - - -@contextmanager -def patch_model( - model: torch.nn.Module, - enable_compile: bool, - num_tokens: int, - tp_group: GroupCoordinator, -): - """Patch the model to make it compatible with with torch.compile""" - backup_ca_comm = None - - try: - if enable_compile: - _to_torch(model, reverse=False, num_tokens=num_tokens) - backup_ca_comm = tp_group.ca_comm - # Use custom-allreduce here. - # We found the custom allreduce is much faster than the built-in allreduce in torch, - # even with ENABLE_INTRA_NODE_COMM=1. - # tp_group.ca_comm = None - yield torch.compile( - torch.no_grad()(model.forward), - mode=os.environ.get( - "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs" - ), - dynamic=False, - ) - else: - yield model.forward - finally: - if enable_compile: - _to_torch(model, reverse=True, num_tokens=num_tokens) - tp_group.ca_comm = backup_ca_comm - - -def set_torch_compile_config(): - import torch._dynamo.config - import torch._inductor.config - - torch._inductor.config.coordinate_descent_tuning = True - torch._inductor.config.triton.unique_kernel_names = True - torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future - - # FIXME: tmp workaround - torch._dynamo.config.accumulated_cache_size_limit = 1024 - if hasattr(torch._dynamo.config, "cache_size_limit"): - torch._dynamo.config.cache_size_limit = 1024 - - monkey_patch_torch_compile() - - -def get_batch_sizes_to_capture(model_runner: ModelRunner): - server_args = model_runner.server_args - capture_bs = server_args.cuda_graph_bs - - if capture_bs is None: - if server_args.speculative_algorithm is None: - if server_args.disable_cuda_graph_padding: - capture_bs = list(range(1, 33)) + list(range(48, 161, 16)) - else: - capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8)) - else: - # Since speculative decoding requires more cuda graph memory, we - # capture less. - capture_bs = ( - list(range(1, 9)) - + list(range(10, 33, 2)) - + list(range(40, 64, 8)) - + list(range(80, 161, 16)) - ) - - gpu_mem = get_device_memory_capacity() - if gpu_mem is not None: - if gpu_mem > 90 * 1024: # H200, H20 - capture_bs += list(range(160, 257, 8)) - if gpu_mem > 160 * 1000: # B200, MI300 - capture_bs += list(range(256, 513, 16)) - - if max(capture_bs) > model_runner.req_to_token_pool.size: - # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests - # is very small. We add more values here to make sure we capture the maximum bs. - capture_bs += [model_runner.req_to_token_pool.size] - - mul_base = 1 - - if server_args.enable_two_batch_overlap: - mul_base *= 2 - - if require_gathered_buffer(server_args): - mul_base *= get_attention_tp_size() - - capture_bs = [bs for bs in capture_bs if bs % mul_base == 0] - - if server_args.cuda_graph_max_bs: - capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs] - if max(capture_bs) < server_args.cuda_graph_max_bs: - capture_bs += list( - range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16) - ) - capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] - capture_bs = list(sorted(set(capture_bs))) - assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}" - compile_bs = ( - [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] - if server_args.enable_torch_compile - else [] - ) - return capture_bs, compile_bs - - -# Reuse this memory pool across all device graph runners. -global_graph_memory_pool = None - - -def get_global_graph_memory_pool(): - return global_graph_memory_pool - - -def set_global_graph_memory_pool(val): - global global_graph_memory_pool - global_graph_memory_pool = val - - -class GraphRunner: - """A GraphRunner is a base class to run the forward pass of a model with device graph and torch.compile.""" - - def __init__(self, model_runner: ModelRunner): - # Parse args - self.model_runner = model_runner - self.device = model_runner.device - self.device_module = torch.get_device_module(self.device) - self.graphs = {} - self.output_buffers = {} - self.enable_torch_compile = model_runner.server_args.enable_torch_compile - self.disable_padding = model_runner.server_args.disable_cuda_graph_padding - self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder - self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args) - self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args) - self.require_mlp_sync = require_mlp_sync(model_runner.server_args) - self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args) - self.enable_two_batch_overlap = ( - model_runner.server_args.enable_two_batch_overlap - ) - self.speculative_algorithm = model_runner.server_args.speculative_algorithm - self.enable_profile_cuda_graph = ( - model_runner.server_args.enable_profile_cuda_graph - ) - self.tp_size = model_runner.server_args.tp_size - self.dp_size = model_runner.server_args.dp_size - self.pp_size = model_runner.server_args.pp_size - - self.attn_tp_size = get_attention_tp_size() - self.attn_tp_rank = get_attention_tp_rank() - - # Batch sizes to capture - self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) - rank0_log(f"Capture graph bs {self.capture_bs}") - self.capture_forward_mode = ForwardMode.DECODE - self.capture_hidden_mode = CaptureHiddenMode.NULL - self.num_tokens_per_bs = 1 - if model_runner.spec_algorithm.is_eagle(): - if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen") - else: - self.capture_forward_mode = ForwardMode.TARGET_VERIFY - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_num_draft_tokens - ) - - # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup - if model_runner.server_args.enable_return_hidden_states: - self.capture_hidden_mode = CaptureHiddenMode.FULL - - # Attention backend - self.max_bs = max(self.capture_bs) - self.max_num_token = self.max_bs * self.num_tokens_per_bs - self.model_runner.attn_backend.init_cuda_graph_state( - self.max_bs, self.max_num_token - ) - self.seq_len_fill_value = ( - self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() - ) - - # FIXME(lsyin): leave it here for now, I don't know whether it is necessary - self.encoder_len_fill_value = 0 - self.seq_lens_cpu = torch.full( - (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 - ) - - if self.enable_torch_compile: - set_torch_compile_config() - - if self.model_runner.server_args.enable_lora: - self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) - - # Graph inputs - with torch.device(self.device): - self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) - self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) - self.seq_lens = torch.full( - (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 - ) - self.out_cache_loc = torch.zeros( - (self.max_num_token,), dtype=self._cache_loc_dtype() - ) - self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) - self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) - self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) - self.tbo_plugin = TboCudaGraphRunnerPlugin() - - # pipeline parallelism - if self.pp_size > 1: - self.pp_proxy_tensors = { - "hidden_states": torch.zeros( - (self.max_bs, self.model_runner.model_config.hidden_size), - dtype=torch.bfloat16, - ), - "residual": torch.zeros( - (self.max_bs, self.model_runner.model_config.hidden_size), - dtype=torch.bfloat16, - ), - } - - # Speculative_inference - if model_runner.spec_algorithm.is_eagle3(): - self.model_runner.model.set_eagle3_layers_to_capture() - - if self.is_encoder_decoder: - # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch - self.encoder_lens = torch.full( - (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32 - ) - else: - self.encoder_lens = None - - if self.require_gathered_buffer: - if self.require_mlp_tp_gather: - self.global_num_tokens_gpu = torch.zeros( - (self.dp_size,), dtype=torch.int32 - ) - self.global_num_tokens_for_logprob_gpu = torch.zeros( - (self.dp_size,), dtype=torch.int32 - ) - else: - assert self.require_attn_tp_gather - self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) - self.global_num_tokens_for_logprob_gpu = torch.zeros( - (1,), dtype=torch.int32 - ) - else: - self.global_num_tokens_gpu = None - self.global_num_tokens_for_logprob_gpu = None - - self.custom_mask = torch.ones( - ( - (self.seq_lens.sum().item() + self.max_num_token) - * self.num_tokens_per_bs - ), - dtype=torch.bool, - device=self.device, - ) - self.next_token_logits_buffer = torch.zeros( - (self.max_num_token, self.model_runner.model_config.vocab_size), - dtype=torch.float, - device=self.device, - ) - - # Capture - try: - with model_capture_mode(): - self.capture() - except RuntimeError as e: - raise Exception( - f"Capture device graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" - ) - - def _cache_loc_dtype(self): - return torch.int64 - - def can_run(self, forward_batch: ForwardBatch): - if self.require_mlp_tp_gather: - cuda_graph_bs = ( - max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs - if self.model_runner.spec_algorithm.is_eagle() - else max(forward_batch.global_num_tokens_cpu) - ) - else: - cuda_graph_bs = forward_batch.batch_size - - is_bs_supported = ( - cuda_graph_bs in self.graphs - if self.disable_padding - else cuda_graph_bs <= self.max_bs - ) - - if self.require_mlp_sync: - is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph - - # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) - # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph - # because the full_text_row_masked_out_mask tensor will always be ones - is_encoder_lens_supported = ( - torch.all(forward_batch.encoder_lens > 0) - if self.is_encoder_decoder - else True - ) - - requested_capture_hidden_mode = max( - forward_batch.capture_hidden_mode, - ( - forward_batch.spec_info.capture_hidden_mode - if getattr(forward_batch.spec_info, "capture_hidden_mode", None) - is not None - else CaptureHiddenMode.NULL - ), - ) - capture_hidden_mode_matches = ( - requested_capture_hidden_mode == CaptureHiddenMode.NULL - or requested_capture_hidden_mode == self.capture_hidden_mode - ) - is_tbo_supported = ( - forward_batch.can_run_tbo if self.enable_two_batch_overlap else True - ) - - return ( - is_bs_supported - and is_encoder_lens_supported - and is_tbo_supported - and capture_hidden_mode_matches - ) - - def capture(self) -> None: - profile_context = empty_context() - if self.enable_profile_cuda_graph: - profile_context = profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True, - ) - - # Trigger CUDA graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. - with freeze_gc( - self.model_runner.server_args.enable_cudagraph_gc - ), graph_capture() as graph_capture_context: - with profile_context as prof: - self.stream = graph_capture_context.stream - avail_mem = get_available_gpu_memory( - self.model_runner.device, - self.model_runner.gpu_id, - empty_cache=False, - ) - # Reverse the order to enable better memory sharing across cuda graphs. - capture_range = ( - tqdm.tqdm(list(reversed(self.capture_bs))) - if get_tensor_model_parallel_rank() == 0 - else reversed(self.capture_bs) - ) - for i, bs in enumerate(capture_range): - if get_tensor_model_parallel_rank() == 0: - avail_mem = get_available_gpu_memory( - self.model_runner.device, - self.model_runner.gpu_id, - empty_cache=False, - ) - capture_range.set_description( - f"Capturing batches ({bs=} {avail_mem=:.2f} GB)" - ) - - with patch_model( - self.model_runner.model, - bs in self.compile_bs, - num_tokens=bs * self.num_tokens_per_bs, - tp_group=self.model_runner.tp_group, - ) as forward: - ( - graph, - output_buffers, - ) = self.capture_one_batch_size(bs, forward) - self.graphs[bs] = graph - self.output_buffers[bs] = output_buffers - - # Save gemlite cache after each capture - save_gemlite_cache() - - if self.enable_profile_cuda_graph: - log_message = ( - "Sorted by CUDA Time:\n" - + prof.key_averages(group_by_input_shape=True).table( - sort_by="cuda_time_total", row_limit=10 - ) - + "\n\nSorted by CPU Time:\n" - + prof.key_averages(group_by_input_shape=True).table( - sort_by="cpu_time_total", row_limit=10 - ) - ) - logger.info(log_message) - - def _capture_graph(self, graph, pool, stream, run_once_fn): - with self.device_module.graph(graph, pool=pool, stream=stream): - out = run_once_fn() - return out - - def _create_device_graph(self): - pass - - def capture_one_batch_size(self, bs: int, forward: Callable): - graph = self._create_device_graph() - stream = self.stream - num_tokens = bs * self.num_tokens_per_bs - - # Graph inputs - input_ids = self.input_ids[:num_tokens] - req_pool_indices = self.req_pool_indices[:bs] - seq_lens = self.seq_lens[:bs] - out_cache_loc = self.out_cache_loc[:num_tokens] - positions = self.positions[:num_tokens] - if self.is_encoder_decoder: - encoder_lens = self.encoder_lens[:bs] - else: - encoder_lens = None - mrope_positions = self.mrope_positions[:, :bs] - next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens] - self.num_token_non_padded[...] = num_tokens - - # pipeline parallelism - if self.pp_size > 1: - pp_proxy_tensors = PPProxyTensors( - {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()} - ) - - if self.require_mlp_tp_gather: - self.global_num_tokens_gpu.copy_( - torch.tensor( - [num_tokens] * self.dp_size, - dtype=torch.int32, - device=input_ids.device, - ) - ) - self.global_num_tokens_for_logprob_gpu.copy_( - torch.tensor( - [num_tokens] * self.dp_size, - dtype=torch.int32, - device=input_ids.device, - ) - ) - global_dp_buffer_len = num_tokens * self.dp_size - elif self.require_attn_tp_gather: - self.global_num_tokens_gpu.copy_( - torch.tensor( - [num_tokens], - dtype=torch.int32, - device=input_ids.device, - ) - ) - self.global_num_tokens_for_logprob_gpu.copy_( - torch.tensor( - [num_tokens], - dtype=torch.int32, - device=input_ids.device, - ) - ) - global_dp_buffer_len = num_tokens - else: - global_dp_buffer_len = None - - spec_info = self.get_spec_info(num_tokens) - if self.capture_hidden_mode != CaptureHiddenMode.FULL: - self.capture_hidden_mode = ( - spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL - ) - - if self.model_runner.server_args.enable_lora: - # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever - # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization). - lora_ids = [None] * bs - else: - lora_ids = None - - forward_batch = ForwardBatch( - forward_mode=self.capture_forward_mode, - batch_size=bs, - input_ids=input_ids, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - next_token_logits_buffer=next_token_logits_buffer, - orig_seq_lens=seq_lens, - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - attn_backend=self.model_runner.attn_backend, - out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens.sum().item(), - encoder_lens=encoder_lens, - return_logprob=False, - positions=positions, - global_num_tokens_gpu=self.global_num_tokens_gpu, - global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu, - dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), - global_dp_buffer_len=global_dp_buffer_len, - mrope_positions=mrope_positions, - spec_algorithm=self.model_runner.spec_algorithm, - spec_info=spec_info, - capture_hidden_mode=self.capture_hidden_mode, - num_token_non_padded=self.num_token_non_padded, - global_forward_mode=self.capture_forward_mode, - lora_ids=lora_ids, - ) - self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) - - if lora_ids is not None: - self.model_runner.lora_manager.prepare_lora_batch(forward_batch) - - # Attention backend - self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( - bs, - num_tokens, - req_pool_indices, - seq_lens, - encoder_lens, - forward_batch.forward_mode, - forward_batch.spec_info, - ) - - # Run and capture - def run_once(): - # Clean intermediate result cache for DP attention - forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) - - kwargs = {} - if ( - self.pp_size > 1 - and "pp_proxy_tensors" in inspect.signature(forward).parameters - ): - kwargs["pp_proxy_tensors"] = PPProxyTensors( - {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} - ) - - logits_output_or_pp_proxy_tensors = forward( - input_ids, - forward_batch.positions, - forward_batch, - **kwargs, - ) - return logits_output_or_pp_proxy_tensors - - for _ in range(2): - self.device_module.synchronize() - self.model_runner.tp_group.barrier() - run_once() - - if get_global_graph_memory_pool() is None: - set_global_graph_memory_pool(self.device_module.graph_pool_handle()) - # Set graph pool id globally to be able to use symmetric memory - set_graph_pool_id(get_global_graph_memory_pool()) - out = self._capture_graph( - graph, get_global_graph_memory_pool(), stream, run_once - ) - - return graph, out - - def recapture_if_needed(self, forward_batch: ForwardBatch): - - # If the required capture_hidden_mode changes, we need to recapture the graph - - # These are the different factors that can influence the capture_hidden_mode - capture_hidden_mode_required_by_forward_batch = ( - forward_batch.capture_hidden_mode - ) - capture_hidden_mode_required_by_spec_info = getattr( - forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL - ) - capture_hidden_mode_required_for_returning_hidden_states = ( - CaptureHiddenMode.FULL - if self.model_runner.server_args.enable_return_hidden_states - else CaptureHiddenMode.NULL - ) - - # Determine the highest capture_hidden_mode required - # (If we have FULL, we can emulate LAST or NULL) - # (If we have LAST, we can emulate NULL) - required_capture_hidden_mode = max( - capture_hidden_mode_required_by_forward_batch, - capture_hidden_mode_required_by_spec_info, - capture_hidden_mode_required_for_returning_hidden_states, - ) - - # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture - if self.capture_hidden_mode != required_capture_hidden_mode: - self.capture_hidden_mode = required_capture_hidden_mode - self.capture() - - def replay_prepare( - self, - forward_batch: ForwardBatch, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ): - self.recapture_if_needed(forward_batch) - - raw_bs = forward_batch.batch_size - raw_num_token = raw_bs * self.num_tokens_per_bs - - # Pad - if self.require_mlp_tp_gather: - max_num_tokens = max(forward_batch.global_num_tokens_cpu) - max_batch_size = ( - max_num_tokens / self.num_tokens_per_bs - if self.model_runner.spec_algorithm.is_eagle() - else max_num_tokens - ) - index = bisect.bisect_left(self.capture_bs, max_batch_size) - else: - index = bisect.bisect_left(self.capture_bs, raw_bs) - bs = self.capture_bs[index] - if bs != raw_bs: - self.seq_lens.fill_(self.seq_len_fill_value) - self.out_cache_loc.zero_() - - # Common inputs - self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) - self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) - self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) - self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) - self.positions[:raw_num_token].copy_(forward_batch.positions) - - seq_lens_cpu = None - if forward_batch.seq_lens_cpu is not None: - if bs != raw_bs: - self.seq_lens_cpu.fill_(self.seq_len_fill_value) - self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) - seq_lens_cpu = self.seq_lens_cpu[:bs] - - if pp_proxy_tensors: - for key in self.pp_proxy_tensors.keys(): - dim = pp_proxy_tensors[key].shape[0] - self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key]) - - if self.is_encoder_decoder: - self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) - if forward_batch.mrope_positions is not None: - self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) - if self.require_gathered_buffer: - self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) - self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) - if enable_num_token_non_padded(self.model_runner.server_args): - num_token_non_padded = forward_batch.num_token_non_padded - if self.require_gathered_buffer: - tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs - num_local_token_non_padded = torch.clamp( - num_token_non_padded - tokens_per_rank * self.attn_tp_rank, - min=0, - max=tokens_per_rank, - ) - self.num_token_non_padded.copy_(num_local_token_non_padded) - else: - self.num_token_non_padded.copy_(num_token_non_padded) - if self.enable_two_batch_overlap: - self.tbo_plugin.replay_prepare( - forward_mode=self.capture_forward_mode, - bs=bs, - num_token_non_padded=len(forward_batch.input_ids), - spec_info=forward_batch.spec_info, - ) - if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None: - forward_batch.spec_info.custom_mask = self.custom_mask - # Attention backend - self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( - bs, - self.req_pool_indices[:bs], - self.seq_lens[:bs], - forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value, - self.encoder_lens[:bs] if self.is_encoder_decoder else None, - self.capture_forward_mode, - forward_batch.spec_info, - seq_lens_cpu=seq_lens_cpu, - ) - - # Store fields - self.raw_bs = raw_bs - self.raw_num_token = raw_num_token - self.bs = bs - - def replay( - self, - forward_batch: ForwardBatch, - skip_attn_backend_init: bool = False, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - if not skip_attn_backend_init: - self.replay_prepare(forward_batch, pp_proxy_tensors) - else: - # In speculative decoding, these two fields are still needed. - self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) - self.positions[: self.raw_num_token].copy_(forward_batch.positions) - - # Replay - self.graphs[self.bs].replay() - - output = self.output_buffers[self.bs] - if isinstance(output, LogitsProcessorOutput): - return LogitsProcessorOutput( - next_token_logits=output.next_token_logits[: self.raw_num_token], - hidden_states=( - output.hidden_states[: self.raw_num_token] - if output.hidden_states is not None - else None - ), - ) - else: - assert isinstance(output, PPProxyTensors) - return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()}) - - def get_spec_info(self, num_tokens: int): - spec_info = None - if self.model_runner.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_utils import EagleVerifyInput - - if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen.") - else: - spec_info = EagleVerifyInput( - draft_token=None, - custom_mask=self.custom_mask, - positions=None, - retrive_index=None, - retrive_next_token=None, - retrive_next_sibling=None, - retrive_cum_len=None, - spec_steps=self.model_runner.server_args.speculative_num_steps, - topk=self.model_runner.server_args.speculative_eagle_topk, - draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, - capture_hidden_mode=CaptureHiddenMode.FULL, - seq_lens_sum=None, - seq_lens_cpu=None, - ) - - return spec_info - - -GRAPH_CAPTURE_FAILED_MSG = ( - "Possible solutions:\n" - "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" - "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" - "3. disable torch compile by not using --enable-torch-compile\n" - "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" - "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" -) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b05973c812b..6665458b879 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -91,7 +91,6 @@ ) from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype @@ -342,12 +341,9 @@ def initialize(self, min_per_gpu_memory: float): if self.device == "cuda": self.init_cublas() self.init_attention_backend() - self.init_device_graphs() - elif self.device == "npu": - self.init_attention_backend() - self.init_device_graphs() + self.init_cuda_graphs() else: - self.graph_runner = None + self.cuda_graph_runner = None self.cuda_graph_mem_usage = 0 self.init_attention_backend() @@ -921,8 +917,7 @@ def update_weights_from_tensor( ) # We need to get device after patch otherwise the device would be wrong - self.device_module = torch.get_device_module(self.device) - infered_device = self.device_module.current_device() + infered_device = torch.cuda.current_device() named_tensors = [ (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device)) @@ -1590,9 +1585,9 @@ def init_double_sparsity_channel_config(self, selected_channel): .cuda() ) - def init_device_graphs(self): + def init_cuda_graphs(self): """Capture cuda graphs.""" - self.graph_runner = None + self.cuda_graph_runner = None self.cuda_graph_mem_usage = 0 if not self.is_generation: @@ -1607,9 +1602,8 @@ def init_device_graphs(self): logger.info( f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" ) - self.graph_runner = ( - CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self) - ) + self.cuda_graph_runner = CudaGraphRunner(self) + after_mem = get_available_gpu_memory(self.device, self.gpu_id) self.cuda_graph_mem_usage = before_mem - after_mem logger.info( @@ -1761,11 +1755,11 @@ def _forward_raw( ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: can_run_cuda_graph = bool( forward_batch.forward_mode.is_cuda_graph() - and self.graph_runner - and self.graph_runner.can_run(forward_batch) + and self.cuda_graph_runner + and self.cuda_graph_runner.can_run(forward_batch) ) if can_run_cuda_graph: - ret = self.graph_runner.replay( + ret = self.cuda_graph_runner.replay( forward_batch, skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py deleted file mode 100644 index 582b5b7c612..00000000000 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Run the model with npu graph and torch.compile.""" - -from __future__ import annotations - -import logging -import threading -from typing import TYPE_CHECKING - -import torch - -from sglang.srt.model_executor.graph_runner import GraphRunner - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from sglang.srt.model_executor.model_runner import ModelRunner - -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors - - -class NPUGraphRunner(GraphRunner): - """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile.""" - - def __init__(self, model_runner: ModelRunner): - super().__init__(model_runner) - - def _create_device_graph(self): - return torch.npu.NPUGraph() - - def _capture_graph(self, graph, pool, stream, run_once_fn): - with torch.npu.graph( - graph, - pool=pool, - stream=stream, - auto_dispatch_capture=True, - ): - out = run_once_fn() - return out - - def _update_inputs(self, seq_lens): - self.graphs[self.bs].update( - cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}] - ) - - def _cache_loc_dtype(self): - return torch.int32 - - def replay( - self, - forward_batch: ForwardBatch, - skip_attn_backend_init: bool = False, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - if not skip_attn_backend_init: - self.replay_prepare(forward_batch, pp_proxy_tensors) - else: - # In speculative decoding, these two fields are still needed. - self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) - self.positions[: self.raw_num_token].copy_(forward_batch.positions) - - # Replay - seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs) - thread = threading.Thread(target=self._update_inputs, args=(seq_lens,)) - thread.start() - self.graphs[self.bs].replay() - thread.join() - - output = self.output_buffers[self.bs] - if isinstance(output, LogitsProcessorOutput): - return LogitsProcessorOutput( - next_token_logits=output.next_token_logits[: self.raw_num_token], - hidden_states=( - output.hidden_states[: self.raw_num_token] - if output.hidden_states is not None - else None - ), - ) - else: - assert isinstance(output, PPProxyTensors) - return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()}) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 37274e45b30..eeebe1863fb 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1200,7 +1200,7 @@ def forward_absorb_prepare( forward_batch: ForwardBatch, zero_allocator: BumpAllocator, ): - from sglang.srt.model_executor.graph_runner import get_is_capture_mode + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode if self.q_lora_rank is not None: if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm: diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index bf6ceaeb875..ab118ad9c5f 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -68,8 +68,8 @@ VocabParallelEmbedding, ) from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import ( DeepseekV2DecoderLayer, diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 3ba736c7a94..fa294ddcd0c 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -966,7 +966,7 @@ def forward( positions: torch.Tensor, forward_batch: ForwardBatch, ) -> Union[Tuple, CausalLMOutputWithPast]: - from sglang.srt.model_executor.graph_runner import get_is_capture_mode + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = ( self._batch_image_inputs(forward_batch) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index a73d8764acc..042159a5030 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -22,8 +22,8 @@ from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2Model diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 26971c119c5..fcb45b94716 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -52,8 +52,8 @@ from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeModel diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 3401e2738b2..e824fb1ae8e 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -6,20 +6,20 @@ import torch from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len -from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner -from sglang.srt.model_executor.forward_batch_info import ( - CaptureHiddenMode, - ForwardBatch, - ForwardMode, -) -from sglang.srt.model_executor.graph_runner import ( - GRAPH_CAPTURE_FAILED_MSG, +from sglang.srt.model_executor.cuda_graph_runner import ( + CUDA_GRAPH_CAPTURE_FAILED_MSG, + CudaGraphRunner, get_batch_sizes_to_capture, get_global_graph_memory_pool, model_capture_mode, set_global_graph_memory_pool, set_torch_compile_config, ) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) from sglang.srt.speculative.eagle_utils import EagleDraftInput from sglang.srt.utils import ( require_attn_tp_gather, @@ -121,7 +121,7 @@ def __init__(self, eagle_worker: EAGLEWorker): self.capture() except RuntimeError as e: raise Exception( - f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) def can_run(self, forward_batch: ForwardBatch): diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index b40db90dd98..4f4403fee50 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -6,14 +6,9 @@ import torch from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len -from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner -from sglang.srt.model_executor.forward_batch_info import ( - CaptureHiddenMode, - ForwardBatch, - ForwardMode, -) -from sglang.srt.model_executor.graph_runner import ( - GRAPH_CAPTURE_FAILED_MSG, +from sglang.srt.model_executor.cuda_graph_runner import ( + CUDA_GRAPH_CAPTURE_FAILED_MSG, + CudaGraphRunner, LogitsProcessorOutput, get_batch_sizes_to_capture, get_global_graph_memory_pool, @@ -21,6 +16,11 @@ set_global_graph_memory_pool, set_torch_compile_config, ) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk from sglang.srt.utils import ( require_attn_tp_gather, @@ -149,7 +149,7 @@ def __init__(self, eagle_worker: EAGLEWorker): self.capture() except RuntimeError as e: raise Exception( - f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) def can_run(self, forward_batch: ForwardBatch): diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 26e99ae1029..b948bc82eb1 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -229,17 +229,6 @@ class TestFile: TestFile("test_wave_attention_kernels.py", 2), TestFile("test_wave_attention_backend.py", 150), ], - "per-commit-1-ascend-npu": [ - TestFile("test_ascend_tp1_bf16.py", 400), - TestFile("test_ascend_graph_tp1_bf16.py", 400), - ], - "per-commit-2-ascend-npu": [ - TestFile("test_ascend_tp2_bf16.py", 400), - TestFile("test_ascend_graph_tp2_bf16.py", 400), - ], - "per-commit-4-ascend-npu": [ - TestFile("test_ascend_mla_w8a8int8.py", 400), - ], "per-commit-2-gpu-amd": [ TestFile("lora/test_lora_tp.py", 116), TestFile("rl/test_update_weights_from_distributed.py", 103), diff --git a/test/srt/test_ascend_graph_tp1_bf16.py b/test/srt/test_ascend_graph_tp1_bf16.py deleted file mode 100644 index 95c6b7bcf5b..00000000000 --- a/test/srt/test_ascend_graph_tp1_bf16.py +++ /dev/null @@ -1,95 +0,0 @@ -import unittest -from types import SimpleNamespace -from urllib.parse import urlparse - -from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - is_in_ci, - popen_launch_server, - run_bench_offline_throughput, -) - -TEST_MODEL_MATRIX = { - "Qwen/Qwen2.5-7B-Instruct": { - "accuracy": 0.85, - "latency": 150, - "output_throughput": 30, - }, -} - - -class TestAscendGraphTp1Bf16(CustomTestCase): - - @classmethod - def setUpClass(cls): - cls.models = TEST_MODEL_MATRIX.keys() - cls.base_url = DEFAULT_URL_FOR_TEST - cls.url = urlparse(DEFAULT_URL_FOR_TEST) - cls.common_args = [ - "--trust-remote-code", - "--mem-fraction-static", - 0.8, - "--attention-backend", - "ascend", - ] - - def test_a_gsm8k(self): - for model in self.models: - with self.subTest(model=model): - print(f"##=== Testing accuracy: {model} ===##") - - process = popen_launch_server( - model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - *self.common_args, - ], - ) - - try: - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=128, - host=f"http://{self.url.hostname}", - port=int(self.url.port), - ) - - metrics = run_eval_few_shot_gsm8k(args) - self.assertGreaterEqual( - metrics["accuracy"], - TEST_MODEL_MATRIX[model]["accuracy"], - ) - finally: - kill_process_tree(process.pid) - - def test_b_throughput(self): - for model in self.models: - with self.subTest(model=model): - print(f"##=== Testing throughput: {model} ===##") - - output_throughput = run_bench_offline_throughput( - model, - [ - *self.common_args, - ], - ) - - print(f"##=== {model} throughput: {output_throughput} ===##") - - if is_in_ci(): - self.assertGreater( - output_throughput, - TEST_MODEL_MATRIX[model]["output_throughput"], - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/test_ascend_graph_tp2_bf16.py b/test/srt/test_ascend_graph_tp2_bf16.py deleted file mode 100644 index f7c3c65377d..00000000000 --- a/test/srt/test_ascend_graph_tp2_bf16.py +++ /dev/null @@ -1,97 +0,0 @@ -import unittest -from types import SimpleNamespace -from urllib.parse import urlparse - -from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - is_in_ci, - popen_launch_server, - run_bench_offline_throughput, -) - -TEST_MODEL_MATRIX = { - "Qwen/Qwen2.5-7B-Instruct": { - "accuracy": 0.85, - "latency": 180, - "output_throughput": 20, - }, -} - - -class TestAscendGraphTp2Bf16(CustomTestCase): - - @classmethod - def setUpClass(cls): - cls.models = TEST_MODEL_MATRIX.keys() - cls.base_url = DEFAULT_URL_FOR_TEST - cls.url = urlparse(DEFAULT_URL_FOR_TEST) - cls.common_args = [ - "--trust-remote-code", - "--mem-fraction-static", - 0.8, - "--attention-backend", - "ascend", - "--tp-size", - 2, - ] - - def test_a_gsm8k(self): - for model in self.models: - with self.subTest(model=model): - print(f"##=== Testing accuracy: {model} ===##") - - process = popen_launch_server( - model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - *self.common_args, - ], - ) - - try: - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=128, - host=f"http://{self.url.hostname}", - port=int(self.url.port), - ) - - metrics = run_eval_few_shot_gsm8k(args) - self.assertGreaterEqual( - metrics["accuracy"], - TEST_MODEL_MATRIX[model]["accuracy"], - ) - finally: - kill_process_tree(process.pid) - - def test_b_throughput(self): - for model in self.models: - with self.subTest(model=model): - print(f"##=== Testing throughput: {model} ===##") - - output_throughput = run_bench_offline_throughput( - model, - [ - *self.common_args, - ], - ) - - print(f"##=== {model} throughput: {output_throughput} ===##") - - if is_in_ci(): - self.assertGreater( - output_throughput, - TEST_MODEL_MATRIX[model]["output_throughput"], - ) - - -if __name__ == "__main__": - unittest.main()