diff --git a/api/types.go b/api/types.go index 7fc53b370..51ddafb86 100644 --- a/api/types.go +++ b/api/types.go @@ -814,6 +814,7 @@ type ModelRecommendation struct { ContextLength int `json:"context_length,omitempty"` MaxOutputTokens int `json:"max_output_tokens,omitempty"` VRAMBytes int64 `json:"vram_bytes,omitempty"` + RequiredPlan string `json:"required_plan,omitempty"` } // ProcessResponse is the response from [Client.Process]. diff --git a/cmd/cmd.go b/cmd/cmd.go index 7448a3e7d..2ecdb6fb0 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -61,28 +61,20 @@ import ( func init() { // Override default selectors to use Bubbletea TUI instead of raw terminal I/O. - launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) { - if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) { - return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode") - } - tuiItems := tui.ReorderItems(tui.ConvertItems(items)) - result, err := tui.SelectSingle(title, tuiItems, current) - if errors.Is(err, tui.ErrCancelled) { - return "", launch.ErrCancelled - } - return result, err + launch.DefaultSingleSelector = func(title string, items []launch.SelectionItem, current string) (string, error) { + return runTUISingleSelector(title, items, current, nil) } - launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) { - if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) { - return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode") - } - tuiItems := tui.ReorderItems(tui.ConvertItems(items)) - result, err := tui.SelectMultiple(title, tuiItems, preChecked) - if errors.Is(err, tui.ErrCancelled) { - return nil, launch.ErrCancelled - } - return result, err + launch.DefaultSingleSelectorWithUpdates = func(title string, items []launch.SelectionItem, current string, updates <-chan []launch.SelectionItem) (string, error) { + return runTUISingleSelector(title, items, current, updates) + } + + launch.DefaultMultiSelector = func(title string, items []launch.SelectionItem, preChecked []string) ([]string, error) { + return runTUIMultiSelector(title, items, preChecked, nil) + } + + launch.DefaultMultiSelectorWithUpdates = func(title string, items []launch.SelectionItem, preChecked []string, updates <-chan []launch.SelectionItem) ([]string, error) { + return runTUIMultiSelector(title, items, preChecked, updates) } launch.DefaultSignIn = func(modelName, signInURL string) (string, error) { @@ -93,9 +85,55 @@ func init() { return userName, err } + launch.DefaultUpgrade = func(modelName, requiredPlan string) (string, error) { + plan, err := tui.RunUpgrade(modelName, requiredPlan) + if errors.Is(err, tui.ErrCancelled) { + return "", launch.ErrCancelled + } + return plan, err + } + launch.DefaultConfirmPrompt = tui.RunConfirmWithOptions } +func runTUISingleSelector(title string, items []launch.SelectionItem, current string, updates <-chan []launch.SelectionItem) (string, error) { + if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) { + return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode") + } + tuiItems := tui.ReorderItems(tui.ConvertItems(items)) + result, err := tui.SelectSingleWithUpdates(title, tuiItems, current, convertSelectionItemUpdates(updates)) + if errors.Is(err, tui.ErrCancelled) { + return "", launch.ErrCancelled + } + return result, err +} + +func runTUIMultiSelector(title string, items []launch.SelectionItem, preChecked []string, updates <-chan []launch.SelectionItem) ([]string, error) { + if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) { + return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode") + } + tuiItems := tui.ReorderItems(tui.ConvertItems(items)) + result, err := tui.SelectMultipleWithUpdates(title, tuiItems, preChecked, convertSelectionItemUpdates(updates)) + if errors.Is(err, tui.ErrCancelled) { + return nil, launch.ErrCancelled + } + return result, err +} + +func convertSelectionItemUpdates(updates <-chan []launch.SelectionItem) <-chan []tui.SelectItem { + if updates == nil { + return nil + } + out := make(chan []tui.SelectItem, 1) + go func() { + defer close(out) + for items := range updates { + out <- tui.ReorderItems(tui.ConvertItems(items)) + } + }() + return out +} + const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n" // ensureThinkingSupport emits a warning if the model does not advertise thinking support @@ -2090,12 +2128,15 @@ func runInteractiveTUI(cmd *cobra.Command) { return } + accountPrefetch := launch.StartAccountStatePrefetch(cmd.Context()) deps := launcherDeps{ - buildState: launch.BuildLauncherState, - runMenu: tui.RunMenu, - resolveRunModel: launch.ResolveRunModel, - launchIntegration: launch.LaunchIntegration, - runModel: launchInteractiveModel, + buildState: launch.BuildLauncherState, + runMenu: tui.RunMenu, + resolveRunModel: launch.ResolveRunModel, + launchIntegration: launch.LaunchIntegration, + runModel: launchInteractiveModel, + accountState: accountPrefetch.StateIfReady, + accountStateUpdates: accountPrefetch.StateUpdates, } for { @@ -2110,11 +2151,13 @@ func runInteractiveTUI(cmd *cobra.Command) { } type launcherDeps struct { - buildState func(context.Context) (*launch.LauncherState, error) - runMenu func(*launch.LauncherState) (tui.TUIAction, error) - resolveRunModel func(context.Context, launch.RunModelRequest) (string, error) - launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error - runModel func(*cobra.Command, string) error + buildState func(context.Context) (*launch.LauncherState, error) + runMenu func(*launch.LauncherState) (tui.TUIAction, error) + resolveRunModel func(context.Context, launch.RunModelRequest) (string, error) + launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error + runModel func(*cobra.Command, string) error + accountState func() *launch.AccountState + accountStateUpdates func(context.Context) <-chan *launch.AccountState } func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) { @@ -2122,6 +2165,9 @@ func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) if err != nil { return false, fmt.Errorf("build launcher state: %w", err) } + if state != nil && deps.accountState != nil { + state.AccountState = deps.accountState() + } action, err := deps.runMenu(state) if err != nil { @@ -2142,7 +2188,13 @@ func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDe return false, nil case tui.TUIActionRunModel: saveLauncherSelection(action) - modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest()) + req := action.RunModelRequest() + if deps.accountState != nil { + req.AccountState = deps.accountState() + req.AccountStateProvider = deps.accountState + } + req.AccountStateUpdates = deps.accountStateUpdates + modelName, err := deps.resolveRunModel(cmd.Context(), req) if errors.Is(err, launch.ErrCancelled) { return true, nil } @@ -2155,7 +2207,13 @@ func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDe return true, nil case tui.TUIActionLaunchIntegration: saveLauncherSelection(action) - err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest()) + req := action.IntegrationLaunchRequest() + if deps.accountState != nil { + req.AccountState = deps.accountState() + req.AccountStateProvider = deps.accountState + } + req.AccountStateUpdates = deps.accountStateUpdates + err := deps.launchIntegration(cmd.Context(), req) if errors.Is(err, launch.ErrCancelled) { return true, nil } diff --git a/cmd/cmd_launcher_test.go b/cmd/cmd_launcher_test.go index a90c74e94..853520e60 100644 --- a/cmd/cmd_launcher_test.go +++ b/cmd/cmd_launcher_test.go @@ -76,11 +76,18 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) { var gotReq launch.RunModelRequest var launched string + prefetchedAccount := &launch.AccountState{} + accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil } deps := launcherDeps{ buildState: func(ctx context.Context) (*launch.LauncherState, error) { return &launch.LauncherState{}, nil }, - runMenu: runMenu, + runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) { + if state.AccountState != prefetchedAccount { + t.Fatalf("prefetched account state was not piped to menu state") + } + return runMenu(state) + }, resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) { gotReq = req return tt.wantModel, nil @@ -90,6 +97,10 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) { launched = model return nil }, + accountState: func() *launch.AccountState { + return prefetchedAccount + }, + accountStateUpdates: accountUpdates, } cmd := &cobra.Command{} @@ -107,6 +118,12 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) { if gotReq.ForcePicker != tt.wantForce { t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker) } + if gotReq.AccountState != prefetchedAccount { + t.Fatalf("expected prefetched account state to be passed to run model request") + } + if gotReq.AccountStateUpdates == nil { + t.Fatalf("expected account state updates to be passed to run model request") + } if launched != tt.wantModel { t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched) } @@ -148,17 +165,28 @@ func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) } var gotReq launch.IntegrationLaunchRequest + prefetchedAccount := &launch.AccountState{} + accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil } deps := launcherDeps{ buildState: func(ctx context.Context) (*launch.LauncherState, error) { return &launch.LauncherState{}, nil }, - runMenu: runMenu, + runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) { + if state.AccountState != prefetchedAccount { + t.Fatalf("prefetched account state was not piped to menu state") + } + return runMenu(state) + }, resolveRunModel: unexpectedRunModelResolution(t), launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error { gotReq = req return nil }, runModel: unexpectedModelLaunch(t), + accountState: func() *launch.AccountState { + return prefetchedAccount + }, + accountStateUpdates: accountUpdates, } cmd := &cobra.Command{} @@ -179,6 +207,12 @@ func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) if gotReq.ForceConfigure != tt.wantForce { t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure) } + if gotReq.AccountState != prefetchedAccount { + t.Fatalf("expected prefetched account state to be passed to integration request") + } + if gotReq.AccountStateUpdates == nil { + t.Fatalf("expected account state updates to be passed to integration request") + } if got := config.LastSelection(); got != "claude" { t.Fatalf("expected last selection to be claude, got %q", got) } diff --git a/cmd/launch/account.go b/cmd/launch/account.go new file mode 100644 index 000000000..1471e1f08 --- /dev/null +++ b/cmd/launch/account.go @@ -0,0 +1,371 @@ +package launch + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/ollama/ollama/api" +) + +const ( + // DefaultUpgradeURL is the fixed destination for subscription upgrades. + DefaultUpgradeURL = "https://ollama.com/upgrade" + + accountCheckTimeout = 3 * time.Second +) + +var ( + ErrPlanVerificationUnavailable = errors.New("Could not verify your plan. Try again in a moment.") + errUpgradeCancelled = errors.New("upgrade cancelled") +) + +type accountStateStatus int + +const ( + accountStateUnknown accountStateStatus = iota + accountStateSignedOut + accountStateSignedIn +) + +type AccountState struct { + Status accountStateStatus + Plan string +} + +type AccountStatePrefetch struct { + done chan struct{} + state AccountState +} + +func StartAccountStatePrefetch(ctx context.Context) *AccountStatePrefetch { + if ctx == nil { + ctx = context.Background() + } + p := &AccountStatePrefetch{done: make(chan struct{})} + go func() { + state := AccountState{Status: accountStateUnknown} + client, err := api.ClientFromEnvironment() + if err == nil { + prefetchCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout) + defer cancel() + if disabled, known := cloudStatusDisabled(prefetchCtx, client); !known || !disabled { + state = launchAccountState(prefetchCtx, client) + } + } + p.state = state + close(p.done) + }() + return p +} + +func (p *AccountStatePrefetch) StateIfReady() *AccountState { + if p == nil { + return nil + } + select { + case <-p.done: + state := p.state + return &state + default: + return nil + } +} + +func (p *AccountStatePrefetch) StateUpdates(ctx context.Context) <-chan *AccountState { + if p == nil { + return nil + } + if ctx == nil { + ctx = context.Background() + } + out := make(chan *AccountState, 1) + go func() { + defer close(out) + select { + case <-p.done: + if p.state.Status == accountStateUnknown { + return + } + state := p.state + select { + case out <- &state: + case <-ctx.Done(): + } + case <-ctx.Done(): + } + }() + return out +} + +func launchAccountState(ctx context.Context, client *api.Client) AccountState { + if client == nil { + return AccountState{Status: accountStateUnknown} + } + + user, err := whoamiWithTimeout(ctx, client) + if err != nil { + var authErr api.AuthorizationError + if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized { + return AccountState{Status: accountStateSignedOut} + } + return AccountState{Status: accountStateUnknown} + } + if user == nil || strings.TrimSpace(user.Name) == "" { + return AccountState{Status: accountStateSignedOut} + } + return AccountState{ + Status: accountStateSignedIn, + Plan: strings.TrimSpace(user.Plan), + } +} + +func whoamiWithTimeout(ctx context.Context, client *api.Client) (*api.UserResponse, error) { + if ctx == nil { + ctx = context.Background() + } + checkCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout) + defer cancel() + return client.Whoami(checkCtx) +} + +func ApplyAccountStateToSelectionItems(items []ModelItem, state AccountState) []SelectionItem { + out := make([]SelectionItem, len(items)) + for i, item := range items { + out[i] = SelectionItem{ + Name: item.Name, + Description: item.Description, + Recommended: item.Recommended, + AvailabilityBadge: availabilityBadge(item, state), + } + } + return out +} + +func SelectionItemsWithAccountState(items []ModelItem, state *AccountState) []SelectionItem { + if state == nil || !selectionItemsNeedAccountState(items) { + return ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateUnknown}) + } + return ApplyAccountStateToSelectionItems(items, *state) +} + +func selectionItemsNeedAccountState(items []ModelItem) bool { + for _, item := range items { + if isCloudModelName(item.Name) && itemHasRecommendationMetadata(item) { + return true + } + } + return false +} + +func (c *launcherClient) selectionItemUpdates(ctx context.Context, items []ModelItem, state *AccountState) <-chan []SelectionItem { + if !selectionItemsNeedAccountState(items) || state != nil { + return nil + } + if ctx == nil { + ctx = context.Background() + } + + stateUpdates := c.accountStateUpdateSource(ctx) + if stateUpdates == nil { + return nil + } + + out := make(chan []SelectionItem, 1) + go func() { + defer close(out) + select { + case state, ok := <-stateUpdates: + if !ok || state == nil { + return + } + select { + case out <- SelectionItemsWithAccountState(items, state): + case <-ctx.Done(): + } + case <-ctx.Done(): + } + }() + return out +} + +func (c *launcherClient) accountStateUpdateSource(ctx context.Context) <-chan *AccountState { + if c.accountStateUpdates != nil { + return c.accountStateUpdates(ctx) + } + if c.apiClient == nil { + return nil + } + out := make(chan *AccountState, 1) + go func() { + defer close(out) + state := launchAccountState(ctx, c.apiClient) + if state.Status == accountStateUnknown { + return + } + select { + case out <- &state: + case <-ctx.Done(): + } + }() + return out +} + +func availabilityBadge(item ModelItem, state AccountState) string { + if !isCloudModelName(item.Name) { + return "" + } + switch state.Status { + case accountStateSignedOut: + if itemHasRecommendationMetadata(item) { + return "Sign in required" + } + case accountStateSignedIn: + if item.RequiredPlan != "" && !PlanSatisfies(state.Plan, item.RequiredPlan) { + return "Upgrade required" + } + } + return "" +} + +func itemHasRecommendationMetadata(item ModelItem) bool { + return item.Recommended || strings.TrimSpace(item.RequiredPlan) != "" +} + +func (c *launcherClient) ensureCloudModelAccess(ctx context.Context, model string) error { + item, ok := c.modelRecommendationItem(ctx, model) + if !ok || strings.TrimSpace(item.RequiredPlan) == "" { + return nil + } + + state := launchAccountState(ctx, c.apiClient) + if state.Status != accountStateUnknown { + c.accountState = &state + } + if state.Status == accountStateUnknown { + return ErrPlanVerificationUnavailable + } + + if state.Status == accountStateSignedOut { + if err := ensureCloudAuth(ctx, c.apiClient, model); err != nil { + return err + } + state = launchAccountState(ctx, c.apiClient) + if state.Status != accountStateUnknown { + c.accountState = &state + } + if state.Status == accountStateUnknown { + return ErrPlanVerificationUnavailable + } + } + + if PlanSatisfies(state.Plan, item.RequiredPlan) { + return nil + } + + if err := c.runUpgradeFlow(ctx, item); err != nil { + return err + } + state = launchAccountState(ctx, c.apiClient) + if state.Status == accountStateUnknown { + return ErrPlanVerificationUnavailable + } + if state.Status != accountStateSignedIn || !PlanSatisfies(state.Plan, item.RequiredPlan) { + return errUpgradeCancelled + } + return nil +} + +func (c *launcherClient) modelRecommendationItem(ctx context.Context, model string) (ModelItem, bool) { + for _, item := range c.recommendations(ctx) { + if item.Name == model { + return item, true + } + } + return ModelItem{}, false +} + +func (c *launcherClient) runUpgradeFlow(ctx context.Context, item ModelItem) error { + if DefaultUpgrade != nil { + if _, err := DefaultUpgrade(item.Name, item.RequiredPlan); err != nil { + if errors.Is(err, ErrCancelled) { + return errUpgradeCancelled + } + return err + } + return nil + } + + yes, err := ConfirmPrompt(fmt.Sprintf("Upgrade to use %s?", item.Name)) + if errors.Is(err, ErrCancelled) { + return errUpgradeCancelled + } + if err != nil { + return err + } + if !yes { + return errUpgradeCancelled + } + + fmt.Fprintf(os.Stderr, "\nTo upgrade, navigate to:\n %s\n\n", DefaultUpgradeURL) + openNow, err := ConfirmPrompt("Open now?") + if errors.Is(err, ErrCancelled) { + return errUpgradeCancelled + } + if err != nil { + return err + } + if openNow { + OpenBrowser(DefaultUpgradeURL) + } else { + return errUpgradeCancelled + } + + spinnerFrames := []string{"|", "/", "-", "\\"} + frame := 0 + fmt.Fprintf(os.Stderr, "\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[0]) + + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + fmt.Fprintf(os.Stderr, "\r\033[K") + return ctx.Err() + case <-ticker.C: + frame++ + fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) + if frame%10 != 0 { + continue + } + state := launchAccountState(ctx, c.apiClient) + if state.Status == accountStateUnknown { + fmt.Fprintf(os.Stderr, "\r\033[K") + return ErrPlanVerificationUnavailable + } + if state.Status == accountStateSignedIn && PlanSatisfies(state.Plan, item.RequiredPlan) { + fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1mplan updated\033[0m\n") + return nil + } + } + } +} + +// PlanSatisfies reports whether currentPlan can use a model that has a requiredPlan. +func PlanSatisfies(currentPlan, requiredPlan string) bool { + required := normalizePlan(requiredPlan) + if required == "" || required == "free" { + return true + } + current := normalizePlan(currentPlan) + return current != "" && current != "free" +} + +func normalizePlan(plan string) string { + return strings.ToLower(strings.TrimSpace(plan)) +} diff --git a/cmd/launch/command_test.go b/cmd/launch/command_test.go index 4bae22d8f..24dc78577 100644 --- a/cmd/launch/command_test.go +++ b/cmd/launch/command_test.go @@ -319,7 +319,7 @@ func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) { var selectorCalls int var gotCurrent string - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { selectorCalls++ gotCurrent = current return "llama3.2", nil @@ -553,7 +553,7 @@ func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) defer func() { DefaultSingleSelector = oldSelector }() var gotCurrent string - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { gotCurrent = current return "qwen3:8b", nil } @@ -607,7 +607,7 @@ func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T oldSelector := DefaultSingleSelector defer func() { DefaultSingleSelector = oldSelector }() - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { t.Fatal("selector should not be called for headless --yes saved-model launch") return "", nil } @@ -644,7 +644,7 @@ func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testin oldSelector := DefaultSingleSelector defer func() { DefaultSingleSelector = oldSelector }() - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { t.Fatal("selector should not be called for headless --yes without saved model") return "", nil } diff --git a/cmd/launch/integrations_test.go b/cmd/launch/integrations_test.go index fdd3ce188..9e61de8b9 100644 --- a/cmd/launch/integrations_test.go +++ b/cmd/launch/integrations_test.go @@ -10,7 +10,9 @@ import ( "net/url" "slices" "strings" + "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" @@ -456,6 +458,28 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) { } } +func TestBuildModelList_PreservesRecommendationRequiredPlanForExistingCloudModel(t *testing.T) { + recommendations := []ModelItem{ + { + Name: "glm-5:cloud", + Description: "Reasoning and code generation", + Recommended: true, + RequiredPlan: "pro", + ContextLength: 202_752, + }, + } + existing := []modelInfo{{Name: "glm-5:cloud", Remote: true}} + + items, _, _, _ := buildModelListWithRecommendations(existing, recommendations, nil, "") + if len(items) != 1 { + t.Fatalf("expected one item, got %v", items) + } + item := items[0] + if item.RequiredPlan != "pro" { + t.Fatalf("RequiredPlan = %q, want pro", item.RequiredPlan) + } +} + func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) { existing := []modelInfo{ {Name: "gemma4", Remote: false}, @@ -1390,6 +1414,187 @@ func TestEnsureAuth_EmptyWhoamiRequiresSignIn(t *testing.T) { } } +func TestApplyAccountStateToSelectionItems_BadgesOnlyWhenActionRequired(t *testing.T) { + items := []ModelItem{ + {Name: "qwen3.5:cloud", Recommended: true}, + {Name: "kimi-k2.6:cloud", Recommended: true, RequiredPlan: "pro"}, + {Name: "llama3.2", RequiredPlan: "pro"}, + {Name: "glm-5:cloud"}, + {Name: "nemotron-3-super:cloud", Recommended: true, RequiredPlan: "free"}, + } + + signedOut := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedOut}) + if signedOut[0].AvailabilityBadge != "Sign in required" { + t.Fatalf("account cloud badge = %q", signedOut[0].AvailabilityBadge) + } + if signedOut[1].AvailabilityBadge != "Sign in required" { + t.Fatalf("subscription cloud signed-out badge = %q", signedOut[1].AvailabilityBadge) + } + if signedOut[4].AvailabilityBadge != "Sign in required" { + t.Fatalf("free-plan cloud signed-out badge = %q", signedOut[4].AvailabilityBadge) + } + if signedOut[2].AvailabilityBadge != "" || signedOut[3].AvailabilityBadge != "" { + t.Fatalf("unexpected badge for local or unmetadata item: %#v", signedOut) + } + + freeUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "free"}) + if freeUser[0].AvailabilityBadge != "" { + t.Fatalf("signed-in account model should not be badged, got %q", freeUser[0].AvailabilityBadge) + } + if freeUser[1].AvailabilityBadge != "Upgrade required" { + t.Fatalf("subscription cloud free-plan badge = %q", freeUser[1].AvailabilityBadge) + } + if freeUser[4].AvailabilityBadge != "" { + t.Fatalf("free required plan should be usable by free user, got %q", freeUser[4].AvailabilityBadge) + } + + proUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "pro"}) + if proUser[1].AvailabilityBadge != "" { + t.Fatalf("pro user should not see included badge, got %q", proUser[1].AvailabilityBadge) + } + + maxUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "max"}) + if maxUser[1].AvailabilityBadge != "" { + t.Fatalf("max user should not see upgrade badge, got %q", maxUser[1].AvailabilityBadge) + } + + unknown := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateUnknown}) + for _, item := range unknown { + if item.AvailabilityBadge != "" { + t.Fatalf("unknown account state should not render badges: %#v", unknown) + } + } +} + +func TestSelectionItemsWithAccountState_SkipsBadgesWithoutBadgeableCloudItems(t *testing.T) { + items := []ModelItem{ + {Name: "llama3.2"}, + {Name: "custom:cloud"}, + } + state := &AccountState{Status: accountStateSignedOut} + got := SelectionItemsWithAccountState(items, state) + if len(got) != len(items) { + t.Fatalf("got %d selection items, want %d", len(got), len(items)) + } + for _, item := range got { + if item.AvailabilityBadge != "" { + t.Fatalf("unexpected badge without account state: %#v", got) + } + } +} + +func TestSelectionItemsWithAccountState_UsesPrefetchedStateForRecommendedCloudItems(t *testing.T) { + state := &AccountState{Status: accountStateSignedOut} + got := SelectionItemsWithAccountState([]ModelItem{{Name: "qwen3.5:cloud", Recommended: true}}, state) + if got[0].AvailabilityBadge != "Sign in required" { + t.Fatalf("badge = %q, want Sign in required", got[0].AvailabilityBadge) + } +} + +func TestRecommendedModelsDoNotIncludeRequiredPlanStubs(t *testing.T) { + byName := make(map[string]ModelItem, len(recommendedModels)) + for _, item := range recommendedModels { + byName[item.Name] = item + } + + if item := byName["kimi-k2.6:cloud"]; item.RequiredPlan != "" { + t.Fatalf("kimi fallback required plan should not be stubbed: %#v", item) + } + if item := byName["minimax-m2.7:cloud"]; item.RequiredPlan != "" { + t.Fatalf("minimax fallback required plan should not be stubbed: %#v", item) + } + if item := byName["qwen3.5:cloud"]; item.RequiredPlan != "" { + t.Fatalf("qwen fallback required plan = %#v", item) + } + if item := byName["glm-5.1:cloud"]; item.RequiredPlan != "" { + t.Fatalf("glm fallback required plan = %#v", item) + } +} + +func TestLaunchAccountState(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantStatus accountStateStatus + wantPlan string + }{ + { + name: "signed in", + statusCode: http.StatusOK, + body: `{"name":"parth","plan":"pro"}`, + wantStatus: accountStateSignedIn, + wantPlan: "pro", + }, + { + name: "signed out", + statusCode: http.StatusUnauthorized, + body: `{"error":"unauthorized","signin_url":"https://example.com/signin"}`, + wantStatus: accountStateSignedOut, + }, + { + name: "unreachable", + statusCode: http.StatusInternalServerError, + body: `{"error":"temporary failure"}`, + wantStatus: accountStateUnknown, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/me" { + http.NotFound(w, r) + return + } + w.WriteHeader(tt.statusCode) + fmt.Fprint(w, tt.body) + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + got := launchAccountState(context.Background(), api.NewClient(u, srv.Client())) + if got.Status != tt.wantStatus { + t.Fatalf("Status = %v, want %v", got.Status, tt.wantStatus) + } + if got.Plan != tt.wantPlan { + t.Fatalf("Plan = %q, want %q", got.Plan, tt.wantPlan) + } + }) + } +} + +func TestStartAccountStatePrefetch_SkipsWhoamiWhenCloudDisabled(t *testing.T) { + var whoamiCalled atomic.Bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/status": + fmt.Fprint(w, `{"cloud":{"disabled":true,"source":"config"}}`) + case "/api/me": + whoamiCalled.Store(true) + http.NotFound(w, r) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + prefetch := StartAccountStatePrefetch(context.Background()) + select { + case <-prefetch.done: + case <-time.After(time.Second): + t.Fatal("account prefetch did not finish") + } + if whoamiCalled.Load() { + t.Fatal("prefetch should not call whoami when cloud is disabled") + } + state := prefetch.StateIfReady() + if state == nil || state.Status != accountStateUnknown { + t.Fatalf("prefetch state = %#v, want unknown", state) + } +} + func TestEnsureAuth_PreservesCancelledSignInHook(t *testing.T) { oldSignIn := DefaultSignIn DefaultSignIn = func(modelName, signInURL string) (string, error) { diff --git a/cmd/launch/launch.go b/cmd/launch/launch.go index ee03d778b..7044ba344 100644 --- a/cmd/launch/launch.go +++ b/cmd/launch/launch.go @@ -22,6 +22,7 @@ type LauncherState struct { RunModel string RunModelUsable bool Integrations map[string]LauncherIntegrationState + AccountState *AccountState } // LauncherIntegrationState is the launch-owned status for one launcher integration. @@ -41,8 +42,11 @@ type LauncherIntegrationState struct { // RunModelRequest controls how the root launcher resolves the chat model. type RunModelRequest struct { - ForcePicker bool - Policy *LaunchPolicy + ForcePicker bool + Policy *LaunchPolicy + AccountState *AccountState + AccountStateProvider func() *AccountState + AccountStateUpdates func(context.Context) <-chan *AccountState } // LaunchConfirmMode controls confirmation behavior across launch flows. @@ -117,13 +121,16 @@ func (p LaunchPolicy) missingModelPolicy() missingModelPolicy { // IntegrationLaunchRequest controls the canonical integration launcher flow. type IntegrationLaunchRequest struct { - Name string - ModelOverride string - ForceConfigure bool - ConfigureOnly bool - Restore bool - ExtraArgs []string - Policy *LaunchPolicy + Name string + ModelOverride string + ForceConfigure bool + ConfigureOnly bool + Restore bool + ExtraArgs []string + Policy *LaunchPolicy + AccountState *AccountState + AccountStateProvider func() *AccountState + AccountStateUpdates func(context.Context) <-chan *AccountState } var isInteractiveSession = func() bool { @@ -241,7 +248,7 @@ type modelInfo struct { // ModelInfo re-exports launcher model inventory details for callers. type ModelInfo = modelInfo -// ModelItem represents a model for selection UIs. +// ModelItem represents model metadata before selector-only UI state is derived. type ModelItem struct { Name string Description string @@ -249,6 +256,15 @@ type ModelItem struct { VRAMBytes int64 ContextLength int MaxOutputTokens int + RequiredPlan string +} + +// SelectionItem represents a model row after launch has derived selector-only UI state. +type SelectionItem struct { + Name string + Description string + Recommended bool + AvailabilityBadge string } // LaunchCmd returns the cobra command for launching integrations. @@ -384,10 +400,15 @@ func launchCommandCanSkipHeartbeat(args []string) bool { } type launcherClient struct { - apiClient *api.Client - modelInventory []ModelInfo - inventoryLoaded bool - policy LaunchPolicy + apiClient *api.Client + modelInventory []ModelInfo + inventoryLoaded bool + recommendationsLoaded bool + recommendationItems []ModelItem + accountState *AccountState + accountStateProvider func() *AccountState + accountStateUpdates func(context.Context) <-chan *AccountState + policy LaunchPolicy } func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) { @@ -425,6 +446,9 @@ func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) { if err != nil { return "", err } + launchClient.accountState = req.AccountState + launchClient.accountStateProvider = req.AccountStateProvider + launchClient.accountStateUpdates = req.AccountStateUpdates return launchClient.resolveRunModel(ctx, req) } @@ -449,6 +473,9 @@ func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error if err != nil { return err } + launchClient.accountState = req.AccountState + launchClient.accountStateProvider = req.AccountStateProvider + launchClient.accountStateUpdates = req.AccountStateUpdates if autodiscovery, ok := runner.(ManagedAutodiscoveryIntegration); ok { if err := EnsureIntegrationInstalled(name, runner); err != nil { @@ -811,7 +838,10 @@ func (c *launcherClient) managedAutodiscoveryUsable(ctx context.Context, autodis if !managedAutodiscoveryUsesOllamaCloud(autodiscovery) { return true } - return c.ollamaCloudSignedIn(ctx) + if disabled, known := cloudStatusDisabled(ctx, c.apiClient); known && disabled { + return false + } + return true } func (c *launcherClient) ensureManagedAutodiscoveryUsable(ctx context.Context, autodiscovery ManagedAutodiscoveryIntegration, label string) error { @@ -858,14 +888,6 @@ func printRestoreSuccess(integration any) { } } -func (c *launcherClient) ollamaCloudSignedIn(ctx context.Context) bool { - if disabled, known := cloudStatusDisabled(ctx, c.apiClient); known && disabled { - return false - } - user, err := c.apiClient.Whoami(ctx) - return err == nil && user != nil && user.Name != "" -} - func (c *launcherClient) managedSingleConfigureModels(ctx context.Context, managed ManagedSingleModel, target string) ([]string, error) { models := []string{target} if _, ok := managed.(ManagedModelListConfigurer); !ok { @@ -959,8 +981,15 @@ func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, titl return c.selectSingleModelWithSelectorReady(ctx, title, current, selector, true) } +func (c *launcherClient) latestAccountState() *AccountState { + if c.accountStateProvider != nil { + return c.accountStateProvider() + } + return c.accountState +} + func (c *launcherClient) selectSingleModelWithSelectorReady(ctx context.Context, title, current string, selector SingleSelector, ensureReady bool) (string, error) { - if selector == nil { + if selector == nil && DefaultSingleSelectorWithUpdates == nil { return "", fmt.Errorf("no selector configured") } @@ -969,45 +998,88 @@ func (c *launcherClient) selectSingleModelWithSelectorReady(ctx context.Context, return "", err } - selected, err := selector(title, items, current) - if err != nil { - return "", err - } - if selected == "" { - return "", ErrCancelled - } - if ensureReady { - if err := c.ensureModelsReady(ctx, []string{selected}); err != nil { + for { + accountState := c.latestAccountState() + selectionItems := SelectionItemsWithAccountState(items, accountState) + var updates <-chan []SelectionItem + if DefaultSingleSelectorWithUpdates != nil { + updates = c.selectionItemUpdates(ctx, items, accountState) + } + selected, err := runSingleSelector(title, selectionItems, current, updates, selector) + if err != nil { return "", err } + if selected == "" { + return "", ErrCancelled + } + if ensureReady { + if err := c.ensureModelsReady(ctx, []string{selected}); err != nil { + if errors.Is(err, errUpgradeCancelled) { + current = selected + continue + } + return "", err + } + } + return selected, nil } - return selected, nil } func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) { - if DefaultMultiSelector == nil { + if DefaultMultiSelector == nil && DefaultMultiSelectorWithUpdates == nil { return nil, fmt.Errorf("no selector configured") } current := firstModel(preChecked) - items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available") if err != nil { return nil, err } - selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked) - if err != nil { - return nil, err + for { + accountState := c.latestAccountState() + selectionItems := SelectionItemsWithAccountState(items, accountState) + var updates <-chan []SelectionItem + if DefaultMultiSelectorWithUpdates != nil { + updates = c.selectionItemUpdates(ctx, items, accountState) + } + selected, err := runMultiSelector(fmt.Sprintf("Select models for %s:", runner), selectionItems, orderedChecked, updates) + if err != nil { + return nil, err + } + accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected) + if err != nil { + if errors.Is(err, errUpgradeCancelled) { + orderedChecked = append([]string(nil), selected...) + continue + } + return nil, err + } + for _, skip := range skipped { + fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason) + } + return accepted, nil } - accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected) - if err != nil { - return nil, err +} + +func runSingleSelector(title string, items []SelectionItem, current string, updates <-chan []SelectionItem, fallback SingleSelector) (string, error) { + if DefaultSingleSelectorWithUpdates != nil { + return DefaultSingleSelectorWithUpdates(title, items, current, updates) } - for _, skip := range skipped { - fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason) + if fallback == nil { + return "", fmt.Errorf("no selector configured") } - return accepted, nil + return fallback(title, items, current) +} + +func runMultiSelector(title string, items []SelectionItem, preChecked []string, updates <-chan []SelectionItem) ([]string, error) { + if DefaultMultiSelectorWithUpdates != nil { + return DefaultMultiSelectorWithUpdates(title, items, preChecked, updates) + } + if DefaultMultiSelector == nil { + return nil, fmt.Errorf("no selector configured") + } + return DefaultMultiSelector(title, items, preChecked) } func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) { @@ -1029,16 +1101,24 @@ func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked [] } func (c *launcherClient) recommendations(ctx context.Context) []ModelItem { + if c.recommendationsLoaded { + return append([]ModelItem(nil), c.recommendationItems...) + } + recommendations, err := c.requestRecommendations(ctx) if err != nil || len(recommendations) == 0 { // Fail open: recommendation issues should not block launch flows. // Fall back to built-in recommendations until server data is available. fallback := append([]ModelItem(nil), recommendedModels...) setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(fallback)) - return fallback + c.recommendationItems = fallback + c.recommendationsLoaded = true + return append([]ModelItem(nil), fallback...) } setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(recommendations)) - return recommendations + c.recommendationItems = recommendations + c.recommendationsLoaded = true + return append([]ModelItem(nil), recommendations...) } func (c *launcherClient) requestRecommendations(ctx context.Context) ([]ModelItem, error) { @@ -1076,6 +1156,7 @@ func (c *launcherClient) requestRecommendations(ctx context.Context) ([]ModelIte VRAMBytes: rec.VRAMBytes, ContextLength: rec.ContextLength, MaxOutputTokens: rec.MaxOutputTokens, + RequiredPlan: strings.TrimSpace(rec.RequiredPlan), }) } @@ -1093,6 +1174,9 @@ func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) isCloudModel := isCloudModelName(model) if isCloudModel { cloudModels[model] = true + if err := c.ensureCloudModelAccess(ctx, model); err != nil { + return err + } } if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil { return err @@ -1126,6 +1210,9 @@ func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected for _, model := range selected { if err := c.ensureModelsReady(ctx, []string{model}); err != nil { + if errors.Is(err, errUpgradeCancelled) { + return nil, nil, err + } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, nil, err } @@ -1142,6 +1229,9 @@ func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected } func skippedModelReason(model string, err error) string { + if errors.Is(err, errUpgradeCancelled) { + return "upgrade was cancelled" + } if errors.Is(err, ErrCancelled) { if isCloudModelName(model) { return "sign in was cancelled" diff --git a/cmd/launch/launch_test.go b/cmd/launch/launch_test.go index 9b63dcdf1..8333c89ca 100644 --- a/cmd/launch/launch_test.go +++ b/cmd/launch/launch_test.go @@ -12,6 +12,7 @@ import ( "slices" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/cmd/config" @@ -192,14 +193,20 @@ func withInteractiveSession(t *testing.T, interactive bool) { func withLauncherHooks(t *testing.T) { t.Helper() oldSingle := DefaultSingleSelector + oldSingleWithUpdates := DefaultSingleSelectorWithUpdates oldMulti := DefaultMultiSelector + oldMultiWithUpdates := DefaultMultiSelectorWithUpdates oldConfirm := DefaultConfirmPrompt oldSignIn := DefaultSignIn + oldUpgrade := DefaultUpgrade t.Cleanup(func() { DefaultSingleSelector = oldSingle + DefaultSingleSelectorWithUpdates = oldSingleWithUpdates DefaultMultiSelector = oldMulti + DefaultMultiSelectorWithUpdates = oldMultiWithUpdates DefaultConfirmPrompt = oldConfirm DefaultSignIn = oldSignIn + DefaultUpgrade = oldUpgrade }) } @@ -346,7 +353,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t * } withIntegrationOverride(t, "stubmanaged", runner) - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { return "gemma4", nil } DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) { @@ -501,7 +508,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches( runner := &launcherManagedRunner{} withIntegrationOverride(t, "stubmanaged", runner) - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { t.Fatal("selector should not be called when saved model matches target") return "", nil } @@ -553,7 +560,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *t runner := &launcherManagedRunner{} withIntegrationOverride(t, "stubmanaged", runner) - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { t.Fatal("selector should not be called when model override is provided") return "", nil } @@ -607,7 +614,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenLiveConfigDrifts( } withIntegrationOverride(t, "stubmanaged", runner) - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { t.Fatal("selector should not be called when live config already provides the target") return "", nil } @@ -734,7 +741,7 @@ func TestLaunchIntegration_ManagedAutodiscoverySkipsModelPicker(t *testing.T) { runner := &launcherManagedAutodiscoveryRunner{} withIntegrationOverride(t, "stubmanaged", runner) - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { t.Fatal("model selector should not run for autodiscovery integrations") return "", nil } @@ -987,7 +994,7 @@ func TestLaunchIntegration_CloudAutodiscoveryUsesSignInHook(t *testing.T) { } } -func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *testing.T) { +func TestBuildLauncherIntegrationState_CloudAutodiscoveryDoesNotCheckSignIn(t *testing.T) { tmpDir := t.TempDir() setLaunchTestHome(t, tmpDir) withLauncherHooks(t) @@ -1004,8 +1011,7 @@ func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *tes w.WriteHeader(http.StatusNotFound) fmt.Fprint(w, `{"error":"not found"}`) case "/api/me": - w.WriteHeader(http.StatusUnauthorized) - fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`) + t.Fatal("build launcher state should not check whoami") default: http.NotFound(w, r) } @@ -1028,8 +1034,8 @@ func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *tes if state.CurrentModel != "Ollama Cloud" { t.Fatalf("current model = %q, want Ollama Cloud", state.CurrentModel) } - if state.ModelUsable { - t.Fatal("expected cloud autodiscovery config to be unusable while signed out") + if !state.ModelUsable { + t.Fatal("expected cloud autodiscovery config to stay usable until launch-time auth check") } } @@ -1296,7 +1302,7 @@ func TestResolveRunModel_UsesSavedModelWithoutSelector(t *testing.T) { } selectorCalled := false - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { selectorCalled = true return "", nil } @@ -1338,7 +1344,7 @@ func TestResolveRunModel_HeadlessYesAutoPicksLastModel(t *testing.T) { t.Fatalf("failed to save last model: %v", err) } - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { t.Fatal("selector should not be called in headless --yes mode") return "", nil } @@ -1405,7 +1411,7 @@ func TestResolveRunModel_UsesRequestPolicy(t *testing.T) { t.Fatalf("failed to save last model: %v", err) } - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { t.Fatal("selector should not be called when request policy enables headless auto-pick") return "", nil } @@ -1465,7 +1471,7 @@ func TestResolveRunModel_ForcePickerAlwaysUsesSelector(t *testing.T) { } var selectorCalls int - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { selectorCalls++ if current != "llama3.2" { t.Fatalf("expected current selection to be last model, got %q", current) @@ -1513,7 +1519,7 @@ func TestResolveRunModel_ForcePicker_DoesNotReorderByLastModel(t *testing.T) { } var gotNames []string - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { if current != "qwen3.5" { t.Fatalf("expected current selection to be last model, got %q", current) } @@ -1564,7 +1570,7 @@ func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) { setLaunchTestHome(t, tmpDir) withLauncherHooks(t) - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { return "glm-5:cloud", nil } @@ -1610,6 +1616,241 @@ func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) { } } +func TestResolveRunModel_MetadataSignedOutUsesSignInHook(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { + return "qwen3.5:cloud", nil + } + + signedIn := false + signInCalled := false + DefaultSignIn = func(modelName, signInURL string) (string, error) { + signInCalled = true + signedIn = true + if modelName != "qwen3.5:cloud" { + t.Fatalf("unexpected model passed to sign-in: %q", modelName) + } + if signInURL != "https://example.com/signin" { + t.Fatalf("unexpected sign-in URL: %q", signInURL) + } + return "test-user", nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/experimental/model-recommendations": + fmt.Fprint(w, `{"recommendations":[{"model":"qwen3.5:cloud","description":"Reasoning","context_length":262144,"max_output_tokens":32768}]}`) + case "/api/tags": + fmt.Fprint(w, `{"models":[]}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + case "/api/show": + fmt.Fprint(w, `{"remote_model":"qwen3.5"}`) + case "/api/me": + if !signedIn { + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`) + return + } + fmt.Fprint(w, `{"name":"test-user","plan":"free"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true}) + if err != nil { + t.Fatalf("ResolveRunModel returned error: %v", err) + } + if model != "qwen3.5:cloud" { + t.Fatalf("expected selected cloud model, got %q", model) + } + if !signInCalled { + t.Fatal("expected sign-in hook to be used for account-gated cloud model") + } +} + +func TestResolveRunModel_SubscriptionModelUsesUpgradeHook(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + DefaultSingleSelectorWithUpdates = func(title string, items []SelectionItem, current string, updates <-chan []SelectionItem) (string, error) { + for _, item := range items { + if item.Name == "kimi-k2.6:cloud" && item.AvailabilityBadge != "" { + t.Fatalf("initial availability badge = %q, want empty before account update", item.AvailabilityBadge) + } + } + select { + case items = <-updates: + case <-time.After(time.Second): + t.Fatal("timed out waiting for selector item update") + } + for _, item := range items { + if item.Name == "kimi-k2.6:cloud" { + if item.AvailabilityBadge != "Upgrade required" { + t.Fatalf("availability badge = %q, want Upgrade required", item.AvailabilityBadge) + } + return "kimi-k2.6:cloud", nil + } + } + t.Fatalf("paid cloud model missing from selector items: %#v", items) + return "kimi-k2.6:cloud", nil + } + DefaultSignIn = func(modelName, signInURL string) (string, error) { + t.Fatalf("did not expect sign-in hook for signed-in user") + return "", nil + } + + plan := "free" + upgradeCalled := false + DefaultUpgrade = func(modelName, requiredPlan string) (string, error) { + upgradeCalled = true + if modelName != "kimi-k2.6:cloud" { + t.Fatalf("unexpected model passed to upgrade: %q", modelName) + } + if requiredPlan != "pro" { + t.Fatalf("unexpected required plan: %q", requiredPlan) + } + plan = "max" + return plan, nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/experimental/model-recommendations": + fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`) + case "/api/tags": + fmt.Fprint(w, `{"models":[]}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + case "/api/show": + fmt.Fprint(w, `{"remote_model":"kimi-k2.6"}`) + case "/api/me": + fmt.Fprintf(w, `{"name":"test-user","plan":%q}`, plan) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true}) + if err != nil { + t.Fatalf("ResolveRunModel returned error: %v", err) + } + if model != "kimi-k2.6:cloud" { + t.Fatalf("expected selected cloud model, got %q", model) + } + if !upgradeCalled { + t.Fatal("expected upgrade hook to be used for subscription-gated cloud model") + } +} + +func TestResolveRunModel_UpgradeCancelledReturnsToModelSelector(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + selectorCalls := 0 + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { + selectorCalls++ + switch selectorCalls { + case 1: + return "kimi-k2.6:cloud", nil + case 2: + return "llama3.2", nil + default: + t.Fatalf("selector called too many times: %d", selectorCalls) + return "", nil + } + } + DefaultUpgrade = func(modelName, requiredPlan string) (string, error) { + return "", ErrCancelled + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/experimental/model-recommendations": + fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`) + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + case "/api/show": + var req apiShowRequest + _ = json.NewDecoder(r.Body).Decode(&req) + fmt.Fprintf(w, `{"model":%q}`, req.Model) + case "/api/me": + fmt.Fprint(w, `{"name":"test-user","plan":"free"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true}) + if err != nil { + t.Fatalf("ResolveRunModel returned error: %v", err) + } + if model != "llama3.2" { + t.Fatalf("model = %q, want llama3.2", model) + } + if selectorCalls != 2 { + t.Fatalf("selector calls = %d, want 2", selectorCalls) + } +} + +func TestResolveRunModel_SubscriptionModelUnavailableWhoamiFailsGracefully(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { + return "kimi-k2.6:cloud", nil + } + DefaultUpgrade = func(modelName, requiredPlan string) (string, error) { + t.Fatalf("did not expect upgrade hook when plan could not be verified") + return "", nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/experimental/model-recommendations": + fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`) + case "/api/tags": + fmt.Fprint(w, `{"models":[]}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + case "/api/me": + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, `{"error":"temporary failure"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + _, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true}) + if err == nil { + t.Fatal("expected plan verification error") + } + if !strings.Contains(err.Error(), "Could not verify your plan. Try again in a moment.") { + t.Fatalf("unexpected error: %v", err) + } +} + func TestLaunchIntegration_EditorForceConfigure(t *testing.T) { tmpDir := t.TempDir() setLaunchTestHome(t, tmpDir) @@ -1623,7 +1864,7 @@ func TestLaunchIntegration_EditorForceConfigure(t *testing.T) { withIntegrationOverride(t, "droid", editor) var multiCalled bool - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { multiCalled = true return []string{"llama3.2", "qwen3:8b"}, nil } @@ -1688,7 +1929,7 @@ func TestLaunchIntegration_EditorForceConfigure_FloatsCheckedModelsInPicker(t *t var gotItems []string var gotPreChecked []string - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { for _, item := range items { gotItems = append(gotItems, item.Name) } @@ -1809,7 +2050,7 @@ func TestLaunchIntegration_EditorCloudDisabledFallsBackToSelector(t *testing.T) } var multiCalled bool - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { multiCalled = true return []string{"llama3.2"}, nil } @@ -1851,7 +2092,7 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsMissingLocalAndPersistsAccep editor := &launcherEditorRunner{} withIntegrationOverride(t, "droid", editor) - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { return []string{"glm-5:cloud", "missing-local"}, nil } DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) { @@ -1932,7 +2173,7 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAcce editor := &launcherEditorRunner{} withIntegrationOverride(t, "droid", editor) - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { return []string{"llama3.2", "glm-5:cloud"}, nil } DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) { @@ -2001,6 +2242,84 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAcce } } +func TestLaunchIntegration_EditorConfigureUpgradeCancelledReturnsToModelSelector(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + editor := &launcherEditorRunner{} + withIntegrationOverride(t, "droid", editor) + + selectorCalls := 0 + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { + selectorCalls++ + switch selectorCalls { + case 1: + return []string{"kimi-k2.6:cloud"}, nil + case 2: + if diff := compareStrings(preChecked, []string{"kimi-k2.6:cloud"}); diff != "" { + t.Fatalf("second selector preChecked (-want +got):\n%s", diff) + } + return []string{"llama3.2"}, nil + default: + t.Fatalf("selector called too many times: %d", selectorCalls) + return nil, nil + } + } + DefaultUpgrade = func(modelName, requiredPlan string) (string, error) { + return "", ErrCancelled + } + DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) { + if prompt == "Proceed?" { + return true, nil + } + t.Fatalf("unexpected prompt: %q", prompt) + return false, nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/experimental/model-recommendations": + fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`) + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + case "/api/show": + var req apiShowRequest + _ = json.NewDecoder(r.Body).Decode(&req) + fmt.Fprintf(w, `{"model":%q}`, req.Model) + case "/api/me": + fmt.Fprint(w, `{"name":"test-user","plan":"free"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "droid", + ForceConfigure: true, + }); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + if selectorCalls != 2 { + t.Fatalf("selector calls = %d, want 2", selectorCalls) + } + if editor.ranModel != "llama3.2" { + t.Fatalf("expected launch to use local model, got %q", editor.ranModel) + } + if diff := compareStringSlices(editor.edited, [][]string{{"llama3.2"}}); diff != "" { + t.Fatalf("unexpected edited models (-want +got):\n%s", diff) + } +} + func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t *testing.T) { tmpDir := t.TempDir() setLaunchTestHome(t, tmpDir) @@ -2016,7 +2335,7 @@ func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t * if err := config.SaveIntegration("droid", []string{"glm-5:cloud", "llama3.2"}); err != nil { t.Fatalf("failed to seed config: %v", err) } - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { return append([]string(nil), preChecked...), nil } DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) { @@ -2102,7 +2421,7 @@ func TestLaunchIntegration_EditorConfigureMultiAllFailuresKeepsExistingAndSkipsL t.Fatalf("failed to seed config: %v", err) } - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { return []string{"missing-local-a", "missing-local-b"}, nil } DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) { @@ -2342,7 +2661,7 @@ func TestLaunchIntegration_OpenclawInstallsBeforeConfigSideEffects(t *testing.T) withIntegrationOverride(t, "openclaw", editor) selectorCalled := false - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { selectorCalled = true return []string{"llama3.2"}, nil } @@ -2376,7 +2695,7 @@ func TestLaunchIntegration_PiInstallsBeforeConfigSideEffects(t *testing.T) { withIntegrationOverride(t, "pi", editor) selectorCalled := false - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { selectorCalled = true return []string{"llama3.2"}, nil } @@ -2408,7 +2727,7 @@ func TestLaunchIntegration_ConfigureOnlyDoesNotRequireInstalledBinary(t *testing editor := &launcherEditorRunner{paths: []string{"/tmp/settings.json"}} withIntegrationOverride(t, "droid", editor) - DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) { return []string{"llama3.2"}, nil } @@ -2519,7 +2838,7 @@ func TestLaunchIntegration_ClaudeForceConfigureReprompts(t *testing.T) { } var selectorCalls int - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { selectorCalls++ return "glm-5:cloud", nil } @@ -2572,7 +2891,7 @@ func TestLaunchIntegration_ClaudeForceConfigureMissingSelectionDoesNotSave(t *te t.Fatalf("failed to seed config: %v", err) } - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { return "missing-model", nil } DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) { @@ -2633,7 +2952,7 @@ func TestLaunchIntegration_ClaudeModelOverrideSkipsSelector(t *testing.T) { t.Setenv("PATH", binDir) var selectorCalls int - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { selectorCalls++ return "", fmt.Errorf("selector should not run when --model override is set") } @@ -2698,7 +3017,7 @@ func TestLaunchIntegration_ConfigureOnlyPrompt(t *testing.T) { runner := &launcherSingleRunner{} withIntegrationOverride(t, "stubsingle", runner) - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { return "llama3.2", nil } @@ -2926,7 +3245,7 @@ func TestLaunchIntegration_HeadlessSelectorFlowFailsWithoutPrompt(t *testing.T) runner := &launcherSingleRunner{} withIntegrationOverride(t, "droid", runner) - DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) { return "missing-model", nil } diff --git a/cmd/launch/models.go b/cmd/launch/models.go index 619de47f2..2283af541 100644 --- a/cmd/launch/models.go +++ b/cmd/launch/models.go @@ -188,7 +188,7 @@ func ensureCloudAuth(ctx context.Context, client *api.Client, modelList string) return errors.New(internalcloud.DisabledError("remote inference is unavailable")) } - user, err := client.Whoami(ctx) + user, err := whoamiWithTimeout(ctx, client) if err == nil && user != nil && user.Name != "" { return nil } @@ -243,7 +243,7 @@ func ensureCloudAuth(ctx context.Context, client *api.Client, modelList string) fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) if frame%10 == 0 { - u, err := client.Whoami(ctx) + u, err := whoamiWithTimeout(ctx, client) if err == nil && u != nil && u.Name != "" { fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) return nil @@ -348,9 +348,11 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M var hasLocalModel, hasCloudModel bool recDesc := make(map[string]string) + recByName := make(map[string]ModelItem) for _, rec := range recommendations { recommended[rec.Name] = true recDesc[rec.Name] = rec.Description + recByName[rec.Name] = rec } for _, m := range existing { @@ -364,6 +366,9 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M displayName := strings.TrimSuffix(m.Name, ":latest") existingModels[displayName] = true item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]} + if rec, ok := recByName[displayName]; ok { + item = copyModelRecommendationFields(displayName, rec) + } items = append(items, item) } @@ -472,6 +477,12 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M return items, preChecked, existingModels, cloudModels } +func copyModelRecommendationFields(name string, rec ModelItem) ModelItem { + rec.Name = name + rec.Recommended = true + return rec +} + // isCloudModelName reports whether the model name has an explicit cloud source. func isCloudModelName(name string) bool { return modelref.HasExplicitCloudSource(name) diff --git a/cmd/launch/selector_hooks.go b/cmd/launch/selector_hooks.go index 0f55aadea..d8d83ea91 100644 --- a/cmd/launch/selector_hooks.go +++ b/cmd/launch/selector_hooks.go @@ -35,22 +35,38 @@ type ConfirmOptions struct { // SingleSelector is a function type for single item selection. // current is the name of the previously selected item to highlight; empty means no pre-selection. -type SingleSelector func(title string, items []ModelItem, current string) (string, error) +type SingleSelector func(title string, items []SelectionItem, current string) (string, error) + +// SingleSelectorWithUpdates is a single item selector that can receive refreshed item state while open. +type SingleSelectorWithUpdates func(title string, items []SelectionItem, current string, updates <-chan []SelectionItem) (string, error) // MultiSelector is a function type for multi item selection. -type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error) +type MultiSelector func(title string, items []SelectionItem, preChecked []string) ([]string, error) + +// MultiSelectorWithUpdates is a multi item selector that can receive refreshed item state while open. +type MultiSelectorWithUpdates func(title string, items []SelectionItem, preChecked []string, updates <-chan []SelectionItem) ([]string, error) // DefaultSingleSelector is the default single-select implementation. var DefaultSingleSelector SingleSelector +// DefaultSingleSelectorWithUpdates is the default single-select implementation with live updates. +var DefaultSingleSelectorWithUpdates SingleSelectorWithUpdates + // DefaultMultiSelector is the default multi-select implementation. var DefaultMultiSelector MultiSelector +// DefaultMultiSelectorWithUpdates is the default multi-select implementation with live updates. +var DefaultMultiSelectorWithUpdates MultiSelectorWithUpdates + // DefaultSignIn provides a TUI-based sign-in flow. // When set, ensureAuth uses it instead of plain text prompts. // Returns the signed-in username or an error. var DefaultSignIn func(modelName, signInURL string) (string, error) +// DefaultUpgrade provides a TUI-based upgrade flow. +// Returns the updated plan or an error. +var DefaultUpgrade func(modelName, requiredPlan string) (string, error) + type launchConfirmPolicy struct { yes bool requireYesMessage bool diff --git a/cmd/tui/selector.go b/cmd/tui/selector.go index b9baa3b6d..c53fe16d5 100644 --- a/cmd/tui/selector.go +++ b/cmd/tui/selector.go @@ -35,8 +35,7 @@ var ( Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"}) selectorDefaultTagStyle = lipgloss.NewStyle(). - Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}). - Italic(true) + Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}) selectorHelpStyle = lipgloss.NewStyle(). Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"}) @@ -58,16 +57,39 @@ const maxSelectorItems = 10 var ErrCancelled = launch.ErrCancelled type SelectItem struct { - Name string - Description string - Recommended bool + Name string + Description string + Recommended bool + AvailabilityBadge string } -// ConvertItems converts launch.ModelItem slice to SelectItem slice. -func ConvertItems(items []launch.ModelItem) []SelectItem { +type selectorItemsUpdatedMsg struct { + items []SelectItem +} + +func waitForSelectorItems(updates <-chan []SelectItem) tea.Cmd { + if updates == nil { + return nil + } + return func() tea.Msg { + items, ok := <-updates + if !ok { + return nil + } + return selectorItemsUpdatedMsg{items: items} + } +} + +// ConvertItems converts launch.SelectionItem slice to SelectItem slice. +func ConvertItems(items []launch.SelectionItem) []SelectItem { out := make([]SelectItem, len(items)) for i, item := range items { - out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended} + out[i] = SelectItem{ + Name: item.Name, + Description: item.Description, + Recommended: item.Recommended, + AvailabilityBadge: item.AvailabilityBadge, + } } return out } @@ -91,6 +113,7 @@ func ReorderItems(items []SelectItem) []SelectItem { type selectorModel struct { title string items []SelectItem + updates <-chan []SelectItem filter string cursor int scrollOffset int @@ -110,6 +133,33 @@ func selectorModelWithCurrent(title string, items []SelectItem, current string) return m } +func currentItemName(items []SelectItem, cursor int) string { + if cursor < 0 || cursor >= len(items) { + return "" + } + return items[cursor].Name +} + +func cursorForItemName(items []SelectItem, name string, fallback int) int { + if len(items) == 0 { + return 0 + } + if name != "" { + for i, item := range items { + if item.Name == name { + return i + } + } + } + if fallback < 0 { + return 0 + } + if fallback >= len(items) { + return len(items) - 1 + } + return fallback +} + func (m selectorModel) filteredItems() []SelectItem { if m.filter == "" { return m.items @@ -125,7 +175,7 @@ func (m selectorModel) filteredItems() []SelectItem { } func (m selectorModel) Init() tea.Cmd { - return nil + return waitForSelectorItems(m.updates) } // otherStart returns the index of the first non-recommended item in the filtered list. @@ -235,6 +285,13 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return m, nil + case selectorItemsUpdatedMsg: + current := currentItemName(m.filteredItems(), m.cursor) + m.items = msg.items + m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor) + m.updateScroll(m.otherStart()) + return m, waitForSelectorItems(m.updates) + case tea.KeyMsg: switch msg.Type { case tea.KeyCtrlC, tea.KeyEsc: @@ -260,9 +317,17 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } +func cursorItemSuffix(item SelectItem) string { + if item.AvailabilityBadge == "" { + return "" + } + return " " + selectorDefaultTagStyle.Render("("+item.AvailabilityBadge+")") +} + func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) { if idx == m.cursor { s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name)) + s.WriteString(cursorItemSuffix(item)) } else { s.WriteString(selectorItemStyle.Render(item.Name)) } @@ -402,11 +467,16 @@ func cursorForCurrent(items []SelectItem, current string) int { } func SelectSingle(title string, items []SelectItem, current string) (string, error) { + return SelectSingleWithUpdates(title, items, current, nil) +} + +func SelectSingleWithUpdates(title string, items []SelectItem, current string, updates <-chan []SelectItem) (string, error) { if len(items) == 0 { return "", fmt.Errorf("no items to select from") } m := selectorModelWithCurrent(title, items, current) + m.updates = updates p := tea.NewProgram(m) finalModel, err := p.Run() @@ -426,6 +496,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err type multiSelectorModel struct { title string items []SelectItem + updates <-chan []SelectItem itemIndex map[string]int filter string cursor int @@ -475,6 +546,36 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string return m } +func (m *multiSelectorModel) rebuildItemIndex() { + m.itemIndex = make(map[string]int, len(m.items)) + for i, item := range m.items { + m.itemIndex[item.Name] = i + } +} + +func (m *multiSelectorModel) replaceItems(items []SelectItem) { + current := currentItemName(m.filteredItems(), m.cursor) + checkedNames := make([]string, 0, len(m.checkOrder)) + for _, idx := range m.checkOrder { + if idx >= 0 && idx < len(m.items) { + checkedNames = append(checkedNames, m.items[idx].Name) + } + } + + m.items = items + m.rebuildItemIndex() + m.checked = make(map[int]bool, len(checkedNames)) + m.checkOrder = nil + for _, name := range checkedNames { + if idx, ok := m.itemIndex[name]; ok { + m.checked[idx] = true + m.checkOrder = append(m.checkOrder, idx) + } + } + m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor) + m.updateScroll(m.otherStart()) +} + func (m multiSelectorModel) filteredItems() []SelectItem { if m.filter == "" { return m.items @@ -590,7 +691,7 @@ func (m multiSelectorModel) selectedCount() int { } func (m multiSelectorModel) Init() tea.Cmd { - return nil + return waitForSelectorItems(m.updates) } func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -603,6 +704,10 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return m, nil + case selectorItemsUpdatedMsg: + m.replaceItems(msg.items) + return m, waitForSelectorItems(m.updates) + case tea.KeyMsg: filtered := m.filteredItems() @@ -689,6 +794,7 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) { if idx == m.cursor { s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name)) + s.WriteString(cursorItemSuffix(item)) } else { s.WriteString(selectorItemStyle.Render(item.Name)) } @@ -716,6 +822,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, if idx == m.cursor { s.WriteString(selectorSelectedItemStyle.Render("▸ " + check + item.Name)) + s.WriteString(cursorItemSuffix(item)) } else { s.WriteString(selectorItemStyle.Render(check + item.Name)) } @@ -841,11 +948,16 @@ func (m multiSelectorModel) View() string { } func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) { + return SelectMultipleWithUpdates(title, items, preChecked, nil) +} + +func SelectMultipleWithUpdates(title string, items []SelectItem, preChecked []string, updates <-chan []SelectItem) ([]string, error) { if len(items) == 0 { return nil, fmt.Errorf("no items to select from") } m := newMultiSelectorModel(title, items, preChecked) + m.updates = updates p := tea.NewProgram(m) finalModel, err := p.Run() diff --git a/cmd/tui/selector_test.go b/cmd/tui/selector_test.go index 1887c3af5..ad6182134 100644 --- a/cmd/tui/selector_test.go +++ b/cmd/tui/selector_test.go @@ -311,6 +311,91 @@ func TestRenderContent_SelectedItemIndicator(t *testing.T) { } } +func TestRenderContent_AvailabilityBadgeOnlyOnCursor(t *testing.T) { + m := selectorModel{ + title: "Pick:", + items: []SelectItem{ + {Name: "kimi-k2.6:cloud", AvailabilityBadge: "Upgrade required"}, + {Name: "qwen3.5:cloud", AvailabilityBadge: "Sign in required"}, + {Name: "glm-5:cloud", AvailabilityBadge: "Included"}, + }, + cursor: 0, + } + content := m.renderContent() + + if !strings.Contains(content, "(Upgrade required)") { + t.Fatalf("cursor badge missing:\n%s", content) + } + if strings.Contains(content, "(Sign in required)") { + t.Fatalf("non-cursor badge should not render:\n%s", content) + } + if strings.Contains(content, "Included") { + t.Fatalf("included badge should not render:\n%s", content) + } +} + +func TestSelectorModel_ItemsUpdatedPreservesCursorAndRendersBadge(t *testing.T) { + m := selectorModelWithCurrent("Pick:", []SelectItem{ + {Name: "kimi-k2.6:cloud", Recommended: true}, + {Name: "llama3.2"}, + }, "kimi-k2.6:cloud") + + updated, _ := m.Update(selectorItemsUpdatedMsg{items: []SelectItem{ + {Name: "kimi-k2.6:cloud", Recommended: true, AvailabilityBadge: "Upgrade required"}, + {Name: "llama3.2"}, + }}) + fm := updated.(selectorModel) + if fm.cursor != 0 { + t.Fatalf("cursor = %d, want 0", fm.cursor) + } + content := fm.renderContent() + if !strings.Contains(content, "(Upgrade required)") { + t.Fatalf("updated badge missing:\n%s", content) + } +} + +func TestMultiSelector_AvailabilityBadgePreservesDefaultSuffix(t *testing.T) { + m := newMultiSelectorModel("Pick:", []SelectItem{ + {Name: "kimi-k2.6:cloud", AvailabilityBadge: "Upgrade required"}, + {Name: "qwen3.5:cloud"}, + }, []string{"kimi-k2.6:cloud"}) + m.multi = true + m.cursor = 0 + + content := m.View() + if !strings.Contains(content, "(Upgrade required)") { + t.Fatalf("cursor badge missing:\n%s", content) + } + if !strings.Contains(content, "(default)") { + t.Fatalf("default suffix missing:\n%s", content) + } +} + +func TestMultiSelector_ItemsUpdatedPreservesCheckedStateAndRendersBadge(t *testing.T) { + m := newMultiSelectorModel("Pick:", []SelectItem{ + {Name: "kimi-k2.6:cloud", Recommended: true}, + {Name: "llama3.2"}, + }, []string{"kimi-k2.6:cloud"}) + m.multi = true + + updated, _ := m.Update(selectorItemsUpdatedMsg{items: []SelectItem{ + {Name: "kimi-k2.6:cloud", Recommended: true, AvailabilityBadge: "Upgrade required"}, + {Name: "llama3.2"}, + }}) + fm := updated.(multiSelectorModel) + idx := fm.itemIndex["kimi-k2.6:cloud"] + if !fm.checked[idx] { + t.Fatalf("checked state was not preserved: %#v", fm.checked) + } + content := fm.View() + if !strings.Contains(content, "(Upgrade required)") { + t.Fatalf("updated badge missing:\n%s", content) + } + if !strings.Contains(content, "(default)") { + t.Fatalf("default suffix missing after update:\n%s", content) + } +} + func TestRenderContent_Description(t *testing.T) { m := selectorModel{ title: "Pick:", diff --git a/cmd/tui/signin.go b/cmd/tui/signin.go index 4cfea1e00..c45326ebd 100644 --- a/cmd/tui/signin.go +++ b/cmd/tui/signin.go @@ -19,6 +19,14 @@ type signInCheckMsg struct { userName string } +type upgradeTickMsg struct{} + +type upgradeCheckMsg struct { + upgraded bool + plan string + err error +} + type signInModel struct { modelName string signInURL string @@ -28,6 +36,18 @@ type signInModel struct { cancelled bool } +type upgradeModel struct { + modelName string + requiredPlan string + spinner int + width int + openNow bool + polling bool + plan string + cancelled bool + err error +} + func (m signInModel) Init() tea.Cmd { return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg { return signInTickMsg{} @@ -82,6 +102,85 @@ func (m signInModel) View() string { return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width) } +func (m upgradeModel) Init() tea.Cmd { + if m.polling { + return upgradeTickCmd() + } + return nil +} + +func (m upgradeModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + wasSet := m.width > 0 + m.width = msg.Width + if wasSet { + return m, tea.EnterAltScreen + } + return m, nil + + case tea.KeyMsg: + switch msg.Type { + case tea.KeyCtrlC, tea.KeyEsc: + m.cancelled = true + return m, tea.Quit + case tea.KeyLeft: + if !m.polling { + m.openNow = true + } + case tea.KeyRight: + if !m.polling { + m.openNow = false + } + case tea.KeyEnter: + if !m.polling { + if !m.openNow { + m.cancelled = true + return m, tea.Quit + } + launch.OpenBrowser(launch.DefaultUpgradeURL) + m.polling = true + return m, upgradeTickCmd() + } + } + + case upgradeTickMsg: + if !m.polling { + return m, nil + } + m.spinner++ + if m.spinner%5 == 0 { + return m, tea.Batch( + upgradeTickCmd(), + checkUpgrade(m.requiredPlan), + ) + } + return m, upgradeTickCmd() + + case upgradeCheckMsg: + if msg.err != nil { + m.err = msg.err + return m, tea.Quit + } + if msg.upgraded { + m.plan = msg.plan + return m, tea.Quit + } + } + + return m, nil +} + +func (m upgradeModel) View() string { + if m.plan != "" { + return "" + } + if m.err != nil { + return "" + } + return renderUpgrade(m.modelName, m.spinner, m.width, m.polling, m.openNow) +} + func renderSignIn(modelName, signInURL string, spinner, width int) string { spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"} frame := spinnerFrames[spinner%len(spinnerFrames)] @@ -110,18 +209,88 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string { return lipgloss.NewStyle().PaddingLeft(2).Render(s.String()) } +func upgradeTickCmd() tea.Cmd { + return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg { + return upgradeTickMsg{} + }) +} + +func renderUpgrade(modelName string, spinner, width int, polling, openNow bool) string { + spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"} + frame := spinnerFrames[spinner%len(spinnerFrames)] + + urlColor := lipgloss.NewStyle(). + Foreground(lipgloss.Color("117")) + urlWrap := lipgloss.NewStyle().PaddingLeft(2) + if width > 4 { + urlWrap = urlWrap.Width(width - 4) + } + + var s strings.Builder + + fmt.Fprintf(&s, "To use %s, upgrade your Ollama plan.\n\n", selectorSelectedItemStyle.Render(modelName)) + + s.WriteString("Navigate to:\n") + s.WriteString(urlWrap.Render(urlColor.Render(launch.DefaultUpgradeURL))) + s.WriteString("\n\n") + + if !polling { + var yesBtn, noBtn string + if openNow { + yesBtn = confirmActiveStyle.Render(" Yes ") + noBtn = confirmInactiveStyle.Render(" No ") + } else { + yesBtn = confirmInactiveStyle.Render(" Yes ") + noBtn = confirmActiveStyle.Render(" No ") + } + + s.WriteString("Open now?\n") + s.WriteString(" " + yesBtn + " " + noBtn) + s.WriteString("\n\n") + s.WriteString(selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel")) + } else { + s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render( + frame + " Waiting for upgrade to complete...")) + s.WriteString("\n\n") + s.WriteString(selectorHelpStyle.Render("esc cancel")) + } + + return lipgloss.NewStyle().PaddingLeft(2).Render(s.String()) +} + func checkSignIn() tea.Msg { client, err := api.ClientFromEnvironment() if err != nil { return signInCheckMsg{signedIn: false} } - user, err := client.Whoami(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + user, err := client.Whoami(ctx) if err == nil && user != nil && user.Name != "" { return signInCheckMsg{signedIn: true, userName: user.Name} } return signInCheckMsg{signedIn: false} } +func checkUpgrade(requiredPlan string) tea.Cmd { + return func() tea.Msg { + client, err := api.ClientFromEnvironment() + if err != nil { + return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable} + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + user, err := client.Whoami(ctx) + if err != nil { + return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable} + } + if err == nil && user != nil && user.Name != "" && launch.PlanSatisfies(user.Plan, requiredPlan) { + return upgradeCheckMsg{upgraded: true, plan: user.Plan} + } + return upgradeCheckMsg{upgraded: false} + } +} + // RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels. func RunSignIn(modelName, signInURL string) (string, error) { launch.OpenBrowser(signInURL) @@ -144,3 +313,28 @@ func RunSignIn(modelName, signInURL string) (string, error) { return fm.userName, nil } + +// RunUpgrade shows a bubbletea upgrade dialog and polls until the user's plan is updated or cancelled. +func RunUpgrade(modelName, requiredPlan string) (string, error) { + m := upgradeModel{ + modelName: modelName, + requiredPlan: requiredPlan, + openNow: true, + } + + p := tea.NewProgram(m) + finalModel, err := p.Run() + if err != nil { + return "", fmt.Errorf("error running upgrade: %w", err) + } + + fm := finalModel.(upgradeModel) + if fm.cancelled { + return "", ErrCancelled + } + if fm.err != nil { + return "", fm.err + } + + return fm.plan, nil +} diff --git a/cmd/tui/signin_test.go b/cmd/tui/signin_test.go index 430b4f7b3..ced5c53a9 100644 --- a/cmd/tui/signin_test.go +++ b/cmd/tui/signin_test.go @@ -5,6 +5,7 @@ import ( "testing" tea "github.com/charmbracelet/bubbletea" + "github.com/ollama/ollama/cmd/launch" ) func TestRenderSignIn_ContainsModelName(t *testing.T) { @@ -50,6 +51,35 @@ func TestRenderSignIn_ContainsEscHelp(t *testing.T) { } } +func TestRenderUpgrade_AsksBeforeOpening(t *testing.T) { + got := renderUpgrade("kimi-k2.6:cloud", 0, 80, false, true) + if !strings.Contains(got, "kimi-k2.6:cloud") { + t.Error("should contain model name") + } + if !strings.Contains(got, launch.DefaultUpgradeURL) { + t.Error("should contain upgrade URL") + } + if !strings.Contains(got, "Open now?") { + t.Error("should ask before opening") + } + if !strings.Contains(got, "Yes") || !strings.Contains(got, "No") { + t.Error("should show yes/no selector") + } + if strings.Contains(got, "Waiting for upgrade to complete") { + t.Error("should not start waiting before open choice is confirmed") + } +} + +func TestRenderUpgrade_PollingShowsWaiting(t *testing.T) { + got := renderUpgrade("kimi-k2.6:cloud", 0, 80, true, true) + if !strings.Contains(got, "Waiting for upgrade to complete") { + t.Error("should contain waiting message") + } + if strings.Contains(got, "Open now?") { + t.Error("should not show open prompt while polling") + } +} + func TestSignInModel_EscCancels(t *testing.T) { m := signInModel{ modelName: "test:cloud", @@ -66,6 +96,35 @@ func TestSignInModel_EscCancels(t *testing.T) { } } +func TestUpgradeModel_NoCancelsWithoutPolling(t *testing.T) { + m := upgradeModel{ + modelName: "kimi-k2.6:cloud", + requiredPlan: "pro", + openNow: true, + } + + updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight}) + fm := updated.(upgradeModel) + if fm.openNow { + t.Error("right should select no") + } + if fm.polling { + t.Error("right should not start polling") + } + + updated, cmd := fm.Update(tea.KeyMsg{Type: tea.KeyEnter}) + fm = updated.(upgradeModel) + if !fm.cancelled { + t.Error("enter on no should cancel") + } + if fm.polling { + t.Error("enter on no should not start polling") + } + if cmd == nil { + t.Error("enter on no should quit") + } +} + func TestSignInModel_CtrlCCancels(t *testing.T) { m := signInModel{ modelName: "test:cloud", diff --git a/server/model_recommendations.go b/server/model_recommendations.go index b800ad4c1..46708d2c2 100644 --- a/server/model_recommendations.go +++ b/server/model_recommendations.go @@ -20,9 +20,7 @@ import ( "github.com/ollama/ollama/format" ) -const ( - modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations" -) +const modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations" var ( modelRecommendationsRefreshInterval = 4 * time.Hour @@ -323,6 +321,7 @@ func validateModelRecommendations(recs []api.ModelRecommendation) ([]api.ModelRe for _, rec := range recs { rec.Model = strings.TrimSpace(rec.Model) rec.Description = strings.TrimSpace(rec.Description) + rec.RequiredPlan = strings.TrimSpace(rec.RequiredPlan) if rec.Model == "" { return nil, errors.New("recommendation missing model") diff --git a/server/model_recommendations_test.go b/server/model_recommendations_test.go index 1c7b389ef..a918e5172 100644 --- a/server/model_recommendations_test.go +++ b/server/model_recommendations_test.go @@ -255,7 +255,7 @@ func TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite(t *testing.T) { func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing.T) { input := []api.ModelRecommendation{ - {Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256}, + {Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: " pro "}, {Model: "bad-cloud:cloud", Description: "missing limits"}, {Model: " good-local ", Description: " good local ", VRAMBytes: 2 * format.GigaByte}, } @@ -266,7 +266,7 @@ func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing } want := []api.ModelRecommendation{ - {Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256}, + {Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: "pro"}, {Model: "good-local", Description: "good local", VRAMBytes: 2 * format.GigaByte}, } if !slices.Equal(got, want) { @@ -274,6 +274,38 @@ func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing } } +func TestValidateModelRecommendationsDoesNotSynthesizeRequiredPlans(t *testing.T) { + input := []api.ModelRecommendation{ + {Model: "kimi-k2.6:cloud", Description: "coding", ContextLength: 262_144, MaxOutputTokens: 262_144}, + {Model: "qwen3.5:cloud", Description: "reasoning", ContextLength: 262_144, MaxOutputTokens: 32_768}, + {Model: "custom:cloud", Description: "custom", ContextLength: 4096, MaxOutputTokens: 1024}, + {Model: "minimax-m2.7:cloud", Description: "custom", ContextLength: 204_800, MaxOutputTokens: 128_000, RequiredPlan: "team"}, + } + + got, err := validateModelRecommendations(input) + if err != nil { + t.Fatalf("validateModelRecommendations failed: %v", err) + } + + byName := make(map[string]api.ModelRecommendation, len(got)) + for _, rec := range got { + byName[rec.Model] = rec + } + + if rec := byName["kimi-k2.6:cloud"]; rec.RequiredPlan != "" { + t.Fatalf("kimi required plan should not be synthesized: %#v", rec) + } + if rec := byName["qwen3.5:cloud"]; rec.RequiredPlan != "" { + t.Fatalf("qwen required plan should not be synthesized: %#v", rec) + } + if rec := byName["custom:cloud"]; rec.RequiredPlan != "" { + t.Fatalf("custom required plan should not be synthesized: %#v", rec) + } + if rec := byName["minimax-m2.7:cloud"]; rec.RequiredPlan != "team" { + t.Fatalf("explicit required plan should not be overwritten: %#v", rec) + } +} + func TestModelRecommendationsHandlerReturnsDefaults(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/server/routes.go b/server/routes.go index 646142d8b..06e3ec65b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -2044,11 +2044,30 @@ func (s *Server) WhoamiHandler(c *gin.Context) { client := api.NewClient(u, http.DefaultClient) user, err := client.Whoami(c) if err != nil { + var authErr api.AuthorizationError + if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized { + // Preserve an actionable sign-in response for launch; other failures + // below mean account or plan verification is temporarily unavailable. + sURL := authErr.SigninURL + if sURL == "" { + var sErr error + sURL, sErr = signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) + return + } + } + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + slog.Error(err.Error()) + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "account unavailable"}) + return } - // user isn't signed in - if user != nil && user.Name == "" { + if user == nil || user.Name == "" { sURL, sErr := signinURL() if sErr != nil { slog.Error(sErr.Error()) @@ -2060,6 +2079,10 @@ func (s *Server) WhoamiHandler(c *gin.Context) { return } + if strings.TrimSpace(user.Plan) == "" { + slog.Warn("account plan was not set; defaulting to free") + user.Plan = "free" + } c.JSON(http.StatusOK, user) }