From 29737fec9b94eec244ac7bea5fb895c818e08f0a Mon Sep 17 00:00:00 2001 From: tomholford Date: Wed, 22 Apr 2026 20:07:29 -0700 Subject: [PATCH] reverseproxy: track WebTransport sessions in upstream in-flight counters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bracket the pump's lifetime with Host.countRequest(±1) and incInFlightRequest/decInFlightRequest so WT sessions participate in the same accounting as the normal proxy path: - MaxRequests gating (Upstream.Full) now blocks WT sessions past the cap, instead of silently failing open. - LeastConn / FirstAvailable selection sees WT load, instead of seeing busy upstreams as idle. - Admin /reverse_proxy/upstreams reports WT sessions under num_requests. Integration test holds an upstream session open via a standalone WT server, polls the admin API to assert num_requests increments during the session and drops back to 0 after close. --- caddytest/integration/webtransport_test.go | 151 ++++++++++++++++++ .../reverseproxy/webtransport_transport.go | 11 ++ 2 files changed, 162 insertions(+) diff --git a/caddytest/integration/webtransport_test.go b/caddytest/integration/webtransport_test.go index edec51fb1..ae146a3bb 100644 --- a/caddytest/integration/webtransport_test.go +++ b/caddytest/integration/webtransport_test.go @@ -22,6 +22,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/json" "fmt" "io" "math/big" @@ -540,6 +541,156 @@ func TestWebTransport_UpstreamDialFailureSurfaces5xx(t *testing.T) { } } +// TestWebTransport_InFlightRequestsTracked proves the WT proxy path +// increments upstream.Host.NumRequests for the session's lifetime and +// decrements after it ends, so MaxRequests gating, LeastConn/FirstAvailable +// LB, and the admin /reverse_proxy/upstreams endpoint reflect WT load. +func TestWebTransport_InFlightRequestsTracked(t *testing.T) { + if testing.Short() { + t.Skip() + } + + // Upstream blocks on a release channel until the test finishes probing + // the admin API; this keeps the session alive long enough to observe + // num_requests > 0. + release := make(chan struct{}) + // t.Cleanup drains release in case the test bails early. + t.Cleanup(func() { + select { + case <-release: + default: + close(release) + } + }) + upstreamAddr, stopUpstream := startStandaloneWebTransport(t, func(sess *webtransport.Session, r *http.Request) { + <-release + _ = sess.CloseWithError(0, "") + }) + t.Cleanup(stopUpstream) + + upstreamDial := fmt.Sprintf("127.0.0.1:%d", upstreamAddr.Port) + config := fmt.Sprintf(`{ + "admin": {"listen": "localhost:2999"}, + "apps": { + "http": { + "http_port": 9080, + "https_port": 9443, + "grace_period": 1, + "servers": { + "proxy": { + "listen": [":9443"], + "protocols": ["h3"], + "routes": [ + { + "handle": [ + { + "handler": "reverse_proxy", + "transport": { + "protocol": "http", + "versions": ["3"], + "webtransport": true, + "tls": {"insecure_skip_verify": true} + }, + "upstreams": [{"dial": "%s"}] + } + ] + } + ], + "tls_connection_policies": [ + { + "certificate_selection": {"any_tag": ["cert0"]}, + "default_sni": "a.caddy.localhost" + } + ] + } + } + }, + "tls": { + "certificates": { + "load_files": [ + { + "certificate": "/a.caddy.localhost.crt", + "key": "/a.caddy.localhost.key", + "tags": ["cert0"] + } + ] + } + }, + "pki": {"certificate_authorities": {"local": {"install_trust": false}}} + } +}`, upstreamDial) + + tester := caddytest.NewTester(t) + tester.InitServer(config, "json") + + dialer := &webtransport.Dialer{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec // local CA + ServerName: "a.caddy.localhost", + NextProtos: []string{http3.NextProtoH3}, + }, + QUICConfig: &quic.Config{ + EnableDatagrams: true, + EnableStreamResetPartialDelivery: true, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + + _, sess, err := dialer.Dial(ctx, "https://127.0.0.1:9443/", nil) + if err != nil { + t.Fatalf("proxy Dial failed: %v", err) + } + + // Poll the admin API until we see num_requests >= 1 for our upstream. + if !waitForUpstreamRequests(t, upstreamDial, 1, 2*time.Second) { + t.Fatal("upstream num_requests never reached >= 1 while session was active") + } + + // Close the client session and release the upstream so the server-side + // handler returns; the deferred decrement in serveWebTransport should + // drop num_requests back to 0 once both sides close. + _ = sess.CloseWithError(0, "") + close(release) + + if !waitForUpstreamRequests(t, upstreamDial, 0, 2*time.Second) { + t.Fatal("upstream num_requests did not drop to 0 after session closed") + } +} + +// waitForUpstreamRequests polls the admin /reverse_proxy/upstreams endpoint +// until the entry for dial has exactly wantRequests in-flight, or timeout. +// Returns true on match. +func waitForUpstreamRequests(t *testing.T, dial string, wantRequests int, timeout time.Duration) bool { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + rsp, err := http.Get("http://localhost:2999/reverse_proxy/upstreams") + if err != nil { + time.Sleep(50 * time.Millisecond) + continue + } + var entries []struct { + Address string `json:"address"` + NumRequests int `json:"num_requests"` + } + err = json.NewDecoder(rsp.Body).Decode(&entries) + _ = rsp.Body.Close() + if err != nil { + time.Sleep(50 * time.Millisecond) + continue + } + for _, e := range entries { + if e.Address == dial && e.NumRequests == wantRequests { + return true + } + } + time.Sleep(50 * time.Millisecond) + } + return false +} + // startStandaloneWebTransport starts a webtransport.Server on a random UDP // port with a self-signed cert. handler runs after a successful Upgrade. // Returns the listener addr and a shutdown func. diff --git a/modules/caddyhttp/reverseproxy/webtransport_transport.go b/modules/caddyhttp/reverseproxy/webtransport_transport.go index 6f9823a7e..d881cf2d2 100644 --- a/modules/caddyhttp/reverseproxy/webtransport_transport.go +++ b/modules/caddyhttp/reverseproxy/webtransport_transport.go @@ -185,6 +185,17 @@ func (h *Handler) serveWebTransport(w http.ResponseWriter, r *http.Request) erro fmt.Errorf("webtransport upgrade: %w", err)) } + // Track the session in the same upstream counters the normal proxy + // path maintains: Host.NumRequests drives MaxRequests gating and + // least-connections selection; the per-address in-flight counter + // feeds the admin API's upstream stats. + _ = dialInfo.Upstream.Host.countRequest(1) + incInFlightRequest(dialInfo.Address) + defer func() { + _ = dialInfo.Upstream.Host.countRequest(-1) + decInFlightRequest(dialInfo.Address) + }() + runWebTransportPump(clientSess, upstreamSess, h.logger) return nil }