generated from kubernetes/kubernetes-template-project
-
Notifications
You must be signed in to change notification settings - Fork 201
Add non-streaming response to approximate prefix cache. #1719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zetxqx
wants to merge
4
commits into
kubernetes-sigs:main
Choose a base branch
from
zetxqx:prefix-cache-resp
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+882
−44
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
63d9a71
Fix function comment and pass existing logger into HandleResponseBody…
BenjaminBraunDev 1a7793a
Revert logging parameter addition, keeping consistent with existing f…
BenjaminBraunDev 028974c
Add response to prefix cache in nonStreaming mode.
zetxqx 39ae663
make ResponseComplete to accept LLMResponse and update the encoding m…
zetxqx File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,7 @@ limitations under the License. | |||||
| package prefix | ||||||
|
|
||||||
| import ( | ||||||
| "bytes" | ||||||
| "context" | ||||||
| "encoding/binary" | ||||||
| "encoding/json" | ||||||
|
|
@@ -28,6 +29,7 @@ import ( | |||||
| k8stypes "k8s.io/apimachinery/pkg/types" | ||||||
| "sigs.k8s.io/controller-runtime/pkg/log" | ||||||
|
|
||||||
| "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" | ||||||
| backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" | ||||||
| "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" | ||||||
| "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" | ||||||
|
|
@@ -117,6 +119,12 @@ var _ plugins.StateData = &SchedulingContextState{} | |||||
| type SchedulingContextState struct { | ||||||
| // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks. | ||||||
| PrefixHashes []BlockHash | ||||||
| // RestBytes is the trailing bytes that could not fill a complete block and were left over. | ||||||
| // If not empty, they will be used as the starting bytes of the subsequent response, which will | ||||||
| // also be added to the prefix cache. This happens especially in multi-turn scenarios. | ||||||
| RestBytes []byte | ||||||
| // BlockSize is the block size used to calculate the hash of the request/response. | ||||||
| BlockSize int | ||||||
| // A map of server to its longest prefix cache match length. | ||||||
| PrefixCacheServers map[ServerID]int | ||||||
| } | ||||||
|
|
@@ -192,10 +200,13 @@ func (p *Plugin) WithName(name string) *Plugin { | |||||
|
|
||||||
| // Score returns the scoring result for the given list of pods based on context. | ||||||
| func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { | ||||||
| blockSize := getBlockSize(pods, p.config.DefaultBlockSize) | ||||||
| // pre score step, hashing prompt and find longest prefix match. | ||||||
| hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch) | ||||||
| hashes, restBytes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch) | ||||||
| state := &SchedulingContextState{ | ||||||
| PrefixHashes: hashes, | ||||||
| RestBytes: restBytes, | ||||||
| BlockSize: blockSize, | ||||||
| PrefixCacheServers: p.matchLongestPrefix(ctx, hashes), | ||||||
| } | ||||||
|
|
||||||
|
|
@@ -226,7 +237,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche | |||||
| targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile | ||||||
|
|
||||||
| state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) | ||||||
| p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it | ||||||
| if err != nil { | ||||||
| log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) | ||||||
| return | ||||||
|
|
@@ -244,9 +254,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche | |||||
|
|
||||||
| total := len(state.PrefixHashes) | ||||||
| matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] | ||||||
|
|
||||||
| blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize) | ||||||
| metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize) | ||||||
| metrics.RecordPrefixCacheMatch(matchLen*state.BlockSize, total*state.BlockSize) | ||||||
| } | ||||||
|
|
||||||
| // matchLongestPrefix returns a map of servers and length of prefix that each server caches. | ||||||
|
|
@@ -301,47 +309,59 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle) | |||||
| // hashPrompt divides the prompt into blocks and calculates the prefix hash for each block. | ||||||
| // hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache. | ||||||
| // For block i, hash(i) = hash(block i content, hash(i-1)). | ||||||
| func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash { | ||||||
| // It also returns the trailing bytes that do not fill a complete block. | ||||||
| func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) { | ||||||
| loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) | ||||||
| if request == nil || request.Body == nil { | ||||||
| loggerDebug.Info("Request or request data is nil, skipping hashing") | ||||||
| return nil | ||||||
| return nil, nil | ||||||
| } | ||||||
|
|
||||||
| userInput, err := getUserInputBytes(request) | ||||||
| if err != nil { | ||||||
| loggerDebug.Error(err, "Failed to get user input bytes") | ||||||
| return nil | ||||||
| return nil, nil | ||||||
| } | ||||||
| prevBlockHash := defaultPrevBlock(request) | ||||||
| return hashInputWithPrevBlockHash(ctx, prevBlockHash, 0, userInput, cacheBlockSize, maxPrefixBlocks) | ||||||
| } | ||||||
|
|
||||||
| if len(userInput) < cacheBlockSize { | ||||||
| loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize) | ||||||
| return nil | ||||||
| } | ||||||
| if len(userInput) > cacheBlockSize*maxPrefixBlocks { | ||||||
| loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize) | ||||||
| userInput = userInput[:maxPrefixBlocks*cacheBlockSize] | ||||||
| } | ||||||
| // Split the body into blocks of size cacheBlockSize. | ||||||
| // If the last block is smaller than cacheBlockSize, it will be ignored. | ||||||
| res := make([]BlockHash, 0, len(userInput)/cacheBlockSize) | ||||||
| // Add the model to the first block hash so that different models have different hashes even with the same body. | ||||||
| func defaultPrevBlock(request *types.LLMRequest) BlockHash { | ||||||
| h := xxhash.New() | ||||||
| // Add the model to the first block hash so that different models have different hashes even with the same body. | ||||||
| _, _ = h.Write([]byte(request.TargetModel)) | ||||||
| if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" { | ||||||
| _, _ = h.Write([]byte(cacheSalt)) | ||||||
| } | ||||||
|
|
||||||
| prevBlockHash := BlockHash(h.Sum64()) | ||||||
| for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize { | ||||||
| return BlockHash(h.Sum64()) | ||||||
| } | ||||||
|
|
||||||
| func hashInputWithPrevBlockHash(ctx context.Context, prevBlockHash BlockHash, prevBlockLength int, input []byte, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) { | ||||||
| loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) | ||||||
| if len(input)+prevBlockLength < cacheBlockSize { | ||||||
| loggerDebug.Info("Request body too small for prefix cache", "size", len(input), "block size", cacheBlockSize) | ||||||
| return nil, input | ||||||
| } | ||||||
| if len(input)+prevBlockLength > cacheBlockSize*maxPrefixBlocks { | ||||||
| loggerDebug.Info("Truncating input", "size", len(input), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize) | ||||||
| input = input[:(maxPrefixBlocks*cacheBlockSize - prevBlockLength)] | ||||||
| } | ||||||
| // Split the body into blocks of size cacheBlockSize. | ||||||
| // If the last block is smaller than cacheBlockSize, it will be ignored. | ||||||
| res := make([]BlockHash, 0, len(input)/cacheBlockSize) | ||||||
| lastOffSet := 0 | ||||||
| h := xxhash.New() | ||||||
| for i := 0; i+cacheBlockSize <= len(input); i += cacheBlockSize { | ||||||
| h.Reset() | ||||||
| _, _ = h.Write(userInput[i : i+cacheBlockSize]) | ||||||
| _, _ = h.Write(input[i : i+cacheBlockSize]) | ||||||
| _, _ = h.Write(toBytes(prevBlockHash)) | ||||||
| res = append(res, BlockHash(h.Sum64())) | ||||||
|
|
||||||
| prevBlockHash = res[len(res)-1] | ||||||
| lastOffSet = i + cacheBlockSize | ||||||
| } | ||||||
| return res | ||||||
| return res, input[lastOffSet:] | ||||||
| } | ||||||
|
|
||||||
| func toBytes(i BlockHash) []byte { | ||||||
|
|
@@ -356,7 +376,39 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) { | |||||
| } | ||||||
|
|
||||||
| // must be chat-completions request at this point, return bytes of entire messages | ||||||
| return json.Marshal(request.Body.ChatCompletions.Messages) | ||||||
| return types.MarshalMessagesToJSON(request.Body.ChatCompletions.Messages...) | ||||||
| } | ||||||
|
|
||||||
| func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) { | ||||||
| state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) | ||||||
| if err != nil { | ||||||
| log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) | ||||||
| return | ||||||
| } | ||||||
| p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it. | ||||||
|
|
||||||
| reponseForKVCache, err := response.FirstChoiceContent() | ||||||
| if err != nil { | ||||||
| log.FromContext(ctx).Error(err, "failed to get first choice content", "requestID", request.RequestId) | ||||||
| return | ||||||
| } | ||||||
| var input bytes.Buffer | ||||||
| input.Write(state.RestBytes) | ||||||
| input.Write(reponseForKVCache) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| server := ServerID(targetPod.NamespacedName) | ||||||
| prevBlockHash := defaultPrevBlock(request) | ||||||
| prevBlockHashLength := 0 | ||||||
| if len(state.PrefixHashes) > 0 { | ||||||
| prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1] | ||||||
| prevBlockHashLength = len(state.PrefixHashes) | ||||||
| } | ||||||
| hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, input.Bytes(), state.BlockSize, p.config.MaxPrefixBlocksToMatch) | ||||||
| p.wg.Add(1) | ||||||
| go func() { | ||||||
| p.indexer.Add(hashBlocks, server) | ||||||
| p.wg.Done() | ||||||
| }() | ||||||
| } | ||||||
|
|
||||||
| func getBlockSize(pods []types.Pod, defaultBlockSize int) int { | ||||||
|
|
||||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also. Why only the first choice response and not all of them?