Add usbip tests and auto-load drivers

This commit is contained in:
世界 2026-04-22 01:15:58 +08:00
parent f3bbd3c07b
commit 85c9d6b4f3
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 1173 additions and 97 deletions

77
option/usbip_test.go Normal file
View file

@ -0,0 +1,77 @@
package option
import (
"context"
"testing"
"github.com/sagernet/sing/common/json"
"github.com/stretchr/testify/require"
)
func TestUSBIPHexUint16UnmarshalJSON(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected USBIPHexUint16
errorSubstr string
}{
{name: "number", input: `7531`, expected: USBIPHexUint16(0x1d6b)},
{name: "hex-with-prefix", input: `"0x1d6b"`, expected: USBIPHexUint16(0x1d6b)},
{name: "hex-with-uppercase-prefix", input: `"0X1D6B"`, expected: USBIPHexUint16(0x1d6b)},
{name: "hex-without-prefix", input: `"1d6b"`, expected: USBIPHexUint16(0x1d6b)},
{name: "empty-string", input: `""`, expected: 0},
{name: "out-of-range", input: `65536`, errorSubstr: "out of uint16 range"},
{name: "invalid-hex", input: `"zzzz"`, errorSubstr: "parse usb id zzzz"},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
var value USBIPHexUint16
err := json.UnmarshalContext(context.Background(), []byte(testCase.input), &value)
if testCase.errorSubstr != "" {
require.ErrorContains(t, err, testCase.errorSubstr)
return
}
require.NoError(t, err)
require.Equal(t, testCase.expected, value)
})
}
}
func TestUSBIPHexUint16MarshalJSON(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input USBIPHexUint16
expected string
}{
{name: "zero", input: 0, expected: `""`},
{name: "non-zero", input: USBIPHexUint16(0x1d6b), expected: `"0x1d6b"`},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
content, err := json.Marshal(testCase.input)
require.NoError(t, err)
require.Equal(t, testCase.expected, string(content))
})
}
}
func TestUSBIPDeviceMatchIsZero(t *testing.T) {
t.Parallel()
require.True(t, (USBIPDeviceMatch{}).IsZero())
require.False(t, (USBIPDeviceMatch{BusID: "1-1"}).IsZero())
require.False(t, (USBIPDeviceMatch{VendorID: 0x1d6b}).IsZero())
require.False(t, (USBIPDeviceMatch{ProductID: 0x0002}).IsZero())
require.False(t, (USBIPDeviceMatch{Serial: "abc"}).IsZero())
}

View file

@ -60,6 +60,7 @@ type ClientService struct {
dialer N.Dialer
serverAddr M.Socksaddr
matches []option.USBIPDeviceMatch // empty = import all remote exports
ops usbipOps
stateMu sync.Mutex
targets []clientTarget
@ -99,6 +100,7 @@ func NewClientService(ctx context.Context, logger log.ContextLogger, tag string,
dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(),
matches: options.Devices,
ops: systemUSBIPOps,
allWorkers: make(map[string]*clientBusIDWorker),
ports: make(map[int]struct{}),
}, nil
@ -108,7 +110,7 @@ func (c *ClientService) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if err := ensureVHCI(); err != nil {
if err := c.ops.ensureVHCI(); err != nil {
return err
}
c.initializeWorkers()
@ -319,66 +321,12 @@ func (c *ClientService) applyRemoteExports(entries []DeviceEntry) {
}
func (c *ClientService) applyMatchedExports(entries []DeviceEntry) {
keysByBusID := make(map[string]DeviceKey, len(entries))
for i := range entries {
busid := entries[i].Info.BusIDString()
if busid == "" {
continue
}
keysByBusID[busid] = DeviceKey{
BusID: busid,
VendorID: entries[i].Info.IDVendor,
ProductID: entries[i].Info.IDProduct,
Serial: entries[i].Info.SerialString(),
}
}
c.stateMu.Lock()
if len(c.targets) == 0 {
c.stateMu.Unlock()
return
}
nextAssigned := make([]string, len(c.targets))
reserved := make(map[string]struct{}, len(c.targets))
for i, target := range c.targets {
if target.fixedBusID == "" {
continue
}
if _, ok := keysByBusID[target.fixedBusID]; !ok {
continue
}
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
}
for i, target := range c.targets {
if target.fixedBusID != "" {
continue
}
current := c.assigned[i]
if current == "" {
continue
}
if _, ok := reserved[current]; ok {
continue
}
key, ok := keysByBusID[current]
if !ok || !Matches(target.match, key) {
continue
}
nextAssigned[i] = current
reserved[current] = struct{}{}
}
for i, target := range c.targets {
if target.fixedBusID != "" || nextAssigned[i] != "" {
continue
}
nextAssigned[i] = firstMatchingUnclaimedBusID(target.match, entries, reserved)
if nextAssigned[i] != "" {
reserved[nextAssigned[i]] = struct{}{}
}
}
nextAssigned := assignMatchedBusIDs(c.targets, c.assigned, entries)
workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...)
previous := append([]string(nil), c.assigned...)
c.assigned = nextAssigned
@ -585,11 +533,11 @@ func (c *ClientService) attemptAttach(ctx context.Context, busid string) (int, e
defer file.Close()
c.attachMu.Lock()
defer c.attachMu.Unlock()
port, err := vhciPickFreePort(info.Speed)
port, err := c.ops.vhciPickFreePort(info.Speed)
if err != nil {
return -1, err
}
if err := vhciAttach(port, file.Fd(), info.DevID(), info.Speed); err != nil {
if err := c.ops.vhciAttach(port, file.Fd(), info.DevID(), info.Speed); err != nil {
return -1, E.Cause(err, "vhci attach")
}
return port, nil
@ -601,12 +549,12 @@ func (c *ClientService) watchPort(ctx context.Context, port int, busid string) {
for {
select {
case <-ctx.Done():
if err := vhciDetach(port); err != nil {
if err := c.ops.vhciDetach(port); err != nil {
c.logger.Warn("detach port ", port, " (", busid, "): ", err)
}
return
case <-ticker.C:
used, err := vhciPortUsed(port)
used, err := c.ops.vhciPortUsed(port)
if err != nil {
c.logger.Debug("poll port ", port, ": ", err)
continue
@ -632,6 +580,65 @@ func isBusIDOnlyMatch(m option.USBIPDeviceMatch) bool {
return m.BusID != "" && m.VendorID == 0 && m.ProductID == 0 && m.Serial == ""
}
func assignMatchedBusIDs(targets []clientTarget, current []string, entries []DeviceEntry) []string {
if len(targets) == 0 {
return nil
}
keysByBusID := make(map[string]DeviceKey, len(entries))
for i := range entries {
busid := entries[i].Info.BusIDString()
if busid == "" {
continue
}
keysByBusID[busid] = DeviceKey{
BusID: busid,
VendorID: entries[i].Info.IDVendor,
ProductID: entries[i].Info.IDProduct,
Serial: entries[i].Info.SerialString(),
}
}
nextAssigned := make([]string, len(targets))
reserved := make(map[string]struct{}, len(targets))
for i, target := range targets {
if target.fixedBusID == "" {
continue
}
if _, ok := keysByBusID[target.fixedBusID]; !ok {
continue
}
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
}
for i, target := range targets {
if target.fixedBusID != "" || i >= len(current) {
continue
}
if current[i] == "" {
continue
}
if _, ok := reserved[current[i]]; ok {
continue
}
key, ok := keysByBusID[current[i]]
if !ok || !Matches(target.match, key) {
continue
}
nextAssigned[i] = current[i]
reserved[current[i]] = struct{}{}
}
for i, target := range targets {
if target.fixedBusID != "" || nextAssigned[i] != "" {
continue
}
nextAssigned[i] = firstMatchingUnclaimedBusID(target.match, entries, reserved)
if nextAssigned[i] != "" {
reserved[nextAssigned[i]] = struct{}{}
}
}
return nextAssigned
}
func firstMatchingUnclaimedBusID(match option.USBIPDeviceMatch, entries []DeviceEntry, reserved map[string]struct{}) string {
for i := range entries {
key := DeviceKey{

641
service/usbip/linux_test.go Normal file
View file

@ -0,0 +1,641 @@
//go:build linux
package usbip
import (
"context"
"errors"
"net"
"os"
"slices"
"sync"
"testing"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
M "github.com/sagernet/sing/common/metadata"
"github.com/stretchr/testify/require"
)
type testDialer struct{}
func (testDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(ctx, network, destination.String())
}
func (testDialer) ListenPacket(context.Context, M.Socksaddr) (net.PacketConn, error) {
return nil, errors.New("unused")
}
type testDeviceStore struct {
mu sync.Mutex
devices map[string]sysfsDevice
statuses map[string]int
sockfds map[string]int
}
func newTestDeviceStore(devices ...sysfsDevice) *testDeviceStore {
store := &testDeviceStore{
devices: make(map[string]sysfsDevice),
statuses: make(map[string]int),
sockfds: make(map[string]int),
}
store.setDevices(devices...)
return store
}
func (s *testDeviceStore) setDevices(devices ...sysfsDevice) {
s.mu.Lock()
defer s.mu.Unlock()
s.devices = make(map[string]sysfsDevice, len(devices))
for _, device := range devices {
s.devices[device.BusID] = device
}
}
func (s *testDeviceStore) setStatus(busid string, status int) {
s.mu.Lock()
defer s.mu.Unlock()
s.statuses[busid] = status
}
func (s *testDeviceStore) listUSBDevices() ([]sysfsDevice, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]sysfsDevice, 0, len(s.devices))
for _, device := range s.devices {
out = append(out, device)
}
slices.SortFunc(out, func(left, right sysfsDevice) int {
switch {
case left.BusID < right.BusID:
return -1
case left.BusID > right.BusID:
return 1
default:
return 0
}
})
return out, nil
}
func (s *testDeviceStore) readSysfsDevice(busid, path string) (sysfsDevice, error) {
s.mu.Lock()
defer s.mu.Unlock()
device, ok := s.devices[busid]
if !ok {
return sysfsDevice{}, os.ErrNotExist
}
return device, nil
}
func (s *testDeviceStore) readUsbipStatus(busid string) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
status, ok := s.statuses[busid]
if !ok {
return 0, os.ErrNotExist
}
return status, nil
}
func (s *testDeviceStore) writeUsbipSockfd(busid string, fd int) error {
s.mu.Lock()
defer s.mu.Unlock()
s.sockfds[busid] = fd
return nil
}
func (s *testDeviceStore) lastSockfd(busid string) int {
s.mu.Lock()
defer s.mu.Unlock()
return s.sockfds[busid]
}
func newTestUSBIPOps(t *testing.T) usbipOps {
t.Helper()
return usbipOps{
ensureHostDriver: func() error {
t.Fatalf("unexpected ensureHostDriver")
return nil
},
ensureVHCI: func() error {
t.Fatalf("unexpected ensureVHCI")
return nil
},
listUSBDevices: func() ([]sysfsDevice, error) {
t.Fatalf("unexpected listUSBDevices")
return nil, nil
},
readSysfsDevice: func(string, string) (sysfsDevice, error) {
t.Fatalf("unexpected readSysfsDevice")
return sysfsDevice{}, nil
},
currentDriver: func(string) (string, error) {
t.Fatalf("unexpected currentDriver")
return "", nil
},
unbindFromDriver: func(string, string) error {
t.Fatalf("unexpected unbindFromDriver")
return nil
},
bindToDriver: func(string, string) error {
t.Fatalf("unexpected bindToDriver")
return nil
},
hostMatchBusID: func(string, bool) error {
t.Fatalf("unexpected hostMatchBusID")
return nil
},
hostBind: func(string) error {
t.Fatalf("unexpected hostBind")
return nil
},
hostUnbind: func(string) error {
t.Fatalf("unexpected hostUnbind")
return nil
},
readUsbipStatus: func(string) (int, error) {
t.Fatalf("unexpected readUsbipStatus")
return 0, nil
},
writeUsbipSockfd: func(string, int) error {
t.Fatalf("unexpected writeUsbipSockfd")
return nil
},
newUEventListener: func() (usbEventListener, error) {
t.Fatalf("unexpected newUEventListener")
return nil, nil
},
vhciPickFreePort: func(uint32) (int, error) {
t.Fatalf("unexpected vhciPickFreePort")
return 0, nil
},
vhciAttach: func(int, uintptr, uint32, uint32) error {
t.Fatalf("unexpected vhciAttach")
return nil
},
vhciDetach: func(int) error {
t.Fatalf("unexpected vhciDetach")
return nil
},
vhciPortUsed: func(int) (bool, error) {
t.Fatalf("unexpected vhciPortUsed")
return false, nil
},
}
}
func newTestLogger() log.ContextLogger {
return log.NewNOPFactory().NewLogger("usbip")
}
func newTestDevice(busid string, vendorID, productID uint16, serial string, speed uint32) sysfsDevice {
return sysfsDevice{
BusID: busid,
Path: sysBusDevicePath(busid),
BusNum: 3,
DevNum: 9,
Speed: speed,
VendorID: vendorID,
ProductID: productID,
DeviceClass: 0,
ConfigValue: 1,
NumConfigs: 1,
NumInterfaces: 1,
Serial: serial,
Interfaces: []DeviceInterface{{
BInterfaceClass: 0xff,
}},
}
}
func startDispatchServer(t *testing.T, server *ServerService) (M.Socksaddr, func()) {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
for {
conn, acceptErr := listener.Accept()
if acceptErr != nil {
return
}
go server.dispatchConn(conn)
}
}()
return M.SocksaddrFromNet(listener.Addr()), func() {
_ = listener.Close()
<-done
}
}
func TestBuildTargetsDedupesFixedBusID(t *testing.T) {
t.Parallel()
client := &ClientService{
matches: []option.USBIPDeviceMatch{
{BusID: "1-1"},
{VendorID: 0x1d6b, ProductID: 0x0002},
{BusID: "1-1"},
{BusID: "1-2"},
},
}
require.Equal(t, []clientTarget{
{fixedBusID: "1-1"},
{match: option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}},
{fixedBusID: "1-2"},
}, client.buildTargets())
}
func TestAssignMatchedBusIDs(t *testing.T) {
t.Parallel()
match := option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}
fixed := newTestDevice("1-1", 0x1d6b, 0x0001, "fixed", SpeedHigh)
first := newTestDevice("1-2", 0x1d6b, 0x0002, "first", SpeedHigh)
second := newTestDevice("1-3", 0x1d6b, 0x0002, "second", SpeedHigh)
entries := []DeviceEntry{
{Info: fixed.toProtocol()},
{Info: first.toProtocol()},
{Info: second.toProtocol()},
}
require.Equal(t, []string{"1-1", "1-3", "1-2"}, assignMatchedBusIDs(
[]clientTarget{
{fixedBusID: "1-1"},
{match: match},
{match: match},
},
[]string{"1-1", "1-3", ""},
entries,
))
}
func TestLinuxHelpers(t *testing.T) {
t.Parallel()
require.Equal(t, []vhciStatusRecord{
{hub: "hs", port: 0, state: 6},
{hub: "ss", port: 3, state: 4},
}, parseVHCIStatus("hub port sta spd dev sockfd local_busid\nhs 0 6 3 0 0 0\nignored line\nss 3 4 5 0 0 0\n"))
require.Equal(t, SpeedLow, speedCodeFromString("1.5"))
require.Equal(t, SpeedFull, speedCodeFromString("12"))
require.Equal(t, SpeedHigh, speedCodeFromString("480"))
require.Equal(t, SpeedSuper, speedCodeFromString("5000"))
require.Equal(t, SpeedSuperPlus, speedCodeFromString("10000"))
require.Equal(t, SpeedUnknown, speedCodeFromString("25"))
require.Equal(t, "hs", vhciHubForSpeed(SpeedHigh))
require.Equal(t, "ss", vhciHubForSpeed(SpeedSuper))
require.True(t, isUSBUEvent([]byte("ACTION=add\x00SUBSYSTEM=usb\x00")))
require.False(t, isUSBUEvent([]byte("ACTION=add\x00SUBSYSTEM=net\x00")))
}
func TestServerReconcileExportsBindsMatchesAndSkipsHub(t *testing.T) {
t.Parallel()
regular := newTestDevice("1-1", 0x1d6b, 0x0002, "regular", SpeedHigh)
hub := newTestDevice("1-2", 0x1d6b, 0x0002, "hub", SpeedHigh)
hub.DeviceClass = 0x09
store := newTestDeviceStore(regular, hub)
ops := newTestUSBIPOps(t)
var actions []string
ops.listUSBDevices = store.listUSBDevices
ops.currentDriver = func(busid string) (string, error) {
return map[string]string{
"1-1": "usbhid",
"1-2": "hubdrv",
}[busid], nil
}
ops.unbindFromDriver = func(busid, driver string) error {
actions = append(actions, "unbind "+busid+" "+driver)
return nil
}
ops.hostMatchBusID = func(busid string, add bool) error {
actions = append(actions, "match "+busid+" "+map[bool]string{true: "add", false: "del"}[add])
return nil
}
ops.hostBind = func(busid string) error {
actions = append(actions, "hostbind "+busid)
return nil
}
ops.bindToDriver = func(busid, driver string) error {
actions = append(actions, "bind "+busid+" "+driver)
return nil
}
server := &ServerService{
ctx: context.Background(),
logger: newTestLogger(),
matches: []option.USBIPDeviceMatch{{VendorID: 0x1d6b, ProductID: 0x0002}},
exports: make(map[string]serverExport),
controlSubs: make(map[uint64]*serverControlConn),
ops: ops,
}
changed, err := server.reconcileExports()
require.NoError(t, err)
require.True(t, changed)
require.Equal(t, []string{
"unbind 1-1 usbhid",
"match 1-1 add",
"hostbind 1-1",
}, actions)
require.Equal(t, map[string]serverExport{
"1-1": {
busid: "1-1",
managed: true,
originalDriver: "usbhid",
},
}, server.snapshotExports())
}
func TestServerReconcileExportsReleasesRemovedExports(t *testing.T) {
t.Parallel()
device := newTestDevice("1-1", 0x1d6b, 0x0002, "regular", SpeedHigh)
store := newTestDeviceStore(device)
ops := newTestUSBIPOps(t)
var actions []string
ops.listUSBDevices = store.listUSBDevices
ops.writeUsbipSockfd = func(busid string, fd int) error {
actions = append(actions, "sockfd "+busid)
return nil
}
ops.hostUnbind = func(busid string) error {
actions = append(actions, "hostunbind "+busid)
return nil
}
ops.hostMatchBusID = func(busid string, add bool) error {
actions = append(actions, "match "+busid+" "+map[bool]string{true: "add", false: "del"}[add])
return nil
}
ops.bindToDriver = func(busid, driver string) error {
actions = append(actions, "bind "+busid+" "+driver)
return nil
}
ops.readSysfsDevice = store.readSysfsDevice
server := &ServerService{
ctx: context.Background(),
logger: newTestLogger(),
exports: map[string]serverExport{"1-1": {busid: "1-1", managed: true, originalDriver: "usbhid"}},
ops: ops,
}
changed, err := server.reconcileExports()
require.NoError(t, err)
require.True(t, changed)
require.Empty(t, server.snapshotExports())
require.Equal(t, []string{
"sockfd 1-1",
"hostunbind 1-1",
"match 1-1 del",
"bind 1-1 usbhid",
}, actions)
}
func TestServerBuildDevListEntriesFiltersUnavailableAndRefreshFailures(t *testing.T) {
t.Parallel()
available := newTestDevice("1-1", 0x1d6b, 0x0002, "ok", SpeedHigh)
store := newTestDeviceStore(available)
store.setStatus("1-1", usbipStatusAvailable)
store.setStatus("1-2", usbipStatusUsed)
store.setStatus("1-3", usbipStatusAvailable)
ops := newTestUSBIPOps(t)
ops.readUsbipStatus = store.readUsbipStatus
ops.readSysfsDevice = store.readSysfsDevice
server := &ServerService{
logger: newTestLogger(),
exports: map[string]serverExport{
"1-1": {busid: "1-1"},
"1-2": {busid: "1-2"},
"1-3": {busid: "1-3"},
},
ops: ops,
}
entries := server.buildDevListEntries()
require.Len(t, entries, 1)
require.Equal(t, "1-1", entries[0].Info.BusIDString())
require.Equal(t, "ok", entries[0].Info.SerialString())
}
func TestServerDispatchConnHandlesControlPingAndChanged(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server := &ServerService{
ctx: ctx,
cancel: cancel,
logger: newTestLogger(),
exports: make(map[string]serverExport),
controlSubs: make(map[uint64]*serverControlConn),
ops: newTestUSBIPOps(t),
}
serverAddr, closeServer := startDispatchServer(t, server)
defer closeServer()
conn, err := net.Dial("tcp", serverAddr.String())
require.NoError(t, err)
defer conn.Close()
require.NoError(t, WriteControlPreface(conn))
require.NoError(t, WriteControlHello(conn))
ack, err := ReadControlFrame(conn)
require.NoError(t, err)
require.Equal(t, controlFrameAck, ack.Type)
require.Equal(t, controlProtocolVersion, ack.Version)
require.Equal(t, controlCapabilities, ack.Capabilities)
require.Zero(t, ack.Sequence)
require.NoError(t, WriteControlPing(conn))
pong, err := ReadControlFrame(conn)
require.NoError(t, err)
require.Equal(t, controlFramePong, pong.Type)
require.Equal(t, controlProtocolVersion, pong.Version)
server.broadcastChanged()
changed, err := ReadControlFrame(conn)
require.NoError(t, err)
require.Equal(t, controlFrameChanged, changed.Type)
require.Equal(t, uint64(1), changed.Sequence)
}
func TestClientAttemptAttachUsesImportReplyAndVHCIAttach(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
device := newTestDevice("1-1", 0x1d6b, 0x0002, "serial-1", SpeedSuper)
device.BusNum = 7
device.DevNum = 11
store := newTestDeviceStore(device)
store.setStatus("1-1", usbipStatusAvailable)
serverOps := newTestUSBIPOps(t)
serverOps.readUsbipStatus = store.readUsbipStatus
serverOps.readSysfsDevice = store.readSysfsDevice
serverOps.writeUsbipSockfd = store.writeUsbipSockfd
server := &ServerService{
ctx: ctx,
cancel: cancel,
logger: newTestLogger(),
exports: map[string]serverExport{"1-1": {busid: "1-1"}},
controlSubs: make(map[uint64]*serverControlConn),
ops: serverOps,
}
serverAddr, closeServer := startDispatchServer(t, server)
defer closeServer()
clientOps := newTestUSBIPOps(t)
var attachedPort int
var attachedDevID uint32
var attachedSpeed uint32
clientOps.vhciPickFreePort = func(speed uint32) (int, error) {
require.Equal(t, SpeedSuper, speed)
return 7, nil
}
clientOps.vhciAttach = func(port int, _ uintptr, devid uint32, speed uint32) error {
attachedPort = port
attachedDevID = devid
attachedSpeed = speed
return nil
}
client := &ClientService{
ctx: ctx,
cancel: cancel,
logger: newTestLogger(),
dialer: testDialer{},
serverAddr: serverAddr,
ops: clientOps,
}
port, err := client.attemptAttach(ctx, "1-1")
require.NoError(t, err)
require.Equal(t, 7, port)
require.Equal(t, 7, attachedPort)
info := device.toProtocol()
require.Equal(t, info.DevID(), attachedDevID)
require.Equal(t, SpeedSuper, attachedSpeed)
require.Positive(t, store.lastSockfd("1-1"))
}
func TestClientRunControlSessionSyncsAssignmentsOnChanged(t *testing.T) {
t.Parallel()
serverCtx, serverCancel := context.WithCancel(context.Background())
defer serverCancel()
initialDevice := newTestDevice("1-1", 0x1d6b, 0x0002, "first", SpeedHigh)
updatedDevice := newTestDevice("1-2", 0x1d6b, 0x0002, "second", SpeedHigh)
store := newTestDeviceStore(initialDevice)
store.setStatus("1-1", usbipStatusAvailable)
store.setStatus("1-2", usbipStatusAvailable)
serverOps := newTestUSBIPOps(t)
serverOps.readUsbipStatus = store.readUsbipStatus
serverOps.readSysfsDevice = store.readSysfsDevice
server := &ServerService{
ctx: serverCtx,
cancel: serverCancel,
logger: newTestLogger(),
exports: map[string]serverExport{"1-1": {busid: "1-1"}},
controlSubs: make(map[uint64]*serverControlConn),
ops: serverOps,
}
serverAddr, closeServer := startDispatchServer(t, server)
defer closeServer()
clientCtx, clientCancel := context.WithCancel(context.Background())
defer clientCancel()
match := option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}
client := &ClientService{
ctx: clientCtx,
cancel: clientCancel,
logger: newTestLogger(),
dialer: testDialer{},
serverAddr: serverAddr,
matches: []option.USBIPDeviceMatch{match},
targets: []clientTarget{{match: match}},
assigned: make([]string, 1),
ops: newTestUSBIPOps(t),
}
errCh := make(chan error, 1)
go func() {
errCh <- client.runControlSession()
}()
require.Eventually(t, func() bool {
client.stateMu.Lock()
defer client.stateMu.Unlock()
return client.assigned[0] == "1-1"
}, 3*time.Second, 10*time.Millisecond)
store.setDevices(updatedDevice)
server.deleteExport("1-1")
server.setExport(serverExport{busid: "1-2"})
server.broadcastChanged()
require.Eventually(t, func() bool {
client.stateMu.Lock()
defer client.stateMu.Unlock()
return client.assigned[0] == "1-2"
}, 3*time.Second, 10*time.Millisecond)
clientCancel()
select {
case <-errCh:
case <-time.After(3 * time.Second):
t.Fatal("runControlSession did not exit after cancellation")
}
}
func TestUSBIPLinuxSmoke(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("usbip smoke test requires root")
}
require.NoError(t, ensureHostDriver())
require.NoError(t, ensureVHCI())
busid := os.Getenv("USBIP_TEST_BUSID")
if busid == "" {
t.Skip("USBIP_TEST_BUSID not set")
}
device, err := readSysfsDevice(busid, sysBusDevicePath(busid))
require.NoError(t, err)
require.Equal(t, busid, device.BusID)
_, err = currentDriver(busid)
require.NoError(t, err)
_, err = readUsbipStatus(busid)
require.NoError(t, err)
}

View file

@ -0,0 +1,52 @@
//go:build linux
package usbip
type usbEventListener interface {
Close() error
WaitUSBEvent() error
}
type usbipOps struct {
ensureHostDriver func() error
ensureVHCI func() error
listUSBDevices func() ([]sysfsDevice, error)
readSysfsDevice func(busid string, path string) (sysfsDevice, error)
currentDriver func(busid string) (string, error)
unbindFromDriver func(busid string, driver string) error
bindToDriver func(busid string, driver string) error
hostMatchBusID func(busid string, add bool) error
hostBind func(busid string) error
hostUnbind func(busid string) error
readUsbipStatus func(busid string) (int, error)
writeUsbipSockfd func(busid string, fd int) error
newUEventListener func() (usbEventListener, error)
vhciPickFreePort func(speed uint32) (int, error)
vhciAttach func(port int, fd uintptr, devid uint32, speed uint32) error
vhciDetach func(port int) error
vhciPortUsed func(port int) (bool, error)
}
var systemUSBIPOps = usbipOps{
ensureHostDriver: ensureHostDriver,
ensureVHCI: ensureVHCI,
listUSBDevices: listUSBDevices,
readSysfsDevice: readSysfsDevice,
currentDriver: currentDriver,
unbindFromDriver: unbindFromDriver,
bindToDriver: bindToDriver,
hostMatchBusID: hostMatchBusID,
hostBind: hostBind,
hostUnbind: hostUnbind,
readUsbipStatus: readUsbipStatus,
writeUsbipSockfd: writeUsbipSockfd,
newUEventListener: func() (usbEventListener, error) {
return newUEventListener()
},
vhciPickFreePort: vhciPickFreePort,
vhciAttach: vhciAttach,
vhciDetach: vhciDetach,
vhciPortUsed: vhciPortUsed,
}

View file

@ -0,0 +1,273 @@
package usbip
import (
"bytes"
"encoding/binary"
"io"
"strings"
"testing"
"github.com/sagernet/sing-box/option"
"github.com/stretchr/testify/require"
)
func TestControlPrefaceAndFrames(t *testing.T) {
t.Parallel()
var preface bytes.Buffer
require.NoError(t, WriteControlPreface(&preface))
require.True(t, IsControlPreface(preface.Bytes()))
corruptedPreface := append([]byte(nil), preface.Bytes()...)
corruptedPreface[len(corruptedPreface)-1] = '0'
require.False(t, IsControlPreface(corruptedPreface))
require.False(t, IsControlPreface(preface.Bytes()[:len(preface.Bytes())-1]))
testCases := []struct {
name string
write func(io.Writer) error
expected controlFrame
}{
{
name: "hello",
write: WriteControlHello,
expected: controlFrame{
Type: controlFrameHello,
Version: controlProtocolVersion,
Capabilities: controlCapabilities,
},
},
{
name: "ack",
write: func(writer io.Writer) error { return WriteControlAck(writer, 7) },
expected: controlFrame{
Type: controlFrameAck,
Version: controlProtocolVersion,
Capabilities: controlCapabilities,
Sequence: 7,
},
},
{
name: "changed",
write: func(writer io.Writer) error { return WriteControlChanged(writer, 9) },
expected: controlFrame{
Type: controlFrameChanged,
Version: controlProtocolVersion,
Sequence: 9,
},
},
{
name: "ping",
write: WriteControlPing,
expected: controlFrame{
Type: controlFramePing,
Version: controlProtocolVersion,
},
},
{
name: "pong",
write: WriteControlPong,
expected: controlFrame{
Type: controlFramePong,
Version: controlProtocolVersion,
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
var buffer bytes.Buffer
require.NoError(t, testCase.write(&buffer))
frame, err := ReadControlFrame(&buffer)
require.NoError(t, err)
require.Equal(t, testCase.expected, frame)
})
}
}
func TestOpHeaderRoundTrip(t *testing.T) {
t.Parallel()
var buffer bytes.Buffer
require.NoError(t, WriteOpHeader(&buffer, OpReqDevList, OpStatusError))
header, err := ReadOpHeader(&buffer)
require.NoError(t, err)
require.Equal(t, OpHeader{
Version: ProtocolVersion,
Code: OpReqDevList,
Status: OpStatusError,
}, header)
var raw [8]byte
binary.BigEndian.PutUint16(raw[:2], ProtocolVersion)
binary.BigEndian.PutUint16(raw[2:4], OpRepImport)
binary.BigEndian.PutUint32(raw[4:8], OpStatusOK)
require.Equal(t, OpHeader{
Version: ProtocolVersion,
Code: OpRepImport,
Status: OpStatusOK,
}, ParseOpHeader(raw[:]))
}
func TestOpReqImportRoundTrip(t *testing.T) {
t.Parallel()
var buffer bytes.Buffer
require.NoError(t, WriteOpReqImport(&buffer, "1-2"))
header, err := ReadOpHeader(&buffer)
require.NoError(t, err)
require.Equal(t, OpReqImport, header.Code)
require.Equal(t, OpStatusOK, header.Status)
busid, err := ReadOpReqImportBody(&buffer)
require.NoError(t, err)
require.Equal(t, "1-2", busid)
}
func TestWriteOpReqImportRejectsLongBusID(t *testing.T) {
t.Parallel()
var buffer bytes.Buffer
err := WriteOpReqImport(&buffer, strings.Repeat("a", 32))
require.ErrorContains(t, err, "busid too long")
}
func TestWriteOpRepImportRequiresInfoOnSuccess(t *testing.T) {
t.Parallel()
var buffer bytes.Buffer
err := WriteOpRepImport(&buffer, OpStatusOK, nil)
require.ErrorContains(t, err, "success without device info")
}
func TestOpRepImportRoundTrip(t *testing.T) {
t.Parallel()
var info DeviceInfoTruncated
copy(info.BusID[:], "1-2")
info.BusNum = 1
info.DevNum = 2
info.Speed = SpeedHigh
info.IDVendor = 0x1d6b
info.IDProduct = 0x0002
var buffer bytes.Buffer
require.NoError(t, WriteOpRepImport(&buffer, OpStatusOK, &info))
header, err := ReadOpHeader(&buffer)
require.NoError(t, err)
require.Equal(t, OpRepImport, header.Code)
require.Equal(t, OpStatusOK, header.Status)
body, err := ReadOpRepImportBody(&buffer)
require.NoError(t, err)
require.Equal(t, info, body)
}
func TestOpRepDevListRoundTrip(t *testing.T) {
t.Parallel()
var path [256]byte
encodePathField(&path, "/sys/bus/usb/devices/1-2", "serial-1")
entries := []DeviceEntry{{
Info: DeviceInfoTruncated{
Path: path,
BusNum: 1,
DevNum: 2,
Speed: SpeedHigh,
IDVendor: 0x1d6b,
IDProduct: 0x0002,
BNumInterfaces: 2,
},
Interfaces: []DeviceInterface{
{BInterfaceClass: 0xff, BInterfaceSubClass: 1, BInterfaceProtocol: 2},
{BInterfaceClass: 0x03, BInterfaceSubClass: 1, BInterfaceProtocol: 1},
},
}}
copy(entries[0].Info.BusID[:], "1-2")
var buffer bytes.Buffer
require.NoError(t, WriteOpRepDevList(&buffer, entries))
header, err := ReadOpHeader(&buffer)
require.NoError(t, err)
require.Equal(t, OpRepDevList, header.Code)
require.Equal(t, OpStatusOK, header.Status)
parsed, err := ReadOpRepDevListBody(&buffer)
require.NoError(t, err)
require.Equal(t, entries, parsed)
}
func TestReadOpRepDevListBodyRejectsTooManyEntries(t *testing.T) {
t.Parallel()
var buffer bytes.Buffer
require.NoError(t, binary.Write(&buffer, binary.BigEndian, uint32(maxOpRepDevListEntries+1)))
_, err := ReadOpRepDevListBody(&buffer)
require.ErrorContains(t, err, "device count too large")
}
func TestDeviceInfoHelpers(t *testing.T) {
t.Parallel()
var info DeviceInfoTruncated
encodePathField(&info.Path, "/sys/bus/usb/devices/1-2", "serial-1")
copy(info.BusID[:], "1-2")
info.BusNum = 3
info.DevNum = 9
require.Equal(t, "/sys/bus/usb/devices/1-2", info.PathString())
require.Equal(t, "serial-1", info.SerialString())
require.Equal(t, "1-2", info.BusIDString())
require.Equal(t, uint32(0x00030009), info.DevID())
}
func TestEncodePathFieldSkipsSerialWithoutRoom(t *testing.T) {
t.Parallel()
var info DeviceInfoTruncated
encodePathField(&info.Path, strings.Repeat("a", len(info.Path)-1), "serial-1")
require.Empty(t, info.SerialString())
}
func TestMatches(t *testing.T) {
t.Parallel()
device := DeviceKey{
BusID: "1-2",
VendorID: 0x1d6b,
ProductID: 0x0002,
Serial: "serial-1",
}
testCases := []struct {
name string
match option.USBIPDeviceMatch
expected bool
}{
{name: "zero", match: option.USBIPDeviceMatch{}, expected: false},
{name: "busid", match: option.USBIPDeviceMatch{BusID: "1-2"}, expected: true},
{name: "vendor-and-product", match: option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}, expected: true},
{name: "serial", match: option.USBIPDeviceMatch{Serial: "serial-1"}, expected: true},
{name: "all-fields", match: option.USBIPDeviceMatch{BusID: "1-2", VendorID: 0x1d6b, ProductID: 0x0002, Serial: "serial-1"}, expected: true},
{name: "vendor-mismatch", match: option.USBIPDeviceMatch{VendorID: 0x1d6c}, expected: false},
{name: "serial-mismatch", match: option.USBIPDeviceMatch{Serial: "other"}, expected: false},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, testCase.expected, Matches(testCase.match, device))
})
}
}

View file

@ -41,6 +41,7 @@ type ServerService struct {
logger log.ContextLogger
listener *listener.Listener
matches []option.USBIPDeviceMatch
ops usbipOps
mu sync.Mutex
exports map[string]serverExport
@ -76,6 +77,7 @@ func NewServerService(ctx context.Context, logger log.ContextLogger, tag string,
Listen: options.ListenOptions,
}),
controlSubs: make(map[uint64]*serverControlConn),
ops: systemUSBIPOps,
}
return s, nil
}
@ -84,7 +86,7 @@ func (s *ServerService) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if err := ensureHostDriver(); err != nil {
if err := s.ops.ensureHostDriver(); err != nil {
return err
}
if _, err := s.reconcileExports(); err != nil {
@ -115,7 +117,7 @@ func (s *ServerService) Close() error {
}
func (s *ServerService) reconcileExports() (bool, error) {
devices, err := listUSBDevices()
devices, err := s.ops.listUSBDevices()
if err != nil {
return false, E.Cause(err, "enumerate usb devices")
}
@ -163,7 +165,7 @@ func (s *ServerService) reconcileExports() (bool, error) {
}
func (s *ServerService) bindOne(d *sysfsDevice) error {
driver, err := currentDriver(d.BusID)
driver, err := s.ops.currentDriver(d.BusID)
if err != nil {
return err
}
@ -173,20 +175,20 @@ func (s *ServerService) bindOne(d *sysfsDevice) error {
return nil
}
if driver != "" {
if err := unbindFromDriver(d.BusID, driver); err != nil {
if err := s.ops.unbindFromDriver(d.BusID, driver); err != nil {
return E.Cause(err, "unbind from ", driver)
}
}
if err := hostMatchBusID(d.BusID, true); err != nil {
if err := s.ops.hostMatchBusID(d.BusID, true); err != nil {
if driver != "" {
_ = bindToDriver(d.BusID, driver)
_ = s.ops.bindToDriver(d.BusID, driver)
}
return E.Cause(err, "match_busid add")
}
if err := hostBind(d.BusID); err != nil {
_ = hostMatchBusID(d.BusID, false)
if err := s.ops.hostBind(d.BusID); err != nil {
_ = s.ops.hostMatchBusID(d.BusID, false)
if driver != "" {
_ = bindToDriver(d.BusID, driver)
_ = s.ops.bindToDriver(d.BusID, driver)
}
return E.Cause(err, "bind to usbip-host")
}
@ -203,17 +205,17 @@ func (s *ServerService) releaseExport(export serverExport, restore bool) error {
s.deleteExport(export.busid)
var releaseErr error
if err := writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) {
if err := s.ops.writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) {
releaseErr = err
}
if !export.managed {
s.logger.Info("stopped tracking ", export.busid, " on usbip-host")
return releaseErr
}
if err := hostUnbind(export.busid); err != nil && !os.IsNotExist(err) && releaseErr == nil {
if err := s.ops.hostUnbind(export.busid); err != nil && !os.IsNotExist(err) && releaseErr == nil {
releaseErr = err
}
if err := hostMatchBusID(export.busid, false); err != nil && releaseErr == nil {
if err := s.ops.hostMatchBusID(export.busid, false); err != nil && releaseErr == nil {
releaseErr = err
}
if !restore {
@ -224,7 +226,7 @@ func (s *ServerService) releaseExport(export serverExport, restore bool) error {
s.logger.Info("released ", export.busid, " from usbip-host")
return releaseErr
}
if err := bindToDriver(export.busid, export.originalDriver); err != nil {
if err := s.ops.bindToDriver(export.busid, export.originalDriver); err != nil {
if releaseErr == nil {
releaseErr = err
}
@ -237,7 +239,8 @@ func (s *ServerService) releaseExport(export serverExport, restore bool) error {
func (s *ServerService) rollbackExports() {
exports := s.snapshotExports()
for _, export := range exports {
_, restore := currentSysfsDevice(export.busid)
_, err := s.ops.readSysfsDevice(export.busid, sysBusDevicePath(export.busid))
restore := err == nil
if err := s.releaseExport(export, restore); err != nil {
s.logger.Warn("rollback ", export.busid, ": ", err)
}
@ -409,7 +412,7 @@ func (s *ServerService) buildDevListEntries() []DeviceEntry {
}
entries := make([]DeviceEntry, 0, len(busids))
for _, busid := range busids {
status, err := readUsbipStatus(busid)
status, err := s.ops.readUsbipStatus(busid)
if err != nil {
s.logger.Debug("status ", busid, ": ", err)
continue
@ -417,7 +420,7 @@ func (s *ServerService) buildDevListEntries() []DeviceEntry {
if status != usbipStatusAvailable {
continue
}
d, err := readSysfsDevice(busid, sysBusDevicePath(busid))
d, err := s.ops.readSysfsDevice(busid, sysBusDevicePath(busid))
if err != nil {
s.logger.Debug("refresh ", busid, ": ", err)
continue
@ -441,13 +444,13 @@ func (s *ServerService) handleImport(conn net.Conn) {
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
}
status, err := readUsbipStatus(busid)
status, err := s.ops.readUsbipStatus(busid)
if err != nil || status != usbipStatusAvailable {
s.logger.Info("import rejected (busid ", busid, " status=", status, " err=", err, ")")
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
}
dev, err := readSysfsDevice(busid, sysBusDevicePath(busid))
dev, err := s.ops.readSysfsDevice(busid, sysBusDevicePath(busid))
if err != nil {
s.logger.Warn("refresh ", busid, ": ", err)
_ = WriteOpRepImport(conn, OpStatusError, nil)
@ -466,7 +469,7 @@ func (s *ServerService) handleImport(conn net.Conn) {
return
}
defer file.Close()
if err := writeUsbipSockfd(busid, int(file.Fd())); err != nil {
if err := s.ops.writeUsbipSockfd(busid, int(file.Fd())); err != nil {
s.logger.Warn("hand off ", busid, " to kernel: ", err)
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
@ -474,7 +477,7 @@ func (s *ServerService) handleImport(conn net.Conn) {
info := dev.toProtocol()
if err := WriteOpRepImport(conn, OpStatusOK, &info); err != nil {
s.logger.Warn("reply import ", busid, ": ", err)
_ = writeUsbipSockfd(busid, -1)
_ = s.ops.writeUsbipSockfd(busid, -1)
return
}
s.logger.Info("attached ", busid, " to remote ", conn.RemoteAddr())
@ -489,7 +492,7 @@ func (s *ServerService) isExported(busid string) bool {
func (s *ServerService) ueventLoop() {
for {
listener, err := newUEventListener()
listener, err := s.ops.newUEventListener()
if err != nil {
if s.ctx.Err() != nil {
return
@ -595,14 +598,6 @@ func (s *ServerService) enqueueControlFrame(sub *serverControlConn, frame contro
}
}
func currentSysfsDevice(busid string) (sysfsDevice, bool) {
device, err := readSysfsDevice(busid, sysBusDevicePath(busid))
if err != nil {
return sysfsDevice{}, false
}
return device, true
}
func sysBusDevicePath(busid string) string {
return sysBusUSBDevices + "/" + busid
}

View file

@ -6,11 +6,13 @@ import (
"bufio"
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/shell"
)
const (
@ -80,18 +82,12 @@ type vhciStatusRecord struct {
// ensureHostDriver verifies the usbip-host kernel driver is loaded.
func ensureHostDriver() error {
if _, err := os.Stat(sysUsbipHostDriver); err != nil {
return E.Cause(err, "usbip-host driver not present; modprobe usbip-host")
}
return nil
return ensureKernelPath(sysUsbipHostDriver, "usbip-host", "usbip-host driver")
}
// ensureVHCI verifies the vhci_hcd controller is loaded.
func ensureVHCI() error {
if _, err := os.Stat(sysVHCIControllerV0); err != nil {
return E.Cause(err, "vhci_hcd.0 not present; modprobe vhci-hcd")
}
return nil
return ensureKernelPath(sysVHCIControllerV0, "vhci-hcd", "vhci_hcd.0")
}
// listUSBDevices enumerates /sys/bus/usb/devices, returning non-interface
@ -324,6 +320,41 @@ func vhciHubForSpeed(speed uint32) string {
}
}
func ensureKernelPath(path string, module string, description string) error {
_, err := os.Stat(path)
if err == nil {
return nil
}
if os.Getuid() != 0 {
return E.Cause(err, description, " not present; root is required to load kernel module ", module)
}
modprobePath, modprobeErr := findModprobePath()
if modprobeErr != nil {
return E.Cause(modprobeErr, "load kernel module ", module, " for ", description)
}
output, modprobeErr := shell.Exec(modprobePath, module).Read()
if modprobeErr != nil {
return E.Extend(E.Cause(modprobeErr, "load kernel module ", module, " for ", description), strings.TrimSpace(output))
}
if _, err = os.Stat(path); err != nil {
return E.Cause(err, description, " still not present after loading kernel module ", module)
}
return nil
}
func findModprobePath() (string, error) {
if path, err := exec.LookPath("modprobe"); err == nil {
return path, nil
}
for _, path := range []string{"/usr/sbin/modprobe", "/sbin/modprobe", "/usr/bin/modprobe", "/bin/modprobe"} {
info, err := os.Stat(path)
if err == nil && info.Mode().IsRegular() && info.Mode()&0o111 != 0 {
return path, nil
}
}
return "", E.New("modprobe executable not found")
}
// --- small helpers ------------------------------------------------------
func writeSysfs(path, content string) error {