diff --git a/oss/downloader.go b/oss/downloader.go index 296ed2b..23c7877 100644 --- a/oss/downloader.go +++ b/oss/downloader.go @@ -29,6 +29,23 @@ type DownloaderOptions struct { ClientOptions []func(*Options) } +type downloaderProgress struct { + pr ProgressFunc + written int64 + total int64 + mu sync.Mutex +} + +func (cpt *downloaderProgress) Write(b []byte) (n int, err error) { + n = len(b) + increment := int64(n) + cpt.mu.Lock() + defer cpt.mu.Unlock() + cpt.written += increment + cpt.pr(increment, cpt.written, cpt.total) + return +} + type Downloader struct { options DownloaderOptions client DownloadAPIClient @@ -378,7 +395,7 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) { } // writeChunkFn runs in worker goroutines to pull chunks off of the ch channel - writeChunkFn := func(ch chan downloaderChunk) { + writeChunkFn := func(ch chan downloaderChunk, progress *downloaderProgress) { defer wg.Done() var hash hash.Hash64 if d.calcCRC { @@ -395,7 +412,7 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) { continue } - dchunk, derr := d.downloadChunk(chunk, hash) + dchunk, derr := d.downloadChunk(chunk, hash, progress) if derr != nil && derr != io.EOF { saveErrFn(derr) @@ -455,9 +472,16 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) { // Start the download workers ch := make(chan downloaderChunk, d.options.ParallelNum) + var progress *downloaderProgress + if d.request.ProgressFn != nil { + progress = &downloaderProgress{ + pr: d.request.ProgressFn, + total: d.sizeInBytes, + } + } for i := 0; i < d.options.ParallelNum; i++ { wg.Add(1) - go writeChunkFn(ch) + go writeChunkFn(ch, progress) } // Start tracker worker if need track downloaded chunk @@ -511,7 +535,7 @@ func (d *downloaderDelegate) incrWritten(n int64) { d.written += n } -func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, hash hash.Hash64) (downloadedChunk, error) { +func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, 
hash hash.Hash64, progress *downloaderProgress) (downloadedChunk, error) {
 	// Get the next byte range of data
 	var request GetObjectRequest
 	copyRequest(&request, d.request)
@@ -546,7 +570,12 @@ func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, hash hash.Hash
 		r     io.Reader = reader
 		crc64 uint64    = 0
 	)
-	if hash != nil {
-		hash.Reset()
-		r = io.TeeReader(reader, hash)
-	}
+	writer := io.MultiWriter()
+	if progress != nil {
+		writer = io.MultiWriter(writer, progress)
+	}
+	if hash != nil {
+		hash.Reset()
+		writer = io.MultiWriter(writer, hash)
+	}
+	r = io.TeeReader(reader, writer)
diff --git a/oss/downloader_mock_test.go b/oss/downloader_mock_test.go
index de29a1a..dc95307 100644
--- a/oss/downloader_mock_test.go
+++ b/oss/downloader_mock_test.go
@@ -1273,3 +1273,90 @@ func TestMockDownloaderDownloadFilePayer(t *testing.T) {
 	})
 	assert.Nil(t, err)
 }
+
+func TestMockDownloaderWithProgress(t *testing.T) {
+	length := 3*1024*1024 + 1234
+	data := []byte(randStr(length))
+	gmtTime := getNowGMT()
+	tracker := &downloaderMockTracker{
+		lastModified: gmtTime,
+		data:         data,
+	}
+	server := testSetupDownloaderMockServer(t, tracker)
+	defer server.Close()
+	assert.NotNil(t, server)
+	cfg := LoadDefaultConfig().
+		WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()).
+		WithRegion("cn-hangzhou").
+		WithEndpoint(server.URL).
+ WithReadWriteTimeout(300 * time.Second) + client := NewClient(cfg) + var n int64 + d := client.NewDownloader(func(do *DownloaderOptions) { + do.ParallelNum = 1 + do.PartSize = 1 * 1024 * 1024 + }) + assert.NotNil(t, d) + assert.NotNil(t, d.client) + assert.Equal(t, int64(1*1024*1024), d.options.PartSize) + assert.Equal(t, 1, d.options.ParallelNum) + // filePath is invalid + _, err := d.DownloadFile( + context.TODO(), + &GetObjectRequest{ + Bucket: Ptr("bucket"), + Key: Ptr("key"), + ProgressFn: func(increment, transferred, total int64) { + n = transferred + }, + }, "") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "invalid field, filePath") + localFile := randStr(8) + "-no-surfix" + defer func() { + os.Remove(localFile) + }() + _, err = d.DownloadFile( + context.TODO(), + &GetObjectRequest{ + Bucket: Ptr("bucket"), + Key: Ptr("key"), + ProgressFn: func(increment, transferred, total int64) { + n = transferred + }, + }, localFile) + assert.Nil(t, err) + assert.Equal(t, n, int64(length)) + n = int64(0) + d = client.NewDownloader(func(do *DownloaderOptions) { + do.ParallelNum = 3 + do.PartSize = 3 * 1024 * 1024 + }) + assert.NotNil(t, d) + assert.NotNil(t, d.client) + assert.Equal(t, int64(3*1024*1024), d.options.PartSize) + assert.Equal(t, 3, d.options.ParallelNum) + // filePath is invalid + _, err = d.DownloadFile( + context.TODO(), + &GetObjectRequest{ + Bucket: Ptr("bucket"), + Key: Ptr("key"), + ProgressFn: func(increment, transferred, total int64) { + n = transferred + }, + }, "") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "invalid field, filePath") + _, err = d.DownloadFile( + context.TODO(), + &GetObjectRequest{ + Bucket: Ptr("bucket"), + Key: Ptr("key"), + ProgressFn: func(increment, transferred, total int64) { + n = transferred + }, + }, localFile) + assert.Nil(t, err) + assert.Equal(t, n, int64(length)) +} diff --git a/oss/uploader.go b/oss/uploader.go index 4daf578..ae34912 100644 --- a/oss/uploader.go +++ b/oss/uploader.go @@ 
-26,6 +26,21 @@ type UploaderOptions struct { ClientOptions []func(*Options) } +type uploaderProgress struct { + pr ProgressFunc + written int64 + total int64 + mu sync.Mutex +} + +func (cpt *uploaderProgress) incrWritten(n int64) { + increment := n + cpt.mu.Lock() + defer cpt.mu.Unlock() + cpt.written += increment + cpt.pr(increment, cpt.written, cpt.total) +} + type Uploader struct { options UploaderOptions client UploadAPIClient @@ -539,7 +554,7 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) { } // readChunk runs in worker goroutines to pull chunks off of the ch channel - readChunkFn := func(ch chan uploaderChunk) { + readChunkFn := func(ch chan uploaderChunk, progress *uploaderProgress) { defer wg.Done() for { data, ok := <-ch @@ -563,6 +578,9 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) { //fmt.Printf("UploadPart result: %#v, %#v\n", upResult, err) if err == nil { + if progress != nil { + progress.incrWritten(int64(data.size)) + } mu.Lock() parts = append(parts, UploadPart{ETag: upResult.ETag, PartNumber: data.partNum}) if enableCRC { @@ -579,9 +597,16 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) { } ch := make(chan uploaderChunk, u.options.ParallelNum) + var progress *uploaderProgress + if u.request.ProgressFn != nil { + progress = &uploaderProgress{ + pr: u.request.ProgressFn, + total: u.totalSize, + } + } for i := 0; i < u.options.ParallelNum; i++ { wg.Add(1) - go readChunkFn(ch) + go readChunkFn(ch, progress) } // Read and queue the parts diff --git a/oss/uploader_mock_test.go b/oss/uploader_mock_test.go index e41a8bf..60b109e 100644 --- a/oss/uploader_mock_test.go +++ b/oss/uploader_mock_test.go @@ -1858,3 +1858,125 @@ func TestMockUploadWithPayer(t *testing.T) { bytes.NewReader(data)) assert.Nil(t, err) } + +func TestMockUploadSinglePartFromFileWithProgress(t *testing.T) { + partSize := DefaultUploadPartSize + length := 5*100*1024 + 123 + partsNum := length/int(partSize) + 1 + tracker := 
&uploaderMockTracker{ + partNum: partsNum, + saveDate: make([][]byte, partsNum), + checkTime: make([]time.Time, partsNum), + timeout: make([]time.Duration, partsNum), + uploadPartErr: make([]bool, partsNum), + } + + data := []byte(randStr(length)) + hash := NewCRC64(0) + hash.Write(data) + dataCrc64ecma := fmt.Sprint(hash.Sum64()) + + localFile := randStr(8) + ".txt" + createFileFromByte(t, localFile, data) + defer func() { + os.Remove(localFile) + }() + + server := testSetupUploaderMockServer(t, tracker) + defer server.Close() + assert.NotNil(t, server) + + cfg := LoadDefaultConfig(). + WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()). + WithRegion("cn-hangzhou"). + WithEndpoint(server.URL). + WithReadWriteTimeout(300 * time.Second) + + client := NewClient(cfg) + u := NewUploader(client) + + assert.NotNil(t, u.client) + assert.Equal(t, DefaultUploadParallel, u.options.ParallelNum) + assert.Equal(t, DefaultUploadPartSize, u.options.PartSize) + + n := int64(0) + result, err := u.UploadFile(context.TODO(), &PutObjectRequest{ + Bucket: Ptr("bucket"), + Key: Ptr("key"), + ProgressFn: func(increment, transferred, total int64) { + n = transferred + fmt.Printf("increment:%#v, transferred:%#v, total:%#v\n", increment, transferred, total) + }, + }, localFile) + assert.Nil(t, err) + assert.NotNil(t, result) + assert.Nil(t, result.UploadId) + assert.Equal(t, dataCrc64ecma, *result.HashCRC64) + assert.Equal(t, n, int64(length)) +} + +func TestMockUploadParallelFromFileWithProgress(t *testing.T) { + partSize := int64(100 * 1024) + length := 5*100*1024 + 123 + partsNum := length/int(partSize) + 1 + tracker := &uploaderMockTracker{ + partNum: partsNum, + saveDate: make([][]byte, partsNum), + checkTime: make([]time.Time, partsNum), + timeout: make([]time.Duration, partsNum), + uploadPartErr: make([]bool, partsNum), + } + + data := []byte(randStr(length)) + hash := NewCRC64(0) + hash.Write(data) + dataCrc64ecma := fmt.Sprint(hash.Sum64()) + + localFile := 
randStr(8) + "-no-surfix"
+	createFileFromByte(t, localFile, data)
+	defer func() {
+		os.Remove(localFile)
+	}()
+
+	server := testSetupUploaderMockServer(t, tracker)
+	defer server.Close()
+	assert.NotNil(t, server)
+
+	cfg := LoadDefaultConfig().
+		WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()).
+		WithRegion("cn-hangzhou").
+		WithEndpoint(server.URL).
+		WithReadWriteTimeout(300 * time.Second)
+
+	client := NewClient(cfg)
+
+	u := NewUploader(client,
+		func(uo *UploaderOptions) {
+			uo.ParallelNum = 4
+			uo.PartSize = partSize
+		},
+	)
+	assert.Equal(t, 4, u.options.ParallelNum)
+	assert.Equal(t, partSize, u.options.PartSize)
+
+	tracker.timeout[0] = 1 * time.Second
+	tracker.timeout[2] = 500 * time.Millisecond
+
+	n := int64(0)
+	result, err := u.UploadFile(context.TODO(), &PutObjectRequest{
+		Bucket: Ptr("bucket"),
+		Key:    Ptr("key"),
+		ProgressFn: func(increment, transferred, total int64) {
+			n = transferred
+			fmt.Printf("increment:%#v, transferred:%#v, total:%#v\n", increment, transferred, total)
+		},
+	}, localFile)
+	assert.Nil(t, err)
+	assert.NotNil(t, result)
+
+	assert.Nil(t, err)
+	assert.NotNil(t, result)
+	assert.Equal(t, "uploadId-1234", *result.UploadId)
+	assert.Equal(t, dataCrc64ecma, *result.HashCRC64)
+	assert.Equal(t, n, int64(length))
+}