Fix USB/IP import lease and attach races

This commit is contained in:
世界 2026-04-25 05:08:36 +08:00
parent 524d580cf4
commit 0c892411aa
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
8 changed files with 405 additions and 35 deletions

View file

@ -4,6 +4,7 @@ package usbip
import (
"context"
"errors"
"fmt"
"sync"
"time"
@ -17,6 +18,8 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
type ClientService struct {
@ -215,25 +218,33 @@ func (c *ClientService) attemptAttach(ctx context.Context, busid string) (int, <
c.logger.Debug("usbip client handoff ", busid, ": ", handoff.mode())
c.portAssignAccess.Lock()
defer c.portAssignAccess.Unlock()
port, err := c.ops.vhciPickFreePort(info.Speed)
if err != nil {
return -1, nil, err
triedPorts := make(map[int]struct{})
for {
port, err := c.ops.vhciPickFreePort(info.Speed, triedPorts)
if err != nil {
return -1, nil, err
}
if !c.reservePort(port) {
triedPorts[port] = struct{}{}
continue
}
err = c.ops.vhciAttach(port, handoff.kernelFD(), info.DevID(), info.Speed)
if err != nil {
c.trackPort(port, false)
if errors.Is(err, unix.EBUSY) {
triedPorts[port] = struct{}{}
continue
}
return -1, nil, E.Cause(err, "vhci attach")
}
err = handoff.closeKernelFD()
if err != nil {
c.logger.Debug("close kernel fd ", busid, ": ", err)
}
done := handoff.startRelay(ctx, c.logger, "client", busid)
relayStarted = true
return port, done, nil
}
if !c.reservePort(port) {
return -1, nil, E.New("vhci port ", port, " already reserved")
}
err = c.ops.vhciAttach(port, handoff.kernelFD(), info.DevID(), info.Speed)
if err != nil {
c.trackPort(port, false)
return -1, nil, E.Cause(err, "vhci attach")
}
err = handoff.closeKernelFD()
if err != nil {
c.logger.Debug("close kernel fd ", busid, ": ", err)
}
done := handoff.startRelay(ctx, c.logger, "client", busid)
relayStarted = true
return port, done, nil
}
func (c *ClientService) waitPortSession(ctx context.Context, port int, busid string, done <-chan struct{}) {

View file

@ -82,7 +82,7 @@ func (s *ServerService) consumeImportLease(request ImportExtRequest) bool {
return false
}
delete(s.leasesByBusID, request.BusID)
return now.Before(lease.Expires)
return now.Before(lease.Expires) && lease.Generation == s.controlSeq
}
func (s *ServerService) cleanupExpiredImportLeasesLocked(now time.Time) {

View file

@ -0,0 +1,88 @@
//go:build linux || (darwin && cgo)
package usbip
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestConsumeImportLeaseAcceptsCurrentGeneration(t *testing.T) {
t.Parallel()
const busid = "1-1"
server := &ServerService{
controlSeq: 7,
leasesByBusID: map[string]serverImportLease{
busid: {
ID: 55,
BusID: busid,
ClientNonce: 99,
Generation: 7,
Expires: time.Now().Add(importLeaseTTL),
},
},
}
ok := server.consumeImportLease(ImportExtRequest{
BusID: busid,
LeaseID: 55,
ClientNonce: 99,
})
require.True(t, ok)
require.Empty(t, server.leasesByBusID)
}
func TestConsumeImportLeaseRejectsStaleGeneration(t *testing.T) {
t.Parallel()
const busid = "1-1"
server := &ServerService{
controlSeq: 8,
leasesByBusID: map[string]serverImportLease{
busid: {
ID: 55,
BusID: busid,
ClientNonce: 99,
Generation: 7,
Expires: time.Now().Add(importLeaseTTL),
},
},
}
ok := server.consumeImportLease(ImportExtRequest{
BusID: busid,
LeaseID: 55,
ClientNonce: 99,
})
require.False(t, ok)
require.Empty(t, server.leasesByBusID)
}
func TestConsumeImportLeaseRejectsExpiredLease(t *testing.T) {
t.Parallel()
const busid = "1-1"
server := &ServerService{
controlSeq: 7,
leasesByBusID: map[string]serverImportLease{
busid: {
ID: 55,
BusID: busid,
ClientNonce: 99,
Generation: 7,
Expires: time.Now().Add(-time.Second),
},
},
}
ok := server.consumeImportLease(ImportExtRequest{
BusID: busid,
LeaseID: 55,
ClientNonce: 99,
})
require.False(t, ok)
require.Empty(t, server.leasesByBusID)
}

View file

@ -278,7 +278,7 @@ func newTestUSBIPOps(t *testing.T) usbipOps {
t.Fatalf("unexpected newUEventListener")
return nil, nil
},
vhciPickFreePort: func(uint32) (int, error) {
vhciPickFreePort: func(uint32, map[int]struct{}) (int, error) {
t.Fatalf("unexpected vhciPickFreePort")
return 0, nil
},
@ -1946,7 +1946,8 @@ func TestClientAttemptAttachUsesImportReplyAndVHCIAttach(t *testing.T) {
var attachedPort int
var attachedDevID uint32
var attachedSpeed uint32
clientOps.vhciPickFreePort = func(speed uint32) (int, error) {
clientOps.vhciPickFreePort = func(speed uint32, skip map[int]struct{}) (int, error) {
require.Empty(t, skip)
require.Equal(t, SpeedSuper, speed)
return 7, nil
}
@ -2075,7 +2076,8 @@ func TestClientAttemptAttachUsesImportExtLease(t *testing.T) {
}()
ops := newTestUSBIPOps(t)
ops.vhciPickFreePort = func(speed uint32) (int, error) {
ops.vhciPickFreePort = func(speed uint32, skip map[int]struct{}) (int, error) {
require.Empty(t, skip)
require.Equal(t, SpeedHigh, speed)
return 4, nil
}
@ -2164,7 +2166,8 @@ func TestClientAttemptAttachWithOpaqueConnRelay(t *testing.T) {
kernelConnCh := make(chan net.Conn, 1)
ops := newTestUSBIPOps(t)
ops.vhciPickFreePort = func(speed uint32) (int, error) {
ops.vhciPickFreePort = func(speed uint32, skip map[int]struct{}) (int, error) {
require.Empty(t, skip)
require.Equal(t, SpeedHigh, speed)
return 4, nil
}
@ -2226,6 +2229,123 @@ func TestClientAttemptAttachWithOpaqueConnRelay(t *testing.T) {
}
}
func TestClientAttemptAttachRetriesNextPortOnEBUSY(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
device := newTestDevice("1-1", 0x1d6b, 0x0002, "serial-1", SpeedHigh)
serverErrCh := make(chan error, 1)
go func() {
conn, acceptErr := listener.Accept()
if acceptErr != nil {
serverErrCh <- acceptErr
return
}
defer conn.Close()
header, readErr := ReadOpHeader(conn)
if readErr != nil {
serverErrCh <- readErr
return
}
if header.Code != OpReqImport {
serverErrCh <- fmt.Errorf("unexpected request code 0x%04x", header.Code)
return
}
busid, readErr := ReadOpReqImportBody(conn)
if readErr != nil {
serverErrCh <- readErr
return
}
if busid != "1-1" {
serverErrCh <- fmt.Errorf("unexpected busid %s", busid)
return
}
info := device.toProtocol()
if writeErr := WriteOpRepImport(conn, OpStatusOK, &info); writeErr != nil {
serverErrCh <- writeErr
return
}
buffer := make([]byte, 1)
n, readErr := conn.Read(buffer)
if n != 0 {
serverErrCh <- fmt.Errorf("unexpected server read bytes after relay close: %d", n)
return
}
if !errors.Is(readErr, io.EOF) {
serverErrCh <- readErr
return
}
serverErrCh <- nil
}()
ops := newTestUSBIPOps(t)
var pickCalls int
var attachedPorts []int
ops.vhciPickFreePort = func(speed uint32, skip map[int]struct{}) (int, error) {
require.Equal(t, SpeedHigh, speed)
pickCalls++
if _, skipped := skip[4]; skipped {
return 5, nil
}
return 4, nil
}
ops.vhciAttach = func(port int, fd uintptr, devid uint32, speed uint32) error {
requireStreamSocketFD(t, fd)
info := device.toProtocol()
require.Equal(t, info.DevID(), devid)
require.Equal(t, SpeedHigh, speed)
attachedPorts = append(attachedPorts, port)
if port == 4 {
return unix.EBUSY
}
if port == 5 {
return nil
}
return fmt.Errorf("unexpected vhci port %d", port)
}
client := &ClientService{
ctx: ctx,
cancel: cancel,
logger: newTestLogger(t),
dialer: wrappingDialer{},
serverAddr: M.SocksaddrFromNet(listener.Addr()),
ops: ops,
}
port, done, err := client.attemptAttach(ctx, "1-1")
require.NoError(t, err)
require.NotNil(t, done)
require.Equal(t, 5, port)
require.Equal(t, []int{4, 5}, attachedPorts)
require.Equal(t, 2, pickCalls)
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timed out waiting for client relay handoff")
}
select {
case err = <-serverErrCh:
require.NoError(t, err)
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for server side close")
}
client.portsAccess.Lock()
_, firstReserved := client.ports[4]
_, secondReserved := client.ports[5]
client.portsAccess.Unlock()
require.False(t, firstReserved)
require.True(t, secondReserved)
}
func TestClientAttemptAttachRelayClosesHandoffOnVHCIAttachFailure(t *testing.T) {
t.Parallel()
@ -2284,7 +2404,8 @@ func TestClientAttemptAttachRelayClosesHandoffOnVHCIAttachFailure(t *testing.T)
expectedErr := errors.New("vhci attach failed")
kernelConnCh := make(chan net.Conn, 1)
ops := newTestUSBIPOps(t)
ops.vhciPickFreePort = func(speed uint32) (int, error) {
ops.vhciPickFreePort = func(speed uint32, skip map[int]struct{}) (int, error) {
require.Empty(t, skip)
require.Equal(t, SpeedHigh, speed)
return 4, nil
}
@ -2574,7 +2695,7 @@ func TestClientAttemptAttachRejectsUnexpectedReplyVersion(t *testing.T) {
}()
ops := newTestUSBIPOps(t)
ops.vhciPickFreePort = func(uint32) (int, error) {
ops.vhciPickFreePort = func(uint32, map[int]struct{}) (int, error) {
return -1, errors.New("unexpected vhci attach path")
}

View file

@ -24,7 +24,7 @@ type usbipOps struct {
writeUsbipSockfd func(busid string, fd int) error
newUEventListener func() (usbEventListener, error)
vhciPickFreePort func(speed uint32) (int, error)
vhciPickFreePort func(speed uint32, skip map[int]struct{}) (int, error)
vhciAttach func(port int, fd uintptr, devid uint32, speed uint32) error
vhciDetach func(port int) error
}

View file

@ -31,6 +31,7 @@ type serverExport struct {
device *darwinUSBHostDevice
entry DeviceEntry
busy bool
stale bool
}
type darwinUSBHostDeviceWatch interface {
@ -203,6 +204,7 @@ func (s *ServerService) reconcileExports() (bool, error) {
}
if export, ok := current[busid]; ok {
if export.busy {
changed = s.markExportStale(busid) || changed
continue
}
s.deleteExport(busid)
@ -229,6 +231,7 @@ func (s *ServerService) reconcileExports() (bool, error) {
continue
}
if export.busy {
changed = s.markExportStale(busid) || changed
continue
}
s.deleteExport(busid)
@ -270,7 +273,7 @@ func (s *ServerService) currentExports() []serverExport {
defer s.access.Unlock()
out := make([]serverExport, 0, len(s.exports))
for _, export := range s.exports {
if export.busy {
if export.busy || export.stale {
continue
}
out = append(out, export)
@ -320,7 +323,7 @@ func (s *ServerService) claimExport(busid string) (serverExport, bool) {
s.access.Lock()
defer s.access.Unlock()
export, ok := s.exports[busid]
if !ok || export.busy {
if !ok || export.busy || export.stale {
return serverExport{}, false
}
export.busy = true
@ -328,14 +331,35 @@ func (s *ServerService) claimExport(busid string) (serverExport, bool) {
return export, true
}
func (s *ServerService) releaseClaim(busid string) {
func (s *ServerService) markExportStale(busid string) bool {
s.access.Lock()
defer s.access.Unlock()
export, ok := s.exports[busid]
if ok {
export.busy = false
s.exports[busid] = export
if !ok || export.stale {
return false
}
export.stale = true
s.exports[busid] = export
return true
}
func (s *ServerService) releaseClaim(claimed serverExport) bool {
s.access.Lock()
export, ok := s.exports[claimed.busid]
if !ok || export.registryID != claimed.registryID {
s.access.Unlock()
return false
}
if export.stale {
delete(s.exports, claimed.busid)
s.access.Unlock()
export.device.Close()
return true
}
export.busy = false
s.exports[claimed.busid] = export
s.access.Unlock()
return false
}
func (s *ServerService) handleStandardConn(conn net.Conn, header OpHeader) {
@ -440,7 +464,12 @@ func (s *ServerService) handleImportBusID(conn net.Conn, busid string, extended
releaseClaim := true
defer func() {
if releaseClaim {
s.releaseClaim(busid)
if s.releaseClaim(export) {
if err := s.reconcileAndBroadcast(true); err != nil {
s.logger.Warn("reconcile exports after stale release: ", err)
}
return
}
s.broadcastChanged()
}
}()
@ -456,8 +485,14 @@ func (s *ServerService) handleImportBusID(conn net.Conn, busid string, extended
if err != nil && s.ctx.Err() == nil {
s.logger.Debug("data session ", busid, ": ", err)
}
s.releaseClaim(busid)
stale := s.releaseClaim(export)
releaseClaim = false
if stale {
if err := s.reconcileAndBroadcast(true); err != nil {
s.logger.Warn("reconcile exports after stale release: ", err)
}
return
}
s.broadcastChanged()
}
@ -472,6 +507,9 @@ func (s *ServerService) buildDeviceStateV2() []DeviceInfoV2 {
}
devices := make([]DeviceInfoV2, 0, len(exports))
for _, export := range exports {
if export.stale {
continue
}
state := deviceStateAvailable
reason := deviceStateAvailable
if export.busy {
@ -490,6 +528,9 @@ func (s *ServerService) leaseAvailable(busid string) (bool, string) {
if !ok {
return false, "unknown busid"
}
if export.stale {
return false, deviceStateUnavailable
}
if export.busy {
return false, deviceStateBusy
}

View file

@ -228,6 +228,98 @@ func TestDarwinServerBuildDeviceStateIncludesBusyExports(t *testing.T) {
require.Equal(t, deviceStateBusy, devices["busy"].State)
}
func TestDarwinServerReconcileMarksBusyMissingExportStale(t *testing.T) {
t.Parallel()
const busid = "mac-00000001"
entry := standardTestDeviceEntry(busid)
export := serverExport{
busid: busid,
registryID: 1,
device: &darwinUSBHostDevice{},
entry: entry,
busy: true,
}
server := &ServerService{
ctx: context.Background(),
logger: newTestLogger(t),
matches: []option.USBIPDeviceMatch{{BusID: busid}},
exports: map[string]serverExport{busid: export},
controlSubs: make(map[uint64]*serverControlConn),
controlState: deviceInfoV2Map([]DeviceInfoV2{
deviceInfoV2FromEntry(entry, backendIDDarwinIOKit, darwinStableID(export.registryID), deviceStateBusy, 0, deviceStateBusy),
}),
ops: darwinServerOps{
copyUSBHostDevices: func() ([]darwinUSBHostDeviceInfo, error) {
return nil, nil
},
},
}
require.NoError(t, server.reconcileAndBroadcast(true))
snapshot := server.snapshotExports()
require.True(t, snapshot[busid].stale)
require.Equal(t, "", darwinServerControlState(server, busid))
require.True(t, server.releaseClaim(export))
require.NotContains(t, server.snapshotExports(), busid)
}
func TestDarwinServerReconcileCapturesReplacementAfterStaleRelease(t *testing.T) {
t.Parallel()
const busid = "mac-00000001"
oldEntry := standardTestDeviceEntry(busid)
oldExport := serverExport{
busid: busid,
registryID: 1,
device: &darwinUSBHostDevice{},
entry: oldEntry,
busy: true,
}
replacementEntry := standardTestDeviceEntry(busid)
replacementEntry.Info.DevNum = 2
replacementInfo := darwinTestDeviceInfo(2, replacementEntry)
opened := 0
server := &ServerService{
ctx: context.Background(),
logger: newTestLogger(t),
matches: []option.USBIPDeviceMatch{{BusID: busid}},
exports: map[string]serverExport{busid: oldExport},
controlSubs: make(map[uint64]*serverControlConn),
controlState: deviceInfoV2Map([]DeviceInfoV2{
deviceInfoV2FromEntry(oldEntry, backendIDDarwinIOKit, darwinStableID(oldExport.registryID), deviceStateBusy, 0, deviceStateBusy),
}),
ops: darwinServerOps{
copyUSBHostDevices: func() ([]darwinUSBHostDeviceInfo, error) {
return []darwinUSBHostDeviceInfo{replacementInfo}, nil
},
openUSBHostDevice: func(registryID uint64, capture bool) (*darwinUSBHostDevice, error) {
require.Equal(t, replacementInfo.registryID, registryID)
require.True(t, capture)
opened++
return &darwinUSBHostDevice{info: replacementInfo}, nil
},
},
}
require.NoError(t, server.reconcileAndBroadcast(true))
snapshot := server.snapshotExports()
require.True(t, snapshot[busid].stale)
require.Zero(t, opened)
require.Equal(t, "", darwinServerControlState(server, busid))
require.True(t, server.releaseClaim(oldExport))
require.NoError(t, server.reconcileAndBroadcast(true))
snapshot = server.snapshotExports()
require.Equal(t, uint64(2), snapshot[busid].registryID)
require.False(t, snapshot[busid].busy)
require.False(t, snapshot[busid].stale)
require.Equal(t, 1, opened)
require.Equal(t, deviceStateAvailable, darwinServerControlState(server, busid))
}
func TestDarwinServerRegisterControlConnQueuesSnapshotBeforeBroadcast(t *testing.T) {
t.Parallel()
@ -310,6 +402,20 @@ func darwinServerControlState(server *ServerService, busid string) string {
return server.controlState[busid].State
}
func darwinTestDeviceInfo(registryID uint64, entry DeviceEntry) darwinUSBHostDeviceInfo {
busid := entry.Info.BusIDString()
return darwinUSBHostDeviceInfo{
registryID: registryID,
entry: entry,
key: DeviceKey{
BusID: busid,
VendorID: entry.Info.IDVendor,
ProductID: entry.Info.IDProduct,
Serial: entrySerial(entry),
},
}
}
type fakeDarwinUSBHostDeviceWatch struct {
callback func()
closed bool

View file

@ -231,7 +231,7 @@ func writeUsbipSockfd(busid string, fd int) error {
return writeSysfs(filepath.Join(sysBusUSBDevices, busid, "usbip_sockfd"), strconv.Itoa(fd))
}
func vhciPickFreePort(speed uint32) (int, error) {
func vhciPickFreePort(speed uint32, skip map[int]struct{}) (int, error) {
records, err := readVHCIStatus()
if err != nil {
return -1, err
@ -241,6 +241,9 @@ func vhciPickFreePort(speed uint32) (int, error) {
if record.hub != targetHub || record.state != 4 {
continue
}
if _, skipped := skip[record.port]; skipped {
continue
}
return record.port, nil
}
return -1, E.New("no free ", targetHub, " vhci port")