1 change: 1 addition & 0 deletions sdk/data/azcosmos/CHANGELOG.md
@@ -9,6 +9,7 @@
### Bugs Fixed

### Other Changes
* Small performance optimizations to APIs using the query engine. See [PR 25669](https://github.com/Azure/azure-sdk-for-go/pull/25669)

## 1.5.0-beta.4 (2025-11-24)

90 changes: 47 additions & 43 deletions sdk/data/azcosmos/cosmos_container_query_engine.go
@@ -182,7 +182,7 @@ func (c *ContainerClient) executeQueryWithEngine(queryEngine queryengine.QueryEn

// runEngineRequests concurrently executes per-partition QueryRequests for either query or readMany pipelines.
// prepareFn returns the query text, parameters, and a drain flag for each request.
// It serializes ProvideData calls through a single goroutine to preserve ordering guarantees required by the pipeline.
// It collects all results and calls ProvideData once with a single batch to reduce CGo overhead.
func runEngineRequests(
ctx context.Context,
c *ContainerClient,
@@ -192,32 +192,15 @@ func runEngineRequests(
requests []queryengine.QueryRequest,
concurrency int,
prepareFn func(req queryengine.QueryRequest) (query string, params []QueryParameter, drain bool),
) (totalCharge float32, err error) {
) (float32, error) {
if len(requests) == 0 {
return 0, nil
}

jobs := make(chan queryengine.QueryRequest, len(requests))
provideCh := make(chan []queryengine.QueryResult)
errCh := make(chan error, 1)
done := make(chan struct{})
providerDone := make(chan struct{})
var wg sync.WaitGroup
var chargeMu sync.Mutex

// Provider goroutine ensures only one ProvideData executes at a time.
go func() {
defer close(providerDone)
for batch := range provideCh {
if perr := pipeline.ProvideData(batch); perr != nil {
select {
case errCh <- perr:
default:
}
return
}
}
}()

// Adjust concurrency.
workerCount := concurrency
@@ -228,10 +211,15 @@ func runEngineRequests(
workerCount = 1
}

// Per-worker request charge slots and result slices (lock-free updates)
charges := make([]float32, workerCount)
resultsSlices := make([][]queryengine.QueryResult, workerCount)

for w := 0; w < workerCount; w++ {
wg.Add(1)
go func() {
go func(workerIndex int) {
defer wg.Done()
localResults := make([]queryengine.QueryResult, 0, 8)
for {
select {
case <-done:
@@ -240,12 +228,12 @@
return
case req, ok := <-jobs:
if !ok {
// jobs exhausted
resultsSlices[workerIndex] = localResults
return
}

log.Writef(azlog.EventRequest, "Engine pipeline requested data for PKRange: %s", req.PartitionKeyRangeID)
queryText, params, drain := prepareFn(req)
// Pagination loop
fetchMorePages := true
for fetchMorePages {
qr := queryRequest(req)
@@ -265,7 +253,6 @@
}
return
}

qResp, err := newQueryResponse(azResponse)
if err != nil {
select {
@@ -274,11 +261,7 @@
}
return
}
chargeMu.Lock()
totalCharge += qResp.RequestCharge
chargeMu.Unlock()

// Load the data into a buffer to send it to the pipeline
charges[workerIndex] += qResp.RequestCharge
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(azResponse.Body); err != nil {
select {
@@ -290,24 +273,17 @@
continuation := azResponse.Header.Get(cosmosHeaderContinuationToken)
data := buf.Bytes()
fetchMorePages = continuation != "" && drain

// Provide the data to the pipeline, make sure it's tagged with the partition key range ID so the pipeline can merge it into the correct partition.
result := queryengine.QueryResult{
localResults = append(localResults, queryengine.QueryResult{
PartitionKeyRangeID: req.PartitionKeyRangeID,
NextContinuation: continuation,
Data: data,
RequestId: req.Id,
}
})
log.Writef(EventQueryEngine, "Received response for PKRange: %s. Continuation present: %v", req.PartitionKeyRangeID, continuation != "")
select {
case <-done:
return
case provideCh <- []queryengine.QueryResult{result}:
}
}
}
}
}()
}(w)
}

// Feed jobs
@@ -323,8 +299,18 @@
close(jobs)
}()

// Close provider after workers finish
go func() { wg.Wait(); close(provideCh) }()
// Wait for workers to finish (or error/cancel)
workersDone := make(chan struct{})
go func() { wg.Wait(); close(workersDone) }()

// Helper to sum charges
sumCharges := func() float32 {
var total float32
for _, cval := range charges {
total += cval
}
return total
}

// Wait for completion / error / cancellation
select {
@@ -334,15 +320,33 @@
default:
close(done)
}
return totalCharge, e
return sumCharges(), e
case <-ctx.Done():
select {
case <-done:
default:
close(done)
}
return totalCharge, ctx.Err()
case <-providerDone:
return sumCharges(), ctx.Err()
case <-workersDone:
}

totalCharge := sumCharges()

// Merge per-worker result slices deterministically
// Pre-size combined slice for efficiency
var combinedCount int
for _, rs := range resultsSlices {
combinedCount += len(rs)
}
if combinedCount > 0 {
all := make([]queryengine.QueryResult, 0, combinedCount)
for _, rs := range resultsSlices {
all = append(all, rs...)
}
if err := pipeline.ProvideData(all); err != nil {
return totalCharge, err
}
}

return totalCharge, nil
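
For reviewers who want the shape of the change in isolation: the following is a minimal, standalone sketch of the collect-then-batch pattern the new `runEngineRequests` follows — each worker appends to its own slice (so no mutex is needed), the slices are merged after `wg.Wait`, and `ProvideData` is invoked once with the combined batch. The `Result` and `fakePipeline` types here are simplified stand-ins, not the SDK's `queryengine` types, and the per-partition query execution is elided.

```go
// Standalone sketch of the collect-then-batch pattern; Result and fakePipeline
// are simplified stand-ins for queryengine.QueryResult and the query pipeline.
package main

import (
	"fmt"
	"sync"
)

type Result struct {
	PartitionKeyRangeID string
	Data                []byte
}

type fakePipeline struct{}

// ProvideData receives one batch; calling it once per run (instead of once per
// page) is what cuts the per-call overhead the PR is optimizing.
func (fakePipeline) ProvideData(batch []Result) error {
	fmt.Printf("ProvideData called once with %d results\n", len(batch))
	return nil
}

func main() {
	requests := []string{"0", "1", "2", "3", "4"} // partition key range IDs
	workerCount := 2

	jobs := make(chan string, len(requests))
	for _, r := range requests {
		jobs <- r
	}
	close(jobs)

	// One slot per worker: each goroutine writes only to its own index,
	// so no lock is required around the shared slice of slices.
	resultsSlices := make([][]Result, workerCount)

	var wg sync.WaitGroup
	for w := 0; w < workerCount; w++ {
		wg.Add(1)
		go func(workerIndex int) {
			defer wg.Done()
			local := make([]Result, 0, 8)
			for pkRange := range jobs {
				// In the real code this is where the per-partition query
				// executes and continuation pages are drained.
				local = append(local, Result{PartitionKeyRangeID: pkRange, Data: []byte("{}")})
			}
			resultsSlices[workerIndex] = local
		}(w)
	}
	wg.Wait()

	// Merge per-worker slices and hand the pipeline a single batch.
	var all []Result
	for _, rs := range resultsSlices {
		all = append(all, rs...)
	}
	_ = fakePipeline{}.ProvideData(all)
}
```

The per-worker slot pattern replaces both the previous provider goroutine and the request-charge mutex, at the cost of buffering all pages in memory before the pipeline sees any data.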