mirror of https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00

mlx: Support NVIDIA TensorRT Model Optimizer import (#15566)

* mlx: Support NVIDIA TensorRT Model Optimizer import

* x/create: support FP8 safetensors import

Decode HF F8_E4M3 safetensors with block scale companions into MLX-importable tensor blobs, including compressed-tensors weight_scale metadata, packed NVFP4 layouts, and mixed-precision tensor headers. Use that source-precision metadata during create quantization: default FP8-sourced imports to mxfp8, allow source FP8 to target MLX low-bit formats, preserve source-quantized NVFP4 layouts, selectively keep or promote tensors based on their source precision, and detect quantized dtype from mixed-precision safetensors manifests.

* review comments

This commit is contained in:
parent ec9b4e9e47
commit 03aee88186

12 changed files with 1571 additions and 332 deletions
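The import path below hinges on decoding HF F8_E4M3 bytes. For orientation, here is a minimal scalar sketch of that decode, assuming the FN flavor (exponent bias 7, no infinities) that safetensors' F8_E4M3 denotes; the diff itself decodes whole MLX arrays, not scalars, and pairs the bytes with a coarser block-scale tensor (wired up in the first hunk below).

package main

import (
    "fmt"
    "math"
)

// decodeE4M3 converts one F8_E4M3(FN) byte to float32:
// 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
// The FN flavor has no infinities; exponent 1111 with mantissa 111 is NaN.
func decodeE4M3(b uint8) float32 {
    sign := 1.0
    if b&0x80 != 0 {
        sign = -1.0
    }
    exp := int(b>>3) & 0x0f
    man := float64(b & 0x07)
    switch {
    case exp == 0x0f && man == 7:
        return float32(math.NaN())
    case exp == 0: // subnormal: (m/8) * 2^-6
        return float32(sign * math.Ldexp(man/8, -6))
    default: // normal: (1 + m/8) * 2^(e-7)
        return float32(sign * math.Ldexp(1+man/8, exp-7))
    }
}

func main() {
    fmt.Println(decodeE4M3(0x40)) // 2
    fmt.Println(decodeE4M3(0x7e)) // 448, the E4M3FN maximum
}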
@@ -70,9 +70,13 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]

    if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
        scaleKey := inputKey + ".scale_inv"
        scaleInv := st.Get(scaleKey)
        if scaleInv == nil {
            scaleKey = inputKey + ".scale"
            scaleInv = st.Get(scaleKey)
        }
        if scaleInv == nil {
            st.Free()
            return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
            return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q or %q for fp8 source tensor %q", inputKey+".scale_inv", inputKey+".scale", inputKey)
        }
        arr, err = decodeSourceFP8Tensor(arr, scaleInv)
        if err != nil {

@@ -560,13 +564,13 @@ func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry)
    return keys[0], nil
}

func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
    if weight == nil || scaleInv == nil {
func decodeSourceFP8Tensor(weight, scale *mlx.Array) (*mlx.Array, error) {
    if weight == nil || scale == nil {
        return nil, fmt.Errorf("fp8 weight and scale tensors are required")
    }

    weightShape := weight.Dims()
    scaleShape := scaleInv.Dims()
    scaleShape := scale.Dims()
    if len(weightShape) != 2 || len(scaleShape) != 2 {
        return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
    }

@@ -596,7 +600,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
    }

    decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
    decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
    decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scale, 1), 3))
    decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
    if padBottom > 0 || padSide > 0 {
        decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})
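The Reshape → ExpandDims → Mul → Reshape sequence above broadcasts one scale entry over each (blockRows × blockCols) tile of the padded weight. The same arithmetic element by element, as a plain-Go sketch (illustrative only; the real path stays on MLX arrays):

package main

// blockDequant multiplies each (blockRows x blockCols) tile of w by the
// matching entry of scale — the broadcast the MLX reshape/expand sequence
// above performs on device.
func blockDequant(w, scale [][]float32, blockRows, blockCols int) [][]float32 {
    out := make([][]float32, len(w))
    for r := range w {
        out[r] = make([]float32, len(w[r]))
        for c := range w[r] {
            out[r][c] = w[r][c] * scale[r/blockRows][c/blockCols]
        }
    }
    return out
}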
x/create/client/quantize_test.go (new file, 24 lines)
@@ -0,0 +1,24 @@
package client

import (
    "testing"

    "github.com/ollama/ollama/x/mlxrunner/mlx"
)

func TestDecodeSourceFP8TensorAcceptsWeightScale(t *testing.T) {
    if err := mlx.CheckInit(); err != nil {
        t.Skipf("MLX unavailable: %v", err)
    }

    weight := mlx.FromValues([]uint8{0, 1, 2, 3}, 2, 2)
    scale := mlx.FromValues([]float32{1}, 1, 1).AsType(mlx.DTypeBFloat16)
    got, err := decodeSourceFP8Tensor(weight, scale)
    if err != nil {
        t.Fatal(err)
    }
    mlx.Eval(got)
    if dims := got.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 2 {
        t.Fatalf("decoded dims = %v, want [2 2]", dims)
    }
}

(File diff suppressed because it is too large)
@@ -4,7 +4,9 @@ import (
    "bytes"
    "encoding/binary"
    "encoding/json"
    "fmt"
    "io"
    "math"
    "os"
    "path/filepath"
    "slices"
@@ -59,6 +61,43 @@ func TestIsTensorModelDir(t *testing.T) {
    }
}

func TestValidateScalarFloat32TensorData(t *testing.T) {
    td := st.NewTensorDataFromBytes("linear.weight_scale_2", "F32", []int32{}, encodeFloat32s(2))

    got, err := validateScalarFloat32TensorData(td, "linear.weight.global_scale")
    if err != nil {
        t.Fatalf("validateScalarFloat32TensorData returned error: %v", err)
    }

    if got.Name != "linear.weight.global_scale" {
        t.Fatalf("name = %q, want %q", got.Name, "linear.weight.global_scale")
    }
    if got.Dtype != "F32" {
        t.Fatalf("dtype = %q, want F32", got.Dtype)
    }
    if len(got.Shape) != 0 {
        t.Fatalf("shape = %v, want scalar", got.Shape)
    }
}

func TestValidateScalarFloat32TensorDataRejectsNonScalar(t *testing.T) {
    td := st.NewTensorDataFromBytes("linear.weight_scale_2", "F32", []int32{2}, encodeFloat32s(2, 4))

    _, err := validateScalarFloat32TensorData(td, "linear.weight.global_scale")
    if err == nil || !strings.Contains(err.Error(), "expected scalar F32 tensor") {
        t.Fatalf("validateScalarFloat32TensorData error = %v, want scalar-shape failure", err)
    }
}

func TestInvertScalarFloat32TensorDataRejectsNonF32(t *testing.T) {
    td := st.NewTensorDataFromBytes("linear.weight_global_scale", "BF16", []int32{}, []byte{0, 0})

    _, err := invertScalarFloat32TensorData(td, "linear.weight.global_scale")
    if err == nil || !strings.Contains(err.Error(), "expected F32 tensor") {
        t.Fatalf("invertScalarFloat32TensorData error = %v, want dtype failure", err)
    }
}

func TestIsSafetensorsModelDir(t *testing.T) {
    tests := []struct {
        name string
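These tests pin down the scalar-companion helpers: validate that a global-scale companion is a scalar F32 tensor, rename it to the canonical ".global_scale" key, and — for invertScalarFloat32TensorData — store its reciprocal. That inversion is why the NVFP4 tests later in this file write weight_global_scale = 4 and expect 0.25 back. A byte-level sketch of that behavior as the tests imply it (the real helper operates on st.TensorData, and the helper name here is illustrative):

package main

import (
    "encoding/binary"
    "fmt"
    "math"
)

// invertScalarF32 accepts only a scalar F32 payload and returns the
// little-endian bytes of its reciprocal.
func invertScalarF32(dtype string, shape []int32, raw []byte) ([]byte, error) {
    if dtype != "F32" {
        return nil, fmt.Errorf("expected F32 tensor, got %s", dtype)
    }
    if len(shape) != 0 || len(raw) != 4 {
        return nil, fmt.Errorf("expected scalar F32 tensor")
    }
    v := math.Float32frombits(binary.LittleEndian.Uint32(raw))
    out := make([]byte, 4)
    binary.LittleEndian.PutUint32(out, math.Float32bits(1/v))
    return out, nil
}

func main() {
    raw := make([]byte, 4)
    binary.LittleEndian.PutUint32(raw, math.Float32bits(4))
    inv, _ := invertScalarF32("F32", nil, raw)
    fmt.Println(math.Float32frombits(binary.LittleEndian.Uint32(inv))) // 0.25
}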
@@ -246,6 +285,41 @@ func readSingleTensorRaw(t *testing.T, data []byte) []byte {
    return nil
}

func encodeFloat32s(vals ...float32) []byte {
    raw := make([]byte, 4*len(vals))
    for i, v := range vals {
        binary.LittleEndian.PutUint32(raw[i*4:(i+1)*4], math.Float32bits(v))
    }
    return raw
}

func readPackedTensorRaw(t *testing.T, data []byte, tensorName string) []byte {
    t.Helper()

    var headerSize uint64
    if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
        t.Fatalf("failed to read header size: %v", err)
    }

    var header map[string]struct {
        Dtype       string  `json:"dtype"`
        Shape       []int32 `json:"shape"`
        DataOffsets [2]int  `json:"data_offsets"`
    }
    if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
        t.Fatalf("failed to parse header: %v", err)
    }

    info, ok := header[tensorName]
    if !ok {
        t.Fatalf("tensor %q not found in header", tensorName)
    }

    start := 8 + int(headerSize) + info.DataOffsets[0]
    end := 8 + int(headerSize) + info.DataOffsets[1]
    return data[start:end]
}

func readSafetensorsHeaderNames(t *testing.T, data []byte) []string {
    t.Helper()
@@ -612,10 +686,22 @@ func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {

    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

    if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
    var statusMessages []string
    progressFn := func(status string) {
        statusMessages = append(statusMessages, status)
    }

    if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    if len(statusMessages) == 0 {
        t.Fatal("no status messages received")
    }
    if got, want := statusMessages[0], "importing model.safetensors (4 tensors, converting source E4M3 block-FP8 to MLX mxfp8)"; got != want {
        t.Fatalf("status = %q, want %q", got, want)
    }

    if got := quantizeByName["linear.weight"]; got != "mxfp8" {
        t.Fatalf("linear.weight quantization = %q, want %q", got, "mxfp8")
    }
@@ -643,6 +729,166 @@ func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
    }
}

func TestCreateSafetensorsModel_CompressedTensorsFP8WeightScale(t *testing.T) {
    dir := t.TempDir()

    configJSON := `{
        "model_type": "test",
        "architectures": ["TestModel"],
        "compression_config": {
            "quant_method": "compressed-tensors",
            "format": "float-quantized",
            "config_groups": {
                "group_0": {
                    "format": "float-quantized",
                    "weights": {
                        "type": "float",
                        "num_bits": 8,
                        "block_structure": [128, 128]
                    }
                }
            }
        }
    }`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
        st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
        st.NewTensorDataFromBytes("linear.weight_scale", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
    })

    quantizeByName := make(map[string]string)
    headerNamesByName := make(map[string][]string)

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        if _, err := io.ReadAll(r); err != nil {
            return LayerInfo{}, err
        }
        return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        data, err := io.ReadAll(r)
        if err != nil {
            return nil, err
        }
        quantizeByName[name] = quantize
        headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
        return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

    var statusMessages []string
    progressFn := func(status string) {
        statusMessages = append(statusMessages, status)
    }

    if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }
    if len(statusMessages) == 0 {
        t.Fatal("no status messages received")
    }
    if got, want := statusMessages[0], "importing model.safetensors (3 tensors, converting source E4M3 block-FP8 to MLX mxfp8)"; got != want {
        t.Fatalf("status = %q, want %q", got, want)
    }
    if got := quantizeByName["linear.weight"]; got != "mxfp8" {
        t.Fatalf("linear.weight quantization = %q, want mxfp8", got)
    }
    if _, ok := quantizeByName["linear.weight_scale"]; ok {
        t.Fatal("linear.weight_scale should not be imported as a standalone tensor")
    }
    if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale"}) {
        t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale"})
    }
}

func TestCreateSafetensorsModel_HFFP8SourceCanConvertToNVFP4(t *testing.T) {
    dir := t.TempDir()

    configJSON := `{
        "model_type": "test",
        "architectures": ["TestModel"],
        "quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
    }`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
        st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
        st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
        st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
        st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
        st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
        st.NewTensorDataFromBytes("model.layers.0.mlp.gate.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
        st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{128}, make([]byte, 256)),
    })

    quantizeByName := make(map[string]string)
    headerNamesByName := make(map[string][]string)

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        if _, err := io.ReadAll(r); err != nil {
            return LayerInfo{}, err
        }
        return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        data, err := io.ReadAll(r)
        if err != nil {
            return nil, err
        }
        quantizeByName[name] = quantize
        headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
        return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

    var statusMessages []string
    progressFn := func(status string) {
        statusMessages = append(statusMessages, status)
    }

    if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }
    if len(statusMessages) == 0 {
        t.Fatal("no status messages received")
    }
    if got, want := statusMessages[0], "importing model.safetensors (9 tensors, converting source E4M3 block-FP8 to MLX nvfp4)"; got != want {
        t.Fatalf("status = %q, want %q", got, want)
    }
    if got := quantizeByName["linear.weight"]; got != "nvfp4" {
        t.Fatalf("linear.weight quantization = %q, want nvfp4", got)
    }
    if got := quantizeByName["model.layers.0.mlp.experts.0.down_proj.weight"]; got != "mxfp8" {
        t.Fatalf("source fp8 down_proj quantization = %q, want mxfp8", got)
    }
    for _, name := range []string{
        "model.layers.0.self_attn.q_proj.weight",
        "model.embed_tokens.weight",
        "lm_head.weight",
    } {
        if got := quantizeByName[name]; got != "mxfp8" {
            t.Fatalf("%s quantization = %q, want mxfp8", name, got)
        }
    }
    if got := quantizeByName["model.layers.0.mlp.gate.weight"]; got != "" {
        t.Fatalf("router gate quantization = %q, want empty", got)
    }
    if got := quantizeByName["norm.weight"]; got != "" {
        t.Fatalf("norm.weight quantization = %q, want empty", got)
    }
    if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale_inv"}) {
        t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale_inv"})
    }
}

func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T) {
    tests := []struct {
        name string
@@ -670,7 +916,20 @@ func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T
                st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
                st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
            },
            wantErr: `cannot requantize already-quantized fp8 source model with --quantize "int4"`,
            wantErr: `cannot convert already-quantized fp8 source model with --quantize "int4"`,
        },
        {
            name: "packed nvfp4 source",
            configJSON: `{
                "model_type": "test",
                "architectures": ["TestModel"],
                "compression_config": {"format": "nvfp4-pack-quantized"}
            }`,
            tensors: []*st.TensorData{
                st.NewTensorDataFromBytes("linear.weight_packed", "U8", []int32{16, 8}, make([]byte, 128)),
                st.NewTensorDataFromBytes("linear.weight_scale", "F8_E4M3", []int32{16, 1}, make([]byte, 16)),
            },
            wantErr: `cannot requantize already-quantized source model with --quantize "int4"`,
        },
    }
@@ -701,6 +960,317 @@ func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T
    }
}

func TestCreateSafetensorsModel_PackedNVFP4PreservesSourceLayout(t *testing.T) {
    dir := t.TempDir()

    configJSON := `{
        "model_type": "test",
        "architectures": ["TestModel"],
        "compression_config": {"format": "nvfp4-pack-quantized"}
    }`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
        st.NewTensorDataFromBytes("linear.weight_packed", "U8", []int32{16, 8}, make([]byte, 128)),
        st.NewTensorDataFromBytes("linear.weight_scale", "F8_E4M3", []int32{16, 1}, make([]byte, 16)),
        st.NewTensorDataFromBytes("linear.weight_global_scale", "F32", []int32{}, encodeFloat32s(4)),
        st.NewTensorDataFromBytes("linear.input_global_scale", "F32", []int32{}, encodeFloat32s(8)),
        st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{16}, make([]byte, 32)),
    })

    var statusMessages []string
    layerHeaders := make(map[string]map[string]json.RawMessage)
    layerData := make(map[string][]byte)
    var tensorLayerNames []string

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        data, err := io.ReadAll(r)
        if err != nil {
            return LayerInfo{}, err
        }
        if mediaType == "application/vnd.ollama.image.tensor" {
            if len(data) < 8 {
                return LayerInfo{}, io.ErrUnexpectedEOF
            }
            var headerSize uint64
            if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
                return LayerInfo{}, err
            }
            var header map[string]json.RawMessage
            if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
                return LayerInfo{}, err
            }
            layerHeaders[name] = header
            layerData[name] = data
        }
        return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        if _, err := io.ReadAll(r); err != nil {
            return nil, err
        }
        tensorLayerNames = append(tensorLayerNames, name)
        return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
    progressFn := func(status string) { statusMessages = append(statusMessages, status) }

    if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    if len(statusMessages) == 0 {
        t.Fatal("no status messages received")
    }
    if got, want := statusMessages[0], "importing model.safetensors (5 tensors, preserving source quantization)"; got != want {
        t.Fatalf("status = %q, want %q", got, want)
    }

    if slices.Contains(tensorLayerNames, "linear.weight_scale") || slices.Contains(tensorLayerNames, "linear.weight_global_scale") || slices.Contains(tensorLayerNames, "linear.input_global_scale") {
        t.Fatalf("packed nvfp4 companions unexpectedly emitted as standalone tensor layers: %v", tensorLayerNames)
    }

    packedHeader := layerHeaders["linear.weight"]
    if packedHeader == nil {
        t.Fatalf("missing packed layer header for linear.weight")
    }
    for _, key := range []string{
        "linear.weight",
        "linear.weight.scale",
        "linear.weight.global_scale",
    } {
        if _, ok := packedHeader[key]; !ok {
            t.Fatalf("packed header missing %s: %v", key, packedHeader)
        }
    }
    if _, ok := packedHeader["linear.weight.input_global_scale"]; ok {
        t.Fatalf("packed header unexpectedly includes input_global_scale: %v", packedHeader)
    }
    globalRaw := readPackedTensorRaw(t, layerData["linear.weight"], "linear.weight.global_scale")
    if got := math.Float32frombits(binary.LittleEndian.Uint32(globalRaw)); got != 0.25 {
        t.Fatalf("linear.weight.global_scale = %v, want 0.25", got)
    }

    var metadata map[string]string
    if metaRaw, ok := packedHeader["__metadata__"]; ok {
        if err := json.Unmarshal(metaRaw, &metadata); err != nil {
            t.Fatalf("failed to parse metadata: %v", err)
        }
    }
    if metadata["quant_type"] != "nvfp4" {
        t.Fatalf("quant_type = %q, want %q", metadata["quant_type"], "nvfp4")
    }
    if metadata["group_size"] != "16" {
        t.Fatalf("group_size = %q, want %q", metadata["group_size"], "16")
    }
}

func TestCreateSafetensorsModel_PackedNVFP4CrossShardCompanions(t *testing.T) {
    dir := t.TempDir()

    configJSON := `{
        "model_type": "test",
        "architectures": ["TestModel"],
        "compression_config": {"format": "nvfp4-pack-quantized"}
    }`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    createTestSafetensors(t, filepath.Join(dir, "model-00001-of-00002.safetensors"), []*st.TensorData{
        st.NewTensorDataFromBytes("linear.weight_packed", "U8", []int32{16, 8}, make([]byte, 128)),
        st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{16}, make([]byte, 32)),
    })
    createTestSafetensors(t, filepath.Join(dir, "model-00002-of-00002.safetensors"), []*st.TensorData{
        st.NewTensorDataFromBytes("linear.weight_scale", "F8_E4M3", []int32{16, 1}, make([]byte, 16)),
        st.NewTensorDataFromBytes("linear.weight_global_scale", "F32", []int32{}, encodeFloat32s(2)),
        st.NewTensorDataFromBytes("linear.input_global_scale", "F32", []int32{}, encodeFloat32s(8)),
    })
    indexJSON := `{
        "metadata": {"total_size": 152},
        "weight_map": {
            "linear.weight_packed": "model-00001-of-00002.safetensors",
            "norm.weight": "model-00001-of-00002.safetensors",
            "linear.weight_scale": "model-00002-of-00002.safetensors",
            "linear.weight_global_scale": "model-00002-of-00002.safetensors",
            "linear.input_global_scale": "model-00002-of-00002.safetensors"
        }
    }`
    if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
        t.Fatalf("failed to write index: %v", err)
    }

    layerHeaders := make(map[string]map[string]json.RawMessage)
    var tensorLayerNames []string

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        data, err := io.ReadAll(r)
        if err != nil {
            return LayerInfo{}, err
        }
        if mediaType == "application/vnd.ollama.image.tensor" {
            var headerSize uint64
            if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
                return LayerInfo{}, err
            }
            var header map[string]json.RawMessage
            if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
                return LayerInfo{}, err
            }
            layerHeaders[name] = header
        }
        return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        if _, err := io.ReadAll(r); err != nil {
            return nil, err
        }
        tensorLayerNames = append(tensorLayerNames, name)
        return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

    packedCreator := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
        return LayerInfo{}, fmt.Errorf("unexpected packedCreator call for %s", groupName)
    }
    if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, packedCreator); err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    if slices.Contains(tensorLayerNames, "linear.weight_packed") || slices.Contains(tensorLayerNames, "linear.weight_scale") || slices.Contains(tensorLayerNames, "linear.weight_global_scale") || slices.Contains(tensorLayerNames, "linear.input_global_scale") {
        t.Fatalf("packed nvfp4 tensors unexpectedly emitted as standalone tensor layers: %v", tensorLayerNames)
    }

    packedHeader := layerHeaders["linear.weight"]
    if packedHeader == nil {
        t.Fatalf("missing packed layer header for linear.weight")
    }
    for _, key := range []string{
        "linear.weight",
        "linear.weight.scale",
        "linear.weight.global_scale",
    } {
        if _, ok := packedHeader[key]; !ok {
            t.Fatalf("packed header missing %s: %v", key, packedHeader)
        }
    }
    if _, ok := packedHeader["linear.weight.input_global_scale"]; ok {
        t.Fatalf("packed header unexpectedly includes input_global_scale: %v", packedHeader)
    }
}

func TestCreateSafetensorsModel_PackedNVFP4StacksExperts(t *testing.T) {
    dir := t.TempDir()

    configJSON := `{
        "model_type": "test",
        "architectures": ["TestModel"],
        "compression_config": {"format": "nvfp4-pack-quantized"}
    }`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
        st.NewTensorDataFromBytes("model.layers.1.mlp.experts.0.gate_proj.weight_packed", "U8", []int32{2, 8}, make([]byte, 16)),
        st.NewTensorDataFromBytes("model.layers.1.mlp.experts.0.gate_proj.weight_scale", "F8_E4M3", []int32{2, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.layers.1.mlp.experts.0.gate_proj.weight_global_scale", "F32", []int32{1}, encodeFloat32s(2)),
        st.NewTensorDataFromBytes("model.layers.1.mlp.experts.0.gate_proj.input_global_scale", "F32", []int32{1}, encodeFloat32s(32)),
        st.NewTensorDataFromBytes("model.layers.1.mlp.experts.1.gate_proj.weight_packed", "U8", []int32{2, 8}, make([]byte, 16)),
        st.NewTensorDataFromBytes("model.layers.1.mlp.experts.1.gate_proj.weight_scale", "F8_E4M3", []int32{2, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.layers.1.mlp.experts.1.gate_proj.weight_global_scale", "F32", []int32{1}, encodeFloat32s(4)),
        st.NewTensorDataFromBytes("model.layers.1.mlp.experts.1.gate_proj.input_global_scale", "F32", []int32{1}, encodeFloat32s(64)),
        st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
    })

    layerHeaders := make(map[string]map[string]json.RawMessage)
    layerData := make(map[string][]byte)
    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        data, err := io.ReadAll(r)
        if err != nil {
            return LayerInfo{}, err
        }
        if mediaType == "application/vnd.ollama.image.tensor" {
            var headerSize uint64
            if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
                return LayerInfo{}, err
            }
            var header map[string]json.RawMessage
            if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
                return LayerInfo{}, err
            }
            layerHeaders[name] = header
            layerData[name] = data
        }
        return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        if _, err := io.ReadAll(r); err != nil {
            return nil, err
        }
        return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
    packedCreator := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
        return LayerInfo{}, fmt.Errorf("unexpected packedCreator call for %s", groupName)
    }

    if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, packedCreator); err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    header := layerHeaders["model.layers.1.mlp.experts"]
    if header == nil {
        t.Fatalf("missing packed expert layer header")
    }
    for _, key := range []string{
        "model.layers.1.mlp.switch_mlp.gate_proj.weight",
        "model.layers.1.mlp.switch_mlp.gate_proj.weight.scale",
        "model.layers.1.mlp.switch_mlp.gate_proj.weight.global_scale",
    } {
        if _, ok := header[key]; !ok {
            t.Fatalf("stacked header missing %s: %v", key, header)
        }
    }
    if _, ok := header["model.layers.1.mlp.switch_mlp.gate_proj.weight.input_global_scale"]; ok {
        t.Fatalf("stacked header unexpectedly includes input_global_scale: %v", header)
    }
    if _, ok := header["model.layers.1.mlp.experts.0.gate_proj.weight"]; ok {
        t.Fatalf("unexpected per-expert tensor left in packed header: %v", header)
    }

    var weightInfo struct {
        Dtype string  `json:"dtype"`
        Shape []int32 `json:"shape"`
    }
    if err := json.Unmarshal(header["model.layers.1.mlp.switch_mlp.gate_proj.weight"], &weightInfo); err != nil {
        t.Fatalf("failed to unmarshal stacked weight info: %v", err)
    }
    if weightInfo.Dtype != "U32" || !slices.Equal(weightInfo.Shape, []int32{2, 2, 2}) {
        t.Fatalf("stacked weight = dtype %s shape %v, want U32 [2 2 2]", weightInfo.Dtype, weightInfo.Shape)
    }

    var globalInfo struct {
        Dtype string  `json:"dtype"`
        Shape []int32 `json:"shape"`
    }
    if err := json.Unmarshal(header["model.layers.1.mlp.switch_mlp.gate_proj.weight.global_scale"], &globalInfo); err != nil {
        t.Fatalf("failed to unmarshal stacked global scale info: %v", err)
    }
    if globalInfo.Dtype != "F32" || !slices.Equal(globalInfo.Shape, []int32{2, 1, 1}) {
        t.Fatalf("stacked global scale = dtype %s shape %v, want F32 [2 1 1]", globalInfo.Dtype, globalInfo.Shape)
    }
    globalRaw := readPackedTensorRaw(t, layerData["model.layers.1.mlp.experts"], "model.layers.1.mlp.switch_mlp.gate_proj.weight.global_scale")
    if got0 := math.Float32frombits(binary.LittleEndian.Uint32(globalRaw[0:4])); got0 != 0.5 {
        t.Fatalf("stacked global scale[0] = %v, want 0.5", got0)
    }
    if got1 := math.Float32frombits(binary.LittleEndian.Uint32(globalRaw[4:8])); got1 != 0.25 {
        t.Fatalf("stacked global scale[1] = %v, want 0.25", got1)
    }
}

func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
    dir := t.TempDir()
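These tests preserve the ModelOpt packed layout rather than dequantizing: a U8 [16, 8] weight_packed tensor carries two FP4 (E2M1) codes per byte, i.e. a logical [16, 16] weight, with a per-block E4M3 scale and an inverted per-tensor F32 global scale alongside. A sketch of the nibble decode under those assumptions (the nibble order here is a guess, not taken from the source):

package main

import "fmt"

// e2m1 lists the eight non-negative FP4 (E2M1) magnitudes; bit 3 of each
// 4-bit code is the sign. One packed U8 byte holds two codes.
var e2m1 = [8]float32{0, 0.5, 1, 1.5, 2, 3, 4, 6}

func decodeFP4(code uint8) float32 {
    v := e2m1[code&0x7]
    if code&0x8 != 0 {
        v = -v
    }
    return v
}

// unpackFP4 splits one packed byte into its two FP4 values.
func unpackFP4(b uint8) (lo, hi float32) {
    return decodeFP4(b & 0x0f), decodeFP4(b >> 4)
}

func main() {
    fmt.Println(unpackFP4(0x9a)) // -1 -0.5
}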
@@ -777,6 +1347,26 @@ func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
            t.Fatalf("expected mxfp8 quantize for %s, got %q", tensor.Name, tensor.Quantize)
        }
    }

    packedLayerNames = nil
    packedLayerTensors = nil
    if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
        t.Fatalf("CreateSafetensorsModel nvfp4 failed: %v", err)
    }

    if len(packedLayerNames) != 1 {
        t.Fatalf("expected 1 packed layer for nvfp4, got %d: %v", len(packedLayerNames), packedLayerNames)
    }

    for _, tensor := range packedLayerTensors[0] {
        want := "nvfp4"
        if strings.Contains(tensor.Name, "down_proj") {
            want = "mxfp8"
        }
        if tensor.Quantize != want {
            t.Fatalf("nvfp4 packed tensor %s quantize = %q, want %q", tensor.Name, tensor.Quantize, want)
        }
    }
}

func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
@@ -19,6 +19,10 @@ func DTypeSize(dtype string) (int, error) {
        return 4, nil
    case "F64":
        return 8, nil
    case "U8", "I8":
        return 1, nil
    case "F8_E4M3", "F8_E5M2", "F8_E4M3FN", "F8_E5M2FNUZ":
        return 1, nil
    default:
        return 0, fmt.Errorf("unsupported dtype %q", dtype)
    }
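With every FP8 variant mapping to one byte per element, sizing an FP8 payload reduces to its element count — which is why the tests above back an F8_E4M3 [128, 128] weight with make([]byte, 128*128). A small sketch of that computation (the helper name tensorBytes is illustrative, not from the source):

package main

import "fmt"

// tensorBytes multiplies the element count by the per-element size that
// DTypeSize reports; for F8_E4M3 and friends that size is 1.
func tensorBytes(shape []int32, dtypeSize int) int64 {
    n := int64(1)
    for _, d := range shape {
        n *= int64(d)
    }
    return n * int64(dtypeSize)
}

func main() {
    fmt.Println(tensorBytes([]int32{128, 128}, 1)) // 16384
}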
@@ -64,6 +64,8 @@ func dtypeFromString(s string) mlx.Dtype {
        return mlx.DtypeInt64
    case "U8", "UINT8":
        return mlx.DtypeUint8
    case "F8_E4M3", "F8_E5M2", "F8_E4M3FN", "F8_E5M2FNUZ":
        return mlx.DtypeUint8 // FP8 types stored as raw uint8 bytes
    default:
        return mlx.DtypeFloat32
    }
@@ -7,6 +7,7 @@ import (
    "fmt"
    "iter"
    "runtime"
    "sort"
    "unsafe"
)
@@ -121,10 +122,17 @@ func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata
    cArrays := C.mlx_map_string_to_array_new()
    defer C.mlx_map_string_to_array_free(cArrays)

    arrayNames := make([]string, 0, len(arrays))
    for name, arr := range arrays {
        if arr == nil {
            continue
        }
        arrayNames = append(arrayNames, name)
    }
    sort.Strings(arrayNames)

    for _, name := range arrayNames {
        arr := arrays[name]
        cName := C.CString(name)
        C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
        C.free(unsafe.Pointer(cName))
@@ -133,7 +141,14 @@ func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata
    cMetadata := C.mlx_map_string_to_string_new()
    defer C.mlx_map_string_to_string_free(cMetadata)

    for key, value := range metadata {
    metadataKeys := make([]string, 0, len(metadata))
    for key := range metadata {
        metadataKeys = append(metadataKeys, key)
    }
    sort.Strings(metadataKeys)

    for _, key := range metadataKeys {
        value := metadata[key]
        cKey := C.CString(key)
        cValue := C.CString(value)
        C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
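Both sorting loops above likely exist for the same reason: Go randomizes map iteration order, so inserting arrays and metadata in sorted key order keeps the saved safetensors byte-identical across runs (and therefore keeps blob digests stable — that motivation is my reading of the change, not stated in the diff). A minimal illustration:

package main

import (
    "fmt"
    "sort"
)

func main() {
    m := map[string]string{"quant_type": "nvfp4", "group_size": "16"}
    keys := make([]string, 0, len(m))
    for k := range m {
        keys = append(keys, k)
    }
    sort.Strings(keys) // deterministic order regardless of map randomization
    for _, k := range keys {
        fmt.Println(k, "=", m[k])
    }
}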
@@ -74,14 +74,23 @@ func MakeLinearLayer(
        scales,
    )

    // Check for per-tensor global scale (NVIDIA double-scale nvfp4).
    // NVIDIA ModelOpt stores this as "weight_scale_2"; our import
    // pipeline maps it to "weight.global_scale".
    globalScale := tensors[path+".weight.global_scale"]
    if globalScale == nil {
        globalScale = tensors[path+".weight_scale_2"]
    }

    return &nn.QuantizedLinear{
        Weight:    w,
        Scales:    scales,
        QBiases:   qbiases,
        Bias:      bias,
        GroupSize: groupSize,
        Bits:      bits,
        Mode:      mode,
        Weight:      w,
        Scales:      scales,
        QBiases:     qbiases,
        Bias:        bias,
        GlobalScale: globalScale,
        GroupSize:   groupSize,
        Bits:        bits,
        Mode:        mode,
    }
}
@@ -78,13 +78,14 @@ func (l *Linear) OutputDim() int32 {

// QuantizedLinear applies an affine transformation using quantized weights.
type QuantizedLinear struct {
    Weight  *mlx.Array // Quantized weight data
    Scales  *mlx.Array // Scale factors for dequantization
    QBiases *mlx.Array // Quantization biases (nil for nvfp4)
    Bias    *mlx.Array // Layer bias [output_dims] or nil
    GroupSize int
    Bits      int
    Mode      string
    Weight      *mlx.Array // Quantized weight data
    Scales      *mlx.Array // Scale factors for dequantization
    QBiases     *mlx.Array // Quantization biases (nil for nvfp4)
    Bias        *mlx.Array // Layer bias [output_dims] or nil
    GlobalScale *mlx.Array // Per-tensor global scale for double-scale nvfp4 (nil for standard)
    GroupSize   int
    Bits        int
    Mode        string
}

func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
@@ -106,7 +107,18 @@ func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int,
}

func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
    out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
    var out *mlx.Array
    if ql.GlobalScale != nil {
        // Double-scale nvfp4 (e.g., NVIDIA ModelOpt): standard quantized_matmul
        // followed by global_scale multiply. The global_scale is a per-tensor
        // F32 scalar (weight_scale_2 in NVIDIA's format).
        // TODO: switch to a fused double-scale matmul once MLX has kernel
        // coverage for this path.
        out = mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
        out = mlx.Mul(out, ql.GlobalScale)
    } else {
        out = mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
    }
    if ql.Bias != nil && ql.Bias.Valid() {
        out = out.Add(ql.Bias)
    }
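Because the global scale is a per-tensor scalar, multiplying after the matmul is exact by linearity. In effect the double-scale path computes (with s_block the per-group E4M3 scales and s_global the F32 scalar imported from weight_scale_2):

y = \bigl(x \cdot \mathrm{dequant}(W, s_{\mathrm{block}})^{\top}\bigr) \cdot s_{\mathrm{global}} + b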
@@ -110,6 +110,19 @@ func NewTensorDataFromBytes(name, dtype string, shape []int32, rawData []byte) *
    }
}

// NewTensorDataFromReaderAt creates a TensorData backed by an arbitrary
// io.ReaderAt. This is useful for constructing large synthetic tensors from
// temporary files without loading the full payload into memory.
func NewTensorDataFromReaderAt(name, dtype string, shape []int32, readerAt io.ReaderAt, size int64) *TensorData {
    return &TensorData{
        Name:   name,
        Dtype:  dtype,
        Shape:  shape,
        Size:   size,
        reader: io.NewSectionReader(readerAt, 0, size),
    }
}

// ExtractRawFromSafetensors reads a safetensors-wrapped reader and extracts
// the raw tensor data bytes (stripping the header).
func ExtractRawFromSafetensors(r io.Reader) ([]byte, error) {
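One possible use of the new constructor: backing a large synthetic BF16 tensor with a sparse temp file instead of an in-memory slice. This is a sketch — the st import path is assumed from the test files' alias, not confirmed by the diff:

package main

import (
    "log"
    "os"

    st "github.com/ollama/ollama/x/create/safetensors" // assumed path for the st alias
)

func main() {
    f, err := os.CreateTemp("", "tensor-*.bin")
    if err != nil {
        log.Fatal(err)
    }
    defer os.Remove(f.Name())

    size := int64(4096) * 4096 * 2             // BF16 is 2 bytes per element
    if err := f.Truncate(size); err != nil {   // sparse zero-filled payload
        log.Fatal(err)
    }

    // *os.File implements io.ReaderAt, so it can back the TensorData directly.
    td := st.NewTensorDataFromReaderAt("blk.0.attn_q.weight", "BF16", []int32{4096, 4096}, f, size)
    _ = td
}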
x/server/show.go (102 lines changed)
@@ -306,15 +306,16 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
}

// GetSafetensorsDtype returns the quantization type for a safetensors model.
// Reads quant_type from the first tensor blob's __metadata__.
// Falls back to torch_dtype from config.json if no quant metadata.
// Reads tensor headers until quantized weights are found.
// Falls back to torch_dtype from config.json if no quant metadata exists.
func GetSafetensorsDtype(name model.Name) (string, error) {
    mf, err := manifest.ParseNamedManifest(name)
    if err != nil {
        return "", fmt.Errorf("failed to load manifest: %w", err)
    }

    // Check first tensor blob for quant_type metadata
    // Mixed models can start with unquantized embeddings or heads, so scan until
    // any tensor blob reports quantized weight metadata.
    for _, layer := range mf.Layers {
        if layer.MediaType != manifest.MediaTypeImageTensor {
            continue
|
@ -323,15 +324,20 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
|
|||
if err != nil {
|
||||
continue
|
||||
}
|
||||
info, err := readSafetensorsHeader(blobPath)
|
||||
f, err := os.Open(blobPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if quantType := canonicalQuantType(info.QuantType); quantType != "" {
|
||||
return quantType, nil
|
||||
infos, err := parseSafetensorsAllHeaders(f)
|
||||
_ = f.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, info := range infos {
|
||||
if quantType := canonicalQuantType(info.QuantType); quantType != "" {
|
||||
return quantType, nil
|
||||
}
|
||||
}
|
||||
// Only check the first tensor blob
|
||||
break
|
||||
}
|
||||
|
||||
// Not quantized - return torch_dtype from config.json
|
||||
|
|
@@ -354,86 +360,6 @@ type safetensorsTensorInfo struct {
    GroupSize string // from __metadata__.group_size (e.g., "32", "64")
}

// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
// Safetensors format: 8-byte header size (little endian) + JSON header + tensor data
func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
    f, err := os.Open(path)
    if err != nil {
        return nil, err
    }
    defer f.Close()

    return parseSafetensorsHeader(f)
}

// parseSafetensorsHeader parses a safetensors header from a reader.
// This is separated for testability.
// Parses __metadata__ for quant_type and group_size if present.
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
    // Read header size (8 bytes, little endian)
    var headerSize uint64
    if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
        return nil, fmt.Errorf("failed to read header size: %w", err)
    }

    // Sanity check - header shouldn't be too large
    if headerSize > 1024*1024 {
        return nil, fmt.Errorf("header size too large: %d", headerSize)
    }

    // Read header JSON
    headerBytes := make([]byte, headerSize)
    if _, err := io.ReadFull(r, headerBytes); err != nil {
        return nil, fmt.Errorf("failed to read header: %w", err)
    }

    // Parse as map of tensor name -> info
    var header map[string]json.RawMessage
    if err := json.Unmarshal(headerBytes, &header); err != nil {
        return nil, fmt.Errorf("failed to parse header: %w", err)
    }

    // Parse metadata if present
    var quantType, groupSize string
    if metaRaw, ok := header["__metadata__"]; ok {
        var meta map[string]string
        if json.Unmarshal(metaRaw, &meta) == nil {
            quantType = meta["quant_type"]
            groupSize = meta["group_size"]
        }
    }

    // Find the main tensor entry (not __metadata__, .scale, or .bias)
    for name, raw := range header {
        if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
            continue
        }
        var info safetensorsTensorInfo
        if err := json.Unmarshal(raw, &info); err != nil {
            return nil, fmt.Errorf("failed to parse tensor info: %w", err)
        }
        info.QuantType = quantType
        info.GroupSize = groupSize
        return &info, nil
    }

    // Fall back to first non-metadata tensor entry
    for name, raw := range header {
        if name == "__metadata__" {
            continue
        }
        var info safetensorsTensorInfo
        if err := json.Unmarshal(raw, &info); err != nil {
            return nil, fmt.Errorf("failed to parse tensor info: %w", err)
        }
        info.QuantType = quantType
        info.GroupSize = groupSize
        return &info, nil
    }

    return nil, fmt.Errorf("no tensor found in header")
}

// parseSafetensorsAllHeaders parses all tensor entries from a safetensors header.
// Returns one safetensorsTensorInfo per main tensor (skipping __metadata__, .scale, .bias).
// For packed blobs this returns multiple entries; for single-tensor blobs, one entry.
@@ -9,6 +9,7 @@ import (
    "testing"

    "github.com/ollama/ollama/manifest"
    "github.com/ollama/ollama/types/model"
)

func TestBuildModelInfo(t *testing.T) {
@@ -286,168 +287,7 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) {
    }
}

func TestParseSafetensorsHeader(t *testing.T) {
    tests := []struct {
        name          string
        header        map[string]any
        wantDtype     string
        wantShape     []int64
        wantQuantType string
        wantGroupSize string
        wantErr       bool
    }{
        {
            name: "simple tensor",
            header: map[string]any{
                "weight": map[string]any{
                    "dtype":        "BF16",
                    "shape":        []int64{2560, 262144},
                    "data_offsets": []int64{0, 1342177280},
                },
            },
            wantDtype: "BF16",
            wantShape: []int64{2560, 262144},
        },
        {
            name: "tensor keyed by name",
            header: map[string]any{
                "model.layers.0.weight": map[string]any{
                    "dtype":        "BF16",
                    "shape":        []int64{2560, 2560},
                    "data_offsets": []int64{0, 13107200},
                },
            },
            wantDtype: "BF16",
            wantShape: []int64{2560, 2560},
        },
        {
            name: "with int4 quant metadata",
            header: map[string]any{
                "__metadata__": map[string]any{
                    "quant_type": "int4",
                    "group_size": "32",
                },
                "model.layers.0.mlp.up_proj.weight": map[string]any{
                    "dtype":        "U32",
                    "shape":        []int64{2560, 320},
                    "data_offsets": []int64{0, 3276800},
                },
                "model.layers.0.mlp.up_proj.weight.scale": map[string]any{
                    "dtype":        "BF16",
                    "shape":        []int64{2560, 80},
                    "data_offsets": []int64{3276800, 3686400},
                },
                "model.layers.0.mlp.up_proj.weight.bias": map[string]any{
                    "dtype":        "BF16",
                    "shape":        []int64{2560, 80},
                    "data_offsets": []int64{3686400, 4096000},
                },
            },
            wantDtype:     "U32",
            wantShape:     []int64{2560, 320},
            wantQuantType: "int4",
            wantGroupSize: "32",
        },
        {
            name: "int8 quant metadata",
            header: map[string]any{
                "__metadata__": map[string]any{
                    "quant_type": "int8",
                    "group_size": "64",
                },
                "model.layers.0.mlp.down_proj.weight": map[string]any{
                    "dtype":        "U32",
                    "shape":        []int64{2560, 640},
                    "data_offsets": []int64{0, 6553600},
                },
                "model.layers.0.mlp.down_proj.weight.scale": map[string]any{
                    "dtype":        "BF16",
                    "shape":        []int64{2560, 40},
                    "data_offsets": []int64{6553600, 6963200},
                },
            },
            wantDtype:     "U32",
            wantShape:     []int64{2560, 640},
            wantQuantType: "int8",
            wantGroupSize: "64",
        },
        {
            name: "with old-style format metadata",
            header: map[string]any{
                "__metadata__": map[string]any{
                    "format": "pt",
                },
                "bias": map[string]any{
                    "dtype":        "F32",
                    "shape":        []int64{1024},
                    "data_offsets": []int64{0, 4096},
                },
            },
            wantDtype: "F32",
            wantShape: []int64{1024},
        },
        {
            name: "float16 tensor",
            header: map[string]any{
                "layer.weight": map[string]any{
                    "dtype":        "F16",
                    "shape":        []int64{512, 512, 3, 3},
                    "data_offsets": []int64{0, 4718592},
                },
            },
            wantDtype: "F16",
            wantShape: []int64{512, 512, 3, 3},
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            // Create safetensors format: 8-byte size + JSON header
            headerJSON, err := json.Marshal(tt.header)
            if err != nil {
                t.Fatalf("failed to marshal header: %v", err)
            }

            var buf bytes.Buffer
            if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
                t.Fatalf("failed to write header size: %v", err)
            }
            buf.Write(headerJSON)

            info, err := parseSafetensorsHeader(&buf)
            if (err != nil) != tt.wantErr {
                t.Errorf("parseSafetensorsHeader() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            if tt.wantErr {
                return
            }

            if info.Dtype != tt.wantDtype {
                t.Errorf("Dtype = %v, want %v", info.Dtype, tt.wantDtype)
            }

            if len(info.Shape) != len(tt.wantShape) {
                t.Errorf("Shape length = %v, want %v", len(info.Shape), len(tt.wantShape))
            } else {
                for i, s := range info.Shape {
                    if s != tt.wantShape[i] {
                        t.Errorf("Shape[%d] = %v, want %v", i, s, tt.wantShape[i])
                    }
                }
            }

            if info.QuantType != tt.wantQuantType {
                t.Errorf("QuantType = %v, want %v", info.QuantType, tt.wantQuantType)
            }
            if info.GroupSize != tt.wantGroupSize {
                t.Errorf("GroupSize = %v, want %v", info.GroupSize, tt.wantGroupSize)
            }
        })
    }
}

func TestParseSafetensorsHeader_Errors(t *testing.T) {
func TestParseSafetensorsAllHeaders_Errors(t *testing.T) {
    tests := []struct {
        name string
        data []byte
@@ -467,7 +307,7 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
            name: "header size too large",
            data: func() []byte {
                var buf bytes.Buffer
                binary.Write(&buf, binary.LittleEndian, uint64(2*1024*1024))   // 2MB
                binary.Write(&buf, binary.LittleEndian, uint64(200*1024*1024)) // 200 MiB
                return buf.Bytes()
            }(),
            wantErr: "header size too large",
@@ -510,7 +350,7 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            _, err := parseSafetensorsHeader(bytes.NewReader(tt.data))
            _, err := parseSafetensorsAllHeaders(bytes.NewReader(tt.data))
            if err == nil {
                t.Error("expected error, got nil")
                return
@@ -1209,44 +1049,77 @@ func TestGetTensorInfoFromManifest_Packed(t *testing.T) {
    }
}

func TestReadSafetensorsHeader(t *testing.T) {
    // Create a temp file with a valid safetensors header
    tempDir := t.TempDir()
func TestGetSafetensorsDtypeScansPastUnquantizedFirstBlob(t *testing.T) {
    t.Setenv("OLLAMA_MODELS", t.TempDir())

    header := map[string]any{
        "test_tensor": map[string]any{
            "dtype":        "BF16",
            "shape":        []int64{1024, 768},
            "data_offsets": []int64{0, 1572864},
        },
    }
    headerJSON, _ := json.Marshal(header)
    writeSafetensorsLayer := func(t *testing.T, header map[string]any, name string) manifest.Layer {
        t.Helper()

    var buf bytes.Buffer
    binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
    buf.Write(headerJSON)
        headerJSON, err := json.Marshal(header)
        if err != nil {
            t.Fatalf("failed to marshal header: %v", err)
        }

    filePath := filepath.Join(tempDir, "test.safetensors")
    if err := os.WriteFile(filePath, buf.Bytes(), 0o644); err != nil {
        t.Fatalf("failed to write test file: %v", err)
        var buf bytes.Buffer
        if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
            t.Fatalf("failed to write header size: %v", err)
        }
        buf.Write(headerJSON)

        layer, err := manifest.NewLayer(&buf, manifest.MediaTypeImageTensor)
        if err != nil {
            t.Fatalf("failed to create tensor layer: %v", err)
        }
        layer.Name = name
        return layer
    }

    info, err := readSafetensorsHeader(filePath)
    configData, err := json.Marshal(map[string]any{
        "model_format": "safetensors",
    })
    if err != nil {
        t.Fatalf("readSafetensorsHeader() error = %v", err)
        t.Fatalf("failed to marshal config: %v", err)
    }
    configLayer, err := manifest.NewLayer(bytes.NewReader(configData), "application/vnd.docker.container.image.v1+json")
    if err != nil {
        t.Fatalf("failed to create config layer: %v", err)
    }

    if info.Dtype != "BF16" {
        t.Errorf("Dtype = %v, want BF16", info.Dtype)
    }
    if len(info.Shape) != 2 || info.Shape[0] != 1024 || info.Shape[1] != 768 {
        t.Errorf("Shape = %v, want [1024, 768]", info.Shape)
    }
}
    unquantized := writeSafetensorsLayer(t, map[string]any{
        "model.embed_tokens.weight": map[string]any{
            "dtype":        "BF16",
            "shape":        []int64{16, 8},
            "data_offsets": []int64{0, 256},
        },
    }, "model.embed_tokens.weight")

func TestReadSafetensorsHeader_FileNotFound(t *testing.T) {
    _, err := readSafetensorsHeader("/nonexistent/path/file.safetensors")
    if err == nil {
        t.Error("expected error for nonexistent file")
    quantized := writeSafetensorsLayer(t, map[string]any{
        "__metadata__": map[string]string{
            "quant_type": "mxfp8",
            "group_size": "32",
        },
        "model.layers.0.mlp.down_proj.weight": map[string]any{
            "dtype":        "U32",
            "shape":        []int64{16, 4},
            "data_offsets": []int64{0, 256},
        },
        "model.layers.0.mlp.down_proj.weight.scale": map[string]any{
            "dtype":        "BF16",
            "shape":        []int64{16, 1},
            "data_offsets": []int64{256, 288},
        },
    }, "model.layers.0.mlp.down_proj.weight")

    name := model.ParseName("mixed-fp8-safetensors")
    if err := manifest.WriteManifest(name, configLayer, []manifest.Layer{unquantized, quantized}); err != nil {
        t.Fatalf("failed to write manifest: %v", err)
    }

    got, err := GetSafetensorsDtype(name)
    if err != nil {
        t.Fatalf("GetSafetensorsDtype() error = %v", err)
    }
    if got != "mxfp8" {
        t.Fatalf("GetSafetensorsDtype() = %q, want mxfp8", got)
    }
}