Skip kickWriteHandshake for server first protocols

This commit is contained in:
世界 2026-05-08 12:35:14 +08:00
parent e0c137e232
commit 18f10561c9
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/sniff"
"github.com/sagernet/sing-box/common/tlsfragment"
"github.com/sagernet/sing-box/common/tlsspoof"
C "github.com/sagernet/sing-box/constant"
@ -140,11 +141,12 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
}
remoteConn = spoofConn
}
serverFirst := sniff.Skip(&metadata)
var done atomic.Bool
if m.kickWriteHandshake(ctx, conn, remoteConn, false, &done, onClose) {
if m.kickWriteHandshake(ctx, conn, remoteConn, serverFirst, false, &done, onClose) {
return
}
if m.kickWriteHandshake(ctx, remoteConn, conn, true, &done, onClose) {
if m.kickWriteHandshake(ctx, remoteConn, conn, serverFirst, true, &done, onClose) {
return
}
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
@ -305,37 +307,43 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn,
}
}
func (m *ConnectionManager) kickWriteHandshake(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) bool {
func (m *ConnectionManager) kickWriteHandshake(ctx context.Context, source net.Conn, destination net.Conn, serverFirst bool, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) bool {
if !N.NeedHandshakeForWrite(destination) {
return false
}
var (
cachedBuffer *buf.Buffer
err error
wrotePayload bool
)
sourceReader, readCounters := N.UnwrapCountReader(source, nil)
destinationWriter, writeCounters := N.UnwrapCountWriter(destination, nil)
if cachedReader, ok := sourceReader.(N.CachedReader); ok {
cachedBuffer = cachedReader.ReadCached()
}
var err error
if cachedBuffer != nil {
wrotePayload = true
dataLen := cachedBuffer.Len()
_, err = destinationWriter.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err == nil {
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
} else {
if serverFirst {
_ = destination.SetWriteDeadline(time.Now().Add(C.ReadPayloadTimeout))
_, err = destinationWriter.Write(nil)
_, err = destination.Write(nil)
_ = destination.SetWriteDeadline(time.Time{})
} else {
var cachedBuffer *buf.Buffer
sourceReader, readCounters := N.UnwrapCountReader(source, nil)
destinationWriter, writeCounters := N.UnwrapCountWriter(destination, nil)
if cachedReader, ok := sourceReader.(N.CachedReader); ok {
cachedBuffer = cachedReader.ReadCached()
}
if cachedBuffer != nil {
wrotePayload = true
dataLen := cachedBuffer.Len()
_, err = destinationWriter.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err == nil {
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
} else {
_ = destination.SetWriteDeadline(time.Now().Add(C.ReadPayloadTimeout))
_, err = destinationWriter.Write(nil)
_ = destination.SetWriteDeadline(time.Time{})
}
}
if err == nil {
return false