diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index 07f3ff1c1..6047aacec 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -20,7 +20,7 @@ import ( type Array struct { ctx C.mlx_array name string - pinned bool + pinned int } var arrays []*Array @@ -129,7 +129,7 @@ func (t *Array) Clone() *Array { func Pin(s ...*Array) { for _, t := range s { if t != nil { - t.pinned = true + t.pinned++ } } } @@ -138,7 +138,7 @@ func Pin(s ...*Array) { func Unpin(s ...*Array) { for _, t := range s { if t != nil { - t.pinned = false + t.pinned-- } } } @@ -148,7 +148,7 @@ func Unpin(s ...*Array) { func Sweep() { n := 0 for _, t := range arrays { - if t.pinned && t.Valid() { + if t.pinned > 0 && t.Valid() { arrays[n] = t n++ } else if t.Valid() { @@ -175,7 +175,7 @@ func (t *Array) String() string { func (t *Array) LogValue() slog.Value { attrs := []slog.Attr{ slog.String("name", t.name), - slog.Bool("pinned", t.pinned), + slog.Int("pinned", t.pinned), } if t.Valid() { attrs = append(attrs,