reverseproxy: track WebTransport sessions in upstream in-flight counters

Bracket the pump's lifetime with Host.countRequest(±1) and
incInFlightRequest/decInFlightRequest so WT sessions participate in the
same accounting as the normal proxy path:

- MaxRequests gating (Upstream.Full) now blocks WT sessions past the
  cap, instead of silently failing open.
- LeastConn / FirstAvailable selection sees WT load, instead of seeing
  busy upstreams as idle.
- Admin /reverse_proxy/upstreams reports WT sessions under num_requests.

An integration test holds an upstream session open via a standalone WT
server, polls the admin API to assert that num_requests increments during
the session, and checks that it drops back to 0 after close.
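For readers unfamiliar with the counters involved, here is a minimal, hypothetical sketch (not Caddy's actual types; host, upstream, and leastConn are simplified stand-ins) of how a single shared in-flight counter can drive both a MaxRequests cap and least-connections selection:

package sketch

import "sync/atomic"

// host mirrors the idea of an upstream host whose in-flight request count
// is shared by every handler path, including a WebTransport pump.
type host struct{ numRequests int64 }

func (h *host) countRequest(delta int64) { atomic.AddInt64(&h.numRequests, delta) }

type upstream struct {
	host        *host
	maxRequests int64
}

// full is what a MaxRequests cap reduces to: once the counter reaches the
// cap, selection skips this upstream until something decrements it.
func (u *upstream) full() bool {
	return u.maxRequests > 0 && atomic.LoadInt64(&u.host.numRequests) >= u.maxRequests
}

// leastConn shows why the counter matters for load balancing: a long-lived
// session that is counted keeps steering new work toward other upstreams.
func leastConn(candidates []*upstream) *upstream {
	var best *upstream
	for _, u := range candidates {
		if u.full() {
			continue
		}
		if best == nil || atomic.LoadInt64(&u.host.numRequests) < atomic.LoadInt64(&best.host.numRequests) {
			best = u
		}
	}
	return best
}

If a WebTransport session never increments the counter, full() never trips and leastConn() keeps treating a busy upstream as idle, which is exactly what the commit message describes fixing.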
tomholford 2026-04-22 20:07:29 -07:00
parent 60f93f1a87
commit 29737fec9b
2 changed files with 162 additions and 0 deletions

@@ -22,6 +22,7 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"io"
"math/big"
@@ -540,6 +541,156 @@ func TestWebTransport_UpstreamDialFailureSurfaces5xx(t *testing.T) {
	}
}

// TestWebTransport_InFlightRequestsTracked proves the WT proxy path
// increments upstream.Host.NumRequests for the session's lifetime and
// decrements after it ends, so MaxRequests gating, LeastConn/FirstAvailable
// LB, and the admin /reverse_proxy/upstreams endpoint reflect WT load.
func TestWebTransport_InFlightRequestsTracked(t *testing.T) {
	if testing.Short() {
		t.Skip()
	}

	// Upstream blocks on a release channel until the test finishes probing
	// the admin API; this keeps the session alive long enough to observe
	// num_requests > 0.
	release := make(chan struct{})
	// t.Cleanup closes release if the test bails out before closing it itself.
	t.Cleanup(func() {
		select {
		case <-release:
		default:
			close(release)
		}
	})
	upstreamAddr, stopUpstream := startStandaloneWebTransport(t, func(sess *webtransport.Session, r *http.Request) {
		<-release
		_ = sess.CloseWithError(0, "")
	})
	t.Cleanup(stopUpstream)
	upstreamDial := fmt.Sprintf("127.0.0.1:%d", upstreamAddr.Port)

	config := fmt.Sprintf(`{
	"admin": {"listen": "localhost:2999"},
	"apps": {
		"http": {
			"http_port": 9080,
			"https_port": 9443,
			"grace_period": 1,
			"servers": {
				"proxy": {
					"listen": [":9443"],
					"protocols": ["h3"],
					"routes": [
						{
							"handle": [
								{
									"handler": "reverse_proxy",
									"transport": {
										"protocol": "http",
										"versions": ["3"],
										"webtransport": true,
										"tls": {"insecure_skip_verify": true}
									},
									"upstreams": [{"dial": "%s"}]
								}
							]
						}
					],
					"tls_connection_policies": [
						{
							"certificate_selection": {"any_tag": ["cert0"]},
							"default_sni": "a.caddy.localhost"
						}
					]
				}
			}
		},
		"tls": {
			"certificates": {
				"load_files": [
					{
						"certificate": "/a.caddy.localhost.crt",
						"key": "/a.caddy.localhost.key",
						"tags": ["cert0"]
					}
				]
			}
		},
		"pki": {"certificate_authorities": {"local": {"install_trust": false}}}
	}
}`, upstreamDial)

	tester := caddytest.NewTester(t)
	tester.InitServer(config, "json")

	dialer := &webtransport.Dialer{
		TLSClientConfig: &tls.Config{
			InsecureSkipVerify: true, //nolint:gosec // local CA
			ServerName: "a.caddy.localhost",
			NextProtos: []string{http3.NextProtoH3},
		},
		QUICConfig: &quic.Config{
			EnableDatagrams: true,
			EnableStreamResetPartialDelivery: true,
		},
	}
	ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
	defer cancel()
	_, sess, err := dialer.Dial(ctx, "https://127.0.0.1:9443/", nil)
	if err != nil {
		t.Fatalf("proxy Dial failed: %v", err)
	}

	// Poll the admin API until we see num_requests >= 1 for our upstream.
	if !waitForUpstreamRequests(t, upstreamDial, 1, 2*time.Second) {
		t.Fatal("upstream num_requests never reached >= 1 while session was active")
	}

	// Close the client session and release the upstream so the server-side
	// handler returns; the deferred decrement in serveWebTransport should
	// drop num_requests back to 0 once both sides close.
	_ = sess.CloseWithError(0, "")
	close(release)
	if !waitForUpstreamRequests(t, upstreamDial, 0, 2*time.Second) {
		t.Fatal("upstream num_requests did not drop to 0 after session closed")
	}
}

// waitForUpstreamRequests polls the admin /reverse_proxy/upstreams endpoint
// until the entry for dial has exactly wantRequests in-flight, or timeout.
// Returns true on match.
func waitForUpstreamRequests(t *testing.T, dial string, wantRequests int, timeout time.Duration) bool {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		rsp, err := http.Get("http://localhost:2999/reverse_proxy/upstreams")
		if err != nil {
			time.Sleep(50 * time.Millisecond)
			continue
		}
		var entries []struct {
			Address string `json:"address"`
			NumRequests int `json:"num_requests"`
		}
		err = json.NewDecoder(rsp.Body).Decode(&entries)
		_ = rsp.Body.Close()
		if err != nil {
			time.Sleep(50 * time.Millisecond)
			continue
		}
		for _, e := range entries {
			if e.Address == dial && e.NumRequests == wantRequests {
				return true
			}
		}
		time.Sleep(50 * time.Millisecond)
	}
	return false
}

// startStandaloneWebTransport starts a webtransport.Server on a random UDP
// port with a self-signed cert. handler runs after a successful Upgrade.
// Returns the listener addr and a shutdown func.
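The helper's body is cut off by this excerpt, so the following is only a rough sketch of one way such a helper could be written. It assumes the github.com/quic-go/webtransport-go server API, a throwaway self-signed certificate generated in-process, and the extra imports that implies (crypto/ecdsa, crypto/elliptic, crypto/rand, net, time); it is not necessarily the helper this commit actually adds.

// Sketch only; names and structure are assumptions, not the committed code.
func startStandaloneWebTransport(t *testing.T, handler func(*webtransport.Session, *http.Request)) (*net.UDPAddr, func()) {
	t.Helper()

	// Self-signed cert so the proxy (configured with insecure_skip_verify)
	// can complete the QUIC/TLS handshake against this upstream.
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	tmpl := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "127.0.0.1"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
		IPAddresses:  []net.IP{net.ParseIP("127.0.0.1")},
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}
	der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
	if err != nil {
		t.Fatal(err)
	}
	cert := tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv}

	// Random UDP port on loopback; the caller reads the port from the addr.
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
	if err != nil {
		t.Fatal(err)
	}

	wt := &webtransport.Server{
		H3: http3.Server{
			TLSConfig:       &tls.Config{Certificates: []tls.Certificate{cert}},
			EnableDatagrams: true,
		},
		CheckOrigin: func(*http.Request) bool { return true },
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		sess, err := wt.Upgrade(w, r)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		handler(sess, r)
	})
	wt.H3.Handler = mux

	go func() { _ = wt.Serve(udpConn) }()

	addr := udpConn.LocalAddr().(*net.UDPAddr)
	return addr, func() {
		_ = wt.Close()
		_ = udpConn.Close()
	}
}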

@@ -185,6 +185,17 @@ func (h *Handler) serveWebTransport(w http.ResponseWriter, r *http.Request) error
fmt.Errorf("webtransport upgrade: %w", err))
}

	// Track the session in the same upstream counters the normal proxy
	// path maintains: Host.NumRequests drives MaxRequests gating and
	// least-connections selection; the per-address in-flight counter
	// feeds the admin API's upstream stats.
	_ = dialInfo.Upstream.Host.countRequest(1)
	incInFlightRequest(dialInfo.Address)
	defer func() {
		_ = dialInfo.Upstream.Host.countRequest(-1)
		decInFlightRequest(dialInfo.Address)
	}()
	runWebTransportPump(clientSess, upstreamSess, h.logger)
	return nil
}