125 changes: 49 additions & 76 deletions docs/references/production_request_trace.md
@@ -17,39 +17,57 @@ This section explains how to configure the request tracing and export the trace
pip install opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-grpc
```

-2. launch opentelemetry collector and jaeger
+2. Launch the OpenTelemetry Collector and Jaeger
```bash
docker compose -f examples/monitoring/tracing_compose.yaml up -d
```

-3. start your SGLang server with tracing enabled
+3. Start your SGLang server with tracing enabled
```bash
# set env variables
export SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS=500
export SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE=64
# start the prefill and decode server
python -m sglang.launch_server --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other option>
-# start the mini lb
+# start the model gateway (router)
python -m sglang_router.launch_router --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other option>
```

-Replace `0.0.0.0:4317` with the actual endpoint of the opentelemetry collector. If you launched the openTelemetry collector with tracing_compose.yaml, the default receiving port is 4317.
+Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the Collector with `tracing_compose.yaml`, the default receiving port is 4317.

To use the HTTP/protobuf span exporter, set the following environment variable and point to an HTTP endpoint, for example, `http://0.0.0.0:4318/v1/traces`.
```bash
export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf
```


-4. raise some requests
+4. Send some requests
5. Observe whether trace data is being exported
* Access port 16686 of Jaeger using a web browser to visualize the request traces.
* The OpenTelemetry Collector also exports trace data in JSON format to /tmp/otel_trace.json. In a follow-up patch, we will provide a tool to convert this data into a Perfetto-compatible format, enabling visualization of requests in the Perfetto UI.

-## How to add Tracing for slices you're interested in?
+6. Dynamically adjust the trace level
+The trace level accepts values from `0` to `3`, with the following meanings:
+```
+0: Disable tracing
+1: Trace important slices
+2: Trace all slices except nested ones
+3: Trace all slices
+```
+The trace level can be set dynamically via the HTTP API, for example:
+```bash
+curl "http://0.0.0.0:30000/set_trace_level?level=2"
+```
+Replace `0.0.0.0:30000` with your actual server address, and replace `level=2` with the level you want to set.
+
+**Note**: You must set `--enable-trace` at launch; otherwise, tracing will not be enabled regardless of any dynamic adjustments to the trace level.
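For scripted use, the same trace-level call can be made from Python instead of curl; a minimal sketch, assuming the default server address:

```python
import requests

# Equivalent to the curl command above: ask the server to switch to level 2.
resp = requests.get("http://0.0.0.0:30000/set_trace_level", params={"level": 2})
resp.raise_for_status()
```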

+## How to add tracing for slices you're interested in (API introduction)
We have already inserted instrumentation points in the tokenizer and scheduler main threads. If you wish to trace additional request execution segments or perform finer-grained tracing, please use the APIs from the tracing package as described below.

-1. initialization
+**All of the following APIs are implemented in `python/sglang/srt/observability/req_time_stats.py`. If you want to add another slice, please do it there.**
+
+1. Initialization

Every process involved in tracing during the initialization phase should execute:
```python
@@ -63,98 +81,53 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
```
The "thread label" can be regarded as the name of the thread, used to distinguish different threads in the visualization view.
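The initialization call itself is collapsed in this diff, so here is only a rough sketch; the helper name `process_tracing_init`, its import path, and its signature are assumptions, not the actual API:

```python
# Hypothetical sketch; the real initialization code is hidden in the diff fold above.
from sglang.srt.observability.req_time_stats import process_tracing_init  # assumed name

def start_scheduler_process(otlp_endpoint: str, tp_rank: int) -> None:
    # Register this process with the tracing framework once, using a
    # thread label that identifies it in the visualization view.
    process_tracing_init(otlp_endpoint, thread_label=f"scheduler-tp{tp_rank}")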

-2. Mark the beginning and end of a request
+2. Create a trace context for a request
+Each request needs to call `TraceReqContext()` to initialize a request context, which is used to generate slice spans and record request stage information. You can either store it within the request object or maintain it as a global variable.

+3. Mark the beginning and end of a request
```
-trace_req_start(rid, bootstrap_room)
-trace_req_finish(rid)
+trace_ctx.trace_req_start()
+trace_ctx.trace_req_finish()
```
-These two APIs must be called within the same process, for example, in the tokenizer.
+`trace_req_start()` and `trace_req_finish()` must be called within the same process, for example, in the tokenizer.
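Putting steps 2 and 3 together, a minimal sketch; the `TraceReqContext` constructor arguments are an assumption based on the surrounding text, and `rid` stands for the request id:

```python
from sglang.srt.observability.req_time_stats import TraceReqContext

trace_ctx = TraceReqContext(rid)  # store on the request object or in a global map
trace_ctx.trace_req_start()       # e.g. when the tokenizer first sees the request
# ... the request flows through tokenizer and scheduler ...
trace_ctx.trace_req_finish()      # must run in the same process as trace_req_start()
```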

-3. Add tracing for slice
+4. Add tracing for a slice

* Add slice tracing normally:
```python
-trace_slice_start("slice A", rid)
-trace_slice_end("slice A", rid)
-```
+trace_ctx.trace_slice_start(RequestStage.TOKENIZER.stage_name)
+trace_ctx.trace_slice_end(RequestStage.TOKENIZER.stage_name)
+
-- Use the "anonymous" flag to not specify a slice name at the start of the slice, allowing the slice name to be determined by trace_slice_end.
-<br>Note: Anonymous slices must not be nested.
-```python
-trace_slice_start("", rid, anonymous = True)
-trace_slice_end("slice A", rid)
+or
+trace_ctx.trace_slice(slice: TraceSliceContext)
```

-- In trace_slice_end, use auto_next_anon to automatically create the next anonymous slice, which can reduce the number of instrumentation points needed.
+- The end of the last slice in a thread must be marked with `thread_finish_flag=True`, or explicitly call `trace_ctx.abort()`; otherwise, the thread's span will not be properly generated.
```python
-trace_slice_start("", rid, anonymous = True)
-trace_slice_end("slice A", rid, auto_next_anon = True)
-trace_slice_end("slice B", rid, auto_next_anon = True)
-trace_slice_end("slice C", rid, auto_next_anon = True)
-trace_slice_end("slice D", rid)
-```
-- The end of the last slice in a thread must be marked with thread_finish_flag=True; otherwise, the thread's span will not be properly generated.
-```python
-trace_slice_end("slice D", rid, thread_finish_flag = True)
+trace_ctx.slice_end(RequestStage.D.stage_name, thread_finish_flag=True)
+trace_ctx.abort()
```
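A usage sketch of the slice APIs from step 4: wrapping the traced region in try/finally keeps the slice balanced even if the stage raises. The tokenizer stage name comes from the block above; `trace_ctx`, `tokenizer`, and `prompt` are illustrative:

```python
trace_ctx.trace_slice_start(RequestStage.TOKENIZER.stage_name)
try:
    tokens = tokenizer.encode(prompt)  # the work being measured
finally:
    # Always close the slice, even on error, so the thread span stays well formed.
    trace_ctx.trace_slice_end(RequestStage.TOKENIZER.stage_name)
```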

-4. When the request execution flow transfers to another thread, the trace context needs to be explicitly propagated.
-- sender: Execute the following code before sending the request to another thread via ZMQ
-```python
-trace_context = trace_get_proc_propagate_context(rid)
-req.trace_context = trace_context
-```
+5. When the request execution flow transfers to another thread, the thread context needs to be explicitly rebuilt.
- receiver: Execute the following code after receiving the request via ZMQ
```python
-trace_set_proc_propagate_context(rid, req.trace_context)
-```
-
-5. When the request execution flow transfers to another node(PD disaggregation), the trace context needs to be explicitly propagated.
-- sender: Execute the following code before sending the request to node thread via http
-```python
-trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
-headers = {"trace_context": trace_context}
-session.post(url, headers=headers)
-```
-- receiver: Execute the following code after receiving the request via http
-```python
-trace_set_remote_propagate_context(request.headers['trace_context'])
+trace_ctx.rebuild_thread_context()
```
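A hypothetical sketch of the full cross-thread handoff: only `rebuild_thread_context()` appears in the new API above, so carrying `trace_ctx` on the request object and the ZMQ calls are assumptions:

```python
# sender thread: ship the context together with the request
req.trace_ctx = trace_ctx
zmq_socket.send_pyobj(req)

# receiver thread: re-attach the context to this thread before emitting new slices
req = zmq_socket.recv_pyobj()
req.trace_ctx.rebuild_thread_context()
```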

## How to Extend the Tracing Framework to Support Complex Tracing Scenarios

The currently provided tracing package still has potential for further development. If you wish to build more advanced features upon it, you must first understand its existing design principles.

-The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a two-level trace context structure and a four-level span structure: `SglangTraceReqContext`, `SglangTraceThreadContext`. Their relationship is as follows:
+The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context and span structure: `TraceReqContext`, `TraceThreadContext`, and `TraceSliceContext`. Their relationship is as follows:
```
-SglangTraceReqContext (req_id="req-123")
-├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0)
+TraceReqContext (req_id="req-123")
+├── TraceThreadContext(thread_label="scheduler", tp_rank=0)
| └── TraceSliceContext(slice_name="prefill")
|
-└── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1)
+└── TraceThreadContext(thread_label="scheduler", tp_rank=1)
└── TraceSliceContext(slice_name="prefill")
```

-Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is stored in a list.
+Each traced request maintains a global `TraceReqContext` and creates a corresponding request span. For every thread that processes the request, a `TraceThreadContext` is recorded and a thread span is created. The `TraceThreadContext` is nested within the `TraceReqContext`, and each currently traced code slice (potentially nested) is stored in its associated `TraceThreadContext`.
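A structural sketch of the three levels as plain dataclasses; the field names are assumptions based on the tree above, and the real classes also carry the corresponding OpenTelemetry spans:

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class TraceSliceContext:
    slice_name: str            # e.g. "prefill"

@dataclass
class TraceThreadContext:
    thread_label: str          # e.g. "scheduler"
    tp_rank: int
    # Currently open slices in this thread; may nest.
    slices: List[TraceSliceContext] = field(default_factory=list)

@dataclass
class TraceReqContext:
    req_id: str
    # One entry per thread that has processed this request.
    threads: Dict[str, TraceThreadContext] = field(default_factory=dict)
```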

In addition to the above hierarchy, each slice also records its previous slice via Span.add_link(), which can be used to trace the execution flow.
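For reference, this is how such a link can be expressed with the OpenTelemetry Python API; the sketch attaches the link at span creation time (the commonly available form) rather than via `add_link`, and the span names are illustrative:

```python
from opentelemetry import trace
from opentelemetry.trace import Link

tracer = trace.get_tracer("sglang.demo")

prev_slice_span = tracer.start_span("tokenize")  # the slice that just ended
prev_slice_span.end()

# Link the next slice back to its predecessor to reconstruct execution order.
next_slice_span = tracer.start_span(
    "prefill", links=[Link(prev_slice_span.get_span_context())]
)
next_slice_span.end()
```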

When the request execution flow transfers to a new thread, the trace context needs to be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span.


-We designed a four-level span structure, consisting of `bootstrap_room_span`, `req_root_span`, `thread_span`, and `slice_span`. Among them, `req_root_span` and `thread_span` correspond to `SglangTraceReqContext` and `SglangTraceThreadContext`, respectively, and `slice_span` is stored within the `SglangTraceThreadContext`. The `bootstrap_room_span` is designed to accommodate the separation of PD-disaggregation. On different nodes, we may want to add certain attributes to the `req_root_span`. However, if the `req_root_span` is shared across all nodes, the Prefill and Decode nodes would not be allowed to add attributes due to the constraints imposed by OpenTelemetry's design.
-
-```
-bootstrap room span
-├── router req root span
-|   └── router thread span
-|       └── slice span
-├── prefill req root span
-|   ├── tokenizer thread span
-|   |   └── slice span
-|   └── scheduler thread span
-|       └── slice span
-└── decode req root span
-    ├── tokenizer thread span
-    |   └── slice span
-    └── scheduler thread span
-        └── slice span
-```
32 changes: 9 additions & 23 deletions python/sglang/srt/disaggregation/decode.py
@@ -21,7 +21,6 @@
from __future__ import annotations

import logging
-import time
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
@@ -47,7 +46,7 @@
prepare_abort,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.managers.utils import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -60,7 +59,10 @@
ReqToTokenPool,
)
from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool
-from sglang.srt.tracing.trace import trace_event_batch, trace_slice_end
+from sglang.srt.observability.req_time_stats import (
+    set_schedule_time_batch,
+    set_time_batch,
+)
from sglang.srt.utils import get_int_env_var
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

@@ -361,8 +363,6 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
            prefill_dp_rank=req.data_parallel_rank,
        )

-       req.add_latency(RequestStage.DECODE_PREPARE)
-       trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
        self.queue.append(
            DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
        )
@@ -592,13 +592,7 @@ def pop_preallocated(
                )
                preallocated_reqs.append(decode_req)
                indices_to_remove.add(i)
-               decode_req.req.time_stats.decode_transfer_queue_entry_time = (
-                   time.perf_counter()
-               )
-               decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
-               trace_slice_end(
-                   RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
-               )
+               decode_req.req.time_stats.set_decode_transfer_queue_entry_time()

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -811,12 +805,7 @@ def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> bool:

        decode_req.kv_receiver.clear()
        decode_req.kv_receiver = None
-       trace_slice_end(
-           RequestStage.DECODE_TRANSFERRED,
-           decode_req.req.rid,
-           auto_next_anon=True,
-       )
-       decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
+       decode_req.req.time_stats.set_wait_queue_entry_time()
        return True

def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req]:
@@ -880,7 +869,6 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req]:
        for i in indices_to_remove:
            idx = self.queue[i].metadata_buffer_index
            assert idx != -1
-           self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
            self.req_to_metadata_buffer_idx_allocator.free(idx)

self.queue = [
@@ -1003,7 +991,7 @@ def get_next_disagg_decode_batch_to_run(
        ret = self.maybe_prepare_mlp_sync_batch_and_log_stats(ret)

        if ret:
-           trace_event_batch("schedule", ret.reqs)
+           set_schedule_time_batch(ret)
        return ret

def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
@@ -1031,7 +1019,6 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
            # we can only add at least `num_not_used_batch` new batch to the running queue
            if i < num_not_used_batch:
                can_run_list.append(req)
-               req.add_latency(RequestStage.DECODE_WAITING)
                req.init_next_round_input(self.tree_cache)
            else:
                waiting_queue.append(req)
@@ -1040,8 +1027,7 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
        if len(can_run_list) == 0:
            return None

-       for req in can_run_list:
-           req.time_stats.forward_entry_time = time.perf_counter()
+       set_time_batch(can_run_list, "set_forward_entry_time")

        # construct a schedule batch with those requests and mark as decode
        new_batch = ScheduleBatch.init_new(
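The diff above replaces inline `time_stats` bookkeeping with helpers from `req_time_stats`. The call site `set_time_batch(can_run_list, "set_forward_entry_time")` reads like a getattr dispatch onto each request's `time_stats`; a plausible sketch of the helper, not the actual implementation:

```python
from typing import Iterable

def set_time_batch(reqs: Iterable, setter_name: str) -> None:
    # Invoke the named timestamp setter (e.g. set_forward_entry_time) on
    # every request's time_stats, so call sites stay one line long.
    for req in reqs:
        getattr(req.time_stats, setter_name)()
```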