Skip to content

Commit

Permalink
utils/stdcopy: fix hang in StdCopy
Browse files Browse the repository at this point in the history
Also added some automated tests. I will check whether we can import this
kind of stuff from Dockcer again. I expect more stability now that
they've released the 1.0 version.

Fix #114.
  • Loading branch information
fsouza committed Jul 9, 2014
1 parent b4230ff commit 9dba2cd
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 6 deletions.
21 changes: 15 additions & 6 deletions utils/stdcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
if nr < StdWriterPrefixLen && nr2 < StdWriterPrefixLen {
return written, nil
}
nr += nr2
break
} else if er != nil {
return 0, er
}
Expand Down Expand Up @@ -118,7 +120,7 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
// Extend it if necessary.
if frameSize+StdWriterPrefixLen > bufLen {
Debugf("Extending buffer cap.")
buf = append(buf, make([]byte, frameSize-len(buf)+1)...)
buf = append(buf, make([]byte, frameSize+StdWriterPrefixLen-len(buf)+1)...)
bufLen = len(buf)
}

Expand All @@ -127,17 +129,24 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
var nr2 int
nr2, er = src.Read(buf[nr:])
if er == io.EOF {
return written, nil
}
if er != nil {
if nr == 0 {
return written, nil
}
nr += nr2
break
} else if er != nil {
Debugf("Error reading frame: %s", er)
return 0, er
}
nr += nr2
}

// Write the retrieved frame (without header)
nw, ew = out.Write(buf[StdWriterPrefixLen : frameSize+StdWriterPrefixLen])
bound := frameSize + StdWriterPrefixLen
if bound > nr {
bound = nr
}
nw, ew = out.Write(buf[StdWriterPrefixLen:bound])
if nw > 0 {
written += int64(nw)
}
Expand All @@ -148,7 +157,7 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
// If the frame has not been fully written: error
if nw != frameSize {
Debugf("Error Short Write: (%d on %d)", nw, frameSize)
return 0, io.ErrShortWrite
return written, io.ErrShortWrite
}

// Move the rest of the buffer to the beginning
Expand Down
217 changes: 217 additions & 0 deletions utils/stdcopy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
// Copyright 2014 go-dockerclient authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the DOCKER-LICENSE file.

package utils

import (
"bytes"
"errors"
"io"
"strings"
"testing"
"testing/iotest"
)

type errorWriter struct {
}

func (errorWriter) Write([]byte) (int, error) {
return 0, errors.New("something went wrong")
}

func TestStdCopy(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0, 0, 0, 19})
input.Write([]byte("something happened!"))
input.Write([]byte{1, 0, 0, 0, 0, 0, 0, 12})
input.Write([]byte("just kidding"))
input.Write([]byte{0, 0, 0, 0, 0, 0, 0, 6})
input.Write([]byte("\nyeah!"))
n, err := StdCopy(&stdout, &stderr, &input)
if err != nil {
t.Fatal(err)
}
if expected := int64(19 + 12 + 6); n != expected {
t.Errorf("Wrong number of bytes. Want %d. Got %d.", expected, n)
}
if got := stderr.String(); got != "something happened!" {
t.Errorf("StdCopy: wrong stderr. Want %q. Got %q.", "something happened!", got)
}
if got := stdout.String(); got != "just kidding\nyeah!" {
t.Errorf("StdCopy: wrong stdout. Want %q. Got %q.", "just kidding\nyeah!", got)
}
}

func TestStdCopyStress(t *testing.T) {
var input, stdout, stderr bytes.Buffer
value := strings.Repeat("something ", 4096)
writer := NewStdWriter(&input, Stdout)
writer.Write([]byte(value))
n, err := StdCopy(&stdout, &stderr, &input)
if err != nil {
t.Fatal(err)
}
if n != 40960 {
t.Errorf("Wrong number of bytes. Want 40960. Got %d.", n)
}
if got := stderr.String(); got != "" {
t.Errorf("StdCopy: wrong stderr. Want empty string. Got %q", got)
}
if got := stdout.String(); got != value {
t.Errorf("StdCopy: wrong stdout. Want %q. Got %q", value, got)
}
}

func TestStdCopyInvalidStdHeader(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{3, 0, 0, 0, 0, 0, 0, 19})
n, err := StdCopy(&stdout, &stderr, &input)
if n != 0 {
t.Errorf("StdCopy: wrong number of bytes. Want 0. Got %d", n)
}
if err != ErrInvalidStdHeader {
t.Errorf("StdCopy: wrong error. Want ErrInvalidStdHeader. Got %#v", err)
}
}

func TestStdCopyBigFrame(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0, 0, 0, 18})
input.Write([]byte("something happened!"))
n, err := StdCopy(&stdout, &stderr, &input)
if err != nil {
t.Fatal(err)
}
if expected := int64(18); n != expected {
t.Errorf("Wrong number of bytes. Want %d. Got %d.", expected, n)
}
if got := stderr.String(); got != "something happened" {
t.Errorf("StdCopy: wrong stderr. Want %q. Got %q.", "something happened", got)
}
if got := stdout.String(); got != "" {
t.Errorf("StdCopy: wrong stdout. Want %q. Got %q.", "", got)
}
}

func TestStdCopySmallFrame(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0, 0, 0, 20})
input.Write([]byte("something happened!"))
n, err := StdCopy(&stdout, &stderr, &input)
if err != io.ErrShortWrite {
t.Errorf("StdCopy: wrong error. Want ShortWrite. Got %#v", err)
}
if expected := int64(19); n != expected {
t.Errorf("Wrong number of bytes. Want %d. Got %d.", expected, n)
}
if got := stderr.String(); got != "something happened!" {
t.Errorf("StdCopy: wrong stderr. Want %q. Got %q.", "something happened", got)
}
if got := stdout.String(); got != "" {
t.Errorf("StdCopy: wrong stdout. Want %q. Got %q.", "", got)
}
}

func TestStdCopyEmpty(t *testing.T) {
var input, stdout, stderr bytes.Buffer
n, err := StdCopy(&stdout, &stderr, &input)
if err != nil {
t.Fatal(err)
}
if n != 0 {
t.Errorf("StdCopy: wrong number of bytes. Want 0. Got %d.", n)
}
}

func TestStdCopyCorruptedHeader(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0})
n, err := StdCopy(&stdout, &stderr, &input)
if err != nil {
t.Fatal(err)
}
if n != 0 {
t.Errorf("StdCopy: wrong number of bytes. Want 0. Got %d.", n)
}
}

func TestStdCopyTruncateWriter(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0, 0, 0, 19})
input.Write([]byte("something happened!"))
n, err := StdCopy(&stdout, iotest.TruncateWriter(&stderr, 7), &input)
if err != nil {
t.Fatal(err)
}
if expected := int64(19); n != expected {
t.Errorf("Wrong number of bytes. Want %d. Got %d.", expected, n)
}
if got := stderr.String(); got != "somethi" {
t.Errorf("StdCopy: wrong stderr. Want %q. Got %q.", "somethi", got)
}
if got := stdout.String(); got != "" {
t.Errorf("StdCopy: wrong stdout. Want %q. Got %q.", "", got)
}
}

func TestStdCopyHeaderOnly(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0, 0, 0, 19})
n, err := StdCopy(&stdout, iotest.TruncateWriter(&stderr, 7), &input)
if err != io.ErrShortWrite {
t.Errorf("StdCopy: wrong error. Want ShortWrite. Got %#v", err)
}
if n != 0 {
t.Errorf("Wrong number of bytes. Want 0. Got %d.", n)
}
if got := stderr.String(); got != "" {
t.Errorf("StdCopy: wrong stderr. Want %q. Got %q.", "", got)
}
if got := stdout.String(); got != "" {
t.Errorf("StdCopy: wrong stdout. Want %q. Got %q.", "", got)
}
}

func TestStdCopyDataErrReader(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0, 0, 0, 19})
input.Write([]byte("something happened!"))
n, err := StdCopy(&stdout, &stderr, iotest.DataErrReader(&input))
if err != nil {
t.Fatal(err)
}
if expected := int64(19); n != expected {
t.Errorf("Wrong number of bytes. Want %d. Got %d.", expected, n)
}
if got := stderr.String(); got != "something happened!" {
t.Errorf("StdCopy: wrong stderr. Want %q. Got %q.", "something happened!", got)
}
if got := stdout.String(); got != "" {
t.Errorf("StdCopy: wrong stdout. Want %q. Got %q.", "", got)
}
}

func TestStdCopyTimeoutReader(t *testing.T) {
var input, stdout, stderr bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0, 0, 0, 19})
input.Write([]byte("something happened!"))
_, err := StdCopy(&stdout, &stderr, iotest.TimeoutReader(&input))
if err != iotest.ErrTimeout {
t.Errorf("StdCopy: wrong error. Want ErrTimeout. Got %#v.", err)
}
}

func TestStdCopyWriteError(t *testing.T) {
var input bytes.Buffer
input.Write([]byte{2, 0, 0, 0, 0, 0, 0, 19})
input.Write([]byte("something happened!"))
var stdout, stderr errorWriter
n, err := StdCopy(stdout, stderr, &input)
if err.Error() != "something went wrong" {
t.Errorf("StdCopy: wrong error. Want %q. Got %q", "something went wrong", err)
}
if n != 0 {
t.Errorf("StdCopy: wrong number of bytes. Want 0. Got %d.", n)
}
}

0 comments on commit 9dba2cd

Please sign in to comment.