From ce3340621f7d72d4c65a1246aa63fed44c36ea54 Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Mon, 26 Aug 2019 12:17:53 +0100 Subject: [PATCH] lib/readers: add NoCloser to stop upgrades from io.Reader to io.ReadCloser --- lib/readers/noclose.go | 29 ++++++++++++++++++++++++ lib/readers/noclose_test.go | 44 +++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 lib/readers/noclose.go create mode 100644 lib/readers/noclose_test.go diff --git a/lib/readers/noclose.go b/lib/readers/noclose.go new file mode 100644 index 000000000..dc36e8be2 --- /dev/null +++ b/lib/readers/noclose.go @@ -0,0 +1,29 @@ +package readers + +import "io" + +// noClose is used to wrap an io.Reader to stop it being upgraded +type noClose struct { + in io.Reader +} + +// Read implements io.Closer by passing it straight on +func (nc noClose) Read(p []byte) (n int, err error) { + return nc.in.Read(p) +} + +// NoCloser makes sure that the io.Reader passed in can't upgraded to +// an io.Closer. +// +// This is for use with http.NewRequest to make sure the body doesn't +// get upgraded to an io.Closer and the body closed unexpectedly. +func NoCloser(in io.Reader) io.Reader { + if in == nil { + return in + } + // if in doesn't implement io.Closer, just return it + if _, canClose := in.(io.Closer); !canClose { + return in + } + return noClose{in: in} +} diff --git a/lib/readers/noclose_test.go b/lib/readers/noclose_test.go new file mode 100644 index 000000000..da784425c --- /dev/null +++ b/lib/readers/noclose_test.go @@ -0,0 +1,44 @@ +package readers + +import ( + "io" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +var errRead = errors.New("read error") + +type readOnly struct{} + +func (readOnly) Read(p []byte) (n int, err error) { + return 0, io.EOF +} + +type readClose struct{} + +func (readClose) Read(p []byte) (n int, err error) { + return 0, errRead +} + +func (readClose) Close() (err error) { + return io.EOF +} + +func TestNoCloser(t *testing.T) { + assert.Equal(t, nil, NoCloser(nil)) + + ro := readOnly{} + assert.Equal(t, ro, NoCloser(ro)) + + rc := readClose{} + nc := NoCloser(rc) + assert.NotEqual(t, nc, rc) + + _, hasClose := nc.(io.Closer) + assert.False(t, hasClose) + + _, err := nc.Read(nil) + assert.Equal(t, errRead, err) +}