dns: Refactor reordered pool

This commit is contained in:
世界 2026-05-12 00:10:40 +08:00
parent f0592034a6
commit 8875b52d51
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 94 additions and 47 deletions

View file

@ -6,6 +6,8 @@ import (
"sync" "sync"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"golang.org/x/sync/semaphore"
) )
type ConnPoolMode int type ConnPoolMode int
@ -16,14 +18,18 @@ const (
) )
type ConnPoolOptions[T comparable] struct { type ConnPoolOptions[T comparable] struct {
Mode ConnPoolMode Mode ConnPoolMode
IsAlive func(T) bool // MaxInflight caps concurrent in-progress dials. Only honored in ConnPoolOrdered mode.
Close func(T, error) MaxInflight int
IsAlive func(T) bool
Close func(T, error)
} }
type ConnPool[T comparable] struct { type ConnPool[T comparable] struct {
options ConnPoolOptions[T] options ConnPoolOptions[T]
sem *semaphore.Weighted
access sync.Mutex access sync.Mutex
closed bool closed bool
state *connPoolState[T] state *connPoolState[T]
@ -53,10 +59,14 @@ type connPoolConnect[T comparable] struct {
} }
func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] { func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] {
return &ConnPool[T]{ p := &ConnPool[T]{
options: options, options: options,
state: newConnPoolState[T](options.Mode),
} }
if options.Mode == ConnPoolOrdered && options.MaxInflight > 0 {
p.sem = semaphore.NewWeighted(int64(options.MaxInflight))
}
p.state = newConnPoolState[T](options.Mode)
return p
} }
func newConnPoolState[T comparable](mode ConnPoolMode) *connPoolState[T] { func newConnPoolState[T comparable](mode ConnPoolMode) *connPoolState[T] {
@ -113,7 +123,7 @@ func (p *ConnPool[T]) Release(conn T, reuse bool) {
return return
} }
if p.options.Mode == ConnPoolOrdered { if p.options.Mode == ConnPoolOrdered {
if _, loaded := state.idleElements[conn]; !loaded { if _, idle := state.idleElements[conn]; !idle {
state.idleElements[conn] = state.idle.PushBack(conn) state.idleElements[conn] = state.idle.PushBack(conn)
} }
} }
@ -137,6 +147,31 @@ func (p *ConnPool[T]) Invalidate(conn T, cause error) {
p.options.Close(conn, cause) p.options.Close(conn, cause)
} }
func (p *ConnPool[T]) acquireSlot(ctx context.Context, state *connPoolState[T]) error {
if p.sem == nil {
return nil
}
acquireCtx, cancel := context.WithCancel(ctx)
stopStateCancel := context.AfterFunc(state.ctx, cancel)
err := p.sem.Acquire(acquireCtx, 1)
stopStateCancel()
cancel()
if err == nil {
return nil
}
ctxErr := ctx.Err()
if ctxErr != nil {
return ctxErr
}
return context.Cause(state.ctx)
}
func (p *ConnPool[T]) releaseSlot() {
if p.sem != nil {
p.sem.Release(1)
}
}
// removeConn must be called with p.access held. // removeConn must be called with p.access held.
func (p *ConnPool[T]) removeConn(state *connPoolState[T], conn T, cause error) { func (p *ConnPool[T]) removeConn(state *connPoolState[T], conn T, cause error) {
delete(state.all, conn) delete(state.all, conn)
@ -199,56 +234,65 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont
} }
current := p.state current := p.state
if element := current.idle.Front(); element != nil { if element := current.idle.Front(); element != nil {
conn := current.idle.Remove(element) idleConn := current.idle.Remove(element)
delete(current.idleElements, conn) delete(current.idleElements, idleConn)
if p.options.IsAlive(conn) { if p.options.IsAlive(idleConn) {
p.access.Unlock() p.access.Unlock()
return conn, false, nil return idleConn, false, nil
} }
delete(current.all, conn) delete(current.all, idleConn)
p.access.Unlock() p.access.Unlock()
p.options.Close(conn, net.ErrClosed) p.options.Close(idleConn, net.ErrClosed)
continue continue
} }
p.access.Unlock() p.access.Unlock()
return p.dialAndInstall(ctx, current, dial)
}
}
dialCtx, dialCancel := context.WithCancelCause(ctx) func (p *ConnPool[T]) dialAndInstall(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, bool, error) {
stopStateCancel := context.AfterFunc(current.ctx, func() { var zero T
dialCancel(context.Cause(current.ctx)) err := p.acquireSlot(ctx, current)
}) if err != nil {
conn, err := dial(dialCtx) return zero, false, err
stateCancelStopped := stopStateCancel() }
dialErr := context.Cause(dialCtx) defer p.releaseSlot()
if dialErr == nil && !stateCancelStopped { dialCtx, dialCancel := context.WithCancelCause(ctx)
dialErr = context.Cause(current.ctx) stopStateCancel := context.AfterFunc(current.ctx, func() {
} dialCancel(context.Cause(current.ctx))
dialCancel(nil) })
if err != nil { conn, err := dial(dialCtx)
if dialErr != nil { stateCancelStopped := stopStateCancel()
return zero, false, dialErr dialErr := context.Cause(dialCtx)
} if dialErr == nil && !stateCancelStopped {
return zero, false, err dialErr = context.Cause(current.ctx)
} }
dialCancel(nil)
if err != nil {
if dialErr != nil { if dialErr != nil {
p.options.Close(conn, dialErr)
return zero, false, dialErr return zero, false, dialErr
} }
return zero, false, err
p.access.Lock()
if p.closed {
p.access.Unlock()
p.options.Close(conn, net.ErrClosed)
return zero, false, net.ErrClosed
}
if p.state != current {
p.access.Unlock()
p.options.Close(conn, net.ErrClosed)
return zero, false, net.ErrClosed
}
current.all[conn] = struct{}{}
p.access.Unlock()
return conn, true, nil
} }
if dialErr != nil {
p.options.Close(conn, dialErr)
return zero, false, dialErr
}
p.access.Lock()
if p.closed {
p.access.Unlock()
p.options.Close(conn, net.ErrClosed)
return zero, false, net.ErrClosed
}
if p.state != current {
p.access.Unlock()
p.options.Close(conn, net.ErrClosed)
return zero, false, net.ErrClosed
}
current.all[conn] = struct{}{}
p.access.Unlock()
return conn, true, nil
} }
func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) { func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) {

View file

@ -22,6 +22,8 @@ import (
var _ adapter.DNSTransport = (*TLSTransport)(nil) var _ adapter.DNSTransport = (*TLSTransport)(nil)
const tlsDNSMaxInflight = 8
func RegisterTLS(registry *dns.TransportRegistry) { func RegisterTLS(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS) dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS)
} }
@ -71,7 +73,8 @@ func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer
serverAddr: serverAddr, serverAddr: serverAddr,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{ connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{
Mode: ConnPoolOrdered, Mode: ConnPoolOrdered,
MaxInflight: tlsDNSMaxInflight,
IsAlive: func(conn *tlsDNSConn) bool { IsAlive: func(conn *tlsDNSConn) bool {
return conn != nil return conn != nil
}, },

2
go.mod
View file

@ -59,6 +59,7 @@ require (
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
golang.org/x/mod v0.33.0 golang.org/x/mod v0.33.0
golang.org/x/net v0.50.0 golang.org/x/net v0.50.0
golang.org/x/sync v0.19.0
golang.org/x/sys v0.41.0 golang.org/x/sys v0.41.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
google.golang.org/grpc v1.79.1 google.golang.org/grpc v1.79.1
@ -159,7 +160,6 @@ require (
go.uber.org/zap/exp v0.3.0 // indirect go.uber.org/zap/exp v0.3.0 // indirect
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/term v0.40.0 // indirect golang.org/x/term v0.40.0 // indirect
golang.org/x/text v0.34.0 // indirect golang.org/x/text v0.34.0 // indirect
golang.org/x/time v0.11.0 // indirect golang.org/x/time v0.11.0 // indirect