diff --git a/fs/accounting/accounting.go b/fs/accounting/accounting.go index af55e2688..1693ed2e9 100644 --- a/fs/accounting/accounting.go +++ b/fs/accounting/accounting.go @@ -171,19 +171,22 @@ func (acc *Account) averageLoop() { } } -// Check the read is valid -func (acc *Account) checkRead() (err error) { +// Check the read is valid returning the number of bytes it is over +func (acc *Account) checkRead() (over int64, err error) { acc.statmu.Lock() - if acc.max >= 0 && acc.stats.GetBytes() >= acc.max { - acc.statmu.Unlock() - return ErrorMaxTransferLimitReachedFatal + if acc.max >= 0 { + over = acc.stats.GetBytes() - acc.max + if over >= 0 { + acc.statmu.Unlock() + return over, ErrorMaxTransferLimitReachedFatal + } } // Set start time. if acc.start.IsZero() { acc.start = time.Now() } acc.statmu.Unlock() - return nil + return over, nil } // ServerSideCopyStart should be called at the start of a server side copy @@ -223,10 +226,18 @@ func (acc *Account) accountRead(n int) { // read bytes from the io.Reader passed in and account them func (acc *Account) read(in io.Reader, p []byte) (n int, err error) { - err = acc.checkRead() + _, err = acc.checkRead() if err == nil { n, err = in.Read(p) acc.accountRead(n) + if over, checkErr := acc.checkRead(); checkErr == ErrorMaxTransferLimitReachedFatal { + // chop the overage off + n -= int(over) + if n < 0 { + n = 0 + } + err = checkErr + } } return n, err } @@ -242,7 +253,7 @@ func (acc *Account) Read(p []byte) (n int, err error) { func (acc *Account) AccountRead(n int) (err error) { acc.mu.Lock() defer acc.mu.Unlock() - err = acc.checkRead() + _, err = acc.checkRead() if err == nil { acc.accountRead(n) } diff --git a/fs/accounting/accounting_test.go b/fs/accounting/accounting_test.go index 083072e59..4b876e1cd 100644 --- a/fs/accounting/accounting_test.go +++ b/fs/accounting/accounting_test.go @@ -215,8 +215,8 @@ func TestAccountMaxTransfer(t *testing.T) { assert.Equal(t, 10, n) assert.NoError(t, err) n, err = acc.Read(b) - assert.Equal(t, 10, n) - assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, ErrorMaxTransferLimitReachedFatal, err) n, err = acc.Read(b) assert.Equal(t, 0, n) assert.Equal(t, ErrorMaxTransferLimitReachedFatal, err) diff --git a/fs/operations/operations_test.go b/fs/operations/operations_test.go index 192dd6610..e96b03741 100644 --- a/fs/operations/operations_test.go +++ b/fs/operations/operations_test.go @@ -1541,45 +1541,53 @@ func TestCopyFileMaxTransfer(t *testing.T) { accounting.Stats(context.Background()).ResetCounters() }() + ctx := context.Background() + file1 := r.WriteFile("file1", "file1 contents", t1) file2 := r.WriteFile("file2", "file2 contents...........", t2) rfile1 := file1 rfile1.Path = "sub/file1" - rfile2 := file2 - rfile2.Path = "sub/file2" + rfile2a := file2 + rfile2a.Path = "sub/file2a" + rfile2b := file2 + rfile2b.Path = "sub/file2b" + rfile2c := file2 + rfile2c.Path = "sub/file2c" fs.Config.MaxTransfer = 15 fs.Config.CutoffMode = fs.CutoffModeHard - accounting.Stats(context.Background()).ResetCounters() + accounting.Stats(ctx).ResetCounters() - err := operations.CopyFile(context.Background(), r.Fremote, r.Flocal, rfile1.Path, file1.Path) + err := operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile1.Path, file1.Path) require.NoError(t, err) fstest.CheckItems(t, r.Flocal, file1, file2) fstest.CheckItems(t, r.Fremote, rfile1) - accounting.Stats(context.Background()).ResetCounters() + accounting.Stats(ctx).ResetCounters() - err = operations.CopyFile(context.Background(), r.Fremote, r.Flocal, rfile2.Path, file2.Path) - fstest.CheckItems(t, r.Flocal, file1, file2) - fstest.CheckItems(t, r.Fremote, rfile1) + err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2a.Path, file2.Path) + require.NotNil(t, err) assert.Contains(t, err.Error(), "Max transfer limit reached") assert.True(t, fserrors.IsFatalError(err)) + fstest.CheckItems(t, r.Flocal, file1, file2) + fstest.CheckItems(t, r.Fremote, rfile1) fs.Config.CutoffMode = fs.CutoffModeCautious - accounting.Stats(context.Background()).ResetCounters() + accounting.Stats(ctx).ResetCounters() - err = operations.CopyFile(context.Background(), r.Fremote, r.Flocal, rfile2.Path, file2.Path) - fstest.CheckItems(t, r.Flocal, file1, file2) - fstest.CheckItems(t, r.Fremote, rfile1) + err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2b.Path, file2.Path) + require.NotNil(t, err) assert.Contains(t, err.Error(), "Max transfer limit reached") assert.True(t, fserrors.IsFatalError(err)) + fstest.CheckItems(t, r.Flocal, file1, file2) + fstest.CheckItems(t, r.Fremote, rfile1) fs.Config.CutoffMode = fs.CutoffModeSoft - accounting.Stats(context.Background()).ResetCounters() + accounting.Stats(ctx).ResetCounters() - err = operations.CopyFile(context.Background(), r.Fremote, r.Flocal, rfile2.Path, file2.Path) + err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2c.Path, file2.Path) require.NoError(t, err) fstest.CheckItems(t, r.Flocal, file1, file2) - fstest.CheckItems(t, r.Fremote, rfile1, rfile2) + fstest.CheckItems(t, r.Fremote, rfile1, rfile2c) }