Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClientIP middleware proposal, intended to replace RealIP #967

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions middleware/client_ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package middleware

import (
"context"
"net"
"net/http"
"net/netip"
"strings"
)

var (
// clientIPCtxKey is the context key used to store the client IP address.
clientIPCtxKey = &contextKey{"clientIP"}
)

// ClientIPFromHeader parses the client IP address from a specified HTTP header
// (e.g., X-Real-IP, CF-Connecting-IP) and injects it into the request context
// if it is not already set. The parsed IP address can be retrieved using GetClientIP().
//
// The middleware validates the IP address to ignore loopback, private, and unspecified addresses.
//
// ### Important Notice:
// - Use this middleware only when your infrastructure sets a trusted header containing the client IP.
// - If the specified header is not securely set by your infrastructure, malicious clients could spoof it.
//
// Example trusted headers:
// - "X-Real-IP" - Nginx (ngx_http_realip_module)
// - "X-Client-IP" - Apache (mod_remoteip)
// - "CF-Connecting-IP" - Cloudflare
// - "True-Client-IP" - Akamai, Cloudflare Enterprise
// - "X-Azure-ClientIP" - Azure Front Door
// - "Fastly-Client-IP" - Fastly
func ClientIPFromHeader(trustedHeader string) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// Check if the client IP is already set in the context.
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
h.ServeHTTP(w, r)
return
}

// Parse the IP address from the trusted header.
ip, err := netip.ParseAddr(r.Header.Get(trustedHeader))
if err != nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsPrivate() {
// Ignore invalid or private IPs.
h.ServeHTTP(w, r)
return
}

// Store the valid client IP in the context.
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
h.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
}

// ClientIPFromXFFHeader parses the client IP address from the X-Forwarded-For
// header and injects it into the request context if it is not already set. The
// parsed IP address can be retrieved using GetClientIP().
//
// The middleware traverses the X-Forwarded-For chain (rightmost untrusted IP)
// and excludes loopback, private, unspecified, and trusted IP ranges.
//
// ### Important Notice:
// - Use this middleware only when your infrastructure sets and validates the X-Forwarded-For header.
// - Malicious clients can spoof the header unless a trusted reverse proxy or load balancer sanitizes it.
//
// Parameters:
// - `trustedIPPrefixes`: A list of CIDR prefixes that define trusted proxy IP ranges.
//
// Example trusted IP ranges:
// - "203.0.113.0/24" - Example corporate proxy
// - "198.51.100.0/24" - Example data center or hosting provider
// - "2400:cb00::/32" - Cloudflare IPv6 range
// - "2606:4700::/32" - Cloudflare IPv6 range
// - "192.0.2.0/24" - Example VPN gateway
//
// Note: Private IP ranges (e.g., "10.0.0.0/8", "192.168.0.0/16", "172.16.0.0/12")
// are automatically excluded by netip.Addr.IsPrivate() and do not need to be added here.
func ClientIPFromXFFHeader(trustedIPPrefixes ...string) func(http.Handler) http.Handler {
// Pre-parse trusted prefixes.
trustedPrefixes := make([]netip.Prefix, len(trustedIPPrefixes))
for i, ipRange := range trustedIPPrefixes {
trustedPrefixes[i] = netip.MustParsePrefix(ipRange)
}

return func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// Check if the client IP is already set in the context.
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
h.ServeHTTP(w, r)
return
}

// Parse and split the X-Forwarded-For header(s).
xff := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",")
nextValue:
for i := len(xff) - 1; i >= 0; i-- {
ip, err := netip.ParseAddr(strings.TrimSpace(xff[i]))
if err != nil {
continue
}

// Ignore loopback, private, or unspecified addresses.
if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() {
continue
}

// Ignore trusted IPs within the given ranges.
for _, prefix := range trustedPrefixes {
if prefix.Contains(ip) {
continue nextValue
}
}

// Store the valid client IP in the context.
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
h.ServeHTTP(w, r.WithContext(ctx))
return
}

h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}

// ClientIPFromRemoteAddr extracts the client IP address from the RemoteAddr
// field of the HTTP request and injects it into the request context if it is
// not already set. The parsed IP address can be retrieved using GetClientIP().
//
// The middleware ignores invalid or private IPs.
//
// ### Use Case:
// This middleware is useful when the client IP cannot be determined from headers
// such as X-Forwarded-For or X-Real-IP, and you need to fall back to RemoteAddr.
func ClientIPFromRemoteAddr(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// Check if the client IP is already set in the context.
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
h.ServeHTTP(w, r)
return
}

// Extract the IP from request RemoteAddr.
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
h.ServeHTTP(w, r)
return
}

ip, err := netip.ParseAddr(host)
if err != nil {
h.ServeHTTP(w, r)
return
}

// Store the valid client IP in the context.
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
h.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}

// GetClientIP retrieves the client IP address from the given context.
// The IP address is set by one of the following middlewares:
// - ClientIPFromHeader
// - ClientIPFromXFFHeader
// - ClientIPFromRemoteAddr
//
// Returns an empty string if no valid IP is found.
func GetClientIP(ctx context.Context) string {
ip, ok := ctx.Value(clientIPCtxKey).(netip.Addr)
if !ok || !ip.IsValid() {
return ""
}
return ip.String()
}
141 changes: 141 additions & 0 deletions middleware/client_ip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/chi/v5"
)

func TestClientIPFromHeader(t *testing.T) {
tt := []struct {
name string
in string
out string
}{
// Empty header.
{name: "empty", in: "", out: ""},

// Valid X-Real-IP header values.
{name: "valid_ipv4", in: "100.100.100.100", out: "100.100.100.100"},
{name: "valid_ipv4", in: "178.25.203.2", out: "178.25.203.2"},
{name: "valid_ipv6_lower", in: "2345:0425:2ca1:0000:0000:0567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"},
{name: "valid_ipv6_upper", in: "2345:0425:2CA1:0000:0000:0567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"},
{name: "valid_ipv6_lower_short", in: "2345:425:2ca1::567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"},
{name: "valid_ipv6_upper_short", in: "2345:425:2CA1::567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"},

// Invalid X-Real-IP header values.
{name: "invalid_ip", in: "invalid", out: ""},
{name: "invalid_ip_with_port", in: "100.100.100.100:80", out: ""},
{name: "invalid_multiple_ips", in: "100.100.100.100;100.100.100.101;100.100.100.102", out: ""},
{name: "invalid_loopback", in: "127.0.0.1", out: ""},
{name: "invalid_zeroes", in: "0.0.0.0", out: ""},
{name: "invalid_loopback", in: "127.0.0.1", out: ""},
{name: "invalid_private_ipv4_1", in: "192.168.0.1", out: ""},
{name: "invalid_private_ipv4_2", in: "192.168.10.12", out: ""},
{name: "invalid_private_ipv4_3", in: "172.16.0.0", out: ""},
{name: "invalid_private_ipv4_4", in: "172.25.203.2", out: ""},
{name: "invalid_private_ipv4_5", in: "10.0.0.0", out: ""},
{name: "invalid_private_ipv4_6", in: "10.0.1.10", out: ""},
{name: "invalid_private_ipv6_1", in: "fc00::1", out: ""},
{name: "invalid_private_ipv6_2", in: "fc00:0425:2ca1:0000:0000:0567:5673:23b5", out: ""},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Add("X-Real-IP", tc.in)
w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(ClientIPFromHeader("X-Real-IP"))

var clientIP string
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
clientIP = GetClientIP(r.Context())
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Errorf("Response Code should be 200")
}

if clientIP != tc.out {
t.Errorf("expected %v, got %v", tc.out, clientIP)
}
})
}
}

func TestClientIPFromXFFHeader(t *testing.T) {
tt := []struct {
name string
xff []string
out string
}{
{name: "empty", xff: []string{""}, out: ""},

{name: "", xff: []string{"100.100.100.100"}, out: "100.100.100.100"},
{name: "", xff: []string{"100.100.100.100, 200.200.200.200"}, out: "200.200.200.200"},
{name: "", xff: []string{"100.100.100.100,200.200.200.200"}, out: "200.200.200.200"},
{name: "", xff: []string{"100.100.100.100", "200.200.200.200"}, out: "200.200.200.200"},
{name: "", xff: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"},
{name: "", xff: []string{"203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"},
{name: "", xff: []string{"5.5.5.5, 203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348", "7.7.7.7, 4.4.4.4"}, out: "4.4.4.4"},
}

r := chi.NewRouter()
r.Use(ClientIPFromXFFHeader())

for _, tc := range tt {
req, _ := http.NewRequest("GET", "/", nil)
for _, v := range tc.xff {
req.Header.Add("X-Forwarded-For", v)
}

w := httptest.NewRecorder()

clientIP := ""
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
clientIP = GetClientIP(r.Context())
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Errorf("Response Code should be 200")
}

if clientIP != tc.out {
t.Errorf("expected %v, got %v", tc.out, clientIP)
}
}
}

func TestClientIPFromRemoteAddr(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "192.0.2.1:1234" // Simulate the remote address set by http.Server.

w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(ClientIPFromRemoteAddr)

var clientIP string
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
clientIP = GetClientIP(r.Context())
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Errorf("Response Code should be 200")
}

expected := "192.0.2.1"
if clientIP != expected {
t.Errorf("expected %v, got %v", expected, clientIP)
}
}
Loading