mlx: Support NVIDIA TensorRT Model Optimizer import (#15566)

* mlx: Support NVIDIA TensorRT Model Optimizer import

* x/create: support FP8 safetensors import

Decode HF F8_E4M3 safetensors with block-scale companion tensors into MLX-importable tensor blobs, handling compressed-tensors weight_scale metadata, packed NVFP4 layouts, and mixed-precision tensor headers.

Use that source-precision metadata during create quantization: default FP8-sourced imports to mxfp8, allow source FP8 to target MLX low-bit formats, preserve source-quantized NVFP4 layouts, selectively keep or promote tensors based on their source precision, and detect quantized dtype from mixed-precision safetensors manifests.
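For intuition, the block-scale decode at the core of this import scales each fixed-size tile of the fp8 weight by its companion scale entry. A minimal sketch (illustrative helper only; the shipped decodeSourceFP8Tensor below additionally pads the weight out to block multiples and slices the padding back off):

func dequantizeBlockFP8(weight, scale *mlx.Array, blockRows, blockCols int32) *mlx.Array {
	// weight: fp8 values already widened to float, [rows, cols], assumed here
	// to be exact multiples of the block dims; scale: one factor per block,
	// [rows/blockRows, cols/blockCols].
	ws, ss := weight.Dims(), scale.Dims()
	// Split the weight into per-block tiles: [nBlockRows, blockRows, nBlockCols, blockCols].
	tiled := mlx.Reshape(weight, int32(ss[0]), blockRows, int32(ss[1]), blockCols)
	// Broadcast each block's scale over its tile by expanding scale to [n, 1, m, 1].
	tiled = mlx.Mul(tiled, mlx.ExpandDims(mlx.ExpandDims(scale, 1), 3))
	// Flatten back to the original 2D layout.
	return mlx.Reshape(tiled, int32(ws[0]), int32(ws[1]))
}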

* review comments
Daniel Hiltgen 2026-04-27 18:28:10 -07:00 committed by GitHub
parent ec9b4e9e47
commit 03aee88186
12 changed files with 1571 additions and 332 deletions

View file

@@ -70,9 +70,13 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
scaleKey := inputKey + ".scale_inv"
scaleInv := st.Get(scaleKey)
+ if scaleInv == nil {
+ scaleKey = inputKey + ".scale"
+ scaleInv = st.Get(scaleKey)
+ }
if scaleInv == nil {
st.Free()
- return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
+ return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q or %q for fp8 source tensor %q", inputKey+".scale_inv", inputKey+".scale", inputKey)
}
arr, err = decodeSourceFP8Tensor(arr, scaleInv)
if err != nil {
@@ -560,13 +564,13 @@ func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry)
return keys[0], nil
}
- func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
- if weight == nil || scaleInv == nil {
+ func decodeSourceFP8Tensor(weight, scale *mlx.Array) (*mlx.Array, error) {
+ if weight == nil || scale == nil {
return nil, fmt.Errorf("fp8 weight and scale tensors are required")
}
weightShape := weight.Dims()
- scaleShape := scaleInv.Dims()
+ scaleShape := scale.Dims()
if len(weightShape) != 2 || len(scaleShape) != 2 {
return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
}
@@ -596,7 +600,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
}
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
- decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
+ decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scale, 1), 3))
decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
if padBottom > 0 || padSide > 0 {
decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})

View file

@@ -0,0 +1,24 @@
package client
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestDecodeSourceFP8TensorAcceptsWeightScale(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX unavailable: %v", err)
}
weight := mlx.FromValues([]uint8{0, 1, 2, 3}, 2, 2)
scale := mlx.FromValues([]float32{1}, 1, 1).AsType(mlx.DTypeBFloat16)
got, err := decodeSourceFP8Tensor(weight, scale)
if err != nil {
t.Fatal(err)
}
mlx.Eval(got)
if dims := got.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 2 {
t.Fatalf("decoded dims = %v, want [2 2]", dims)
}
}

File diff suppressed because it is too large

View file

@@ -4,7 +4,9 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"os"
"path/filepath"
"slices"
@@ -59,6 +61,43 @@ func TestIsTensorModelDir(t *testing.T) {
}
}
func TestValidateScalarFloat32TensorData(t *testing.T) {
td := st.NewTensorDataFromBytes("linear.weight_scale_2", "F32", []int32{}, encodeFloat32s(2))
got, err := validateScalarFloat32TensorData(td, "linear.weight.global_scale")
if err != nil {
t.Fatalf("validateScalarFloat32TensorData returned error: %v", err)
}
if got.Name != "linear.weight.global_scale" {
t.Fatalf("name = %q, want %q", got.Name, "linear.weight.global_scale")
}
if got.Dtype != "F32" {
t.Fatalf("dtype = %q, want F32", got.Dtype)
}
if len(got.Shape) != 0 {
t.Fatalf("shape = %v, want scalar", got.Shape)
}
}
func TestValidateScalarFloat32TensorDataRejectsNonScalar(t *testing.T) {
td := st.NewTensorDataFromBytes("linear.weight_scale_2", "F32", []int32{2}, encodeFloat32s(2, 4))
_, err := validateScalarFloat32TensorData(td, "linear.weight.global_scale")
if err == nil || !strings.Contains(err.Error(), "expected scalar F32 tensor") {
t.Fatalf("validateScalarFloat32TensorData error = %v, want scalar-shape failure", err)
}
}
func TestInvertScalarFloat32TensorDataRejectsNonF32(t *testing.T) {
td := st.NewTensorDataFromBytes("linear.weight_global_scale", "BF16", []int32{}, []byte{0, 0})
_, err := invertScalarFloat32TensorData(td, "linear.weight.global_scale")
if err == nil || !strings.Contains(err.Error(), "expected F32 tensor") {
t.Fatalf("invertScalarFloat32TensorData error = %v, want dtype failure", err)
}
}
func TestIsSafetensorsModelDir(t *testing.T) {
tests := []struct {
name string
@@ -246,6 +285,41 @@ func readSingleTensorRaw(t *testing.T, data []byte) []byte {
return nil
}
func encodeFloat32s(vals ...float32) []byte {
raw := make([]byte, 4*len(vals))
for i, v := range vals {
binary.LittleEndian.PutUint32(raw[i*4:(i+1)*4], math.Float32bits(v))
}
return raw
}
func readPackedTensorRaw(t *testing.T, data []byte, tensorName string) []byte {
t.Helper()
var headerSize uint64
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
t.Fatalf("failed to read header size: %v", err)
}
var header map[string]struct {
Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
t.Fatalf("failed to parse header: %v", err)
}
info, ok := header[tensorName]
if !ok {
t.Fatalf("tensor %q not found in header", tensorName)
}
start := 8 + int(headerSize) + info.DataOffsets[0]
end := 8 + int(headerSize) + info.DataOffsets[1]
return data[start:end]
}
func readSafetensorsHeaderNames(t *testing.T, data []byte) []string {
t.Helper()
@@ -612,10 +686,22 @@ func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
- if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
+ var statusMessages []string
+ progressFn := func(status string) {
+ statusMessages = append(statusMessages, status)
+ }
+ if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if len(statusMessages) == 0 {
t.Fatal("no status messages received")
}
if got, want := statusMessages[0], "importing model.safetensors (4 tensors, converting source E4M3 block-FP8 to MLX mxfp8)"; got != want {
t.Fatalf("status = %q, want %q", got, want)
}
if got := quantizeByName["linear.weight"]; got != "mxfp8" {
t.Fatalf("linear.weight quantization = %q, want %q", got, "mxfp8")
}
@@ -643,6 +729,166 @@ func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
}
}
func TestCreateSafetensorsModel_CompressedTensorsFP8WeightScale(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["TestModel"],
"compression_config": {
"quant_method": "compressed-tensors",
"format": "float-quantized",
"config_groups": {
"group_0": {
"format": "float-quantized",
"weights": {
"type": "float",
"num_bits": 8,
"block_structure": [128, 128]
}
}
}
}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
st.NewTensorDataFromBytes("linear.weight_scale", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
})
quantizeByName := make(map[string]string)
headerNamesByName := make(map[string][]string)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
quantizeByName[name] = quantize
headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
var statusMessages []string
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if len(statusMessages) == 0 {
t.Fatal("no status messages received")
}
if got, want := statusMessages[0], "importing model.safetensors (3 tensors, converting source E4M3 block-FP8 to MLX mxfp8)"; got != want {
t.Fatalf("status = %q, want %q", got, want)
}
if got := quantizeByName["linear.weight"]; got != "mxfp8" {
t.Fatalf("linear.weight quantization = %q, want mxfp8", got)
}
if _, ok := quantizeByName["linear.weight_scale"]; ok {
t.Fatal("linear.weight_scale should not be imported as a standalone tensor")
}
if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale"}) {
t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale"})
}
}
func TestCreateSafetensorsModel_HFFP8SourceCanConvertToNVFP4(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["TestModel"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.gate.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{128}, make([]byte, 256)),
})
quantizeByName := make(map[string]string)
headerNamesByName := make(map[string][]string)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
quantizeByName[name] = quantize
headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
var statusMessages []string
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if len(statusMessages) == 0 {
t.Fatal("no status messages received")
}
if got, want := statusMessages[0], "importing model.safetensors (9 tensors, converting source E4M3 block-FP8 to MLX nvfp4)"; got != want {
t.Fatalf("status = %q, want %q", got, want)
}
if got := quantizeByName["linear.weight"]; got != "nvfp4" {
t.Fatalf("linear.weight quantization = %q, want nvfp4", got)
}
if got := quantizeByName["model.layers.0.mlp.experts.0.down_proj.weight"]; got != "mxfp8" {
t.Fatalf("source fp8 down_proj quantization = %q, want mxfp8", got)
}
for _, name := range []string{
"model.layers.0.self_attn.q_proj.weight",
"model.embed_tokens.weight",
"lm_head.weight",
} {
if got := quantizeByName[name]; got != "mxfp8" {
t.Fatalf("%s quantization = %q, want mxfp8", name, got)
}
}
if got := quantizeByName["model.layers.0.mlp.gate.weight"]; got != "" {
t.Fatalf("router gate quantization = %q, want empty", got)
}
if got := quantizeByName["norm.weight"]; got != "" {
t.Fatalf("norm.weight quantization = %q, want empty", got)
}
if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale_inv"}) {
t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale_inv"})
}
}
func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T) {
tests := []struct {
name string
@@ -670,7 +916,20 @@ func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
},
- wantErr: `cannot requantize already-quantized fp8 source model with --quantize "int4"`,
+ wantErr: `cannot convert already-quantized fp8 source model with --quantize "int4"`,
},
{
name: "packed nvfp4 source",
configJSON: `{
"model_type": "test",
"architectures": ["TestModel"],
"compression_config": {"format": "nvfp4-pack-quantized"}
}`,
tensors: []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight_packed", "U8", []int32{16, 8}, make([]byte, 128)),
st.NewTensorDataFromBytes("linear.weight_scale", "F8_E4M3", []int32{16, 1}, make([]byte, 16)),
},
wantErr: `cannot requantize already-quantized source model with --quantize "int4"`,
},
}
@@ -701,6 +960,317 @@ func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T
}
}
func TestCreateSafetensorsModel_PackedNVFP4PreservesSourceLayout(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["TestModel"],
"compression_config": {"format": "nvfp4-pack-quantized"}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight_packed", "U8", []int32{16, 8}, make([]byte, 128)),
st.NewTensorDataFromBytes("linear.weight_scale", "F8_E4M3", []int32{16, 1}, make([]byte, 16)),
st.NewTensorDataFromBytes("linear.weight_global_scale", "F32", []int32{}, encodeFloat32s(4)),
st.NewTensorDataFromBytes("linear.input_global_scale", "F32", []int32{}, encodeFloat32s(8)),
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{16}, make([]byte, 32)),
})
var statusMessages []string
layerHeaders := make(map[string]map[string]json.RawMessage)
layerData := make(map[string][]byte)
var tensorLayerNames []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
if mediaType == "application/vnd.ollama.image.tensor" {
if len(data) < 8 {
return LayerInfo{}, io.ErrUnexpectedEOF
}
var headerSize uint64
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
return LayerInfo{}, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
return LayerInfo{}, err
}
layerHeaders[name] = header
layerData[name] = data
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return nil, err
}
tensorLayerNames = append(tensorLayerNames, name)
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
progressFn := func(status string) { statusMessages = append(statusMessages, status) }
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if len(statusMessages) == 0 {
t.Fatal("no status messages received")
}
if got, want := statusMessages[0], "importing model.safetensors (5 tensors, preserving source quantization)"; got != want {
t.Fatalf("status = %q, want %q", got, want)
}
if slices.Contains(tensorLayerNames, "linear.weight_scale") || slices.Contains(tensorLayerNames, "linear.weight_global_scale") || slices.Contains(tensorLayerNames, "linear.input_global_scale") {
t.Fatalf("packed nvfp4 companions unexpectedly emitted as standalone tensor layers: %v", tensorLayerNames)
}
packedHeader := layerHeaders["linear.weight"]
if packedHeader == nil {
t.Fatalf("missing packed layer header for linear.weight")
}
for _, key := range []string{
"linear.weight",
"linear.weight.scale",
"linear.weight.global_scale",
} {
if _, ok := packedHeader[key]; !ok {
t.Fatalf("packed header missing %s: %v", key, packedHeader)
}
}
if _, ok := packedHeader["linear.weight.input_global_scale"]; ok {
t.Fatalf("packed header unexpectedly includes input_global_scale: %v", packedHeader)
}
globalRaw := readPackedTensorRaw(t, layerData["linear.weight"], "linear.weight.global_scale")
if got := math.Float32frombits(binary.LittleEndian.Uint32(globalRaw)); got != 0.25 {
t.Fatalf("linear.weight.global_scale = %v, want 0.25", got)
}
var metadata map[string]string
if metaRaw, ok := packedHeader["__metadata__"]; ok {
if err := json.Unmarshal(metaRaw, &metadata); err != nil {
t.Fatalf("failed to parse metadata: %v", err)
}
}
if metadata["quant_type"] != "nvfp4" {
t.Fatalf("quant_type = %q, want %q", metadata["quant_type"], "nvfp4")
}
if metadata["group_size"] != "16" {
t.Fatalf("group_size = %q, want %q", metadata["group_size"], "16")
}
}
func TestCreateSafetensorsModel_PackedNVFP4CrossShardCompanions(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["TestModel"],
"compression_config": {"format": "nvfp4-pack-quantized"}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model-00001-of-00002.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight_packed", "U8", []int32{16, 8}, make([]byte, 128)),
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{16}, make([]byte, 32)),
})
createTestSafetensors(t, filepath.Join(dir, "model-00002-of-00002.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight_scale", "F8_E4M3", []int32{16, 1}, make([]byte, 16)),
st.NewTensorDataFromBytes("linear.weight_global_scale", "F32", []int32{}, encodeFloat32s(2)),
st.NewTensorDataFromBytes("linear.input_global_scale", "F32", []int32{}, encodeFloat32s(8)),
})
indexJSON := `{
"metadata": {"total_size": 152},
"weight_map": {
"linear.weight_packed": "model-00001-of-00002.safetensors",
"norm.weight": "model-00001-of-00002.safetensors",
"linear.weight_scale": "model-00002-of-00002.safetensors",
"linear.weight_global_scale": "model-00002-of-00002.safetensors",
"linear.input_global_scale": "model-00002-of-00002.safetensors"
}
}`
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
t.Fatalf("failed to write index: %v", err)
}
layerHeaders := make(map[string]map[string]json.RawMessage)
var tensorLayerNames []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
if mediaType == "application/vnd.ollama.image.tensor" {
var headerSize uint64
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
return LayerInfo{}, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
return LayerInfo{}, err
}
layerHeaders[name] = header
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return nil, err
}
tensorLayerNames = append(tensorLayerNames, name)
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
packedCreator := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
return LayerInfo{}, fmt.Errorf("unexpected packedCreator call for %s", groupName)
}
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, packedCreator); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if slices.Contains(tensorLayerNames, "linear.weight_packed") || slices.Contains(tensorLayerNames, "linear.weight_scale") || slices.Contains(tensorLayerNames, "linear.weight_global_scale") || slices.Contains(tensorLayerNames, "linear.input_global_scale") {
t.Fatalf("packed nvfp4 tensors unexpectedly emitted as standalone tensor layers: %v", tensorLayerNames)
}
packedHeader := layerHeaders["linear.weight"]
if packedHeader == nil {
t.Fatalf("missing packed layer header for linear.weight")
}
for _, key := range []string{
"linear.weight",
"linear.weight.scale",
"linear.weight.global_scale",
} {
if _, ok := packedHeader[key]; !ok {
t.Fatalf("packed header missing %s: %v", key, packedHeader)
}
}
if _, ok := packedHeader["linear.weight.input_global_scale"]; ok {
t.Fatalf("packed header unexpectedly includes input_global_scale: %v", packedHeader)
}
}
func TestCreateSafetensorsModel_PackedNVFP4StacksExperts(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["TestModel"],
"compression_config": {"format": "nvfp4-pack-quantized"}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.layers.1.mlp.experts.0.gate_proj.weight_packed", "U8", []int32{2, 8}, make([]byte, 16)),
st.NewTensorDataFromBytes("model.layers.1.mlp.experts.0.gate_proj.weight_scale", "F8_E4M3", []int32{2, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.layers.1.mlp.experts.0.gate_proj.weight_global_scale", "F32", []int32{1}, encodeFloat32s(2)),
st.NewTensorDataFromBytes("model.layers.1.mlp.experts.0.gate_proj.input_global_scale", "F32", []int32{1}, encodeFloat32s(32)),
st.NewTensorDataFromBytes("model.layers.1.mlp.experts.1.gate_proj.weight_packed", "U8", []int32{2, 8}, make([]byte, 16)),
st.NewTensorDataFromBytes("model.layers.1.mlp.experts.1.gate_proj.weight_scale", "F8_E4M3", []int32{2, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.layers.1.mlp.experts.1.gate_proj.weight_global_scale", "F32", []int32{1}, encodeFloat32s(4)),
st.NewTensorDataFromBytes("model.layers.1.mlp.experts.1.gate_proj.input_global_scale", "F32", []int32{1}, encodeFloat32s(64)),
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
})
layerHeaders := make(map[string]map[string]json.RawMessage)
layerData := make(map[string][]byte)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
if mediaType == "application/vnd.ollama.image.tensor" {
var headerSize uint64
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
return LayerInfo{}, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
return LayerInfo{}, err
}
layerHeaders[name] = header
layerData[name] = data
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return nil, err
}
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
packedCreator := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
return LayerInfo{}, fmt.Errorf("unexpected packedCreator call for %s", groupName)
}
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, packedCreator); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
header := layerHeaders["model.layers.1.mlp.experts"]
if header == nil {
t.Fatalf("missing packed expert layer header")
}
for _, key := range []string{
"model.layers.1.mlp.switch_mlp.gate_proj.weight",
"model.layers.1.mlp.switch_mlp.gate_proj.weight.scale",
"model.layers.1.mlp.switch_mlp.gate_proj.weight.global_scale",
} {
if _, ok := header[key]; !ok {
t.Fatalf("stacked header missing %s: %v", key, header)
}
}
if _, ok := header["model.layers.1.mlp.switch_mlp.gate_proj.weight.input_global_scale"]; ok {
t.Fatalf("stacked header unexpectedly includes input_global_scale: %v", header)
}
if _, ok := header["model.layers.1.mlp.experts.0.gate_proj.weight"]; ok {
t.Fatalf("unexpected per-expert tensor left in packed header: %v", header)
}
var weightInfo struct {
Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
}
if err := json.Unmarshal(header["model.layers.1.mlp.switch_mlp.gate_proj.weight"], &weightInfo); err != nil {
t.Fatalf("failed to unmarshal stacked weight info: %v", err)
}
if weightInfo.Dtype != "U32" || !slices.Equal(weightInfo.Shape, []int32{2, 2, 2}) {
t.Fatalf("stacked weight = dtype %s shape %v, want U32 [2 2 2]", weightInfo.Dtype, weightInfo.Shape)
}
var globalInfo struct {
Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
}
if err := json.Unmarshal(header["model.layers.1.mlp.switch_mlp.gate_proj.weight.global_scale"], &globalInfo); err != nil {
t.Fatalf("failed to unmarshal stacked global scale info: %v", err)
}
if globalInfo.Dtype != "F32" || !slices.Equal(globalInfo.Shape, []int32{2, 1, 1}) {
t.Fatalf("stacked global scale = dtype %s shape %v, want F32 [2 1 1]", globalInfo.Dtype, globalInfo.Shape)
}
globalRaw := readPackedTensorRaw(t, layerData["model.layers.1.mlp.experts"], "model.layers.1.mlp.switch_mlp.gate_proj.weight.global_scale")
if got0 := math.Float32frombits(binary.LittleEndian.Uint32(globalRaw[0:4])); got0 != 0.5 {
t.Fatalf("stacked global scale[0] = %v, want 0.5", got0)
}
if got1 := math.Float32frombits(binary.LittleEndian.Uint32(globalRaw[4:8])); got1 != 0.25 {
t.Fatalf("stacked global scale[1] = %v, want 0.25", got1)
}
}
func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
dir := t.TempDir()
@@ -777,6 +1347,26 @@ func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
t.Fatalf("expected mxfp8 quantize for %s, got %q", tensor.Name, tensor.Quantize)
}
}
packedLayerNames = nil
packedLayerTensors = nil
if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
t.Fatalf("CreateSafetensorsModel nvfp4 failed: %v", err)
}
if len(packedLayerNames) != 1 {
t.Fatalf("expected 1 packed layer for nvfp4, got %d: %v", len(packedLayerNames), packedLayerNames)
}
for _, tensor := range packedLayerTensors[0] {
want := "nvfp4"
if strings.Contains(tensor.Name, "down_proj") {
want = "mxfp8"
}
if tensor.Quantize != want {
t.Fatalf("nvfp4 packed tensor %s quantize = %q, want %q", tensor.Name, tensor.Quantize, want)
}
}
}
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {

View file

@@ -19,6 +19,10 @@ func DTypeSize(dtype string) (int, error) {
return 4, nil
case "F64":
return 8, nil
case "U8", "I8":
return 1, nil
case "F8_E4M3", "F8_E5M2", "F8_E4M3FN", "F8_E5M2FNUZ":
return 1, nil
default:
return 0, fmt.Errorf("unsupported dtype %q", dtype)
}
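With these one-byte entries, raw payload sizes for fp8 tensors follow directly; a quick example using this function:

elem, _ := DTypeSize("F8_E4M3") // elem == 1
nbytes := elem * 128 * 128      // 16384 raw bytes for a 128x128 fp8 weight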

View file

@@ -64,6 +64,8 @@ func dtypeFromString(s string) mlx.Dtype {
return mlx.DtypeInt64
case "U8", "UINT8":
return mlx.DtypeUint8
case "F8_E4M3", "F8_E5M2", "F8_E4M3FN", "F8_E5M2FNUZ":
return mlx.DtypeUint8 // FP8 types stored as raw uint8 bytes
default:
return mlx.DtypeFloat32
}

View file

@ -7,6 +7,7 @@ import (
"fmt"
"iter"
"runtime"
"sort"
"unsafe"
)
@@ -121,10 +122,17 @@ func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
arrayNames := make([]string, 0, len(arrays))
for name, arr := range arrays {
if arr == nil {
continue
}
arrayNames = append(arrayNames, name)
}
sort.Strings(arrayNames)
for _, name := range arrayNames {
arr := arrays[name]
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
C.free(unsafe.Pointer(cName))
@@ -133,7 +141,14 @@ func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata
cMetadata := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMetadata)
- for key, value := range metadata {
+ metadataKeys := make([]string, 0, len(metadata))
+ for key := range metadata {
+ metadataKeys = append(metadataKeys, key)
+ }
+ sort.Strings(metadataKeys)
+ for _, key := range metadataKeys {
+ value := metadata[key]
cKey := C.CString(key)
cValue := C.CString(value)
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
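Both insertion loops now walk sorted key slices instead of ranging over the maps directly, so tensor and metadata ordering in the saved safetensors file is deterministic across runs, and with it, presumably, the digests of the resulting blobs.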

View file

@@ -74,14 +74,23 @@ func MakeLinearLayer(
scales,
)
// Check for per-tensor global scale (NVIDIA double-scale nvfp4).
// NVIDIA ModelOpt stores this as "weight_scale_2"; our import
// pipeline maps it to "weight.global_scale".
globalScale := tensors[path+".weight.global_scale"]
if globalScale == nil {
globalScale = tensors[path+".weight_scale_2"]
}
return &nn.QuantizedLinear{
- Weight: w,
- Scales: scales,
- QBiases: qbiases,
- Bias: bias,
- GroupSize: groupSize,
- Bits: bits,
- Mode: mode,
+ Weight: w,
+ Scales: scales,
+ QBiases: qbiases,
+ Bias: bias,
+ GlobalScale: globalScale,
+ GroupSize: groupSize,
+ Bits: bits,
+ Mode: mode,
}
}

View file

@@ -78,13 +78,14 @@ func (l *Linear) OutputDim() int32 {
// QuantizedLinear applies an affine transformation using quantized weights.
type QuantizedLinear struct {
- Weight *mlx.Array // Quantized weight data
- Scales *mlx.Array // Scale factors for dequantization
- QBiases *mlx.Array // Quantization biases (nil for nvfp4)
- Bias *mlx.Array // Layer bias [output_dims] or nil
- GroupSize int
- Bits int
- Mode string
+ Weight *mlx.Array // Quantized weight data
+ Scales *mlx.Array // Scale factors for dequantization
+ QBiases *mlx.Array // Quantization biases (nil for nvfp4)
+ Bias *mlx.Array // Layer bias [output_dims] or nil
+ GlobalScale *mlx.Array // Per-tensor global scale for double-scale nvfp4 (nil for standard)
+ GroupSize int
+ Bits int
+ Mode string
}
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
@@ -106,7 +107,18 @@ func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int,
}
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
- out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
+ var out *mlx.Array
+ if ql.GlobalScale != nil {
+ // Double-scale nvfp4 (e.g., NVIDIA ModelOpt): standard quantized_matmul
+ // followed by global_scale multiply. The global_scale is a per-tensor
+ // F32 scalar (weight_scale_2 in NVIDIA's format).
+ // TODO: switch to a fused double-scale matmul once MLX has kernel
+ // coverage for this path.
+ out = mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
+ out = mlx.Mul(out, ql.GlobalScale)
+ } else {
+ out = mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
+ }
if ql.Bias != nil && ql.Bias.Valid() {
out = out.Add(ql.Bias)
}
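Numerically, the double-scale scheme composes multiplicatively: each 4-bit code is scaled by its group's E4M3 scale, then by the single per-tensor F32 scale. A scalar sketch with made-up values (the import stores the inverse of the source's global scale, per the create tests above):

code := float32(3)           // decoded fp4 value
groupScale := float32(0.5)   // per-group scale, stored as fp8
globalScale := float32(0.25) // per-tensor scale (inverted weight_scale_2)
w := code * groupScale * globalScale // 0.375, matching matmul-then-multiply above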

View file

@@ -110,6 +110,19 @@ func NewTensorDataFromBytes(name, dtype string, shape []int32, rawData []byte) *
}
}
// NewTensorDataFromReaderAt creates a TensorData backed by an arbitrary
// io.ReaderAt. This is useful for constructing large synthetic tensors from
// temporary files without loading the full payload into memory.
func NewTensorDataFromReaderAt(name, dtype string, shape []int32, readerAt io.ReaderAt, size int64) *TensorData {
return &TensorData{
Name: name,
Dtype: dtype,
Shape: shape,
Size: size,
reader: io.NewSectionReader(readerAt, 0, size),
}
}
// ExtractRawFromSafetensors reads a safetensors-wrapped reader and extracts
// the raw tensor data bytes (stripping the header).
func ExtractRawFromSafetensors(r io.Reader) ([]byte, error) {
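A usage sketch for the new constructor (hypothetical caller; path and shape are illustrative): back a large synthetic tensor with a temp file rather than an in-memory byte slice.

f, err := os.Open(tmpPath) // raw little-endian tensor payload on disk
if err != nil {
	return err
}
fi, err := f.Stat()
if err != nil {
	return err
}
td := st.NewTensorDataFromReaderAt("linear.weight", "BF16", []int32{4096, 4096}, f, fi.Size())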

View file

@@ -306,15 +306,16 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
}
// GetSafetensorsDtype returns the quantization type for a safetensors model.
- // Reads quant_type from the first tensor blob's __metadata__.
- // Falls back to torch_dtype from config.json if no quant metadata.
+ // Reads tensor headers until quantized weights are found.
+ // Falls back to torch_dtype from config.json if no quant metadata exists.
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
- // Check first tensor blob for quant_type metadata
+ // Mixed models can start with unquantized embeddings or heads, so scan until
+ // any tensor blob reports quantized weight metadata.
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
@@ -323,15 +324,20 @@
if err != nil {
continue
}
- info, err := readSafetensorsHeader(blobPath)
+ f, err := os.Open(blobPath)
if err != nil {
continue
}
- if quantType := canonicalQuantType(info.QuantType); quantType != "" {
- return quantType, nil
+ infos, err := parseSafetensorsAllHeaders(f)
+ _ = f.Close()
+ if err != nil {
+ continue
+ }
+ for _, info := range infos {
+ if quantType := canonicalQuantType(info.QuantType); quantType != "" {
+ return quantType, nil
+ }
+ }
- // Only check the first tensor blob
- break
}
// Not quantized - return torch_dtype from config.json
@@ -354,86 +360,6 @@ type safetensorsTensorInfo struct {
GroupSize string // from __metadata__.group_size (e.g., "32", "64")
}
- // readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
- // Safetensors format: 8-byte header size (little endian) + JSON header + tensor data
- func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
- f, err := os.Open(path)
- if err != nil {
- return nil, err
- }
- defer f.Close()
- return parseSafetensorsHeader(f)
- }
- // parseSafetensorsHeader parses a safetensors header from a reader.
- // This is separated for testability.
- // Parses __metadata__ for quant_type and group_size if present.
- func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
- // Read header size (8 bytes, little endian)
- var headerSize uint64
- if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
- return nil, fmt.Errorf("failed to read header size: %w", err)
- }
- // Sanity check - header shouldn't be too large
- if headerSize > 1024*1024 {
- return nil, fmt.Errorf("header size too large: %d", headerSize)
- }
- // Read header JSON
- headerBytes := make([]byte, headerSize)
- if _, err := io.ReadFull(r, headerBytes); err != nil {
- return nil, fmt.Errorf("failed to read header: %w", err)
- }
- // Parse as map of tensor name -> info
- var header map[string]json.RawMessage
- if err := json.Unmarshal(headerBytes, &header); err != nil {
- return nil, fmt.Errorf("failed to parse header: %w", err)
- }
- // Parse metadata if present
- var quantType, groupSize string
- if metaRaw, ok := header["__metadata__"]; ok {
- var meta map[string]string
- if json.Unmarshal(metaRaw, &meta) == nil {
- quantType = meta["quant_type"]
- groupSize = meta["group_size"]
- }
- }
- // Find the main tensor entry (not __metadata__, .scale, or .bias)
- for name, raw := range header {
- if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
- continue
- }
- var info safetensorsTensorInfo
- if err := json.Unmarshal(raw, &info); err != nil {
- return nil, fmt.Errorf("failed to parse tensor info: %w", err)
- }
- info.QuantType = quantType
- info.GroupSize = groupSize
- return &info, nil
- }
- // Fall back to first non-metadata tensor entry
- for name, raw := range header {
- if name == "__metadata__" {
- continue
- }
- var info safetensorsTensorInfo
- if err := json.Unmarshal(raw, &info); err != nil {
- return nil, fmt.Errorf("failed to parse tensor info: %w", err)
- }
- info.QuantType = quantType
- info.GroupSize = groupSize
- return &info, nil
- }
- return nil, fmt.Errorf("no tensor found in header")
- }
+ // parseSafetensorsAllHeaders parses all tensor entries from a safetensors header.
+ // Returns one safetensorsTensorInfo per main tensor (skipping __metadata__, .scale, .bias).
+ // For packed blobs this returns multiple entries; for single-tensor blobs, one entry.
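The framing the parser consumes is just an 8-byte little-endian header length followed by the JSON header, the same layout the tests below build by hand; a minimal writer sketch:

func writeSafetensorsHeader(w io.Writer, header map[string]any) error {
	js, err := json.Marshal(header)
	if err != nil {
		return err
	}
	if err := binary.Write(w, binary.LittleEndian, uint64(len(js))); err != nil {
		return err
	}
	_, err = w.Write(js)
	return err
}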

View file

@@ -9,6 +9,7 @@ import (
"testing"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
func TestBuildModelInfo(t *testing.T) {
@@ -286,168 +287,7 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) {
}
}
- func TestParseSafetensorsHeader(t *testing.T) {
- tests := []struct {
- name string
- header map[string]any
- wantDtype string
- wantShape []int64
- wantQuantType string
- wantGroupSize string
- wantErr bool
- }{
- {
- name: "simple tensor",
- header: map[string]any{
- "weight": map[string]any{
- "dtype": "BF16",
- "shape": []int64{2560, 262144},
- "data_offsets": []int64{0, 1342177280},
- },
- },
- wantDtype: "BF16",
- wantShape: []int64{2560, 262144},
- },
- {
- name: "tensor keyed by name",
- header: map[string]any{
- "model.layers.0.weight": map[string]any{
- "dtype": "BF16",
- "shape": []int64{2560, 2560},
- "data_offsets": []int64{0, 13107200},
- },
- },
- wantDtype: "BF16",
- wantShape: []int64{2560, 2560},
- },
- {
- name: "with int4 quant metadata",
- header: map[string]any{
- "__metadata__": map[string]any{
- "quant_type": "int4",
- "group_size": "32",
- },
- "model.layers.0.mlp.up_proj.weight": map[string]any{
- "dtype": "U32",
- "shape": []int64{2560, 320},
- "data_offsets": []int64{0, 3276800},
- },
- "model.layers.0.mlp.up_proj.weight.scale": map[string]any{
- "dtype": "BF16",
- "shape": []int64{2560, 80},
- "data_offsets": []int64{3276800, 3686400},
- },
- "model.layers.0.mlp.up_proj.weight.bias": map[string]any{
- "dtype": "BF16",
- "shape": []int64{2560, 80},
- "data_offsets": []int64{3686400, 4096000},
- },
- },
- wantDtype: "U32",
- wantShape: []int64{2560, 320},
- wantQuantType: "int4",
- wantGroupSize: "32",
- },
- {
- name: "int8 quant metadata",
- header: map[string]any{
- "__metadata__": map[string]any{
- "quant_type": "int8",
- "group_size": "64",
- },
- "model.layers.0.mlp.down_proj.weight": map[string]any{
- "dtype": "U32",
- "shape": []int64{2560, 640},
- "data_offsets": []int64{0, 6553600},
- },
- "model.layers.0.mlp.down_proj.weight.scale": map[string]any{
- "dtype": "BF16",
- "shape": []int64{2560, 40},
- "data_offsets": []int64{6553600, 6963200},
- },
- },
- wantDtype: "U32",
- wantShape: []int64{2560, 640},
- wantQuantType: "int8",
- wantGroupSize: "64",
- },
- {
- name: "with old-style format metadata",
- header: map[string]any{
- "__metadata__": map[string]any{
- "format": "pt",
- },
- "bias": map[string]any{
- "dtype": "F32",
- "shape": []int64{1024},
- "data_offsets": []int64{0, 4096},
- },
- },
- wantDtype: "F32",
- wantShape: []int64{1024},
- },
- {
- name: "float16 tensor",
- header: map[string]any{
- "layer.weight": map[string]any{
- "dtype": "F16",
- "shape": []int64{512, 512, 3, 3},
- "data_offsets": []int64{0, 4718592},
- },
- },
- wantDtype: "F16",
- wantShape: []int64{512, 512, 3, 3},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create safetensors format: 8-byte size + JSON header
- headerJSON, err := json.Marshal(tt.header)
- if err != nil {
- t.Fatalf("failed to marshal header: %v", err)
- }
- var buf bytes.Buffer
- if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
- t.Fatalf("failed to write header size: %v", err)
- }
- buf.Write(headerJSON)
- info, err := parseSafetensorsHeader(&buf)
- if (err != nil) != tt.wantErr {
- t.Errorf("parseSafetensorsHeader() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if tt.wantErr {
- return
- }
- if info.Dtype != tt.wantDtype {
- t.Errorf("Dtype = %v, want %v", info.Dtype, tt.wantDtype)
- }
- if len(info.Shape) != len(tt.wantShape) {
- t.Errorf("Shape length = %v, want %v", len(info.Shape), len(tt.wantShape))
- } else {
- for i, s := range info.Shape {
- if s != tt.wantShape[i] {
- t.Errorf("Shape[%d] = %v, want %v", i, s, tt.wantShape[i])
- }
- }
- }
- if info.QuantType != tt.wantQuantType {
- t.Errorf("QuantType = %v, want %v", info.QuantType, tt.wantQuantType)
- }
- if info.GroupSize != tt.wantGroupSize {
- t.Errorf("GroupSize = %v, want %v", info.GroupSize, tt.wantGroupSize)
- }
- })
- }
- }
- func TestParseSafetensorsHeader_Errors(t *testing.T) {
+ func TestParseSafetensorsAllHeaders_Errors(t *testing.T) {
tests := []struct {
name string
data []byte
@@ -467,7 +307,7 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
name: "header size too large",
data: func() []byte {
var buf bytes.Buffer
- binary.Write(&buf, binary.LittleEndian, uint64(2*1024*1024)) // 2MB
+ binary.Write(&buf, binary.LittleEndian, uint64(200*1024*1024)) // 200 MiB
return buf.Bytes()
}(),
wantErr: "header size too large",
@@ -510,7 +350,7 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- _, err := parseSafetensorsHeader(bytes.NewReader(tt.data))
+ _, err := parseSafetensorsAllHeaders(bytes.NewReader(tt.data))
if err == nil {
t.Error("expected error, got nil")
return
@@ -1209,44 +1049,77 @@ func TestGetTensorInfoFromManifest_Packed(t *testing.T) {
}
}
- func TestReadSafetensorsHeader(t *testing.T) {
- // Create a temp file with a valid safetensors header
- tempDir := t.TempDir()
+ func TestGetSafetensorsDtypeScansPastUnquantizedFirstBlob(t *testing.T) {
+ t.Setenv("OLLAMA_MODELS", t.TempDir())
- header := map[string]any{
- "test_tensor": map[string]any{
- "dtype": "BF16",
- "shape": []int64{1024, 768},
- "data_offsets": []int64{0, 1572864},
- },
- }
- headerJSON, _ := json.Marshal(header)
+ writeSafetensorsLayer := func(t *testing.T, header map[string]any, name string) manifest.Layer {
+ t.Helper()
- var buf bytes.Buffer
- binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
- buf.Write(headerJSON)
+ headerJSON, err := json.Marshal(header)
+ if err != nil {
+ t.Fatalf("failed to marshal header: %v", err)
+ }
- filePath := filepath.Join(tempDir, "test.safetensors")
- if err := os.WriteFile(filePath, buf.Bytes(), 0o644); err != nil {
- t.Fatalf("failed to write test file: %v", err)
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
+ t.Fatalf("failed to write header size: %v", err)
+ }
+ buf.Write(headerJSON)
+ layer, err := manifest.NewLayer(&buf, manifest.MediaTypeImageTensor)
+ if err != nil {
+ t.Fatalf("failed to create tensor layer: %v", err)
+ }
+ layer.Name = name
+ return layer
+ }
- info, err := readSafetensorsHeader(filePath)
+ configData, err := json.Marshal(map[string]any{
+ "model_format": "safetensors",
+ })
if err != nil {
- t.Fatalf("readSafetensorsHeader() error = %v", err)
+ t.Fatalf("failed to marshal config: %v", err)
}
+ configLayer, err := manifest.NewLayer(bytes.NewReader(configData), "application/vnd.docker.container.image.v1+json")
+ if err != nil {
+ t.Fatalf("failed to create config layer: %v", err)
+ }
- if info.Dtype != "BF16" {
- t.Errorf("Dtype = %v, want BF16", info.Dtype)
- }
- if len(info.Shape) != 2 || info.Shape[0] != 1024 || info.Shape[1] != 768 {
- t.Errorf("Shape = %v, want [1024, 768]", info.Shape)
- }
- }
+ unquantized := writeSafetensorsLayer(t, map[string]any{
+ "model.embed_tokens.weight": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{16, 8},
+ "data_offsets": []int64{0, 256},
+ },
+ }, "model.embed_tokens.weight")
- func TestReadSafetensorsHeader_FileNotFound(t *testing.T) {
- _, err := readSafetensorsHeader("/nonexistent/path/file.safetensors")
- if err == nil {
- t.Error("expected error for nonexistent file")
quantized := writeSafetensorsLayer(t, map[string]any{
"__metadata__": map[string]string{
"quant_type": "mxfp8",
"group_size": "32",
},
"model.layers.0.mlp.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{16, 4},
"data_offsets": []int64{0, 256},
},
"model.layers.0.mlp.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{16, 1},
"data_offsets": []int64{256, 288},
},
}, "model.layers.0.mlp.down_proj.weight")
name := model.ParseName("mixed-fp8-safetensors")
if err := manifest.WriteManifest(name, configLayer, []manifest.Layer{unquantized, quantized}); err != nil {
t.Fatalf("failed to write manifest: %v", err)
}
got, err := GetSafetensorsDtype(name)
if err != nil {
t.Fatalf("GetSafetensorsDtype() error = %v", err)
}
if got != "mxfp8" {
t.Fatalf("GetSafetensorsDtype() = %q, want mxfp8", got)
}
}