mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-05-13 13:57:05 +00:00
Add usbip tests and auto-load drivers
This commit is contained in:
parent
f3bbd3c07b
commit
85c9d6b4f3
7 changed files with 1173 additions and 97 deletions
77
option/usbip_test.go
Normal file
77
option/usbip_test.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/json"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUSBIPHexUint16UnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected USBIPHexUint16
|
||||
errorSubstr string
|
||||
}{
|
||||
{name: "number", input: `7531`, expected: USBIPHexUint16(0x1d6b)},
|
||||
{name: "hex-with-prefix", input: `"0x1d6b"`, expected: USBIPHexUint16(0x1d6b)},
|
||||
{name: "hex-with-uppercase-prefix", input: `"0X1D6B"`, expected: USBIPHexUint16(0x1d6b)},
|
||||
{name: "hex-without-prefix", input: `"1d6b"`, expected: USBIPHexUint16(0x1d6b)},
|
||||
{name: "empty-string", input: `""`, expected: 0},
|
||||
{name: "out-of-range", input: `65536`, errorSubstr: "out of uint16 range"},
|
||||
{name: "invalid-hex", input: `"zzzz"`, errorSubstr: "parse usb id zzzz"},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var value USBIPHexUint16
|
||||
err := json.UnmarshalContext(context.Background(), []byte(testCase.input), &value)
|
||||
if testCase.errorSubstr != "" {
|
||||
require.ErrorContains(t, err, testCase.errorSubstr)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testCase.expected, value)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUSBIPHexUint16MarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input USBIPHexUint16
|
||||
expected string
|
||||
}{
|
||||
{name: "zero", input: 0, expected: `""`},
|
||||
{name: "non-zero", input: USBIPHexUint16(0x1d6b), expected: `"0x1d6b"`},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal(testCase.input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testCase.expected, string(content))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUSBIPDeviceMatchIsZero(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.True(t, (USBIPDeviceMatch{}).IsZero())
|
||||
require.False(t, (USBIPDeviceMatch{BusID: "1-1"}).IsZero())
|
||||
require.False(t, (USBIPDeviceMatch{VendorID: 0x1d6b}).IsZero())
|
||||
require.False(t, (USBIPDeviceMatch{ProductID: 0x0002}).IsZero())
|
||||
require.False(t, (USBIPDeviceMatch{Serial: "abc"}).IsZero())
|
||||
}
|
||||
|
|
@ -60,6 +60,7 @@ type ClientService struct {
|
|||
dialer N.Dialer
|
||||
serverAddr M.Socksaddr
|
||||
matches []option.USBIPDeviceMatch // empty = import all remote exports
|
||||
ops usbipOps
|
||||
|
||||
stateMu sync.Mutex
|
||||
targets []clientTarget
|
||||
|
|
@ -99,6 +100,7 @@ func NewClientService(ctx context.Context, logger log.ContextLogger, tag string,
|
|||
dialer: outboundDialer,
|
||||
serverAddr: options.ServerOptions.Build(),
|
||||
matches: options.Devices,
|
||||
ops: systemUSBIPOps,
|
||||
allWorkers: make(map[string]*clientBusIDWorker),
|
||||
ports: make(map[int]struct{}),
|
||||
}, nil
|
||||
|
|
@ -108,7 +110,7 @@ func (c *ClientService) Start(stage adapter.StartStage) error {
|
|||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if err := ensureVHCI(); err != nil {
|
||||
if err := c.ops.ensureVHCI(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.initializeWorkers()
|
||||
|
|
@ -319,66 +321,12 @@ func (c *ClientService) applyRemoteExports(entries []DeviceEntry) {
|
|||
}
|
||||
|
||||
func (c *ClientService) applyMatchedExports(entries []DeviceEntry) {
|
||||
keysByBusID := make(map[string]DeviceKey, len(entries))
|
||||
for i := range entries {
|
||||
busid := entries[i].Info.BusIDString()
|
||||
if busid == "" {
|
||||
continue
|
||||
}
|
||||
keysByBusID[busid] = DeviceKey{
|
||||
BusID: busid,
|
||||
VendorID: entries[i].Info.IDVendor,
|
||||
ProductID: entries[i].Info.IDProduct,
|
||||
Serial: entries[i].Info.SerialString(),
|
||||
}
|
||||
}
|
||||
|
||||
c.stateMu.Lock()
|
||||
if len(c.targets) == 0 {
|
||||
c.stateMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
nextAssigned := make([]string, len(c.targets))
|
||||
reserved := make(map[string]struct{}, len(c.targets))
|
||||
for i, target := range c.targets {
|
||||
if target.fixedBusID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := keysByBusID[target.fixedBusID]; !ok {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = target.fixedBusID
|
||||
reserved[target.fixedBusID] = struct{}{}
|
||||
}
|
||||
for i, target := range c.targets {
|
||||
if target.fixedBusID != "" {
|
||||
continue
|
||||
}
|
||||
current := c.assigned[i]
|
||||
if current == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := reserved[current]; ok {
|
||||
continue
|
||||
}
|
||||
key, ok := keysByBusID[current]
|
||||
if !ok || !Matches(target.match, key) {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = current
|
||||
reserved[current] = struct{}{}
|
||||
}
|
||||
for i, target := range c.targets {
|
||||
if target.fixedBusID != "" || nextAssigned[i] != "" {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = firstMatchingUnclaimedBusID(target.match, entries, reserved)
|
||||
if nextAssigned[i] != "" {
|
||||
reserved[nextAssigned[i]] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
nextAssigned := assignMatchedBusIDs(c.targets, c.assigned, entries)
|
||||
workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...)
|
||||
previous := append([]string(nil), c.assigned...)
|
||||
c.assigned = nextAssigned
|
||||
|
|
@ -585,11 +533,11 @@ func (c *ClientService) attemptAttach(ctx context.Context, busid string) (int, e
|
|||
defer file.Close()
|
||||
c.attachMu.Lock()
|
||||
defer c.attachMu.Unlock()
|
||||
port, err := vhciPickFreePort(info.Speed)
|
||||
port, err := c.ops.vhciPickFreePort(info.Speed)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
if err := vhciAttach(port, file.Fd(), info.DevID(), info.Speed); err != nil {
|
||||
if err := c.ops.vhciAttach(port, file.Fd(), info.DevID(), info.Speed); err != nil {
|
||||
return -1, E.Cause(err, "vhci attach")
|
||||
}
|
||||
return port, nil
|
||||
|
|
@ -601,12 +549,12 @@ func (c *ClientService) watchPort(ctx context.Context, port int, busid string) {
|
|||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := vhciDetach(port); err != nil {
|
||||
if err := c.ops.vhciDetach(port); err != nil {
|
||||
c.logger.Warn("detach port ", port, " (", busid, "): ", err)
|
||||
}
|
||||
return
|
||||
case <-ticker.C:
|
||||
used, err := vhciPortUsed(port)
|
||||
used, err := c.ops.vhciPortUsed(port)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll port ", port, ": ", err)
|
||||
continue
|
||||
|
|
@ -632,6 +580,65 @@ func isBusIDOnlyMatch(m option.USBIPDeviceMatch) bool {
|
|||
return m.BusID != "" && m.VendorID == 0 && m.ProductID == 0 && m.Serial == ""
|
||||
}
|
||||
|
||||
func assignMatchedBusIDs(targets []clientTarget, current []string, entries []DeviceEntry) []string {
|
||||
if len(targets) == 0 {
|
||||
return nil
|
||||
}
|
||||
keysByBusID := make(map[string]DeviceKey, len(entries))
|
||||
for i := range entries {
|
||||
busid := entries[i].Info.BusIDString()
|
||||
if busid == "" {
|
||||
continue
|
||||
}
|
||||
keysByBusID[busid] = DeviceKey{
|
||||
BusID: busid,
|
||||
VendorID: entries[i].Info.IDVendor,
|
||||
ProductID: entries[i].Info.IDProduct,
|
||||
Serial: entries[i].Info.SerialString(),
|
||||
}
|
||||
}
|
||||
|
||||
nextAssigned := make([]string, len(targets))
|
||||
reserved := make(map[string]struct{}, len(targets))
|
||||
for i, target := range targets {
|
||||
if target.fixedBusID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := keysByBusID[target.fixedBusID]; !ok {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = target.fixedBusID
|
||||
reserved[target.fixedBusID] = struct{}{}
|
||||
}
|
||||
for i, target := range targets {
|
||||
if target.fixedBusID != "" || i >= len(current) {
|
||||
continue
|
||||
}
|
||||
if current[i] == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := reserved[current[i]]; ok {
|
||||
continue
|
||||
}
|
||||
key, ok := keysByBusID[current[i]]
|
||||
if !ok || !Matches(target.match, key) {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = current[i]
|
||||
reserved[current[i]] = struct{}{}
|
||||
}
|
||||
for i, target := range targets {
|
||||
if target.fixedBusID != "" || nextAssigned[i] != "" {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = firstMatchingUnclaimedBusID(target.match, entries, reserved)
|
||||
if nextAssigned[i] != "" {
|
||||
reserved[nextAssigned[i]] = struct{}{}
|
||||
}
|
||||
}
|
||||
return nextAssigned
|
||||
}
|
||||
|
||||
func firstMatchingUnclaimedBusID(match option.USBIPDeviceMatch, entries []DeviceEntry, reserved map[string]struct{}) string {
|
||||
for i := range entries {
|
||||
key := DeviceKey{
|
||||
|
|
|
|||
641
service/usbip/linux_test.go
Normal file
641
service/usbip/linux_test.go
Normal file
|
|
@ -0,0 +1,641 @@
|
|||
//go:build linux
|
||||
|
||||
package usbip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testDialer struct{}
|
||||
|
||||
func (testDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, destination.String())
|
||||
}
|
||||
|
||||
func (testDialer) ListenPacket(context.Context, M.Socksaddr) (net.PacketConn, error) {
|
||||
return nil, errors.New("unused")
|
||||
}
|
||||
|
||||
type testDeviceStore struct {
|
||||
mu sync.Mutex
|
||||
devices map[string]sysfsDevice
|
||||
statuses map[string]int
|
||||
sockfds map[string]int
|
||||
}
|
||||
|
||||
func newTestDeviceStore(devices ...sysfsDevice) *testDeviceStore {
|
||||
store := &testDeviceStore{
|
||||
devices: make(map[string]sysfsDevice),
|
||||
statuses: make(map[string]int),
|
||||
sockfds: make(map[string]int),
|
||||
}
|
||||
store.setDevices(devices...)
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *testDeviceStore) setDevices(devices ...sysfsDevice) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.devices = make(map[string]sysfsDevice, len(devices))
|
||||
for _, device := range devices {
|
||||
s.devices[device.BusID] = device
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testDeviceStore) setStatus(busid string, status int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.statuses[busid] = status
|
||||
}
|
||||
|
||||
func (s *testDeviceStore) listUSBDevices() ([]sysfsDevice, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
out := make([]sysfsDevice, 0, len(s.devices))
|
||||
for _, device := range s.devices {
|
||||
out = append(out, device)
|
||||
}
|
||||
slices.SortFunc(out, func(left, right sysfsDevice) int {
|
||||
switch {
|
||||
case left.BusID < right.BusID:
|
||||
return -1
|
||||
case left.BusID > right.BusID:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *testDeviceStore) readSysfsDevice(busid, path string) (sysfsDevice, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
device, ok := s.devices[busid]
|
||||
if !ok {
|
||||
return sysfsDevice{}, os.ErrNotExist
|
||||
}
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (s *testDeviceStore) readUsbipStatus(busid string) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
status, ok := s.statuses[busid]
|
||||
if !ok {
|
||||
return 0, os.ErrNotExist
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *testDeviceStore) writeUsbipSockfd(busid string, fd int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sockfds[busid] = fd
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testDeviceStore) lastSockfd(busid string) int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.sockfds[busid]
|
||||
}
|
||||
|
||||
func newTestUSBIPOps(t *testing.T) usbipOps {
|
||||
t.Helper()
|
||||
|
||||
return usbipOps{
|
||||
ensureHostDriver: func() error {
|
||||
t.Fatalf("unexpected ensureHostDriver")
|
||||
return nil
|
||||
},
|
||||
ensureVHCI: func() error {
|
||||
t.Fatalf("unexpected ensureVHCI")
|
||||
return nil
|
||||
},
|
||||
listUSBDevices: func() ([]sysfsDevice, error) {
|
||||
t.Fatalf("unexpected listUSBDevices")
|
||||
return nil, nil
|
||||
},
|
||||
readSysfsDevice: func(string, string) (sysfsDevice, error) {
|
||||
t.Fatalf("unexpected readSysfsDevice")
|
||||
return sysfsDevice{}, nil
|
||||
},
|
||||
currentDriver: func(string) (string, error) {
|
||||
t.Fatalf("unexpected currentDriver")
|
||||
return "", nil
|
||||
},
|
||||
unbindFromDriver: func(string, string) error {
|
||||
t.Fatalf("unexpected unbindFromDriver")
|
||||
return nil
|
||||
},
|
||||
bindToDriver: func(string, string) error {
|
||||
t.Fatalf("unexpected bindToDriver")
|
||||
return nil
|
||||
},
|
||||
hostMatchBusID: func(string, bool) error {
|
||||
t.Fatalf("unexpected hostMatchBusID")
|
||||
return nil
|
||||
},
|
||||
hostBind: func(string) error {
|
||||
t.Fatalf("unexpected hostBind")
|
||||
return nil
|
||||
},
|
||||
hostUnbind: func(string) error {
|
||||
t.Fatalf("unexpected hostUnbind")
|
||||
return nil
|
||||
},
|
||||
readUsbipStatus: func(string) (int, error) {
|
||||
t.Fatalf("unexpected readUsbipStatus")
|
||||
return 0, nil
|
||||
},
|
||||
writeUsbipSockfd: func(string, int) error {
|
||||
t.Fatalf("unexpected writeUsbipSockfd")
|
||||
return nil
|
||||
},
|
||||
newUEventListener: func() (usbEventListener, error) {
|
||||
t.Fatalf("unexpected newUEventListener")
|
||||
return nil, nil
|
||||
},
|
||||
vhciPickFreePort: func(uint32) (int, error) {
|
||||
t.Fatalf("unexpected vhciPickFreePort")
|
||||
return 0, nil
|
||||
},
|
||||
vhciAttach: func(int, uintptr, uint32, uint32) error {
|
||||
t.Fatalf("unexpected vhciAttach")
|
||||
return nil
|
||||
},
|
||||
vhciDetach: func(int) error {
|
||||
t.Fatalf("unexpected vhciDetach")
|
||||
return nil
|
||||
},
|
||||
vhciPortUsed: func(int) (bool, error) {
|
||||
t.Fatalf("unexpected vhciPortUsed")
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newTestLogger() log.ContextLogger {
|
||||
return log.NewNOPFactory().NewLogger("usbip")
|
||||
}
|
||||
|
||||
func newTestDevice(busid string, vendorID, productID uint16, serial string, speed uint32) sysfsDevice {
|
||||
return sysfsDevice{
|
||||
BusID: busid,
|
||||
Path: sysBusDevicePath(busid),
|
||||
BusNum: 3,
|
||||
DevNum: 9,
|
||||
Speed: speed,
|
||||
VendorID: vendorID,
|
||||
ProductID: productID,
|
||||
DeviceClass: 0,
|
||||
ConfigValue: 1,
|
||||
NumConfigs: 1,
|
||||
NumInterfaces: 1,
|
||||
Serial: serial,
|
||||
Interfaces: []DeviceInterface{{
|
||||
BInterfaceClass: 0xff,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func startDispatchServer(t *testing.T, server *ServerService) (M.Socksaddr, func()) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
go server.dispatchConn(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return M.SocksaddrFromNet(listener.Addr()), func() {
|
||||
_ = listener.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTargetsDedupesFixedBusID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &ClientService{
|
||||
matches: []option.USBIPDeviceMatch{
|
||||
{BusID: "1-1"},
|
||||
{VendorID: 0x1d6b, ProductID: 0x0002},
|
||||
{BusID: "1-1"},
|
||||
{BusID: "1-2"},
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, []clientTarget{
|
||||
{fixedBusID: "1-1"},
|
||||
{match: option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}},
|
||||
{fixedBusID: "1-2"},
|
||||
}, client.buildTargets())
|
||||
}
|
||||
|
||||
func TestAssignMatchedBusIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
match := option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}
|
||||
fixed := newTestDevice("1-1", 0x1d6b, 0x0001, "fixed", SpeedHigh)
|
||||
first := newTestDevice("1-2", 0x1d6b, 0x0002, "first", SpeedHigh)
|
||||
second := newTestDevice("1-3", 0x1d6b, 0x0002, "second", SpeedHigh)
|
||||
entries := []DeviceEntry{
|
||||
{Info: fixed.toProtocol()},
|
||||
{Info: first.toProtocol()},
|
||||
{Info: second.toProtocol()},
|
||||
}
|
||||
|
||||
require.Equal(t, []string{"1-1", "1-3", "1-2"}, assignMatchedBusIDs(
|
||||
[]clientTarget{
|
||||
{fixedBusID: "1-1"},
|
||||
{match: match},
|
||||
{match: match},
|
||||
},
|
||||
[]string{"1-1", "1-3", ""},
|
||||
entries,
|
||||
))
|
||||
}
|
||||
|
||||
func TestLinuxHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, []vhciStatusRecord{
|
||||
{hub: "hs", port: 0, state: 6},
|
||||
{hub: "ss", port: 3, state: 4},
|
||||
}, parseVHCIStatus("hub port sta spd dev sockfd local_busid\nhs 0 6 3 0 0 0\nignored line\nss 3 4 5 0 0 0\n"))
|
||||
|
||||
require.Equal(t, SpeedLow, speedCodeFromString("1.5"))
|
||||
require.Equal(t, SpeedFull, speedCodeFromString("12"))
|
||||
require.Equal(t, SpeedHigh, speedCodeFromString("480"))
|
||||
require.Equal(t, SpeedSuper, speedCodeFromString("5000"))
|
||||
require.Equal(t, SpeedSuperPlus, speedCodeFromString("10000"))
|
||||
require.Equal(t, SpeedUnknown, speedCodeFromString("25"))
|
||||
|
||||
require.Equal(t, "hs", vhciHubForSpeed(SpeedHigh))
|
||||
require.Equal(t, "ss", vhciHubForSpeed(SpeedSuper))
|
||||
require.True(t, isUSBUEvent([]byte("ACTION=add\x00SUBSYSTEM=usb\x00")))
|
||||
require.False(t, isUSBUEvent([]byte("ACTION=add\x00SUBSYSTEM=net\x00")))
|
||||
}
|
||||
|
||||
func TestServerReconcileExportsBindsMatchesAndSkipsHub(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
regular := newTestDevice("1-1", 0x1d6b, 0x0002, "regular", SpeedHigh)
|
||||
hub := newTestDevice("1-2", 0x1d6b, 0x0002, "hub", SpeedHigh)
|
||||
hub.DeviceClass = 0x09
|
||||
store := newTestDeviceStore(regular, hub)
|
||||
ops := newTestUSBIPOps(t)
|
||||
var actions []string
|
||||
ops.listUSBDevices = store.listUSBDevices
|
||||
ops.currentDriver = func(busid string) (string, error) {
|
||||
return map[string]string{
|
||||
"1-1": "usbhid",
|
||||
"1-2": "hubdrv",
|
||||
}[busid], nil
|
||||
}
|
||||
ops.unbindFromDriver = func(busid, driver string) error {
|
||||
actions = append(actions, "unbind "+busid+" "+driver)
|
||||
return nil
|
||||
}
|
||||
ops.hostMatchBusID = func(busid string, add bool) error {
|
||||
actions = append(actions, "match "+busid+" "+map[bool]string{true: "add", false: "del"}[add])
|
||||
return nil
|
||||
}
|
||||
ops.hostBind = func(busid string) error {
|
||||
actions = append(actions, "hostbind "+busid)
|
||||
return nil
|
||||
}
|
||||
ops.bindToDriver = func(busid, driver string) error {
|
||||
actions = append(actions, "bind "+busid+" "+driver)
|
||||
return nil
|
||||
}
|
||||
|
||||
server := &ServerService{
|
||||
ctx: context.Background(),
|
||||
logger: newTestLogger(),
|
||||
matches: []option.USBIPDeviceMatch{{VendorID: 0x1d6b, ProductID: 0x0002}},
|
||||
exports: make(map[string]serverExport),
|
||||
controlSubs: make(map[uint64]*serverControlConn),
|
||||
ops: ops,
|
||||
}
|
||||
|
||||
changed, err := server.reconcileExports()
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, []string{
|
||||
"unbind 1-1 usbhid",
|
||||
"match 1-1 add",
|
||||
"hostbind 1-1",
|
||||
}, actions)
|
||||
require.Equal(t, map[string]serverExport{
|
||||
"1-1": {
|
||||
busid: "1-1",
|
||||
managed: true,
|
||||
originalDriver: "usbhid",
|
||||
},
|
||||
}, server.snapshotExports())
|
||||
}
|
||||
|
||||
func TestServerReconcileExportsReleasesRemovedExports(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
device := newTestDevice("1-1", 0x1d6b, 0x0002, "regular", SpeedHigh)
|
||||
store := newTestDeviceStore(device)
|
||||
ops := newTestUSBIPOps(t)
|
||||
var actions []string
|
||||
ops.listUSBDevices = store.listUSBDevices
|
||||
ops.writeUsbipSockfd = func(busid string, fd int) error {
|
||||
actions = append(actions, "sockfd "+busid)
|
||||
return nil
|
||||
}
|
||||
ops.hostUnbind = func(busid string) error {
|
||||
actions = append(actions, "hostunbind "+busid)
|
||||
return nil
|
||||
}
|
||||
ops.hostMatchBusID = func(busid string, add bool) error {
|
||||
actions = append(actions, "match "+busid+" "+map[bool]string{true: "add", false: "del"}[add])
|
||||
return nil
|
||||
}
|
||||
ops.bindToDriver = func(busid, driver string) error {
|
||||
actions = append(actions, "bind "+busid+" "+driver)
|
||||
return nil
|
||||
}
|
||||
ops.readSysfsDevice = store.readSysfsDevice
|
||||
|
||||
server := &ServerService{
|
||||
ctx: context.Background(),
|
||||
logger: newTestLogger(),
|
||||
exports: map[string]serverExport{"1-1": {busid: "1-1", managed: true, originalDriver: "usbhid"}},
|
||||
ops: ops,
|
||||
}
|
||||
|
||||
changed, err := server.reconcileExports()
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Empty(t, server.snapshotExports())
|
||||
require.Equal(t, []string{
|
||||
"sockfd 1-1",
|
||||
"hostunbind 1-1",
|
||||
"match 1-1 del",
|
||||
"bind 1-1 usbhid",
|
||||
}, actions)
|
||||
}
|
||||
|
||||
func TestServerBuildDevListEntriesFiltersUnavailableAndRefreshFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
available := newTestDevice("1-1", 0x1d6b, 0x0002, "ok", SpeedHigh)
|
||||
store := newTestDeviceStore(available)
|
||||
store.setStatus("1-1", usbipStatusAvailable)
|
||||
store.setStatus("1-2", usbipStatusUsed)
|
||||
store.setStatus("1-3", usbipStatusAvailable)
|
||||
|
||||
ops := newTestUSBIPOps(t)
|
||||
ops.readUsbipStatus = store.readUsbipStatus
|
||||
ops.readSysfsDevice = store.readSysfsDevice
|
||||
|
||||
server := &ServerService{
|
||||
logger: newTestLogger(),
|
||||
exports: map[string]serverExport{
|
||||
"1-1": {busid: "1-1"},
|
||||
"1-2": {busid: "1-2"},
|
||||
"1-3": {busid: "1-3"},
|
||||
},
|
||||
ops: ops,
|
||||
}
|
||||
|
||||
entries := server.buildDevListEntries()
|
||||
require.Len(t, entries, 1)
|
||||
require.Equal(t, "1-1", entries[0].Info.BusIDString())
|
||||
require.Equal(t, "ok", entries[0].Info.SerialString())
|
||||
}
|
||||
|
||||
func TestServerDispatchConnHandlesControlPingAndChanged(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
server := &ServerService{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: newTestLogger(),
|
||||
exports: make(map[string]serverExport),
|
||||
controlSubs: make(map[uint64]*serverControlConn),
|
||||
ops: newTestUSBIPOps(t),
|
||||
}
|
||||
serverAddr, closeServer := startDispatchServer(t, server)
|
||||
defer closeServer()
|
||||
|
||||
conn, err := net.Dial("tcp", serverAddr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
require.NoError(t, WriteControlPreface(conn))
|
||||
require.NoError(t, WriteControlHello(conn))
|
||||
|
||||
ack, err := ReadControlFrame(conn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, controlFrameAck, ack.Type)
|
||||
require.Equal(t, controlProtocolVersion, ack.Version)
|
||||
require.Equal(t, controlCapabilities, ack.Capabilities)
|
||||
require.Zero(t, ack.Sequence)
|
||||
|
||||
require.NoError(t, WriteControlPing(conn))
|
||||
pong, err := ReadControlFrame(conn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, controlFramePong, pong.Type)
|
||||
require.Equal(t, controlProtocolVersion, pong.Version)
|
||||
|
||||
server.broadcastChanged()
|
||||
changed, err := ReadControlFrame(conn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, controlFrameChanged, changed.Type)
|
||||
require.Equal(t, uint64(1), changed.Sequence)
|
||||
}
|
||||
|
||||
func TestClientAttemptAttachUsesImportReplyAndVHCIAttach(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
device := newTestDevice("1-1", 0x1d6b, 0x0002, "serial-1", SpeedSuper)
|
||||
device.BusNum = 7
|
||||
device.DevNum = 11
|
||||
store := newTestDeviceStore(device)
|
||||
store.setStatus("1-1", usbipStatusAvailable)
|
||||
|
||||
serverOps := newTestUSBIPOps(t)
|
||||
serverOps.readUsbipStatus = store.readUsbipStatus
|
||||
serverOps.readSysfsDevice = store.readSysfsDevice
|
||||
serverOps.writeUsbipSockfd = store.writeUsbipSockfd
|
||||
|
||||
server := &ServerService{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: newTestLogger(),
|
||||
exports: map[string]serverExport{"1-1": {busid: "1-1"}},
|
||||
controlSubs: make(map[uint64]*serverControlConn),
|
||||
ops: serverOps,
|
||||
}
|
||||
serverAddr, closeServer := startDispatchServer(t, server)
|
||||
defer closeServer()
|
||||
|
||||
clientOps := newTestUSBIPOps(t)
|
||||
var attachedPort int
|
||||
var attachedDevID uint32
|
||||
var attachedSpeed uint32
|
||||
clientOps.vhciPickFreePort = func(speed uint32) (int, error) {
|
||||
require.Equal(t, SpeedSuper, speed)
|
||||
return 7, nil
|
||||
}
|
||||
clientOps.vhciAttach = func(port int, _ uintptr, devid uint32, speed uint32) error {
|
||||
attachedPort = port
|
||||
attachedDevID = devid
|
||||
attachedSpeed = speed
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &ClientService{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: newTestLogger(),
|
||||
dialer: testDialer{},
|
||||
serverAddr: serverAddr,
|
||||
ops: clientOps,
|
||||
}
|
||||
|
||||
port, err := client.attemptAttach(ctx, "1-1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 7, port)
|
||||
require.Equal(t, 7, attachedPort)
|
||||
info := device.toProtocol()
|
||||
require.Equal(t, info.DevID(), attachedDevID)
|
||||
require.Equal(t, SpeedSuper, attachedSpeed)
|
||||
require.Positive(t, store.lastSockfd("1-1"))
|
||||
}
|
||||
|
||||
func TestClientRunControlSessionSyncsAssignmentsOnChanged(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverCtx, serverCancel := context.WithCancel(context.Background())
|
||||
defer serverCancel()
|
||||
|
||||
initialDevice := newTestDevice("1-1", 0x1d6b, 0x0002, "first", SpeedHigh)
|
||||
updatedDevice := newTestDevice("1-2", 0x1d6b, 0x0002, "second", SpeedHigh)
|
||||
store := newTestDeviceStore(initialDevice)
|
||||
store.setStatus("1-1", usbipStatusAvailable)
|
||||
store.setStatus("1-2", usbipStatusAvailable)
|
||||
|
||||
serverOps := newTestUSBIPOps(t)
|
||||
serverOps.readUsbipStatus = store.readUsbipStatus
|
||||
serverOps.readSysfsDevice = store.readSysfsDevice
|
||||
|
||||
server := &ServerService{
|
||||
ctx: serverCtx,
|
||||
cancel: serverCancel,
|
||||
logger: newTestLogger(),
|
||||
exports: map[string]serverExport{"1-1": {busid: "1-1"}},
|
||||
controlSubs: make(map[uint64]*serverControlConn),
|
||||
ops: serverOps,
|
||||
}
|
||||
serverAddr, closeServer := startDispatchServer(t, server)
|
||||
defer closeServer()
|
||||
|
||||
clientCtx, clientCancel := context.WithCancel(context.Background())
|
||||
defer clientCancel()
|
||||
|
||||
match := option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}
|
||||
client := &ClientService{
|
||||
ctx: clientCtx,
|
||||
cancel: clientCancel,
|
||||
logger: newTestLogger(),
|
||||
dialer: testDialer{},
|
||||
serverAddr: serverAddr,
|
||||
matches: []option.USBIPDeviceMatch{match},
|
||||
targets: []clientTarget{{match: match}},
|
||||
assigned: make([]string, 1),
|
||||
ops: newTestUSBIPOps(t),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- client.runControlSession()
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
client.stateMu.Lock()
|
||||
defer client.stateMu.Unlock()
|
||||
return client.assigned[0] == "1-1"
|
||||
}, 3*time.Second, 10*time.Millisecond)
|
||||
|
||||
store.setDevices(updatedDevice)
|
||||
server.deleteExport("1-1")
|
||||
server.setExport(serverExport{busid: "1-2"})
|
||||
server.broadcastChanged()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
client.stateMu.Lock()
|
||||
defer client.stateMu.Unlock()
|
||||
return client.assigned[0] == "1-2"
|
||||
}, 3*time.Second, 10*time.Millisecond)
|
||||
|
||||
clientCancel()
|
||||
select {
|
||||
case <-errCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("runControlSession did not exit after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUSBIPLinuxSmoke(t *testing.T) {
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("usbip smoke test requires root")
|
||||
}
|
||||
require.NoError(t, ensureHostDriver())
|
||||
require.NoError(t, ensureVHCI())
|
||||
|
||||
busid := os.Getenv("USBIP_TEST_BUSID")
|
||||
if busid == "" {
|
||||
t.Skip("USBIP_TEST_BUSID not set")
|
||||
}
|
||||
|
||||
device, err := readSysfsDevice(busid, sysBusDevicePath(busid))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, busid, device.BusID)
|
||||
|
||||
_, err = currentDriver(busid)
|
||||
require.NoError(t, err)
|
||||
_, err = readUsbipStatus(busid)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
52
service/usbip/ops_linux.go
Normal file
52
service/usbip/ops_linux.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
//go:build linux
|
||||
|
||||
package usbip
|
||||
|
||||
type usbEventListener interface {
|
||||
Close() error
|
||||
WaitUSBEvent() error
|
||||
}
|
||||
|
||||
type usbipOps struct {
|
||||
ensureHostDriver func() error
|
||||
ensureVHCI func() error
|
||||
|
||||
listUSBDevices func() ([]sysfsDevice, error)
|
||||
readSysfsDevice func(busid string, path string) (sysfsDevice, error)
|
||||
currentDriver func(busid string) (string, error)
|
||||
unbindFromDriver func(busid string, driver string) error
|
||||
bindToDriver func(busid string, driver string) error
|
||||
hostMatchBusID func(busid string, add bool) error
|
||||
hostBind func(busid string) error
|
||||
hostUnbind func(busid string) error
|
||||
readUsbipStatus func(busid string) (int, error)
|
||||
writeUsbipSockfd func(busid string, fd int) error
|
||||
newUEventListener func() (usbEventListener, error)
|
||||
|
||||
vhciPickFreePort func(speed uint32) (int, error)
|
||||
vhciAttach func(port int, fd uintptr, devid uint32, speed uint32) error
|
||||
vhciDetach func(port int) error
|
||||
vhciPortUsed func(port int) (bool, error)
|
||||
}
|
||||
|
||||
var systemUSBIPOps = usbipOps{
|
||||
ensureHostDriver: ensureHostDriver,
|
||||
ensureVHCI: ensureVHCI,
|
||||
listUSBDevices: listUSBDevices,
|
||||
readSysfsDevice: readSysfsDevice,
|
||||
currentDriver: currentDriver,
|
||||
unbindFromDriver: unbindFromDriver,
|
||||
bindToDriver: bindToDriver,
|
||||
hostMatchBusID: hostMatchBusID,
|
||||
hostBind: hostBind,
|
||||
hostUnbind: hostUnbind,
|
||||
readUsbipStatus: readUsbipStatus,
|
||||
writeUsbipSockfd: writeUsbipSockfd,
|
||||
newUEventListener: func() (usbEventListener, error) {
|
||||
return newUEventListener()
|
||||
},
|
||||
vhciPickFreePort: vhciPickFreePort,
|
||||
vhciAttach: vhciAttach,
|
||||
vhciDetach: vhciDetach,
|
||||
vhciPortUsed: vhciPortUsed,
|
||||
}
|
||||
273
service/usbip/protocol_test.go
Normal file
273
service/usbip/protocol_test.go
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
package usbip
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestControlPrefaceAndFrames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var preface bytes.Buffer
|
||||
require.NoError(t, WriteControlPreface(&preface))
|
||||
require.True(t, IsControlPreface(preface.Bytes()))
|
||||
|
||||
corruptedPreface := append([]byte(nil), preface.Bytes()...)
|
||||
corruptedPreface[len(corruptedPreface)-1] = '0'
|
||||
require.False(t, IsControlPreface(corruptedPreface))
|
||||
require.False(t, IsControlPreface(preface.Bytes()[:len(preface.Bytes())-1]))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
write func(io.Writer) error
|
||||
expected controlFrame
|
||||
}{
|
||||
{
|
||||
name: "hello",
|
||||
write: WriteControlHello,
|
||||
expected: controlFrame{
|
||||
Type: controlFrameHello,
|
||||
Version: controlProtocolVersion,
|
||||
Capabilities: controlCapabilities,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ack",
|
||||
write: func(writer io.Writer) error { return WriteControlAck(writer, 7) },
|
||||
expected: controlFrame{
|
||||
Type: controlFrameAck,
|
||||
Version: controlProtocolVersion,
|
||||
Capabilities: controlCapabilities,
|
||||
Sequence: 7,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "changed",
|
||||
write: func(writer io.Writer) error { return WriteControlChanged(writer, 9) },
|
||||
expected: controlFrame{
|
||||
Type: controlFrameChanged,
|
||||
Version: controlProtocolVersion,
|
||||
Sequence: 9,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ping",
|
||||
write: WriteControlPing,
|
||||
expected: controlFrame{
|
||||
Type: controlFramePing,
|
||||
Version: controlProtocolVersion,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pong",
|
||||
write: WriteControlPong,
|
||||
expected: controlFrame{
|
||||
Type: controlFramePong,
|
||||
Version: controlProtocolVersion,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
require.NoError(t, testCase.write(&buffer))
|
||||
|
||||
frame, err := ReadControlFrame(&buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testCase.expected, frame)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpHeaderRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
require.NoError(t, WriteOpHeader(&buffer, OpReqDevList, OpStatusError))
|
||||
|
||||
header, err := ReadOpHeader(&buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, OpHeader{
|
||||
Version: ProtocolVersion,
|
||||
Code: OpReqDevList,
|
||||
Status: OpStatusError,
|
||||
}, header)
|
||||
|
||||
var raw [8]byte
|
||||
binary.BigEndian.PutUint16(raw[:2], ProtocolVersion)
|
||||
binary.BigEndian.PutUint16(raw[2:4], OpRepImport)
|
||||
binary.BigEndian.PutUint32(raw[4:8], OpStatusOK)
|
||||
require.Equal(t, OpHeader{
|
||||
Version: ProtocolVersion,
|
||||
Code: OpRepImport,
|
||||
Status: OpStatusOK,
|
||||
}, ParseOpHeader(raw[:]))
|
||||
}
|
||||
|
||||
func TestOpReqImportRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
require.NoError(t, WriteOpReqImport(&buffer, "1-2"))
|
||||
|
||||
header, err := ReadOpHeader(&buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, OpReqImport, header.Code)
|
||||
require.Equal(t, OpStatusOK, header.Status)
|
||||
|
||||
busid, err := ReadOpReqImportBody(&buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1-2", busid)
|
||||
}
|
||||
|
||||
func TestWriteOpReqImportRejectsLongBusID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
err := WriteOpReqImport(&buffer, strings.Repeat("a", 32))
|
||||
require.ErrorContains(t, err, "busid too long")
|
||||
}
|
||||
|
||||
func TestWriteOpRepImportRequiresInfoOnSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
err := WriteOpRepImport(&buffer, OpStatusOK, nil)
|
||||
require.ErrorContains(t, err, "success without device info")
|
||||
}
|
||||
|
||||
func TestOpRepImportRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var info DeviceInfoTruncated
|
||||
copy(info.BusID[:], "1-2")
|
||||
info.BusNum = 1
|
||||
info.DevNum = 2
|
||||
info.Speed = SpeedHigh
|
||||
info.IDVendor = 0x1d6b
|
||||
info.IDProduct = 0x0002
|
||||
|
||||
var buffer bytes.Buffer
|
||||
require.NoError(t, WriteOpRepImport(&buffer, OpStatusOK, &info))
|
||||
|
||||
header, err := ReadOpHeader(&buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, OpRepImport, header.Code)
|
||||
require.Equal(t, OpStatusOK, header.Status)
|
||||
|
||||
body, err := ReadOpRepImportBody(&buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info, body)
|
||||
}
|
||||
|
||||
func TestOpRepDevListRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var path [256]byte
|
||||
encodePathField(&path, "/sys/bus/usb/devices/1-2", "serial-1")
|
||||
|
||||
entries := []DeviceEntry{{
|
||||
Info: DeviceInfoTruncated{
|
||||
Path: path,
|
||||
BusNum: 1,
|
||||
DevNum: 2,
|
||||
Speed: SpeedHigh,
|
||||
IDVendor: 0x1d6b,
|
||||
IDProduct: 0x0002,
|
||||
BNumInterfaces: 2,
|
||||
},
|
||||
Interfaces: []DeviceInterface{
|
||||
{BInterfaceClass: 0xff, BInterfaceSubClass: 1, BInterfaceProtocol: 2},
|
||||
{BInterfaceClass: 0x03, BInterfaceSubClass: 1, BInterfaceProtocol: 1},
|
||||
},
|
||||
}}
|
||||
copy(entries[0].Info.BusID[:], "1-2")
|
||||
|
||||
var buffer bytes.Buffer
|
||||
require.NoError(t, WriteOpRepDevList(&buffer, entries))
|
||||
|
||||
header, err := ReadOpHeader(&buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, OpRepDevList, header.Code)
|
||||
require.Equal(t, OpStatusOK, header.Status)
|
||||
|
||||
parsed, err := ReadOpRepDevListBody(&buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, entries, parsed)
|
||||
}
|
||||
|
||||
func TestReadOpRepDevListBodyRejectsTooManyEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
require.NoError(t, binary.Write(&buffer, binary.BigEndian, uint32(maxOpRepDevListEntries+1)))
|
||||
|
||||
_, err := ReadOpRepDevListBody(&buffer)
|
||||
require.ErrorContains(t, err, "device count too large")
|
||||
}
|
||||
|
||||
func TestDeviceInfoHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var info DeviceInfoTruncated
|
||||
encodePathField(&info.Path, "/sys/bus/usb/devices/1-2", "serial-1")
|
||||
copy(info.BusID[:], "1-2")
|
||||
info.BusNum = 3
|
||||
info.DevNum = 9
|
||||
|
||||
require.Equal(t, "/sys/bus/usb/devices/1-2", info.PathString())
|
||||
require.Equal(t, "serial-1", info.SerialString())
|
||||
require.Equal(t, "1-2", info.BusIDString())
|
||||
require.Equal(t, uint32(0x00030009), info.DevID())
|
||||
}
|
||||
|
||||
func TestEncodePathFieldSkipsSerialWithoutRoom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var info DeviceInfoTruncated
|
||||
encodePathField(&info.Path, strings.Repeat("a", len(info.Path)-1), "serial-1")
|
||||
|
||||
require.Empty(t, info.SerialString())
|
||||
}
|
||||
|
||||
func TestMatches(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
device := DeviceKey{
|
||||
BusID: "1-2",
|
||||
VendorID: 0x1d6b,
|
||||
ProductID: 0x0002,
|
||||
Serial: "serial-1",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
match option.USBIPDeviceMatch
|
||||
expected bool
|
||||
}{
|
||||
{name: "zero", match: option.USBIPDeviceMatch{}, expected: false},
|
||||
{name: "busid", match: option.USBIPDeviceMatch{BusID: "1-2"}, expected: true},
|
||||
{name: "vendor-and-product", match: option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}, expected: true},
|
||||
{name: "serial", match: option.USBIPDeviceMatch{Serial: "serial-1"}, expected: true},
|
||||
{name: "all-fields", match: option.USBIPDeviceMatch{BusID: "1-2", VendorID: 0x1d6b, ProductID: 0x0002, Serial: "serial-1"}, expected: true},
|
||||
{name: "vendor-mismatch", match: option.USBIPDeviceMatch{VendorID: 0x1d6c}, expected: false},
|
||||
{name: "serial-mismatch", match: option.USBIPDeviceMatch{Serial: "other"}, expected: false},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, testCase.expected, Matches(testCase.match, device))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -41,6 +41,7 @@ type ServerService struct {
|
|||
logger log.ContextLogger
|
||||
listener *listener.Listener
|
||||
matches []option.USBIPDeviceMatch
|
||||
ops usbipOps
|
||||
|
||||
mu sync.Mutex
|
||||
exports map[string]serverExport
|
||||
|
|
@ -76,6 +77,7 @@ func NewServerService(ctx context.Context, logger log.ContextLogger, tag string,
|
|||
Listen: options.ListenOptions,
|
||||
}),
|
||||
controlSubs: make(map[uint64]*serverControlConn),
|
||||
ops: systemUSBIPOps,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
|
@ -84,7 +86,7 @@ func (s *ServerService) Start(stage adapter.StartStage) error {
|
|||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if err := ensureHostDriver(); err != nil {
|
||||
if err := s.ops.ensureHostDriver(); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := s.reconcileExports(); err != nil {
|
||||
|
|
@ -115,7 +117,7 @@ func (s *ServerService) Close() error {
|
|||
}
|
||||
|
||||
func (s *ServerService) reconcileExports() (bool, error) {
|
||||
devices, err := listUSBDevices()
|
||||
devices, err := s.ops.listUSBDevices()
|
||||
if err != nil {
|
||||
return false, E.Cause(err, "enumerate usb devices")
|
||||
}
|
||||
|
|
@ -163,7 +165,7 @@ func (s *ServerService) reconcileExports() (bool, error) {
|
|||
}
|
||||
|
||||
func (s *ServerService) bindOne(d *sysfsDevice) error {
|
||||
driver, err := currentDriver(d.BusID)
|
||||
driver, err := s.ops.currentDriver(d.BusID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -173,20 +175,20 @@ func (s *ServerService) bindOne(d *sysfsDevice) error {
|
|||
return nil
|
||||
}
|
||||
if driver != "" {
|
||||
if err := unbindFromDriver(d.BusID, driver); err != nil {
|
||||
if err := s.ops.unbindFromDriver(d.BusID, driver); err != nil {
|
||||
return E.Cause(err, "unbind from ", driver)
|
||||
}
|
||||
}
|
||||
if err := hostMatchBusID(d.BusID, true); err != nil {
|
||||
if err := s.ops.hostMatchBusID(d.BusID, true); err != nil {
|
||||
if driver != "" {
|
||||
_ = bindToDriver(d.BusID, driver)
|
||||
_ = s.ops.bindToDriver(d.BusID, driver)
|
||||
}
|
||||
return E.Cause(err, "match_busid add")
|
||||
}
|
||||
if err := hostBind(d.BusID); err != nil {
|
||||
_ = hostMatchBusID(d.BusID, false)
|
||||
if err := s.ops.hostBind(d.BusID); err != nil {
|
||||
_ = s.ops.hostMatchBusID(d.BusID, false)
|
||||
if driver != "" {
|
||||
_ = bindToDriver(d.BusID, driver)
|
||||
_ = s.ops.bindToDriver(d.BusID, driver)
|
||||
}
|
||||
return E.Cause(err, "bind to usbip-host")
|
||||
}
|
||||
|
|
@ -203,17 +205,17 @@ func (s *ServerService) releaseExport(export serverExport, restore bool) error {
|
|||
s.deleteExport(export.busid)
|
||||
|
||||
var releaseErr error
|
||||
if err := writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) {
|
||||
if err := s.ops.writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) {
|
||||
releaseErr = err
|
||||
}
|
||||
if !export.managed {
|
||||
s.logger.Info("stopped tracking ", export.busid, " on usbip-host")
|
||||
return releaseErr
|
||||
}
|
||||
if err := hostUnbind(export.busid); err != nil && !os.IsNotExist(err) && releaseErr == nil {
|
||||
if err := s.ops.hostUnbind(export.busid); err != nil && !os.IsNotExist(err) && releaseErr == nil {
|
||||
releaseErr = err
|
||||
}
|
||||
if err := hostMatchBusID(export.busid, false); err != nil && releaseErr == nil {
|
||||
if err := s.ops.hostMatchBusID(export.busid, false); err != nil && releaseErr == nil {
|
||||
releaseErr = err
|
||||
}
|
||||
if !restore {
|
||||
|
|
@ -224,7 +226,7 @@ func (s *ServerService) releaseExport(export serverExport, restore bool) error {
|
|||
s.logger.Info("released ", export.busid, " from usbip-host")
|
||||
return releaseErr
|
||||
}
|
||||
if err := bindToDriver(export.busid, export.originalDriver); err != nil {
|
||||
if err := s.ops.bindToDriver(export.busid, export.originalDriver); err != nil {
|
||||
if releaseErr == nil {
|
||||
releaseErr = err
|
||||
}
|
||||
|
|
@ -237,7 +239,8 @@ func (s *ServerService) releaseExport(export serverExport, restore bool) error {
|
|||
func (s *ServerService) rollbackExports() {
|
||||
exports := s.snapshotExports()
|
||||
for _, export := range exports {
|
||||
_, restore := currentSysfsDevice(export.busid)
|
||||
_, err := s.ops.readSysfsDevice(export.busid, sysBusDevicePath(export.busid))
|
||||
restore := err == nil
|
||||
if err := s.releaseExport(export, restore); err != nil {
|
||||
s.logger.Warn("rollback ", export.busid, ": ", err)
|
||||
}
|
||||
|
|
@ -409,7 +412,7 @@ func (s *ServerService) buildDevListEntries() []DeviceEntry {
|
|||
}
|
||||
entries := make([]DeviceEntry, 0, len(busids))
|
||||
for _, busid := range busids {
|
||||
status, err := readUsbipStatus(busid)
|
||||
status, err := s.ops.readUsbipStatus(busid)
|
||||
if err != nil {
|
||||
s.logger.Debug("status ", busid, ": ", err)
|
||||
continue
|
||||
|
|
@ -417,7 +420,7 @@ func (s *ServerService) buildDevListEntries() []DeviceEntry {
|
|||
if status != usbipStatusAvailable {
|
||||
continue
|
||||
}
|
||||
d, err := readSysfsDevice(busid, sysBusDevicePath(busid))
|
||||
d, err := s.ops.readSysfsDevice(busid, sysBusDevicePath(busid))
|
||||
if err != nil {
|
||||
s.logger.Debug("refresh ", busid, ": ", err)
|
||||
continue
|
||||
|
|
@ -441,13 +444,13 @@ func (s *ServerService) handleImport(conn net.Conn) {
|
|||
_ = WriteOpRepImport(conn, OpStatusError, nil)
|
||||
return
|
||||
}
|
||||
status, err := readUsbipStatus(busid)
|
||||
status, err := s.ops.readUsbipStatus(busid)
|
||||
if err != nil || status != usbipStatusAvailable {
|
||||
s.logger.Info("import rejected (busid ", busid, " status=", status, " err=", err, ")")
|
||||
_ = WriteOpRepImport(conn, OpStatusError, nil)
|
||||
return
|
||||
}
|
||||
dev, err := readSysfsDevice(busid, sysBusDevicePath(busid))
|
||||
dev, err := s.ops.readSysfsDevice(busid, sysBusDevicePath(busid))
|
||||
if err != nil {
|
||||
s.logger.Warn("refresh ", busid, ": ", err)
|
||||
_ = WriteOpRepImport(conn, OpStatusError, nil)
|
||||
|
|
@ -466,7 +469,7 @@ func (s *ServerService) handleImport(conn net.Conn) {
|
|||
return
|
||||
}
|
||||
defer file.Close()
|
||||
if err := writeUsbipSockfd(busid, int(file.Fd())); err != nil {
|
||||
if err := s.ops.writeUsbipSockfd(busid, int(file.Fd())); err != nil {
|
||||
s.logger.Warn("hand off ", busid, " to kernel: ", err)
|
||||
_ = WriteOpRepImport(conn, OpStatusError, nil)
|
||||
return
|
||||
|
|
@ -474,7 +477,7 @@ func (s *ServerService) handleImport(conn net.Conn) {
|
|||
info := dev.toProtocol()
|
||||
if err := WriteOpRepImport(conn, OpStatusOK, &info); err != nil {
|
||||
s.logger.Warn("reply import ", busid, ": ", err)
|
||||
_ = writeUsbipSockfd(busid, -1)
|
||||
_ = s.ops.writeUsbipSockfd(busid, -1)
|
||||
return
|
||||
}
|
||||
s.logger.Info("attached ", busid, " to remote ", conn.RemoteAddr())
|
||||
|
|
@ -489,7 +492,7 @@ func (s *ServerService) isExported(busid string) bool {
|
|||
|
||||
func (s *ServerService) ueventLoop() {
|
||||
for {
|
||||
listener, err := newUEventListener()
|
||||
listener, err := s.ops.newUEventListener()
|
||||
if err != nil {
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
|
|
@ -595,14 +598,6 @@ func (s *ServerService) enqueueControlFrame(sub *serverControlConn, frame contro
|
|||
}
|
||||
}
|
||||
|
||||
func currentSysfsDevice(busid string) (sysfsDevice, bool) {
|
||||
device, err := readSysfsDevice(busid, sysBusDevicePath(busid))
|
||||
if err != nil {
|
||||
return sysfsDevice{}, false
|
||||
}
|
||||
return device, true
|
||||
}
|
||||
|
||||
func sysBusDevicePath(busid string) string {
|
||||
return sysBusUSBDevices + "/" + busid
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,11 +6,13 @@ import (
|
|||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/shell"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -80,18 +82,12 @@ type vhciStatusRecord struct {
|
|||
|
||||
// ensureHostDriver verifies the usbip-host kernel driver is loaded.
|
||||
func ensureHostDriver() error {
|
||||
if _, err := os.Stat(sysUsbipHostDriver); err != nil {
|
||||
return E.Cause(err, "usbip-host driver not present; modprobe usbip-host")
|
||||
}
|
||||
return nil
|
||||
return ensureKernelPath(sysUsbipHostDriver, "usbip-host", "usbip-host driver")
|
||||
}
|
||||
|
||||
// ensureVHCI verifies the vhci_hcd controller is loaded.
|
||||
func ensureVHCI() error {
|
||||
if _, err := os.Stat(sysVHCIControllerV0); err != nil {
|
||||
return E.Cause(err, "vhci_hcd.0 not present; modprobe vhci-hcd")
|
||||
}
|
||||
return nil
|
||||
return ensureKernelPath(sysVHCIControllerV0, "vhci-hcd", "vhci_hcd.0")
|
||||
}
|
||||
|
||||
// listUSBDevices enumerates /sys/bus/usb/devices, returning non-interface
|
||||
|
|
@ -324,6 +320,41 @@ func vhciHubForSpeed(speed uint32) string {
|
|||
}
|
||||
}
|
||||
|
||||
func ensureKernelPath(path string, module string, description string) error {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if os.Getuid() != 0 {
|
||||
return E.Cause(err, description, " not present; root is required to load kernel module ", module)
|
||||
}
|
||||
modprobePath, modprobeErr := findModprobePath()
|
||||
if modprobeErr != nil {
|
||||
return E.Cause(modprobeErr, "load kernel module ", module, " for ", description)
|
||||
}
|
||||
output, modprobeErr := shell.Exec(modprobePath, module).Read()
|
||||
if modprobeErr != nil {
|
||||
return E.Extend(E.Cause(modprobeErr, "load kernel module ", module, " for ", description), strings.TrimSpace(output))
|
||||
}
|
||||
if _, err = os.Stat(path); err != nil {
|
||||
return E.Cause(err, description, " still not present after loading kernel module ", module)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findModprobePath() (string, error) {
|
||||
if path, err := exec.LookPath("modprobe"); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
for _, path := range []string{"/usr/sbin/modprobe", "/sbin/modprobe", "/usr/bin/modprobe", "/bin/modprobe"} {
|
||||
info, err := os.Stat(path)
|
||||
if err == nil && info.Mode().IsRegular() && info.Mode()&0o111 != 0 {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
return "", E.New("modprobe executable not found")
|
||||
}
|
||||
|
||||
// --- small helpers ------------------------------------------------------
|
||||
|
||||
func writeSysfs(path, content string) error {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue