From 04f5f0cdb4b68c51ea08defe260351bb3915ccc8 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 3 Apr 2026 16:25:28 -0700 Subject: [PATCH] mlx: improve thread safety of array management Use atomic.Int32 for Array.pinned and a sync.Mutex for the global arrays slice so MLX arrays can be created and pinned from multiple goroutines without racing on those structures. Convert Array value receivers to pointer receivers and struct fields from Array to *Array to avoid copying the atomic. This does not fully achieve thread safety even when building completely independent graphs. The tracing flag and traceScratch slice in compile.go are unprotected, so concurrent Compile calls will race. MLX itself is not fully thread-safe either although it is working to improve. --- x/mlxrunner/mlx/array.go | 51 ++++++++++++++++++++++++-------------- x/mlxrunner/mlx/compile.go | 2 +- x/mlxrunner/mlx/fast.go | 8 +++--- x/mlxrunner/mlx/nn.go | 10 ++++---- 4 files changed, 42 insertions(+), 29 deletions(-) diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index c4c80dae4..425ea8fc2 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -10,6 +10,8 @@ import ( "reflect" "sort" "strings" + "sync" + "sync/atomic" "unsafe" "github.com/ollama/ollama/logutil" @@ -18,20 +20,28 @@ import ( type Array struct { ctx C.mlx_array name string - pinned int + pinned atomic.Int32 } -var arrays []*Array +var ( + arrays []*Array + arraysMu sync.Mutex +) // constructor utilities func New(name string) *Array { t := &Array{name: name} + if tracing { traceScratch = append(traceScratch, t) } else { + arraysMu.Lock() + defer arraysMu.Unlock() + arrays = append(arrays, t) } + return t } @@ -131,7 +141,7 @@ func (t *Array) Clone() *Array { func Pin(s ...*Array) { for _, t := range s { if t != nil { - t.pinned++ + t.pinned.Add(1) } } } @@ -140,8 +150,7 @@ func Pin(s ...*Array) { func Unpin(s ...*Array) { for _, t := range s { if t != nil { - t.pinned-- - if t.pinned < 0 { + if t.pinned.Add(-1) < 0 { panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name)) } } @@ -151,9 +160,11 @@ func Unpin(s ...*Array) { // Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly // free them when there are no other references, including dependencies in the graph. func Sweep() { + arraysMu.Lock() + defer arraysMu.Unlock() n := 0 for _, t := range arrays { - if t.pinned > 0 && t.Valid() { + if t.pinned.Load() > 0 && t.Valid() { arrays[n] = t n++ } else if t.Valid() { @@ -180,7 +191,7 @@ func (t *Array) String() string { func (t *Array) LogValue() slog.Value { attrs := []slog.Attr{ slog.String("name", t.name), - slog.Int("pinned", t.pinned), + slog.Int("pinned", int(t.pinned.Load())), } if t.Valid() { attrs = append(attrs, @@ -194,19 +205,19 @@ func (t *Array) LogValue() slog.Value { // shape utilities -func (t Array) Size() int { +func (t *Array) Size() int { return int(C.mlx_array_size(t.ctx)) } -func (t Array) NumBytes() int { +func (t *Array) NumBytes() int { return int(C.mlx_array_nbytes(t.ctx)) } -func (t Array) NumDims() int { +func (t *Array) NumDims() int { return int(C.mlx_array_ndim(t.ctx)) } -func (t Array) Dims() []int { +func (t *Array) Dims() []int { dims := make([]int, t.NumDims()) for i := range dims { dims[i] = t.Dim(i) @@ -215,29 +226,29 @@ func (t Array) Dims() []int { return dims } -func (t Array) Dim(dim int) int { +func (t *Array) Dim(dim int) int { return int(C.mlx_array_dim(t.ctx, C.int(dim))) } -func (t Array) DType() DType { +func (t *Array) DType() DType { return DType(C.mlx_array_dtype(t.ctx)) } // data utilities -func (t Array) Int() int { +func (t *Array) Int() int { var item C.int64_t C.mlx_array_item_int64(&item, t.ctx) return int(item) } -func (t Array) Float() float64 { +func (t *Array) Float() float64 { var item C.double C.mlx_array_item_float64(&item, t.ctx) return float64(item) } -func (t Array) Ints() []int { +func (t *Array) Ints() []int { if dt := t.DType(); dt != DTypeInt32 { panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt)) } @@ -248,7 +259,7 @@ func (t Array) Ints() []int { return ints } -func (t Array) Floats() []float32 { +func (t *Array) Floats() []float32 { if dt := t.DType(); dt != DTypeFloat32 { panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt)) } @@ -259,7 +270,7 @@ func (t Array) Floats() []float32 { return floats } -func (t Array) Save(name string) error { +func (t *Array) Save(name string) error { cName := C.CString(name) defer C.free(unsafe.Pointer(cName)) C.mlx_save(cName, t.ctx) @@ -268,6 +279,8 @@ func (t Array) Save(name string) error { // LogArrays logs all live arrays, sorted by size func LogArrays() { + arraysMu.Lock() + defer arraysMu.Unlock() sort.Slice(arrays, func(i, j int) bool { return arrays[i].NumBytes() > arrays[j].NumBytes() }) @@ -276,7 +289,7 @@ func LogArrays() { for _, t := range arrays { nb := t.NumBytes() total += nb - logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims())) + logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims())) } logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory()))) } diff --git a/x/mlxrunner/mlx/compile.go b/x/mlxrunner/mlx/compile.go index 987bb7220..91e564ba0 100644 --- a/x/mlxrunner/mlx/compile.go +++ b/x/mlxrunner/mlx/compile.go @@ -150,7 +150,7 @@ func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload traceScratch = nil defer func() { for _, a := range traceScratch { - if a.pinned > 0 { + if a.pinned.Load() > 0 { panic("mlx: traced array was pinned during compilation") } if a.Valid() { diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go index 7feca3b1e..d5b218d1c 100644 --- a/x/mlxrunner/mlx/fast.go +++ b/x/mlxrunner/mlx/fast.go @@ -24,8 +24,8 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A } type LayerNorm struct { - Weight Array `weight:"weight"` - Bias Array `weight:"bias"` + Weight *Array `weight:"weight"` + Bias *Array `weight:"bias"` } func (r *LayerNorm) Forward(x *Array, eps float32) *Array { @@ -35,10 +35,10 @@ func (r *LayerNorm) Forward(x *Array, eps float32) *Array { } type RMSNorm struct { - Weight Array `weight:"weight"` + Weight *Array `weight:"weight"` } -func (r RMSNorm) Forward(x *Array, eps float32) *Array { +func (r *RMSNorm) Forward(x *Array, eps float32) *Array { out := New("FAST_RMSNORM") C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx) return out diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go index d3a99a6cd..d2e7fb4f1 100644 --- a/x/mlxrunner/mlx/nn.go +++ b/x/mlxrunner/mlx/nn.go @@ -1,12 +1,12 @@ package mlx type Linear struct { - Weight Array `weight:"weight"` - Bias Array `weight:"bias"` + Weight *Array `weight:"weight"` + Bias *Array `weight:"bias"` } // Forward computes the linear transformation: x @ Weight.T + Bias -func (m Linear) Forward(x *Array) *Array { +func (m *Linear) Forward(x *Array) *Array { w := m.Weight.Transpose(1, 0) if m.Bias.Valid() { return m.Bias.Addmm(x, w, 1.0, 1.0) @@ -15,14 +15,14 @@ func (m Linear) Forward(x *Array) *Array { return x.Matmul(w) } -func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array { +func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array { w := m.Weight.Transpose(0, 2, 1) // TODO: bias return x.GatherMM(w, lhs, rhs, sorted) } type Embedding struct { - Weight Array `weight:"weight"` + Weight *Array `weight:"weight"` } func (e *Embedding) Forward(indices *Array) *Array {