diff --git a/trino/trino.go b/trino/trino.go index 6f32316..5c9debd 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -1138,27 +1138,34 @@ func (qr *driverRows) Close() error { if qr.err == sql.ErrNoRows || qr.err == io.EOF { return nil } + qr.err = io.EOF - hs := make(http.Header) - if qr.stmt.user != "" { - hs.Add(trinoUserHeader, qr.stmt.user) - } - ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout) - defer cancel() - req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs) - if err != nil { - return err - } - resp, err := qr.stmt.conn.roundTrip(ctx, req) - if err != nil { - qferr, ok := err.(*ErrQueryFailed) - if ok && qferr.StatusCode == http.StatusNoContent { - qr.nextURI = "" - return nil + + if qr.nextURI != "" { + hs := make(http.Header) + if qr.stmt.user != "" { + hs.Add(trinoUserHeader, qr.stmt.user) } - return err + + ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout) + defer cancel() + req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.nextURI, nil, hs) + if err != nil { + return err + } + resp, err := qr.stmt.conn.roundTrip(ctx, req) + if err != nil { + qferr, ok := err.(*ErrQueryFailed) + if ok && qferr.StatusCode == http.StatusNoContent { + qr.nextURI = "" + return nil + } + return err + } + resp.Body.Close() + } - resp.Body.Close() + return qr.err } @@ -1205,6 +1212,7 @@ func (qr *driverRows) Next(dest []driver.Value) error { if qr.err != nil { return qr.err } + if qr.columns == nil || qr.rowindex >= len(qr.data) { if qr.nextURI == "" { qr.err = io.EOF @@ -1215,22 +1223,34 @@ func (qr *driverRows) Next(dest []driver.Value) error { return err } } + if len(qr.coltype) == 0 { qr.err = sql.ErrNoRows return qr.err } - for i, v := range qr.coltype { - if i > len(dest)-1 { + + row := qr.data[qr.rowindex] + for i, colType := range qr.coltype { + if i >= len(dest) { break } - vv, err := v.ConvertValue(qr.data[qr.rowindex][i]) + val, err := colType.ConvertValue(row[i]) if err != nil { qr.err = err return err } - dest[i] = vv + dest[i] = val } + qr.rowindex++ + + // Prefetch next set of rows + if qr.rowindex == len(qr.data) && qr.nextURI != "" { + if err := qr.fetch(); err != nil { + qr.err = err + } + } + return nil } @@ -1330,6 +1350,7 @@ func (qr *driverRows) fetch() error { return err } qr.rowindex = 0 + qr.nextURI = qresp.NextURI qr.data = qresp.Data qr.rowsAffected = qresp.UpdateCount qr.scheduleProgressUpdate(qresp.ID, qresp.Stats) diff --git a/trino/trino_test.go b/trino/trino_test.go index f1623a3..4013798 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -1973,3 +1973,109 @@ func TestForwardAuthorizationHeader(t *testing.T) { assert.NoError(t, db.Close()) } + +func TestPagination(t *testing.T) { + var buf, buf2, buf3 *bytes.Buffer + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/statement" { + if buf == nil { + buf = new(bytes.Buffer) + + json.NewEncoder(buf).Encode(&stmtResponse{ + ID: "fake-query", + NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", + Stats: stmtStats{ + State: "QUEUED", + }, + }) + } + w.WriteHeader(http.StatusOK) + w.Write(buf.Bytes()) + return + } + + if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { + if buf2 == nil { + buf2 = new(bytes.Buffer) + json.NewEncoder(buf2).Encode(&queryResponse{ + ID: "fake-query", + NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/2", + Columns: []queryColumn{ + { + Name: "_col0", + Type: "integer", + TypeSignature: typeSignature{ + RawType: "integer", + Arguments: []typeArgument{}, + }, + }, + }, + Data: []queryData{ + {1}, + }, + Stats: stmtStats{ + State: "FINISHED", + }, + }) + } + w.WriteHeader(http.StatusOK) + w.Write(buf2.Bytes()) + return + } + + if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/2" { + if buf3 == nil { + buf3 = new(bytes.Buffer) + json.NewEncoder(buf3).Encode(&queryResponse{ + ID: "fake-query", + Columns: []queryColumn{ + { + Name: "_col1", + Type: "integer", + TypeSignature: typeSignature{ + RawType: "integer", + Arguments: []typeArgument{}, + }, + }, + }, + Data: []queryData{ + {2}, + }, + Stats: stmtStats{ + State: "FINISHED", + }, + }) + } + w.WriteHeader(http.StatusOK) + w.Write(buf3.Bytes()) + return + } + + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) + })) + + defer ts.Close() + + db, err := sql.Open("trino", ts.URL) + require.NoError(t, err) + defer db.Close() + + // Run a query + rows, err := db.Query("SELECT 1") + + var results []int + for rows.Next() { + var value int + err := rows.Scan(&value) + require.NoError(t, err) + results = append(results, value) + } + + // Ensure no error in iteration + require.NoError(t, rows.Err()) + + // Assert expected results + assert.Equal(t, []int{1, 2}, results, "Expected query results to match") +}