Skip to content

Commit e7afac4

Browse files
authored
Merge pull request #9 from agoncear-mwb/main
implement parquet rows to read data from clickhouse in parquet format
2 parents ecf0a25 + dd038d2 commit e7afac4

File tree

8 files changed

+729
-223
lines changed

8 files changed

+729
-223
lines changed

chdb/driver/arrow.go

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
package chdbdriver
2+
3+
import (
4+
"database/sql/driver"
5+
"fmt"
6+
"reflect"
7+
"time"
8+
9+
"github.com/apache/arrow/go/v15/arrow"
10+
"github.com/apache/arrow/go/v15/arrow/array"
11+
"github.com/apache/arrow/go/v15/arrow/decimal128"
12+
"github.com/apache/arrow/go/v15/arrow/decimal256"
13+
"github.com/apache/arrow/go/v15/arrow/ipc"
14+
"github.com/chdb-io/chdb-go/chdbstable"
15+
)
16+
17+
type arrowRows struct {
18+
localResult *chdbstable.LocalResult
19+
reader *ipc.FileReader
20+
curRecord arrow.Record
21+
curRow int64
22+
}
23+
24+
func (r *arrowRows) Columns() (out []string) {
25+
sch := r.reader.Schema()
26+
for i := 0; i < sch.NumFields(); i++ {
27+
out = append(out, sch.Field(i).Name)
28+
}
29+
return
30+
}
31+
32+
func (r *arrowRows) Close() error {
33+
if r.curRecord != nil {
34+
r.curRecord = nil
35+
}
36+
// ignore reader close
37+
_ = r.reader.Close()
38+
r.reader = nil
39+
r.localResult = nil
40+
return nil
41+
}
42+
43+
func (r *arrowRows) Next(dest []driver.Value) error {
44+
if r.curRecord != nil && r.curRow == r.curRecord.NumRows() {
45+
r.curRecord = nil
46+
}
47+
for r.curRecord == nil {
48+
record, err := r.reader.Read()
49+
if err != nil {
50+
return err
51+
}
52+
if record.NumRows() == 0 {
53+
continue
54+
}
55+
r.curRecord = record
56+
r.curRow = 0
57+
}
58+
59+
for i, col := range r.curRecord.Columns() {
60+
if col.IsNull(int(r.curRow)) {
61+
dest[i] = nil
62+
continue
63+
}
64+
switch col := col.(type) {
65+
case *array.Boolean:
66+
dest[i] = col.Value(int(r.curRow))
67+
case *array.Int8:
68+
dest[i] = col.Value(int(r.curRow))
69+
case *array.Uint8:
70+
dest[i] = col.Value(int(r.curRow))
71+
case *array.Int16:
72+
dest[i] = col.Value(int(r.curRow))
73+
case *array.Uint16:
74+
dest[i] = col.Value(int(r.curRow))
75+
case *array.Int32:
76+
dest[i] = col.Value(int(r.curRow))
77+
case *array.Uint32:
78+
dest[i] = col.Value(int(r.curRow))
79+
case *array.Int64:
80+
dest[i] = col.Value(int(r.curRow))
81+
case *array.Uint64:
82+
dest[i] = col.Value(int(r.curRow))
83+
case *array.Float32:
84+
dest[i] = col.Value(int(r.curRow))
85+
case *array.Float64:
86+
dest[i] = col.Value(int(r.curRow))
87+
case *array.String:
88+
dest[i] = col.Value(int(r.curRow))
89+
case *array.LargeString:
90+
dest[i] = col.Value(int(r.curRow))
91+
case *array.Binary:
92+
dest[i] = col.Value(int(r.curRow))
93+
case *array.LargeBinary:
94+
dest[i] = col.Value(int(r.curRow))
95+
case *array.Date32:
96+
dest[i] = col.Value(int(r.curRow)).ToTime()
97+
case *array.Date64:
98+
dest[i] = col.Value(int(r.curRow)).ToTime()
99+
case *array.Time32:
100+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.Time32Type).Unit)
101+
case *array.Time64:
102+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.Time64Type).Unit)
103+
case *array.Timestamp:
104+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.TimestampType).Unit)
105+
case *array.Decimal128:
106+
dest[i] = col.Value(int(r.curRow))
107+
case *array.Decimal256:
108+
dest[i] = col.Value(int(r.curRow))
109+
default:
110+
return fmt.Errorf(
111+
"not yet implemented populating from columns of type " + col.DataType().String(),
112+
)
113+
}
114+
}
115+
116+
r.curRow++
117+
return nil
118+
}
119+
120+
func (r *arrowRows) ColumnTypeDatabaseTypeName(index int) string {
121+
return r.reader.Schema().Field(index).Type.String()
122+
}
123+
124+
func (r *arrowRows) ColumnTypeNullable(index int) (nullable, ok bool) {
125+
return r.reader.Schema().Field(index).Nullable, true
126+
}
127+
128+
func (r *arrowRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
129+
typ := r.reader.Schema().Field(index).Type
130+
switch dt := typ.(type) {
131+
case *arrow.Decimal128Type:
132+
return int64(dt.Precision), int64(dt.Scale), true
133+
case *arrow.Decimal256Type:
134+
return int64(dt.Precision), int64(dt.Scale), true
135+
}
136+
return 0, 0, false
137+
}
138+
139+
func (r *arrowRows) ColumnTypeScanType(index int) reflect.Type {
140+
switch r.reader.Schema().Field(index).Type.ID() {
141+
case arrow.BOOL:
142+
return reflect.TypeOf(false)
143+
case arrow.INT8:
144+
return reflect.TypeOf(int8(0))
145+
case arrow.UINT8:
146+
return reflect.TypeOf(uint8(0))
147+
case arrow.INT16:
148+
return reflect.TypeOf(int16(0))
149+
case arrow.UINT16:
150+
return reflect.TypeOf(uint16(0))
151+
case arrow.INT32:
152+
return reflect.TypeOf(int32(0))
153+
case arrow.UINT32:
154+
return reflect.TypeOf(uint32(0))
155+
case arrow.INT64:
156+
return reflect.TypeOf(int64(0))
157+
case arrow.UINT64:
158+
return reflect.TypeOf(uint64(0))
159+
case arrow.FLOAT32:
160+
return reflect.TypeOf(float32(0))
161+
case arrow.FLOAT64:
162+
return reflect.TypeOf(float64(0))
163+
case arrow.DECIMAL128:
164+
return reflect.TypeOf(decimal128.Num{})
165+
case arrow.DECIMAL256:
166+
return reflect.TypeOf(decimal256.Num{})
167+
case arrow.BINARY:
168+
return reflect.TypeOf([]byte{})
169+
case arrow.STRING:
170+
return reflect.TypeOf(string(""))
171+
case arrow.TIME32, arrow.TIME64, arrow.DATE32, arrow.DATE64, arrow.TIMESTAMP:
172+
return reflect.TypeOf(time.Time{})
173+
}
174+
return nil
175+
}

chdb/driver/arrow_test.go

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package chdbdriver
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
"os"
7+
"testing"
8+
9+
"github.com/chdb-io/chdb-go/chdb"
10+
)
11+
12+
// TestDbWithArrow runs a constant SELECT through database/sql with the ARROW
// driver type and checks the column count, the scanned values and the number
// of rows returned.
//
// Setup failures abort the test immediately (t.Fatalf) so that later steps
// never dereference a nil *sql.DB or *sql.Rows.
func TestDbWithArrow(t *testing.T) {
	db, err := sql.Open("chdb", fmt.Sprintf("driverType=%s", "ARROW"))
	if err != nil {
		t.Fatalf("open db fail, err: %s", err)
	}
	defer db.Close()

	if err := db.Ping(); err != nil {
		t.Fatalf("ping db fail, err: %s", err)
	}

	rows, err := db.Query(`SELECT 1,'abc'`)
	if err != nil {
		t.Fatalf("run Query fail, err: %s", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		t.Fatalf("get result columns fail, err: %s", err)
	}
	if len(cols) != 2 {
		t.Fatalf("select result columns length should be 2, actual: %d", len(cols))
	}

	var (
		bar      int
		foo      string
		rowCount int
	)
	for rows.Next() {
		if err := rows.Scan(&bar, &foo); err != nil {
			t.Fatalf("scan fail, err: %s", err)
		}
		if bar != 1 {
			t.Errorf("first column mismatch, want: 1 actual: %d", bar)
		}
		if foo != "abc" {
			t.Errorf("second column mismatch, want: %q actual: %q", "abc", foo)
		}
		rowCount++
	}
	// rows.Next returning false can mean either end of data or a failure.
	if err := rows.Err(); err != nil {
		t.Fatalf("iterate rows fail, err: %s", err)
	}
	if rowCount != 1 {
		t.Errorf("row count mismatch, want: 1 actual: %d", rowCount)
	}
}
50+
51+
func TestDBWithArrowSession(t *testing.T) {
52+
sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
53+
if err != nil {
54+
t.Fatalf("create temp directory fail, err: %s", err)
55+
}
56+
defer os.RemoveAll(sessionDir)
57+
session, err := chdb.NewSession(sessionDir)
58+
if err != nil {
59+
t.Fatalf("new session fail, err: %s", err)
60+
}
61+
defer session.Cleanup()
62+
63+
session.Query("CREATE DATABASE IF NOT EXISTS testdb; " +
64+
"CREATE TABLE IF NOT EXISTS testdb.testtable (id UInt32) ENGINE = MergeTree() ORDER BY id;")
65+
66+
session.Query("INSERT INTO testdb.testtable VALUES (1), (2), (3);")
67+
68+
ret, err := session.Query("SELECT * FROM testdb.testtable;")
69+
if err != nil {
70+
t.Fatalf("Query fail, err: %s", err)
71+
}
72+
if string(ret.Buf()) != "1\n2\n3\n" {
73+
t.Errorf("Query result should be 1\n2\n3\n, got %s", string(ret.Buf()))
74+
}
75+
db, err := sql.Open("chdb", fmt.Sprintf("session=%s;driverType=%s", sessionDir, "ARROW"))
76+
if err != nil {
77+
t.Fatalf("open db fail, err: %s", err)
78+
}
79+
if db.Ping() != nil {
80+
t.Fatalf("ping db fail, err: %s", err)
81+
}
82+
rows, err := db.Query("select * from testdb.testtable;")
83+
if err != nil {
84+
t.Fatalf("exec create function fail, err: %s", err)
85+
}
86+
defer rows.Close()
87+
cols, err := rows.Columns()
88+
if err != nil {
89+
t.Fatalf("get result columns fail, err: %s", err)
90+
}
91+
if len(cols) != 1 {
92+
t.Fatalf("result columns length shoule be 3, actual: %d", len(cols))
93+
}
94+
var bar = 0
95+
var count = 1
96+
for rows.Next() {
97+
err = rows.Scan(&bar)
98+
if err != nil {
99+
t.Fatalf("scan fail, err: %s", err)
100+
}
101+
if bar != count {
102+
t.Fatalf("result is not match, want: %d actual: %d", count, bar)
103+
}
104+
count++
105+
}
106+
}

0 commit comments

Comments
 (0)