diff --git a/client/daemon/proxy/proxy.go b/client/daemon/proxy/proxy.go index 9fc03a34892..ad6d1b2610d 100644 --- a/client/daemon/proxy/proxy.go +++ b/client/daemon/proxy/proxy.go @@ -22,6 +22,7 @@ import ( "encoding/base64" "errors" "io" + "mime" "net" "net/http" "net/http/httputil" @@ -401,6 +402,25 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { return cs[:s], cs[s+1:], true } +// flushInterval returns zero, conditionally +// overriding its value for a specific request/response. +func (proxy *Proxy) flushInterval(res *http.Response) time.Duration { + resCT := res.Header.Get("Content-Type") + + // For Server-Sent Events responses, flush immediately. + // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream + if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" { + return -1 // negative means immediately + } + + // We might have the case of streaming for which Content-Length might be unset. + if res.ContentLength == -1 { + return -1 + } + + return 0 +} + func (proxy *Proxy) handleHTTP(span trace.Span, w http.ResponseWriter, req *http.Request) { resp, err := proxy.transport.RoundTrip(req) if err != nil { @@ -412,7 +432,26 @@ func (proxy *Proxy) handleHTTP(span trace.Span, w http.ResponseWriter, req *http copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode)) - if n, err := io.Copy(w, resp.Body); err != nil && err != io.EOF { + + // support event stream responses, see: https://github.com/golang/go/issues/2012 + var lw io.Writer = w + if flushInterval := proxy.flushInterval(resp); flushInterval != 0 { + mlw := &maxLatencyWriter{ + dst: w, + flush: http.NewResponseController(w).Flush, + latency: flushInterval, + } + defer mlw.stop() + + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + + lw = mlw + logger.Debugf("handle event stream response: %s, url:%s", req.Host, req.URL.String()) + } + + if n, err := io.Copy(lw, resp.Body); err != nil && err != io.EOF { if peerID := resp.Header.Get(config.HeaderDragonflyPeer); peerID != "" { logger.Errorf("failed to write http body: %v, peer: %s, task: %s, written bytes: %d", err, peerID, resp.Header.Get(config.HeaderDragonflyTask), n) diff --git a/client/daemon/proxy/proxy_test.go b/client/daemon/proxy/proxy_test.go index 3a8f486991b..56d8ace7458 100644 --- a/client/daemon/proxy/proxy_test.go +++ b/client/daemon/proxy/proxy_test.go @@ -18,9 +18,13 @@ package proxy import ( "fmt" + "io" "net/http" "net/url" + "strconv" + "strings" "testing" + "time" "github.com/stretchr/testify/assert" @@ -133,6 +137,56 @@ func (tc *testCase) TestMirror(t *testing.T) { } } +func (tc *testCase) TestEventStream(t *testing.T) { + a := assert.New(t) + if !a.Nil(tc.Error) { + return + } + tp, err := NewProxy(WithRules(tc.Rules)) + if !a.Nil(err) { + return + } + tp.transport = &mockTransport{} + for _, item := range tc.Items { + req, err := http.NewRequest("GET", item.URL, nil) + if !a.Nil(err) { + continue + } + if !a.Equal(tp.shouldUseDragonfly(req), !item.Direct) { + fmt.Println(item.URL) + } + if item.UseHTTPS { + a.Equal(req.URL.Scheme, "https") + } else { + a.Equal(req.URL.Scheme, "http") + } + if item.Redirect != "" { + a.Equal(item.Redirect, req.URL.String()) + } + if strings.Contains(req.URL.Path, "event-stream") { + batch := 10 + _, span := tp.tracer.Start(req.Context(), config.SpanProxy) + w := &mockResponseWriter{} + req.Header.Set("X-Response-Batch", strconv.Itoa(batch)) + if req.URL.Path == "/event-stream" { + req.Header.Set("X-Event-Stream", "true") + req.Header.Set("X-Response-Content-Length", "-1") + req.Header.Set("X-Response-Content-Encoding", "chunked") + req.Header.Set("X-Response-Content-Type", "text/event-stream") + tp.handleHTTP(span, w, req) + a.GreaterOrEqual(w.flushCount, batch) + } else { + req.Header.Set("X-Event-Stream", "false") + req.Header.Set("X-Response-Content-Length", strconv.Itoa(batch)) + req.Header.Set("X-Response-Content-Encoding", "") + req.Header.Set("X-Response-Content-Type", "application/octet-stream") + tp.handleHTTP(span, w, req) + a.Less(w.flushCount, batch) + } + } + } +} + func TestMatch(t *testing.T) { newTestCase(). WithRule("/blobs/sha256/", false, false, ""). @@ -235,3 +289,64 @@ func TestMatchWithRedirect(t *testing.T) { TestMirror(t) } + +func TestProxyEventStream(t *testing.T) { + newTestCase(). + WithRule("/blobs/sha256/", false, false, ""). + WithTest("http://h/event-stream", true, false, ""). + WithTest("http://h/not-event-stream", true, false, ""). + TestEventStream(t) +} + +type mockResponseWriter struct { + flushCount int +} + +func (w *mockResponseWriter) Header() http.Header { + return http.Header{} +} + +func (w *mockResponseWriter) Write(p []byte) (int, error) { + return len(string(p)), nil +} + +func (w *mockResponseWriter) WriteHeader(int) {} + +func (w *mockResponseWriter) Flush() { + w.flushCount++ +} + +type mockTransport struct{} + +func (rt *mockTransport) RoundTrip(r *http.Request) (*http.Response, error) { + batch, _ := strconv.Atoi(r.Header.Get("X-Response-Batch")) + return &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{batch: batch}, + Header: http.Header{ + "Content-Length": []string{r.Header.Get("X-Response-Content-Length")}, + "Content-Encoding": []string{r.Header.Get("X-Response-Content-Encoding")}, + "Content-Type": []string{r.Header.Get("X-Response-Content-Type")}, + }, + }, nil +} + +type mockReadCloser struct { + batch int + count int +} + +func (rc *mockReadCloser) Read(p []byte) (n int, err error) { + if rc.count == rc.batch { + return 0, io.EOF + } + time.Sleep(100 * time.Millisecond) + p[0] = '0' + p = p[:1] + rc.count++ + return len(p), nil +} + +func (rc *mockReadCloser) Close() error { + return nil +} diff --git a/client/daemon/proxy/proxy_writter.go b/client/daemon/proxy/proxy_writter.go new file mode 100644 index 00000000000..b347ddd9633 --- /dev/null +++ b/client/daemon/proxy/proxy_writter.go @@ -0,0 +1,57 @@ +package proxy + +import ( + "io" + "sync" + "time" +) + +// copy from golang library, see https://github.com/golang/go/blob/master/src/net/http/httputil/reverseproxy.go +type maxLatencyWriter struct { + dst io.Writer + flush func() error + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.flush() // nolint: errcheck + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.flush() // nolint: errcheck + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +}