Skip to content

add concurrent progress bar #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions oss/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@ type DownloaderOptions struct {
ClientOptions []func(*Options)
}

// downloaderProgress tracks the cumulative number of bytes written across
// concurrent download workers and forwards every increment to the
// user-supplied progress callback. Shared by all workers of one download.
type downloaderProgress struct {
	pr ProgressFunc // user-supplied progress callback; invoked while mu is held
	written int64 // cumulative bytes reported so far; guarded by mu
	total int64 // total size of the object in bytes, reported to pr unchanged
	mu sync.Mutex // serializes updates to written and calls to pr
}

// Write implements io.Writer so the shared progress tracker can sit in an
// io.MultiWriter/TeeReader chain. It never fails: it counts len(b) bytes,
// adds them to the running total under the mutex, and reports the increment,
// the new cumulative count, and the expected total to the callback.
func (cpt *downloaderProgress) Write(b []byte) (n int, err error) {
	size := int64(len(b))
	cpt.mu.Lock()
	defer cpt.mu.Unlock()
	cpt.written += size
	cpt.pr(size, cpt.written, cpt.total)
	return len(b), nil
}

type Downloader struct {
options DownloaderOptions
client DownloadAPIClient
Expand Down Expand Up @@ -378,7 +395,7 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) {
}

// writeChunkFn runs in worker goroutines to pull chunks off of the ch channel
writeChunkFn := func(ch chan downloaderChunk) {
writeChunkFn := func(ch chan downloaderChunk, progress *downloaderProgress) {
defer wg.Done()
var hash hash.Hash64
if d.calcCRC {
Expand All @@ -395,7 +412,7 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) {
continue
}

dchunk, derr := d.downloadChunk(chunk, hash)
dchunk, derr := d.downloadChunk(chunk, hash, progress)

if derr != nil && derr != io.EOF {
saveErrFn(derr)
Expand Down Expand Up @@ -455,9 +472,16 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) {

// Start the download workers
ch := make(chan downloaderChunk, d.options.ParallelNum)
var progress *downloaderProgress
if d.request.ProgressFn != nil {
progress = &downloaderProgress{
pr: d.request.ProgressFn,
total: d.sizeInBytes,
}
}
for i := 0; i < d.options.ParallelNum; i++ {
wg.Add(1)
go writeChunkFn(ch)
go writeChunkFn(ch, progress)
}

// Start tracker worker if need track downloaded chunk
Expand Down Expand Up @@ -511,7 +535,7 @@ func (d *downloaderDelegate) incrWritten(n int64) {
d.written += n
}

func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, hash hash.Hash64) (downloadedChunk, error) {
func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, hash hash.Hash64, progress *downloaderProgress) (downloadedChunk, error) {
// Get the next byte range of data
var request GetObjectRequest
copyRequest(&request, d.request)
Expand Down Expand Up @@ -546,6 +570,15 @@ func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, hash hash.Hash
r io.Reader = reader
crc64 uint64 = 0
)
writer := io.MultiWriter()
if progress != nil {
writer = io.MultiWriter(writer, progress)
}
if hash != nil {
hash.Reset()
writer = io.MultiWriter(writer, hash)
}
r = io.TeeReader(reader, writer)
if hash != nil {
hash.Reset()
r = io.TeeReader(reader, hash)
Expand Down
87 changes: 87 additions & 0 deletions oss/downloader_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1273,3 +1273,90 @@ func TestMockDownloaderDownloadFilePayer(t *testing.T) {
})
assert.Nil(t, err)
}

// TestMockDownloaderWithProgress verifies that DownloadFile reports
// cumulative progress through GetObjectRequest.ProgressFn, for both a
// single-worker and a parallel (3-worker) download, and that the final
// "transferred" value seen by the callback equals the object size.
func TestMockDownloaderWithProgress(t *testing.T) {
	length := 3*1024*1024 + 1234
	data := []byte(randStr(length))
	gmtTime := getNowGMT()
	tracker := &downloaderMockTracker{
		lastModified: gmtTime,
		data:         data,
	}
	server := testSetupDownloaderMockServer(t, tracker)
	defer server.Close()
	assert.NotNil(t, server)
	cfg := LoadDefaultConfig().
		WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()).
		WithRegion("cn-hangzhou").
		WithEndpoint(server.URL).
		WithReadWriteTimeout(300 * time.Second)
	client := NewClient(cfg)

	// n records the last "transferred" value observed by the callback.
	// The downloader serializes callback invocations and DownloadFile waits
	// for all workers before returning, so reading n afterwards is safe.
	var n int64
	d := client.NewDownloader(func(do *DownloaderOptions) {
		do.ParallelNum = 1
		do.PartSize = 1 * 1024 * 1024
	})
	assert.NotNil(t, d)
	assert.NotNil(t, d.client)
	assert.Equal(t, int64(1*1024*1024), d.options.PartSize)
	assert.Equal(t, 1, d.options.ParallelNum)

	// An empty filePath must be rejected before any download starts.
	_, err := d.DownloadFile(
		context.TODO(),
		&GetObjectRequest{
			Bucket: Ptr("bucket"),
			Key:    Ptr("key"),
			ProgressFn: func(increment, transferred, total int64) {
				n = transferred
			},
		}, "")
	assert.NotNil(t, err)
	assert.Contains(t, err.Error(), "invalid field, filePath")

	localFile := randStr(8) + "-no-surfix"
	defer func() {
		os.Remove(localFile)
	}()

	// Single-worker download: progress must end at the full object length.
	_, err = d.DownloadFile(
		context.TODO(),
		&GetObjectRequest{
			Bucket: Ptr("bucket"),
			Key:    Ptr("key"),
			ProgressFn: func(increment, transferred, total int64) {
				n = transferred
			},
		}, localFile)
	assert.Nil(t, err)
	// testify convention: expected value first, actual second.
	assert.Equal(t, int64(length), n)

	// Repeat with a parallel downloader (3 workers, larger parts).
	n = int64(0)
	d = client.NewDownloader(func(do *DownloaderOptions) {
		do.ParallelNum = 3
		do.PartSize = 3 * 1024 * 1024
	})
	assert.NotNil(t, d)
	assert.NotNil(t, d.client)
	assert.Equal(t, int64(3*1024*1024), d.options.PartSize)
	assert.Equal(t, 3, d.options.ParallelNum)

	// filePath is invalid
	_, err = d.DownloadFile(
		context.TODO(),
		&GetObjectRequest{
			Bucket: Ptr("bucket"),
			Key:    Ptr("key"),
			ProgressFn: func(increment, transferred, total int64) {
				n = transferred
			},
		}, "")
	assert.NotNil(t, err)
	assert.Contains(t, err.Error(), "invalid field, filePath")

	// Parallel download must also report the complete size when done.
	_, err = d.DownloadFile(
		context.TODO(),
		&GetObjectRequest{
			Bucket: Ptr("bucket"),
			Key:    Ptr("key"),
			ProgressFn: func(increment, transferred, total int64) {
				n = transferred
			},
		}, localFile)
	assert.Nil(t, err)
	assert.Equal(t, int64(length), n)
}
29 changes: 27 additions & 2 deletions oss/uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ type UploaderOptions struct {
ClientOptions []func(*Options)
}

// uploaderProgress tracks the cumulative number of bytes uploaded across
// concurrent part-upload workers and forwards every increment to the
// user-supplied progress callback. Shared by all workers of one upload.
type uploaderProgress struct {
	pr ProgressFunc // user-supplied progress callback; invoked while mu is held
	written int64 // cumulative bytes reported so far; guarded by mu
	total int64 // total size of the upload in bytes, reported to pr unchanged
	mu sync.Mutex // serializes updates to written and calls to pr
}

// incrWritten records n more uploaded bytes and notifies the progress
// callback with the increment, the new cumulative count, and the total.
// Safe for concurrent use by multiple upload workers.
func (cpt *uploaderProgress) incrWritten(n int64) {
	cpt.mu.Lock()
	defer cpt.mu.Unlock()
	cpt.written += n
	cpt.pr(n, cpt.written, cpt.total)
}

type Uploader struct {
options UploaderOptions
client UploadAPIClient
Expand Down Expand Up @@ -539,7 +554,7 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) {
}

// readChunk runs in worker goroutines to pull chunks off of the ch channel
readChunkFn := func(ch chan uploaderChunk) {
readChunkFn := func(ch chan uploaderChunk, progress *uploaderProgress) {
defer wg.Done()
for {
data, ok := <-ch
Expand All @@ -563,6 +578,9 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) {
//fmt.Printf("UploadPart result: %#v, %#v\n", upResult, err)

if err == nil {
if progress != nil {
progress.incrWritten(int64(data.size))
}
mu.Lock()
parts = append(parts, UploadPart{ETag: upResult.ETag, PartNumber: data.partNum})
if enableCRC {
Expand All @@ -579,9 +597,16 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) {
}

ch := make(chan uploaderChunk, u.options.ParallelNum)
var progress *uploaderProgress
if u.request.ProgressFn != nil {
progress = &uploaderProgress{
pr: u.request.ProgressFn,
total: u.totalSize,
}
}
for i := 0; i < u.options.ParallelNum; i++ {
wg.Add(1)
go readChunkFn(ch)
go readChunkFn(ch, progress)
}

// Read and queue the parts
Expand Down
122 changes: 122 additions & 0 deletions oss/uploader_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1858,3 +1858,125 @@ func TestMockUploadWithPayer(t *testing.T) {
bytes.NewReader(data))
assert.Nil(t, err)
}

// TestMockUploadSinglePartFromFileWithProgress verifies that UploadFile
// reports cumulative progress via PutObjectRequest.ProgressFn on the
// single-part (plain PutObject) path — the file is smaller than the
// default part size — and that the final transferred count equals the
// file length and the CRC matches.
func TestMockUploadSinglePartFromFileWithProgress(t *testing.T) {
	partSize := DefaultUploadPartSize
	length := 5*100*1024 + 123
	partsNum := length/int(partSize) + 1
	tracker := &uploaderMockTracker{
		partNum:       partsNum,
		saveDate:      make([][]byte, partsNum),
		checkTime:     make([]time.Time, partsNum),
		timeout:       make([]time.Duration, partsNum),
		uploadPartErr: make([]bool, partsNum),
	}

	data := []byte(randStr(length))
	hash := NewCRC64(0)
	hash.Write(data)
	dataCrc64ecma := fmt.Sprint(hash.Sum64())

	localFile := randStr(8) + ".txt"
	createFileFromByte(t, localFile, data)
	defer func() {
		os.Remove(localFile)
	}()

	server := testSetupUploaderMockServer(t, tracker)
	defer server.Close()
	assert.NotNil(t, server)

	cfg := LoadDefaultConfig().
		WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()).
		WithRegion("cn-hangzhou").
		WithEndpoint(server.URL).
		WithReadWriteTimeout(300 * time.Second)

	client := NewClient(cfg)
	u := NewUploader(client)

	assert.NotNil(t, u.client)
	assert.Equal(t, DefaultUploadParallel, u.options.ParallelNum)
	assert.Equal(t, DefaultUploadPartSize, u.options.PartSize)

	// n records the last "transferred" value observed by the callback;
	// UploadFile only returns after all progress reporting has completed.
	n := int64(0)
	result, err := u.UploadFile(context.TODO(), &PutObjectRequest{
		Bucket: Ptr("bucket"),
		Key:    Ptr("key"),
		ProgressFn: func(increment, transferred, total int64) {
			n = transferred
		},
	}, localFile)
	assert.Nil(t, err)
	assert.NotNil(t, result)
	// Single-part uploads do not create a multipart upload.
	assert.Nil(t, result.UploadId)
	assert.Equal(t, dataCrc64ecma, *result.HashCRC64)
	// testify convention: expected value first, actual second.
	assert.Equal(t, int64(length), n)
}

// TestMockUploadParallelFromFileWithProgress verifies that UploadFile
// reports cumulative progress via PutObjectRequest.ProgressFn on the
// multipart path with 4 parallel workers, even when some parts are
// artificially delayed so parts complete out of order, and that the final
// transferred count equals the file length and the CRC matches.
func TestMockUploadParallelFromFileWithProgress(t *testing.T) {
	partSize := int64(100 * 1024)
	length := 5*100*1024 + 123
	partsNum := length/int(partSize) + 1
	tracker := &uploaderMockTracker{
		partNum:       partsNum,
		saveDate:      make([][]byte, partsNum),
		checkTime:     make([]time.Time, partsNum),
		timeout:       make([]time.Duration, partsNum),
		uploadPartErr: make([]bool, partsNum),
	}

	data := []byte(randStr(length))
	hash := NewCRC64(0)
	hash.Write(data)
	dataCrc64ecma := fmt.Sprint(hash.Sum64())

	localFile := randStr(8) + "-no-surfix"
	createFileFromByte(t, localFile, data)
	defer func() {
		os.Remove(localFile)
	}()

	server := testSetupUploaderMockServer(t, tracker)
	defer server.Close()
	assert.NotNil(t, server)

	cfg := LoadDefaultConfig().
		WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()).
		WithRegion("cn-hangzhou").
		WithEndpoint(server.URL).
		WithReadWriteTimeout(300 * time.Second)

	client := NewClient(cfg)

	u := NewUploader(client,
		func(uo *UploaderOptions) {
			uo.ParallelNum = 4
			uo.PartSize = partSize
		},
	)
	assert.Equal(t, 4, u.options.ParallelNum)
	assert.Equal(t, partSize, u.options.PartSize)

	// Delay some parts so they complete out of order across the 4 workers.
	tracker.timeout[0] = 1 * time.Second
	tracker.timeout[2] = 500 * time.Millisecond

	// n records the last "transferred" value observed by the callback;
	// the uploader serializes callback invocations under a mutex and
	// UploadFile waits for all workers before returning.
	n := int64(0)
	result, err := u.UploadFile(context.TODO(), &PutObjectRequest{
		Bucket: Ptr("bucket"),
		Key:    Ptr("key"),
		ProgressFn: func(increment, transferred, total int64) {
			n = transferred
		},
	}, localFile)
	assert.Nil(t, err)
	assert.NotNil(t, result)
	assert.Equal(t, "uploadId-1234", *result.UploadId)
	assert.Equal(t, dataCrc64ecma, *result.HashCRC64)
	// testify convention: expected value first, actual second.
	assert.Equal(t, int64(length), n)
}