diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go
index c4c80dae4..425ea8fc2 100644
--- a/x/mlxrunner/mlx/array.go
+++ b/x/mlxrunner/mlx/array.go
@@ -10,6 +10,8 @@ import (
 	"reflect"
 	"sort"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"unsafe"
 
 	"github.com/ollama/ollama/logutil"
@@ -18,20 +20,28 @@ import (
 type Array struct {
 	ctx    C.mlx_array
 	name   string
-	pinned int
+	pinned atomic.Int32
 }
 
-var arrays []*Array
+var (
+	arrays   []*Array
+	arraysMu sync.Mutex
+)
 
 // constructor utilities
 
 func New(name string) *Array {
 	t := &Array{name: name}
+
 	if tracing {
 		traceScratch = append(traceScratch, t)
 	} else {
+		arraysMu.Lock()
+		defer arraysMu.Unlock()
+
 		arrays = append(arrays, t)
 	}
+
 	return t
 }
 
@@ -131,7 +141,7 @@ func (t *Array) Clone() *Array {
 func Pin(s ...*Array) {
 	for _, t := range s {
 		if t != nil {
-			t.pinned++
+			t.pinned.Add(1)
 		}
 	}
 }
@@ -140,8 +150,7 @@ func Pin(s ...*Array) {
 func Unpin(s ...*Array) {
 	for _, t := range s {
 		if t != nil {
-			t.pinned--
-			if t.pinned < 0 {
+			if t.pinned.Add(-1) < 0 {
 				panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
 			}
 		}
@@ -151,9 +160,11 @@ func Unpin(s ...*Array) {
 // Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
 // free them when there are no other references, including dependencies in the graph.
 func Sweep() {
+	arraysMu.Lock()
+	defer arraysMu.Unlock()
 	n := 0
 	for _, t := range arrays {
-		if t.pinned > 0 && t.Valid() {
+		if t.pinned.Load() > 0 && t.Valid() {
 			arrays[n] = t
 			n++
 		} else if t.Valid() {
@@ -180,7 +191,7 @@ func (t *Array) String() string {
 func (t *Array) LogValue() slog.Value {
	attrs := []slog.Attr{
 		slog.String("name", t.name),
-		slog.Int("pinned", t.pinned),
+		slog.Int("pinned", int(t.pinned.Load())),
 	}
 	if t.Valid() {
 		attrs = append(attrs,
@@ -194,19 +205,19 @@
 
 // shape utilities
 
-func (t Array) Size() int {
+func (t *Array) Size() int {
 	return int(C.mlx_array_size(t.ctx))
 }
 
-func (t Array) NumBytes() int {
+func (t *Array) NumBytes() int {
 	return int(C.mlx_array_nbytes(t.ctx))
 }
 
-func (t Array) NumDims() int {
+func (t *Array) NumDims() int {
 	return int(C.mlx_array_ndim(t.ctx))
 }
 
-func (t Array) Dims() []int {
+func (t *Array) Dims() []int {
 	dims := make([]int, t.NumDims())
 	for i := range dims {
 		dims[i] = t.Dim(i)
@@ -215,29 +226,29 @@ func (t Array) Dims() []int {
 	return dims
 }
 
-func (t Array) Dim(dim int) int {
+func (t *Array) Dim(dim int) int {
 	return int(C.mlx_array_dim(t.ctx, C.int(dim)))
 }
 
-func (t Array) DType() DType {
+func (t *Array) DType() DType {
 	return DType(C.mlx_array_dtype(t.ctx))
 }
 
 // data utilities
 
-func (t Array) Int() int {
+func (t *Array) Int() int {
 	var item C.int64_t
 	C.mlx_array_item_int64(&item, t.ctx)
 	return int(item)
 }
 
-func (t Array) Float() float64 {
+func (t *Array) Float() float64 {
 	var item C.double
 	C.mlx_array_item_float64(&item, t.ctx)
 	return float64(item)
 }
 
-func (t Array) Ints() []int {
+func (t *Array) Ints() []int {
 	if dt := t.DType(); dt != DTypeInt32 {
 		panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
 	}
@@ -248,7 +259,7 @@
 	return ints
 }
 
-func (t Array) Floats() []float32 {
+func (t *Array) Floats() []float32 {
 	if dt := t.DType(); dt != DTypeFloat32 {
 		panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
 	}
@@ -259,7 +270,7 @@
 	return floats
 }
 
-func (t Array) Save(name string) error {
+func (t *Array) Save(name string) error {
 	cName := C.CString(name)
 	defer C.free(unsafe.Pointer(cName))
 	C.mlx_save(cName, t.ctx)
@@ -268,6 +279,8 @@ func (t Array) Save(name string) error {
 
 // LogArrays logs all live arrays, sorted by size
 func LogArrays() {
+	arraysMu.Lock()
+	defer arraysMu.Unlock()
 	sort.Slice(arrays, func(i, j int) bool {
 		return arrays[i].NumBytes() > arrays[j].NumBytes()
 	})
@@ -276,7 +289,7 @@
 	for _, t := range arrays {
 		nb := t.NumBytes()
 		total += nb
-		logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
+		logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
 	}
 	logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
 }
diff --git a/x/mlxrunner/mlx/compile.go b/x/mlxrunner/mlx/compile.go
index 987bb7220..91e564ba0 100644
--- a/x/mlxrunner/mlx/compile.go
+++ b/x/mlxrunner/mlx/compile.go
@@ -150,7 +150,7 @@ func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload
 	traceScratch = nil
 	defer func() {
 		for _, a := range traceScratch {
-			if a.pinned > 0 {
+			if a.pinned.Load() > 0 {
 				panic("mlx: traced array was pinned during compilation")
 			}
 			if a.Valid() {
diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go
index 7feca3b1e..d5b218d1c 100644
--- a/x/mlxrunner/mlx/fast.go
+++ b/x/mlxrunner/mlx/fast.go
@@ -24,8 +24,8 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
 }
 
 type LayerNorm struct {
-	Weight Array `weight:"weight"`
-	Bias   Array `weight:"bias"`
+	Weight *Array `weight:"weight"`
+	Bias   *Array `weight:"bias"`
 }
 
 func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
@@ -35,10 +35,10 @@ func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
 }
 
 type RMSNorm struct {
-	Weight Array `weight:"weight"`
+	Weight *Array `weight:"weight"`
 }
 
-func (r RMSNorm) Forward(x *Array, eps float32) *Array {
+func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
 	out := New("FAST_RMSNORM")
 	C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
 	return out
diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go
index d3a99a6cd..d2e7fb4f1 100644
--- a/x/mlxrunner/mlx/nn.go
+++ b/x/mlxrunner/mlx/nn.go
@@ -1,12 +1,12 @@
 package mlx
 
 type Linear struct {
-	Weight Array `weight:"weight"`
-	Bias   Array `weight:"bias"`
+	Weight *Array `weight:"weight"`
+	Bias   *Array `weight:"bias"`
 }
 
 // Forward computes the linear transformation: x @ Weight.T + Bias
-func (m Linear) Forward(x *Array) *Array {
+func (m *Linear) Forward(x *Array) *Array {
 	w := m.Weight.Transpose(1, 0)
 	if m.Bias.Valid() {
 		return m.Bias.Addmm(x, w, 1.0, 1.0)
@@ -15,14 +15,14 @@ func (m Linear) Forward(x *Array) *Array {
 	return x.Matmul(w)
 }
 
-func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
+func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
 	w := m.Weight.Transpose(0, 2, 1)
 	// TODO: bias
 	return x.GatherMM(w, lhs, rhs, sorted)
 }
 
 type Embedding struct {
-	Weight Array `weight:"weight"`
+	Weight *Array `weight:"weight"`
 }
 
 func (e *Embedding) Forward(indices *Array) *Array {
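The diff splits synchronization into two mechanisms: each array's pin count becomes an `atomic.Int32`, so `Pin`/`Unpin` on the hot path never take a lock, while the shared `arrays` registry is guarded by `arraysMu` wherever the slice itself is mutated or iterated (`New`, `Sweep`, `LogArrays`). Below is a minimal standalone sketch of that pattern; the names `item`, `registry`, and `registryMu` are illustrative stand-ins, not part of the mlx package.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// item stands in for mlx.Array: the pin count is per-object and atomic.
type item struct {
	name   string
	pinned atomic.Int32
}

// registry stands in for the package-level arrays slice; the mutex
// guards only the slice, never the pin counts.
var (
	registry   []*item
	registryMu sync.Mutex
)

func newItem(name string) *item {
	t := &item{name: name}
	registryMu.Lock()
	defer registryMu.Unlock()
	registry = append(registry, t)
	return t
}

// pin and unpin touch only the atomic counter, so they never contend
// with newItem or sweep on the registry lock.
func pin(s ...*item) {
	for _, t := range s {
		t.pinned.Add(1)
	}
}

func unpin(s ...*item) {
	for _, t := range s {
		// Add(-1) returns the new value, so an unbalanced unpin
		// panics deterministically, as in mlx.Unpin.
		if t.pinned.Add(-1) < 0 {
			panic(fmt.Sprintf("negative pin count on %q", t.name))
		}
	}
}

// sweep compacts the registry in place, keeping only pinned items,
// mirroring the filter loop in Sweep.
func sweep() {
	registryMu.Lock()
	defer registryMu.Unlock()
	n := 0
	for _, t := range registry {
		if t.pinned.Load() > 0 {
			registry[n] = t
			n++
		}
	}
	registry = registry[:n]
}

func main() {
	a, b := newItem("a"), newItem("b")
	pin(a)
	sweep() // b is unpinned and dropped; a survives
	fmt.Println(len(registry), a.pinned.Load(), b.pinned.Load()) // 1 1 0
}
```

Keeping the counter atomic rather than folding it under `arraysMu` means frequent pin/unpin traffic during a forward pass only serializes against `Sweep` at the moments the slice itself is touched.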