Skip to content

Commit

Permalink
Added context.Context to public methods (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
thrawn01 authored Feb 7, 2025
1 parent caab2d1 commit 0d5ab1e
Show file tree
Hide file tree
Showing 19 changed files with 438 additions and 301 deletions.
6 changes: 3 additions & 3 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@ func main() {
key := []byte("key1")
value := []byte("value1")

db.Put(key, value)
_ = db.Put(ctx, key, value)
fmt.Println("Put:", string(key), string(value))

data, _ := db.Get(ctx, key)
fmt.Println("Get:", string(key), string(data))

db.Delete(key)
_ = db.Delete(ctx, key)
_, err := db.Get(ctx, key)
if err != nil && err.Error() == "key not found" {
fmt.Println("Delete:", string(key))
} else {
slog.Error("Unable to delete", "error", err)
}

if err := db.Close(); err != nil {
if err := db.Close(ctx); err != nil {
slog.Error("Error closing db", "error", err)
}
}
7 changes: 4 additions & 3 deletions internal/sstable/blob.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sstable

import (
"context"
"errors"

"github.com/slatedb/slatedb-go/slatedb/common"
Expand All @@ -15,17 +16,17 @@ type bytesBlob struct {
data []byte
}

func (b *bytesBlob) Len() (int, error) {
func (b *bytesBlob) Len(_ context.Context) (int, error) {
return len(b.data), nil
}

func (b *bytesBlob) ReadRange(r common.Range) ([]byte, error) {
func (b *bytesBlob) ReadRange(_ context.Context, r common.Range) ([]byte, error) {
if r.Start > uint64(len(b.data)) || r.End > uint64(len(b.data)) || r.Start > r.End {
return nil, errors.New("invalid range")
}
return b.data[r.Start:r.End], nil
}

func (b *bytesBlob) Read() ([]byte, error) {
func (b *bytesBlob) Read(_ context.Context) ([]byte, error) {
return b.data, nil
}
11 changes: 6 additions & 5 deletions internal/sstable/builder_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sstable_test

import (
"context"
"fmt"
"testing"

Expand Down Expand Up @@ -95,7 +96,7 @@ func TestBuilder(t *testing.T) {
}

func TestEncodeDecode(t *testing.T) {

ctx := context.Background()
input := [][]types.KeyValue{
{types.KeyValue{Key: []byte("key1"), Value: []byte("value1")}},
{types.KeyValue{Key: []byte("key2"), Value: []byte("value2")}},
Expand Down Expand Up @@ -125,20 +126,20 @@ func TestEncodeDecode(t *testing.T) {
blob := sstable.NewBytesBlob(encoded)

// Decode the Info from the table
info, err := sstable.ReadInfo(blob)
info, err := sstable.ReadInfo(ctx, blob)
assert.NoError(t, err)
assert.NotNil(t, info)
assert.Equal(t, table.Info.FirstKey, info.FirstKey)
assert.Equal(t, table.Info.IndexOffset, info.IndexOffset)
assert.Equal(t, table.Info.IndexLen, info.IndexLen)

// Decode the index from the table
index, err := sstable.ReadIndex(info, blob)
index, err := sstable.ReadIndex(ctx, info, blob)
assert.NoError(t, err)
assert.NotNil(t, index)

// Read the first block from the table
blocks, err := sstable.ReadBlocks(info, index, common.Range{Start: 0, End: 3}, blob)
blocks, err := sstable.ReadBlocks(ctx, info, index, common.Range{Start: 0, End: 3}, blob)
assert.NoError(t, err)
assert.Equal(t, 3, len(input))

Expand All @@ -153,7 +154,7 @@ func TestEncodeDecode(t *testing.T) {
assert2.NextEntry(t, it, []byte("key3"), []byte("value3"))

// Test bloom filter
filter, err := sstable.ReadFilter(info, blob)
filter, err := sstable.ReadFilter(ctx, info, blob)
assert.NoError(t, err)
assert.True(t, filter.IsPresent())
f, ok := filter.Get()
Expand Down
21 changes: 11 additions & 10 deletions internal/sstable/decode.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sstable

import (
"context"
"encoding/binary"
"fmt"

Expand All @@ -21,8 +22,8 @@ func DefaultConfig() Config {
}
}

func ReadInfo(obj common.ReadOnlyBlob) (*Info, error) {
size, err := obj.Len()
func ReadInfo(ctx context.Context, obj common.ReadOnlyBlob) (*Info, error) {
size, err := obj.Len(ctx)
if err != nil {
return nil, err
}
Expand All @@ -32,21 +33,21 @@ func ReadInfo(obj common.ReadOnlyBlob) (*Info, error) {

// Get the metadata. Last 4 bytes are the metadata offset of SsTableInfo
offsetIndex := uint64(size - 4)
offsetBytes, err := obj.ReadRange(common.Range{Start: offsetIndex, End: uint64(size)})
offsetBytes, err := obj.ReadRange(ctx, common.Range{Start: offsetIndex, End: uint64(size)})
if err != nil {
return nil, err
}

metadataOffset := binary.BigEndian.Uint32(offsetBytes)
metadataBytes, err := obj.ReadRange(common.Range{Start: uint64(metadataOffset), End: offsetIndex})
metadataBytes, err := obj.ReadRange(ctx, common.Range{Start: uint64(metadataOffset), End: offsetIndex})
if err != nil {
return nil, err
}

return DecodeInfo(metadataBytes)
}

func ReadFilter(sstInfo *Info, obj common.ReadOnlyBlob) (mo.Option[bloom.Filter], error) {
func ReadFilter(ctx context.Context, sstInfo *Info, obj common.ReadOnlyBlob) (mo.Option[bloom.Filter], error) {
if sstInfo.FilterLen < 1 {
return mo.None[bloom.Filter](), nil
}
Expand All @@ -56,7 +57,7 @@ func ReadFilter(sstInfo *Info, obj common.ReadOnlyBlob) (mo.Option[bloom.Filter]
End: sstInfo.FilterOffset + sstInfo.FilterLen,
}

filterBytes, err := obj.ReadRange(filterOffsetRange)
filterBytes, err := obj.ReadRange(ctx, filterOffsetRange)
if err != nil {
return mo.None[bloom.Filter](), fmt.Errorf("while reading filter offset: %w", err)
}
Expand All @@ -69,8 +70,8 @@ func ReadFilter(sstInfo *Info, obj common.ReadOnlyBlob) (mo.Option[bloom.Filter]
return mo.Some(filterData), nil
}

func ReadIndex(info *Info, obj common.ReadOnlyBlob) (*Index, error) {
indexBytes, err := obj.ReadRange(common.Range{
func ReadIndex(ctx context.Context, info *Info, obj common.ReadOnlyBlob) (*Index, error) {
indexBytes, err := obj.ReadRange(ctx, common.Range{
Start: info.IndexOffset,
End: info.IndexOffset + info.IndexLen,
})
Expand Down Expand Up @@ -103,7 +104,7 @@ func getBlockRange(rng common.Range, sstInfo *Info, index *Index) common.Range {

// ReadBlocks reads the complete data required into a byte slice (dataBytes)
// and then breaks the data up into slice of Blocks (decodedBlocks) which is returned
func ReadBlocks(info *Info, index *Index, r common.Range, obj common.ReadOnlyBlob) ([]block.Block, error) {
func ReadBlocks(ctx context.Context, info *Info, index *Index, r common.Range, obj common.ReadOnlyBlob) ([]block.Block, error) {
if r.Start >= r.End {
return nil, fmt.Errorf("block start '%d' range cannot be greater than end range '%d'", r.Start, r.End)
}
Expand All @@ -117,7 +118,7 @@ func ReadBlocks(info *Info, index *Index, r common.Range, obj common.ReadOnlyBlo
}

rng := getBlockRange(r, info, index)
dataBytes, err := obj.ReadRange(rng)
dataBytes, err := obj.ReadRange(ctx, rng)
if err != nil {
return nil, fmt.Errorf("while reading block range [%d:%d]: %w", rng.Start, rng.End, err)
}
Expand Down
18 changes: 9 additions & 9 deletions internal/sstable/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
)

type TableStore interface {
ReadIndex(*Handle) (*Index, error)
ReadBlocksUsingIndex(*Handle, common.Range, *Index) ([]block.Block, error)
ReadIndex(context.Context, *Handle) (*Index, error)
ReadBlocksUsingIndex(context.Context, *Handle, common.Range, *Index) ([]block.Block, error)
}

// Iterator iterates through KeyValue pairs present in the SSTable.
Expand All @@ -26,8 +26,8 @@ type Iterator struct {
nextBlock uint64
}

func NewIterator(handle *Handle, store TableStore) (*Iterator, error) {
index, err := store.ReadIndex(handle)
func NewIterator(ctx context.Context, handle *Handle, store TableStore) (*Iterator, error) {
index, err := store.ReadIndex(ctx, handle)
if err != nil {
return nil, err
}
Expand All @@ -40,8 +40,8 @@ func NewIterator(handle *Handle, store TableStore) (*Iterator, error) {
}, nil
}

func NewIteratorAtKey(handle *Handle, key []byte, store TableStore) (*Iterator, error) {
index, err := store.ReadIndex(handle)
func NewIteratorAtKey(ctx context.Context, handle *Handle, key []byte, store TableStore) (*Iterator, error) {
index, err := store.ReadIndex(ctx, handle)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -77,7 +77,7 @@ func (iter *Iterator) Next(ctx context.Context) (types.KeyValue, bool) {
func (iter *Iterator) NextEntry(ctx context.Context) (types.RowEntry, bool) {
for {
if iter.blockIter == nil {
it, err := iter.nextBlockIter()
it, err := iter.nextBlockIter(ctx)
if err != nil {
// TODO(thrawn01): This could be a transient error, or a corruption error
// we need to handle each differently.
Expand Down Expand Up @@ -107,14 +107,14 @@ func (iter *Iterator) NextEntry(ctx context.Context) (types.RowEntry, bool) {
}

// nextBlockIter fetches the next block and returns an iterator for that block
func (iter *Iterator) nextBlockIter() (*block.Iterator, error) {
func (iter *Iterator) nextBlockIter(ctx context.Context) (*block.Iterator, error) {
if iter.nextBlock >= uint64(iter.index.BlockMetaLength()) {
return nil, nil // No more blocks to read
}

// Fetch the next block
rng := common.Range{Start: iter.nextBlock, End: iter.nextBlock + 1}
blocks, err := iter.store.ReadBlocksUsingIndex(iter.handle, rng, iter.index)
blocks, err := iter.store.ReadBlocksUsingIndex(ctx, iter.handle, rng, iter.index)
if err != nil {
return nil, fmt.Errorf("while reading block range [%d:%d]: %w", rng.Start, rng.End, err)
}
Expand Down
8 changes: 5 additions & 3 deletions slatedb/common/blob.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package common

import "context"

type ReadOnlyBlob interface {
Len() (int, error)
ReadRange(r Range) ([]byte, error)
Read() ([]byte, error)
Len(ctx context.Context) (int, error)
ReadRange(ctx context.Context, r Range) ([]byte, error)
Read(ctx context.Context) ([]byte, error)
}
19 changes: 11 additions & 8 deletions slatedb/compacted/sortedrun.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,24 @@ type SortedRunIterator struct {
warn types.ErrWarn
}

func NewSortedRunIterator(sr SortedRun, store sstable.TableStore) (*SortedRunIterator, error) {
return newSortedRunIter(sr.SSTList, store, mo.None[[]byte]())
func NewSortedRunIterator(ctx context.Context, sr SortedRun, store sstable.TableStore) (*SortedRunIterator, error) {
return newSortedRunIter(ctx, sr.SSTList, store, mo.None[[]byte]())
}

func NewSortedRunIteratorFromKey(sr SortedRun, key []byte, store sstable.TableStore) (*SortedRunIterator, error) {
func NewSortedRunIteratorFromKey(ctx context.Context, sr SortedRun, key []byte, store sstable.TableStore) (*SortedRunIterator, error) {
sstList := sr.SSTList
idx, ok := sr.indexOfSSTWithKey(key).Get()
if ok {
sstList = sr.SSTList[idx:]
}

return newSortedRunIter(sstList, store, mo.Some(key))
return newSortedRunIter(ctx, sstList, store, mo.Some(key))
}

func newSortedRunIter(sstList []sstable.Handle, store sstable.TableStore, fromKey mo.Option[[]byte]) (*SortedRunIterator, error) {
func newSortedRunIter(ctx context.Context,
sstList []sstable.Handle,
store sstable.TableStore,
fromKey mo.Option[[]byte]) (*SortedRunIterator, error) {

sstListIter := newSSTListIterator(sstList)
currentKVIter := mo.None[*sstable.Iterator]()
Expand All @@ -86,12 +89,12 @@ func newSortedRunIter(sstList []sstable.Handle, store sstable.TableStore, fromKe
var err error
if fromKey.IsPresent() {
key, _ := fromKey.Get()
iter, err = sstable.NewIteratorAtKey(&sst, key, store)
iter, err = sstable.NewIteratorAtKey(ctx, &sst, key, store)
if err != nil {
return nil, err
}
} else {
iter, err = sstable.NewIterator(&sst, store)
iter, err = sstable.NewIterator(ctx, &sst, store)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -148,7 +151,7 @@ func (iter *SortedRunIterator) NextEntry(ctx context.Context) (types.RowEntry, b
return types.RowEntry{}, false
}

newKVIter, err := sstable.NewIterator(&sst, iter.tableStore)
newKVIter, err := sstable.NewIterator(ctx, &sst, iter.tableStore)
if err != nil {
iter.warn.Add("while creating SSTable iterator: %s", err.Error())
return types.RowEntry{}, false
Expand Down
5 changes: 3 additions & 2 deletions slatedb/compaction/compactor.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package compaction

import (
"context"
"strconv"

"github.com/oklog/ulid/v2"
Expand Down Expand Up @@ -114,8 +115,8 @@ func NewCompactor(manifestStore *store.ManifestStore, tableStore *store.TableSto
}, nil
}

func (c *Compactor) Close() {
c.orchestrator.shutdown()
func (c *Compactor) Close(ctx context.Context) error {
return c.orchestrator.shutdown(ctx)
}

func spawnAndRunCompactionOrchestrator(
Expand Down
23 changes: 17 additions & 6 deletions slatedb/compaction/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ func (e *Executor) loadIterators(compaction Job) (iter.KVIterator, error) {

l0Iters := make([]iter.KVIterator, 0)
for _, sst := range compaction.sstList {
sstIter, err := sstable.NewIterator(&sst, e.tableStore.Clone())
ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout)
sstIter, err := sstable.NewIterator(ctx, &sst, e.tableStore.Clone())
cancel()
if err != nil {
return nil, err
}
Expand All @@ -63,15 +65,18 @@ func (e *Executor) loadIterators(compaction Job) (iter.KVIterator, error) {

srIters := make([]iter.KVIterator, 0)
for _, sr := range compaction.sortedRuns {
srIter, err := compacted.NewSortedRunIterator(sr, e.tableStore.Clone())
ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout)
srIter, err := compacted.NewSortedRunIterator(ctx, sr, e.tableStore.Clone())
cancel()
if err != nil {
return nil, err
}
srIters = append(srIters, srIter)
}

ctx := context.TODO()
var l0MergeIter, srMergeIter iter.KVIterator
ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout)
defer cancel()
if len(compaction.sortedRuns) == 0 {
l0MergeIter = iter.NewMergeSort(ctx, l0Iters...)
return l0MergeIter, nil
Expand All @@ -95,7 +100,9 @@ func (e *Executor) executeCompaction(compaction Job) (*compacted.SortedRun, erro
currentWriter := e.tableStore.TableWriter(sstable.NewIDCompacted(ulid.Make()))
currentSize := 0
for {
kv, ok := allIter.NextEntry(context.TODO())
ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout)
kv, ok := allIter.NextEntry(ctx)
cancel()
if !ok {
if w := allIter.Warnings(); w != nil {
warn.Merge(w)
Expand All @@ -119,15 +126,19 @@ func (e *Executor) executeCompaction(compaction Job) (*compacted.SortedRun, erro
currentSize = 0
finishedWriter := currentWriter
currentWriter = e.tableStore.TableWriter(sstable.NewIDCompacted(ulid.Make()))
sst, err := finishedWriter.Close()
ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout)
sst, err := finishedWriter.Close(ctx)
cancel()
if err != nil {
return nil, err
}
outputSSTs = append(outputSSTs, *sst)
}
}
if currentSize > 0 {
sst, err := currentWriter.Close()
ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout)
sst, err := currentWriter.Close(ctx)
cancel()
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 0d5ab1e

Please sign in to comment.