mlx: improve thread safety of array management

Use atomic.Int32 for Array.pinned and a sync.Mutex for the global
arrays slice so MLX arrays can be created and pinned from multiple
goroutines without racing on those structures. Convert Array value
receivers to pointer receivers and struct fields from Array to
*Array to avoid copying the atomic.

This does not fully achieve thread safety even when building
completely independent graphs. The tracing flag and traceScratch
slice in compile.go are unprotected, so concurrent Compile calls
will race. MLX itself is not fully thread-safe either although
it is working to improve.
This commit is contained in:
Jesse Gross 2026-04-03 16:25:28 -07:00
parent fb36a01ffe
commit 04f5f0cdb4
4 changed files with 42 additions and 29 deletions

View file

@ -10,6 +10,8 @@ import (
"reflect"
"sort"
"strings"
"sync"
"sync/atomic"
"unsafe"
"github.com/ollama/ollama/logutil"
@ -18,20 +20,28 @@ import (
type Array struct {
ctx C.mlx_array
name string
pinned int
pinned atomic.Int32
}
var arrays []*Array
var (
arrays []*Array
arraysMu sync.Mutex
)
// constructor utilities
func New(name string) *Array {
t := &Array{name: name}
if tracing {
traceScratch = append(traceScratch, t)
} else {
arraysMu.Lock()
defer arraysMu.Unlock()
arrays = append(arrays, t)
}
return t
}
@ -131,7 +141,7 @@ func (t *Array) Clone() *Array {
func Pin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned++
t.pinned.Add(1)
}
}
}
@ -140,8 +150,7 @@ func Pin(s ...*Array) {
func Unpin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned--
if t.pinned < 0 {
if t.pinned.Add(-1) < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
}
}
@ -151,9 +160,11 @@ func Unpin(s ...*Array) {
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
arraysMu.Lock()
defer arraysMu.Unlock()
n := 0
for _, t := range arrays {
if t.pinned > 0 && t.Valid() {
if t.pinned.Load() > 0 && t.Valid() {
arrays[n] = t
n++
} else if t.Valid() {
@ -180,7 +191,7 @@ func (t *Array) String() string {
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Int("pinned", t.pinned),
slog.Int("pinned", int(t.pinned.Load())),
}
if t.Valid() {
attrs = append(attrs,
@ -194,19 +205,19 @@ func (t *Array) LogValue() slog.Value {
// shape utilities
func (t Array) Size() int {
func (t *Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t Array) NumBytes() int {
func (t *Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t Array) NumDims() int {
func (t *Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t Array) Dims() []int {
func (t *Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
@ -215,29 +226,29 @@ func (t Array) Dims() []int {
return dims
}
func (t Array) Dim(dim int) int {
func (t *Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t Array) DType() DType {
func (t *Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t Array) Int() int {
func (t *Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t Array) Float() float64 {
func (t *Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t Array) Ints() []int {
func (t *Array) Ints() []int {
if dt := t.DType(); dt != DTypeInt32 {
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
}
@ -248,7 +259,7 @@ func (t Array) Ints() []int {
return ints
}
func (t Array) Floats() []float32 {
func (t *Array) Floats() []float32 {
if dt := t.DType(); dt != DTypeFloat32 {
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
}
@ -259,7 +270,7 @@ func (t Array) Floats() []float32 {
return floats
}
func (t Array) Save(name string) error {
func (t *Array) Save(name string) error {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
C.mlx_save(cName, t.ctx)
@ -268,6 +279,8 @@ func (t Array) Save(name string) error {
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
arraysMu.Lock()
defer arraysMu.Unlock()
sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
@ -276,7 +289,7 @@ func LogArrays() {
for _, t := range arrays {
nb := t.NumBytes()
total += nb
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
}
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
}

View file

@ -150,7 +150,7 @@ func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload
traceScratch = nil
defer func() {
for _, a := range traceScratch {
if a.pinned > 0 {
if a.pinned.Load() > 0 {
panic("mlx: traced array was pinned during compilation")
}
if a.Valid() {

View file

@ -24,8 +24,8 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
}
type LayerNorm struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
@ -35,10 +35,10 @@ func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
}
type RMSNorm struct {
Weight Array `weight:"weight"`
Weight *Array `weight:"weight"`
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out

View file

@ -1,12 +1,12 @@
package mlx
type Linear struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array {
func (m *Linear) Forward(x *Array) *Array {
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
@ -15,14 +15,14 @@ func (m Linear) Forward(x *Array) *Array {
return x.Matmul(w)
}
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
w := m.Weight.Transpose(0, 2, 1)
// TODO: bias
return x.GatherMM(w, lhs, rhs, sorted)
}
type Embedding struct {
Weight Array `weight:"weight"`
Weight *Array `weight:"weight"`
}
func (e *Embedding) Forward(indices *Array) *Array {