From 562ba1c461baa5243a0186df87281cc9b0c452a5 Mon Sep 17 00:00:00 2001 From: dmathieu <42@dmathieu.com> Date: Fri, 13 Sep 2024 10:13:05 +0200 Subject: [PATCH] ensure that a superfluous WriteHeader call panics in httptest --- src/net/http/httptest/recorder.go | 1 + src/net/http/httptest/recorder_test.go | 21 +++++++++++++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index dd51901b0d3b9..0c6dbbcfe97c3 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -142,6 +142,7 @@ func checkWriteHeaderCode(code int) { // WriteHeader implements [http.ResponseWriter]. func (rw *ResponseRecorder) WriteHeader(code int) { if rw.wroteHeader { + panic(fmt.Sprintf("superfluous response.WriteHeader call")) return } diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index 4782eced43e6c..21b5e46d0ace0 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -138,7 +138,6 @@ func TestRecorder(t *testing.T) { "first code only", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(201) - w.WriteHeader(202) w.Write([]byte("hi")) }, check(hasStatus(201), hasContents("hi")), @@ -147,8 +146,6 @@ func TestRecorder(t *testing.T) { "write sends 200", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hi first")) - w.WriteHeader(201) - w.WriteHeader(202) }, check(hasStatus(200), hasContents("hi first"), hasFlush(false)), }, @@ -168,7 +165,6 @@ func TestRecorder(t *testing.T) { "flush", func(w http.ResponseWriter, r *http.Request) { w.(http.Flusher).Flush() // also sends a 200 - w.WriteHeader(201) }, check(hasStatus(200), hasFlush(true), hasContentLength(-1)), }, @@ -369,3 +365,20 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) { }) } } + +// Ensure that httptest.Recorder panics when using WriteHeader twice. +func TestRecorderPanicsOnSuperfluousWriteHeader(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("Expected a panic") + } + }() + + handler := func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(200) + rw.WriteHeader(201) + } + r, _ := http.NewRequest("GET", "http://example.org/", nil) + rw := NewRecorder() + handler(rw, r) +}