mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-05-13 05:51:36 +00:00
dns: Refactor reordered pool
This commit is contained in:
parent
f0592034a6
commit
8875b52d51
3 changed files with 94 additions and 47 deletions
|
|
@ -6,6 +6,8 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type ConnPoolMode int
|
||||
|
|
@ -16,14 +18,18 @@ const (
|
|||
)
|
||||
|
||||
type ConnPoolOptions[T comparable] struct {
|
||||
Mode ConnPoolMode
|
||||
IsAlive func(T) bool
|
||||
Close func(T, error)
|
||||
Mode ConnPoolMode
|
||||
// MaxInflight caps concurrent in-progress dials. Only honored in ConnPoolOrdered mode.
|
||||
MaxInflight int
|
||||
IsAlive func(T) bool
|
||||
Close func(T, error)
|
||||
}
|
||||
|
||||
type ConnPool[T comparable] struct {
|
||||
options ConnPoolOptions[T]
|
||||
|
||||
sem *semaphore.Weighted
|
||||
|
||||
access sync.Mutex
|
||||
closed bool
|
||||
state *connPoolState[T]
|
||||
|
|
@ -53,10 +59,14 @@ type connPoolConnect[T comparable] struct {
|
|||
}
|
||||
|
||||
func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] {
|
||||
return &ConnPool[T]{
|
||||
p := &ConnPool[T]{
|
||||
options: options,
|
||||
state: newConnPoolState[T](options.Mode),
|
||||
}
|
||||
if options.Mode == ConnPoolOrdered && options.MaxInflight > 0 {
|
||||
p.sem = semaphore.NewWeighted(int64(options.MaxInflight))
|
||||
}
|
||||
p.state = newConnPoolState[T](options.Mode)
|
||||
return p
|
||||
}
|
||||
|
||||
func newConnPoolState[T comparable](mode ConnPoolMode) *connPoolState[T] {
|
||||
|
|
@ -113,7 +123,7 @@ func (p *ConnPool[T]) Release(conn T, reuse bool) {
|
|||
return
|
||||
}
|
||||
if p.options.Mode == ConnPoolOrdered {
|
||||
if _, loaded := state.idleElements[conn]; !loaded {
|
||||
if _, idle := state.idleElements[conn]; !idle {
|
||||
state.idleElements[conn] = state.idle.PushBack(conn)
|
||||
}
|
||||
}
|
||||
|
|
@ -137,6 +147,31 @@ func (p *ConnPool[T]) Invalidate(conn T, cause error) {
|
|||
p.options.Close(conn, cause)
|
||||
}
|
||||
|
||||
func (p *ConnPool[T]) acquireSlot(ctx context.Context, state *connPoolState[T]) error {
|
||||
if p.sem == nil {
|
||||
return nil
|
||||
}
|
||||
acquireCtx, cancel := context.WithCancel(ctx)
|
||||
stopStateCancel := context.AfterFunc(state.ctx, cancel)
|
||||
err := p.sem.Acquire(acquireCtx, 1)
|
||||
stopStateCancel()
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
ctxErr := ctx.Err()
|
||||
if ctxErr != nil {
|
||||
return ctxErr
|
||||
}
|
||||
return context.Cause(state.ctx)
|
||||
}
|
||||
|
||||
func (p *ConnPool[T]) releaseSlot() {
|
||||
if p.sem != nil {
|
||||
p.sem.Release(1)
|
||||
}
|
||||
}
|
||||
|
||||
// removeConn must be called with p.access held.
|
||||
func (p *ConnPool[T]) removeConn(state *connPoolState[T], conn T, cause error) {
|
||||
delete(state.all, conn)
|
||||
|
|
@ -199,56 +234,65 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont
|
|||
}
|
||||
current := p.state
|
||||
if element := current.idle.Front(); element != nil {
|
||||
conn := current.idle.Remove(element)
|
||||
delete(current.idleElements, conn)
|
||||
if p.options.IsAlive(conn) {
|
||||
idleConn := current.idle.Remove(element)
|
||||
delete(current.idleElements, idleConn)
|
||||
if p.options.IsAlive(idleConn) {
|
||||
p.access.Unlock()
|
||||
return conn, false, nil
|
||||
return idleConn, false, nil
|
||||
}
|
||||
delete(current.all, conn)
|
||||
delete(current.all, idleConn)
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
p.options.Close(idleConn, net.ErrClosed)
|
||||
continue
|
||||
}
|
||||
p.access.Unlock()
|
||||
return p.dialAndInstall(ctx, current, dial)
|
||||
}
|
||||
}
|
||||
|
||||
dialCtx, dialCancel := context.WithCancelCause(ctx)
|
||||
stopStateCancel := context.AfterFunc(current.ctx, func() {
|
||||
dialCancel(context.Cause(current.ctx))
|
||||
})
|
||||
conn, err := dial(dialCtx)
|
||||
stateCancelStopped := stopStateCancel()
|
||||
dialErr := context.Cause(dialCtx)
|
||||
if dialErr == nil && !stateCancelStopped {
|
||||
dialErr = context.Cause(current.ctx)
|
||||
}
|
||||
dialCancel(nil)
|
||||
if err != nil {
|
||||
if dialErr != nil {
|
||||
return zero, false, dialErr
|
||||
}
|
||||
return zero, false, err
|
||||
}
|
||||
func (p *ConnPool[T]) dialAndInstall(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, bool, error) {
|
||||
var zero T
|
||||
err := p.acquireSlot(ctx, current)
|
||||
if err != nil {
|
||||
return zero, false, err
|
||||
}
|
||||
defer p.releaseSlot()
|
||||
dialCtx, dialCancel := context.WithCancelCause(ctx)
|
||||
stopStateCancel := context.AfterFunc(current.ctx, func() {
|
||||
dialCancel(context.Cause(current.ctx))
|
||||
})
|
||||
conn, err := dial(dialCtx)
|
||||
stateCancelStopped := stopStateCancel()
|
||||
dialErr := context.Cause(dialCtx)
|
||||
if dialErr == nil && !stateCancelStopped {
|
||||
dialErr = context.Cause(current.ctx)
|
||||
}
|
||||
dialCancel(nil)
|
||||
if err != nil {
|
||||
if dialErr != nil {
|
||||
p.options.Close(conn, dialErr)
|
||||
return zero, false, dialErr
|
||||
}
|
||||
|
||||
p.access.Lock()
|
||||
if p.closed {
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return zero, false, net.ErrClosed
|
||||
}
|
||||
if p.state != current {
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return zero, false, net.ErrClosed
|
||||
}
|
||||
current.all[conn] = struct{}{}
|
||||
p.access.Unlock()
|
||||
return conn, true, nil
|
||||
return zero, false, err
|
||||
}
|
||||
if dialErr != nil {
|
||||
p.options.Close(conn, dialErr)
|
||||
return zero, false, dialErr
|
||||
}
|
||||
|
||||
p.access.Lock()
|
||||
if p.closed {
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return zero, false, net.ErrClosed
|
||||
}
|
||||
if p.state != current {
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return zero, false, net.ErrClosed
|
||||
}
|
||||
current.all[conn] = struct{}{}
|
||||
p.access.Unlock()
|
||||
return conn, true, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ import (
|
|||
|
||||
var _ adapter.DNSTransport = (*TLSTransport)(nil)
|
||||
|
||||
const tlsDNSMaxInflight = 8
|
||||
|
||||
func RegisterTLS(registry *dns.TransportRegistry) {
|
||||
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS)
|
||||
}
|
||||
|
|
@ -71,7 +73,8 @@ func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer
|
|||
serverAddr: serverAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{
|
||||
Mode: ConnPoolOrdered,
|
||||
Mode: ConnPoolOrdered,
|
||||
MaxInflight: tlsDNSMaxInflight,
|
||||
IsAlive: func(conn *tlsDNSConn) bool {
|
||||
return conn != nil
|
||||
},
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -59,6 +59,7 @@ require (
|
|||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
|
||||
golang.org/x/mod v0.33.0
|
||||
golang.org/x/net v0.50.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/sys v0.41.0
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
google.golang.org/grpc v1.79.1
|
||||
|
|
@ -159,7 +160,6 @@ require (
|
|||
go.uber.org/zap/exp v0.3.0 // indirect
|
||||
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/term v0.40.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/time v0.11.0 // indirect
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue