diff --git a/dns/transport/conn_pool.go b/dns/transport/conn_pool.go index ff288b773..0e7a20a8c 100644 --- a/dns/transport/conn_pool.go +++ b/dns/transport/conn_pool.go @@ -6,6 +6,8 @@ import ( "sync" "github.com/sagernet/sing/common/x/list" + + "golang.org/x/sync/semaphore" ) type ConnPoolMode int @@ -16,14 +18,18 @@ const ( ) type ConnPoolOptions[T comparable] struct { - Mode ConnPoolMode - IsAlive func(T) bool - Close func(T, error) + Mode ConnPoolMode + // MaxInflight caps concurrent in-progress dials. Only honored in ConnPoolOrdered mode. + MaxInflight int + IsAlive func(T) bool + Close func(T, error) } type ConnPool[T comparable] struct { options ConnPoolOptions[T] + sem *semaphore.Weighted + access sync.Mutex closed bool state *connPoolState[T] @@ -53,10 +59,14 @@ type connPoolConnect[T comparable] struct { } func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] { - return &ConnPool[T]{ + p := &ConnPool[T]{ options: options, - state: newConnPoolState[T](options.Mode), } + if options.Mode == ConnPoolOrdered && options.MaxInflight > 0 { + p.sem = semaphore.NewWeighted(int64(options.MaxInflight)) + } + p.state = newConnPoolState[T](options.Mode) + return p } func newConnPoolState[T comparable](mode ConnPoolMode) *connPoolState[T] { @@ -113,7 +123,7 @@ func (p *ConnPool[T]) Release(conn T, reuse bool) { return } if p.options.Mode == ConnPoolOrdered { - if _, loaded := state.idleElements[conn]; !loaded { + if _, idle := state.idleElements[conn]; !idle { state.idleElements[conn] = state.idle.PushBack(conn) } } @@ -137,6 +147,31 @@ func (p *ConnPool[T]) Invalidate(conn T, cause error) { p.options.Close(conn, cause) } +func (p *ConnPool[T]) acquireSlot(ctx context.Context, state *connPoolState[T]) error { + if p.sem == nil { + return nil + } + acquireCtx, cancel := context.WithCancel(ctx) + stopStateCancel := context.AfterFunc(state.ctx, cancel) + err := p.sem.Acquire(acquireCtx, 1) + stopStateCancel() + cancel() + if err == nil { + return nil + } + ctxErr := ctx.Err() + if ctxErr != nil { + return ctxErr + } + return context.Cause(state.ctx) +} + +func (p *ConnPool[T]) releaseSlot() { + if p.sem != nil { + p.sem.Release(1) + } +} + // removeConn must be called with p.access held. func (p *ConnPool[T]) removeConn(state *connPoolState[T], conn T, cause error) { delete(state.all, conn) @@ -199,56 +234,65 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont } current := p.state if element := current.idle.Front(); element != nil { - conn := current.idle.Remove(element) - delete(current.idleElements, conn) - if p.options.IsAlive(conn) { + idleConn := current.idle.Remove(element) + delete(current.idleElements, idleConn) + if p.options.IsAlive(idleConn) { p.access.Unlock() - return conn, false, nil + return idleConn, false, nil } - delete(current.all, conn) + delete(current.all, idleConn) p.access.Unlock() - p.options.Close(conn, net.ErrClosed) + p.options.Close(idleConn, net.ErrClosed) continue } p.access.Unlock() + return p.dialAndInstall(ctx, current, dial) + } +} - dialCtx, dialCancel := context.WithCancelCause(ctx) - stopStateCancel := context.AfterFunc(current.ctx, func() { - dialCancel(context.Cause(current.ctx)) - }) - conn, err := dial(dialCtx) - stateCancelStopped := stopStateCancel() - dialErr := context.Cause(dialCtx) - if dialErr == nil && !stateCancelStopped { - dialErr = context.Cause(current.ctx) - } - dialCancel(nil) - if err != nil { - if dialErr != nil { - return zero, false, dialErr - } - return zero, false, err - } +func (p *ConnPool[T]) dialAndInstall(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, bool, error) { + var zero T + err := p.acquireSlot(ctx, current) + if err != nil { + return zero, false, err + } + defer p.releaseSlot() + dialCtx, dialCancel := context.WithCancelCause(ctx) + stopStateCancel := context.AfterFunc(current.ctx, func() { + dialCancel(context.Cause(current.ctx)) + }) + conn, err := dial(dialCtx) + stateCancelStopped := stopStateCancel() + dialErr := context.Cause(dialCtx) + if dialErr == nil && !stateCancelStopped { + dialErr = context.Cause(current.ctx) + } + dialCancel(nil) + if err != nil { if dialErr != nil { - p.options.Close(conn, dialErr) return zero, false, dialErr } - - p.access.Lock() - if p.closed { - p.access.Unlock() - p.options.Close(conn, net.ErrClosed) - return zero, false, net.ErrClosed - } - if p.state != current { - p.access.Unlock() - p.options.Close(conn, net.ErrClosed) - return zero, false, net.ErrClosed - } - current.all[conn] = struct{}{} - p.access.Unlock() - return conn, true, nil + return zero, false, err } + if dialErr != nil { + p.options.Close(conn, dialErr) + return zero, false, dialErr + } + + p.access.Lock() + if p.closed { + p.access.Unlock() + p.options.Close(conn, net.ErrClosed) + return zero, false, net.ErrClosed + } + if p.state != current { + p.access.Unlock() + p.options.Close(conn, net.ErrClosed) + return zero, false, net.ErrClosed + } + current.all[conn] = struct{}{} + p.access.Unlock() + return conn, true, nil } func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) { diff --git a/dns/transport/tls.go b/dns/transport/tls.go index b7ef25fb7..8ce415144 100644 --- a/dns/transport/tls.go +++ b/dns/transport/tls.go @@ -22,6 +22,8 @@ import ( var _ adapter.DNSTransport = (*TLSTransport)(nil) +const tlsDNSMaxInflight = 8 + func RegisterTLS(registry *dns.TransportRegistry) { dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS) } @@ -71,7 +73,8 @@ func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer serverAddr: serverAddr, tlsConfig: tlsConfig, connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{ - Mode: ConnPoolOrdered, + Mode: ConnPoolOrdered, + MaxInflight: tlsDNSMaxInflight, IsAlive: func(conn *tlsDNSConn) bool { return conn != nil }, diff --git a/go.mod b/go.mod index 5c5621574..923c086e8 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,7 @@ require ( golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 golang.org/x/mod v0.33.0 golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 golang.org/x/sys v0.41.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 google.golang.org/grpc v1.79.1 @@ -159,7 +160,6 @@ require ( go.uber.org/zap/exp v0.3.0 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect golang.org/x/oauth2 v0.34.0 // indirect - golang.org/x/sync v0.19.0 // indirect golang.org/x/term v0.40.0 // indirect golang.org/x/text v0.34.0 // indirect golang.org/x/time v0.11.0 // indirect