operations: Make --max-transfer more accurate

Before this change we checked the transfer was out of range only
before the Read call. This means that we returned all the data to the
reader before declaring an error. This means that some backends wrote
the file even though an error was returned.

This fix checks the transfer after the Read as well, and chops the
excess characters off the read data if we are over the limit so that
we don't ever deliver all the data.

This fixes the tests introduced as part of 6f1766dd9e and #2672
on backends other than local.
master
Nick Craig-Wood 2020-03-13 16:20:15 +00:00
parent 6fdd7149c1
commit 6d0063d685
3 changed files with 44 additions and 25 deletions

View File

@ -171,19 +171,22 @@ func (acc *Account) averageLoop() {
} }
} }
// Check the read is valid // Check the read is valid returning the number of bytes it is over
func (acc *Account) checkRead() (err error) { func (acc *Account) checkRead() (over int64, err error) {
acc.statmu.Lock() acc.statmu.Lock()
if acc.max >= 0 && acc.stats.GetBytes() >= acc.max { if acc.max >= 0 {
acc.statmu.Unlock() over = acc.stats.GetBytes() - acc.max
return ErrorMaxTransferLimitReachedFatal if over >= 0 {
acc.statmu.Unlock()
return over, ErrorMaxTransferLimitReachedFatal
}
} }
// Set start time. // Set start time.
if acc.start.IsZero() { if acc.start.IsZero() {
acc.start = time.Now() acc.start = time.Now()
} }
acc.statmu.Unlock() acc.statmu.Unlock()
return nil return over, nil
} }
// ServerSideCopyStart should be called at the start of a server side copy // ServerSideCopyStart should be called at the start of a server side copy
@ -223,10 +226,18 @@ func (acc *Account) accountRead(n int) {
// read bytes from the io.Reader passed in and account them // read bytes from the io.Reader passed in and account them
func (acc *Account) read(in io.Reader, p []byte) (n int, err error) { func (acc *Account) read(in io.Reader, p []byte) (n int, err error) {
err = acc.checkRead() _, err = acc.checkRead()
if err == nil { if err == nil {
n, err = in.Read(p) n, err = in.Read(p)
acc.accountRead(n) acc.accountRead(n)
if over, checkErr := acc.checkRead(); checkErr == ErrorMaxTransferLimitReachedFatal {
// chop the overage off
n -= int(over)
if n < 0 {
n = 0
}
err = checkErr
}
} }
return n, err return n, err
} }
@ -242,7 +253,7 @@ func (acc *Account) Read(p []byte) (n int, err error) {
func (acc *Account) AccountRead(n int) (err error) { func (acc *Account) AccountRead(n int) (err error) {
acc.mu.Lock() acc.mu.Lock()
defer acc.mu.Unlock() defer acc.mu.Unlock()
err = acc.checkRead() _, err = acc.checkRead()
if err == nil { if err == nil {
acc.accountRead(n) acc.accountRead(n)
} }

View File

@ -215,8 +215,8 @@ func TestAccountMaxTransfer(t *testing.T) {
assert.Equal(t, 10, n) assert.Equal(t, 10, n)
assert.NoError(t, err) assert.NoError(t, err)
n, err = acc.Read(b) n, err = acc.Read(b)
assert.Equal(t, 10, n) assert.Equal(t, 5, n)
assert.NoError(t, err) assert.Equal(t, ErrorMaxTransferLimitReachedFatal, err)
n, err = acc.Read(b) n, err = acc.Read(b)
assert.Equal(t, 0, n) assert.Equal(t, 0, n)
assert.Equal(t, ErrorMaxTransferLimitReachedFatal, err) assert.Equal(t, ErrorMaxTransferLimitReachedFatal, err)

View File

@ -1541,45 +1541,53 @@ func TestCopyFileMaxTransfer(t *testing.T) {
accounting.Stats(context.Background()).ResetCounters() accounting.Stats(context.Background()).ResetCounters()
}() }()
ctx := context.Background()
file1 := r.WriteFile("file1", "file1 contents", t1) file1 := r.WriteFile("file1", "file1 contents", t1)
file2 := r.WriteFile("file2", "file2 contents...........", t2) file2 := r.WriteFile("file2", "file2 contents...........", t2)
rfile1 := file1 rfile1 := file1
rfile1.Path = "sub/file1" rfile1.Path = "sub/file1"
rfile2 := file2 rfile2a := file2
rfile2.Path = "sub/file2" rfile2a.Path = "sub/file2a"
rfile2b := file2
rfile2b.Path = "sub/file2b"
rfile2c := file2
rfile2c.Path = "sub/file2c"
fs.Config.MaxTransfer = 15 fs.Config.MaxTransfer = 15
fs.Config.CutoffMode = fs.CutoffModeHard fs.Config.CutoffMode = fs.CutoffModeHard
accounting.Stats(context.Background()).ResetCounters() accounting.Stats(ctx).ResetCounters()
err := operations.CopyFile(context.Background(), r.Fremote, r.Flocal, rfile1.Path, file1.Path) err := operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile1.Path, file1.Path)
require.NoError(t, err) require.NoError(t, err)
fstest.CheckItems(t, r.Flocal, file1, file2) fstest.CheckItems(t, r.Flocal, file1, file2)
fstest.CheckItems(t, r.Fremote, rfile1) fstest.CheckItems(t, r.Fremote, rfile1)
accounting.Stats(context.Background()).ResetCounters() accounting.Stats(ctx).ResetCounters()
err = operations.CopyFile(context.Background(), r.Fremote, r.Flocal, rfile2.Path, file2.Path) err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2a.Path, file2.Path)
fstest.CheckItems(t, r.Flocal, file1, file2) require.NotNil(t, err)
fstest.CheckItems(t, r.Fremote, rfile1)
assert.Contains(t, err.Error(), "Max transfer limit reached") assert.Contains(t, err.Error(), "Max transfer limit reached")
assert.True(t, fserrors.IsFatalError(err)) assert.True(t, fserrors.IsFatalError(err))
fstest.CheckItems(t, r.Flocal, file1, file2)
fstest.CheckItems(t, r.Fremote, rfile1)
fs.Config.CutoffMode = fs.CutoffModeCautious fs.Config.CutoffMode = fs.CutoffModeCautious
accounting.Stats(context.Background()).ResetCounters() accounting.Stats(ctx).ResetCounters()
err = operations.CopyFile(context.Background(), r.Fremote, r.Flocal, rfile2.Path, file2.Path) err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2b.Path, file2.Path)
fstest.CheckItems(t, r.Flocal, file1, file2) require.NotNil(t, err)
fstest.CheckItems(t, r.Fremote, rfile1)
assert.Contains(t, err.Error(), "Max transfer limit reached") assert.Contains(t, err.Error(), "Max transfer limit reached")
assert.True(t, fserrors.IsFatalError(err)) assert.True(t, fserrors.IsFatalError(err))
fstest.CheckItems(t, r.Flocal, file1, file2)
fstest.CheckItems(t, r.Fremote, rfile1)
fs.Config.CutoffMode = fs.CutoffModeSoft fs.Config.CutoffMode = fs.CutoffModeSoft
accounting.Stats(context.Background()).ResetCounters() accounting.Stats(ctx).ResetCounters()
err = operations.CopyFile(context.Background(), r.Fremote, r.Flocal, rfile2.Path, file2.Path) err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2c.Path, file2.Path)
require.NoError(t, err) require.NoError(t, err)
fstest.CheckItems(t, r.Flocal, file1, file2) fstest.CheckItems(t, r.Flocal, file1, file2)
fstest.CheckItems(t, r.Fremote, rfile1, rfile2) fstest.CheckItems(t, r.Fremote, rfile1, rfile2c)
} }