launch: add plan-aware model gating (#16027)

This commit is contained in:
Parth Sareen 2026-05-06 14:34:26 -07:00 committed by GitHub
parent 7c2c36bda2
commit bab59072fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1747 additions and 138 deletions

View file

@ -814,6 +814,7 @@ type ModelRecommendation struct {
ContextLength int `json:"context_length,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
VRAMBytes int64 `json:"vram_bytes,omitempty"`
RequiredPlan string `json:"required_plan,omitempty"`
}
// ProcessResponse is the response from [Client.Process].

View file

@ -61,28 +61,20 @@ import (
func init() {
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", launch.ErrCancelled
}
return result, err
launch.DefaultSingleSelector = func(title string, items []launch.SelectionItem, current string) (string, error) {
return runTUISingleSelector(title, items, current, nil)
}
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, launch.ErrCancelled
}
return result, err
launch.DefaultSingleSelectorWithUpdates = func(title string, items []launch.SelectionItem, current string, updates <-chan []launch.SelectionItem) (string, error) {
return runTUISingleSelector(title, items, current, updates)
}
launch.DefaultMultiSelector = func(title string, items []launch.SelectionItem, preChecked []string) ([]string, error) {
return runTUIMultiSelector(title, items, preChecked, nil)
}
launch.DefaultMultiSelectorWithUpdates = func(title string, items []launch.SelectionItem, preChecked []string, updates <-chan []launch.SelectionItem) ([]string, error) {
return runTUIMultiSelector(title, items, preChecked, updates)
}
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
@ -93,9 +85,55 @@ func init() {
return userName, err
}
launch.DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
plan, err := tui.RunUpgrade(modelName, requiredPlan)
if errors.Is(err, tui.ErrCancelled) {
return "", launch.ErrCancelled
}
return plan, err
}
launch.DefaultConfirmPrompt = tui.RunConfirmWithOptions
}
func runTUISingleSelector(title string, items []launch.SelectionItem, current string, updates <-chan []launch.SelectionItem) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingleWithUpdates(title, tuiItems, current, convertSelectionItemUpdates(updates))
if errors.Is(err, tui.ErrCancelled) {
return "", launch.ErrCancelled
}
return result, err
}
func runTUIMultiSelector(title string, items []launch.SelectionItem, preChecked []string, updates <-chan []launch.SelectionItem) ([]string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultipleWithUpdates(title, tuiItems, preChecked, convertSelectionItemUpdates(updates))
if errors.Is(err, tui.ErrCancelled) {
return nil, launch.ErrCancelled
}
return result, err
}
func convertSelectionItemUpdates(updates <-chan []launch.SelectionItem) <-chan []tui.SelectItem {
if updates == nil {
return nil
}
out := make(chan []tui.SelectItem, 1)
go func() {
defer close(out)
for items := range updates {
out <- tui.ReorderItems(tui.ConvertItems(items))
}
}()
return out
}
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
@ -2090,12 +2128,15 @@ func runInteractiveTUI(cmd *cobra.Command) {
return
}
accountPrefetch := launch.StartAccountStatePrefetch(cmd.Context())
deps := launcherDeps{
buildState: launch.BuildLauncherState,
runMenu: tui.RunMenu,
resolveRunModel: launch.ResolveRunModel,
launchIntegration: launch.LaunchIntegration,
runModel: launchInteractiveModel,
buildState: launch.BuildLauncherState,
runMenu: tui.RunMenu,
resolveRunModel: launch.ResolveRunModel,
launchIntegration: launch.LaunchIntegration,
runModel: launchInteractiveModel,
accountState: accountPrefetch.StateIfReady,
accountStateUpdates: accountPrefetch.StateUpdates,
}
for {
@ -2110,11 +2151,13 @@ func runInteractiveTUI(cmd *cobra.Command) {
}
type launcherDeps struct {
buildState func(context.Context) (*launch.LauncherState, error)
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
runModel func(*cobra.Command, string) error
buildState func(context.Context) (*launch.LauncherState, error)
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
runModel func(*cobra.Command, string) error
accountState func() *launch.AccountState
accountStateUpdates func(context.Context) <-chan *launch.AccountState
}
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
@ -2122,6 +2165,9 @@ func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error)
if err != nil {
return false, fmt.Errorf("build launcher state: %w", err)
}
if state != nil && deps.accountState != nil {
state.AccountState = deps.accountState()
}
action, err := deps.runMenu(state)
if err != nil {
@ -2142,7 +2188,13 @@ func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDe
return false, nil
case tui.TUIActionRunModel:
saveLauncherSelection(action)
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
req := action.RunModelRequest()
if deps.accountState != nil {
req.AccountState = deps.accountState()
req.AccountStateProvider = deps.accountState
}
req.AccountStateUpdates = deps.accountStateUpdates
modelName, err := deps.resolveRunModel(cmd.Context(), req)
if errors.Is(err, launch.ErrCancelled) {
return true, nil
}
@ -2155,7 +2207,13 @@ func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDe
return true, nil
case tui.TUIActionLaunchIntegration:
saveLauncherSelection(action)
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
req := action.IntegrationLaunchRequest()
if deps.accountState != nil {
req.AccountState = deps.accountState()
req.AccountStateProvider = deps.accountState
}
req.AccountStateUpdates = deps.accountStateUpdates
err := deps.launchIntegration(cmd.Context(), req)
if errors.Is(err, launch.ErrCancelled) {
return true, nil
}

View file

@ -76,11 +76,18 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
var gotReq launch.RunModelRequest
var launched string
prefetchedAccount := &launch.AccountState{}
accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil }
deps := launcherDeps{
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
return &launch.LauncherState{}, nil
},
runMenu: runMenu,
runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) {
if state.AccountState != prefetchedAccount {
t.Fatalf("prefetched account state was not piped to menu state")
}
return runMenu(state)
},
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
gotReq = req
return tt.wantModel, nil
@ -90,6 +97,10 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
launched = model
return nil
},
accountState: func() *launch.AccountState {
return prefetchedAccount
},
accountStateUpdates: accountUpdates,
}
cmd := &cobra.Command{}
@ -107,6 +118,12 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
if gotReq.ForcePicker != tt.wantForce {
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
}
if gotReq.AccountState != prefetchedAccount {
t.Fatalf("expected prefetched account state to be passed to run model request")
}
if gotReq.AccountStateUpdates == nil {
t.Fatalf("expected account state updates to be passed to run model request")
}
if launched != tt.wantModel {
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
}
@ -148,17 +165,28 @@ func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T)
}
var gotReq launch.IntegrationLaunchRequest
prefetchedAccount := &launch.AccountState{}
accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil }
deps := launcherDeps{
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
return &launch.LauncherState{}, nil
},
runMenu: runMenu,
runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) {
if state.AccountState != prefetchedAccount {
t.Fatalf("prefetched account state was not piped to menu state")
}
return runMenu(state)
},
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
gotReq = req
return nil
},
runModel: unexpectedModelLaunch(t),
accountState: func() *launch.AccountState {
return prefetchedAccount
},
accountStateUpdates: accountUpdates,
}
cmd := &cobra.Command{}
@ -179,6 +207,12 @@ func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T)
if gotReq.ForceConfigure != tt.wantForce {
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
}
if gotReq.AccountState != prefetchedAccount {
t.Fatalf("expected prefetched account state to be passed to integration request")
}
if gotReq.AccountStateUpdates == nil {
t.Fatalf("expected account state updates to be passed to integration request")
}
if got := config.LastSelection(); got != "claude" {
t.Fatalf("expected last selection to be claude, got %q", got)
}

371
cmd/launch/account.go Normal file
View file

@ -0,0 +1,371 @@
package launch
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"
"github.com/ollama/ollama/api"
)
const (
// DefaultUpgradeURL is the fixed destination for subscription upgrades.
DefaultUpgradeURL = "https://ollama.com/upgrade"
accountCheckTimeout = 3 * time.Second
)
var (
ErrPlanVerificationUnavailable = errors.New("Could not verify your plan. Try again in a moment.")
errUpgradeCancelled = errors.New("upgrade cancelled")
)
type accountStateStatus int
const (
accountStateUnknown accountStateStatus = iota
accountStateSignedOut
accountStateSignedIn
)
type AccountState struct {
Status accountStateStatus
Plan string
}
type AccountStatePrefetch struct {
done chan struct{}
state AccountState
}
func StartAccountStatePrefetch(ctx context.Context) *AccountStatePrefetch {
if ctx == nil {
ctx = context.Background()
}
p := &AccountStatePrefetch{done: make(chan struct{})}
go func() {
state := AccountState{Status: accountStateUnknown}
client, err := api.ClientFromEnvironment()
if err == nil {
prefetchCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout)
defer cancel()
if disabled, known := cloudStatusDisabled(prefetchCtx, client); !known || !disabled {
state = launchAccountState(prefetchCtx, client)
}
}
p.state = state
close(p.done)
}()
return p
}
func (p *AccountStatePrefetch) StateIfReady() *AccountState {
if p == nil {
return nil
}
select {
case <-p.done:
state := p.state
return &state
default:
return nil
}
}
func (p *AccountStatePrefetch) StateUpdates(ctx context.Context) <-chan *AccountState {
if p == nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
out := make(chan *AccountState, 1)
go func() {
defer close(out)
select {
case <-p.done:
if p.state.Status == accountStateUnknown {
return
}
state := p.state
select {
case out <- &state:
case <-ctx.Done():
}
case <-ctx.Done():
}
}()
return out
}
func launchAccountState(ctx context.Context, client *api.Client) AccountState {
if client == nil {
return AccountState{Status: accountStateUnknown}
}
user, err := whoamiWithTimeout(ctx, client)
if err != nil {
var authErr api.AuthorizationError
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
return AccountState{Status: accountStateSignedOut}
}
return AccountState{Status: accountStateUnknown}
}
if user == nil || strings.TrimSpace(user.Name) == "" {
return AccountState{Status: accountStateSignedOut}
}
return AccountState{
Status: accountStateSignedIn,
Plan: strings.TrimSpace(user.Plan),
}
}
func whoamiWithTimeout(ctx context.Context, client *api.Client) (*api.UserResponse, error) {
if ctx == nil {
ctx = context.Background()
}
checkCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout)
defer cancel()
return client.Whoami(checkCtx)
}
func ApplyAccountStateToSelectionItems(items []ModelItem, state AccountState) []SelectionItem {
out := make([]SelectionItem, len(items))
for i, item := range items {
out[i] = SelectionItem{
Name: item.Name,
Description: item.Description,
Recommended: item.Recommended,
AvailabilityBadge: availabilityBadge(item, state),
}
}
return out
}
func SelectionItemsWithAccountState(items []ModelItem, state *AccountState) []SelectionItem {
if state == nil || !selectionItemsNeedAccountState(items) {
return ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateUnknown})
}
return ApplyAccountStateToSelectionItems(items, *state)
}
func selectionItemsNeedAccountState(items []ModelItem) bool {
for _, item := range items {
if isCloudModelName(item.Name) && itemHasRecommendationMetadata(item) {
return true
}
}
return false
}
func (c *launcherClient) selectionItemUpdates(ctx context.Context, items []ModelItem, state *AccountState) <-chan []SelectionItem {
if !selectionItemsNeedAccountState(items) || state != nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
stateUpdates := c.accountStateUpdateSource(ctx)
if stateUpdates == nil {
return nil
}
out := make(chan []SelectionItem, 1)
go func() {
defer close(out)
select {
case state, ok := <-stateUpdates:
if !ok || state == nil {
return
}
select {
case out <- SelectionItemsWithAccountState(items, state):
case <-ctx.Done():
}
case <-ctx.Done():
}
}()
return out
}
func (c *launcherClient) accountStateUpdateSource(ctx context.Context) <-chan *AccountState {
if c.accountStateUpdates != nil {
return c.accountStateUpdates(ctx)
}
if c.apiClient == nil {
return nil
}
out := make(chan *AccountState, 1)
go func() {
defer close(out)
state := launchAccountState(ctx, c.apiClient)
if state.Status == accountStateUnknown {
return
}
select {
case out <- &state:
case <-ctx.Done():
}
}()
return out
}
func availabilityBadge(item ModelItem, state AccountState) string {
if !isCloudModelName(item.Name) {
return ""
}
switch state.Status {
case accountStateSignedOut:
if itemHasRecommendationMetadata(item) {
return "Sign in required"
}
case accountStateSignedIn:
if item.RequiredPlan != "" && !PlanSatisfies(state.Plan, item.RequiredPlan) {
return "Upgrade required"
}
}
return ""
}
func itemHasRecommendationMetadata(item ModelItem) bool {
return item.Recommended || strings.TrimSpace(item.RequiredPlan) != ""
}
func (c *launcherClient) ensureCloudModelAccess(ctx context.Context, model string) error {
item, ok := c.modelRecommendationItem(ctx, model)
if !ok || strings.TrimSpace(item.RequiredPlan) == "" {
return nil
}
state := launchAccountState(ctx, c.apiClient)
if state.Status != accountStateUnknown {
c.accountState = &state
}
if state.Status == accountStateUnknown {
return ErrPlanVerificationUnavailable
}
if state.Status == accountStateSignedOut {
if err := ensureCloudAuth(ctx, c.apiClient, model); err != nil {
return err
}
state = launchAccountState(ctx, c.apiClient)
if state.Status != accountStateUnknown {
c.accountState = &state
}
if state.Status == accountStateUnknown {
return ErrPlanVerificationUnavailable
}
}
if PlanSatisfies(state.Plan, item.RequiredPlan) {
return nil
}
if err := c.runUpgradeFlow(ctx, item); err != nil {
return err
}
state = launchAccountState(ctx, c.apiClient)
if state.Status == accountStateUnknown {
return ErrPlanVerificationUnavailable
}
if state.Status != accountStateSignedIn || !PlanSatisfies(state.Plan, item.RequiredPlan) {
return errUpgradeCancelled
}
return nil
}
func (c *launcherClient) modelRecommendationItem(ctx context.Context, model string) (ModelItem, bool) {
for _, item := range c.recommendations(ctx) {
if item.Name == model {
return item, true
}
}
return ModelItem{}, false
}
func (c *launcherClient) runUpgradeFlow(ctx context.Context, item ModelItem) error {
if DefaultUpgrade != nil {
if _, err := DefaultUpgrade(item.Name, item.RequiredPlan); err != nil {
if errors.Is(err, ErrCancelled) {
return errUpgradeCancelled
}
return err
}
return nil
}
yes, err := ConfirmPrompt(fmt.Sprintf("Upgrade to use %s?", item.Name))
if errors.Is(err, ErrCancelled) {
return errUpgradeCancelled
}
if err != nil {
return err
}
if !yes {
return errUpgradeCancelled
}
fmt.Fprintf(os.Stderr, "\nTo upgrade, navigate to:\n %s\n\n", DefaultUpgradeURL)
openNow, err := ConfirmPrompt("Open now?")
if errors.Is(err, ErrCancelled) {
return errUpgradeCancelled
}
if err != nil {
return err
}
if openNow {
OpenBrowser(DefaultUpgradeURL)
} else {
return errUpgradeCancelled
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
if frame%10 != 0 {
continue
}
state := launchAccountState(ctx, c.apiClient)
if state.Status == accountStateUnknown {
fmt.Fprintf(os.Stderr, "\r\033[K")
return ErrPlanVerificationUnavailable
}
if state.Status == accountStateSignedIn && PlanSatisfies(state.Plan, item.RequiredPlan) {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1mplan updated\033[0m\n")
return nil
}
}
}
}
// PlanSatisfies reports whether currentPlan can use a model that has a requiredPlan.
func PlanSatisfies(currentPlan, requiredPlan string) bool {
required := normalizePlan(requiredPlan)
if required == "" || required == "free" {
return true
}
current := normalizePlan(currentPlan)
return current != "" && current != "free"
}
func normalizePlan(plan string) string {
return strings.ToLower(strings.TrimSpace(plan))
}

View file

@ -319,7 +319,7 @@ func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
var selectorCalls int
var gotCurrent string
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
selectorCalls++
gotCurrent = current
return "llama3.2", nil
@ -553,7 +553,7 @@ func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T)
defer func() { DefaultSingleSelector = oldSelector }()
var gotCurrent string
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
gotCurrent = current
return "qwen3:8b", nil
}
@ -607,7 +607,7 @@ func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called for headless --yes saved-model launch")
return "", nil
}
@ -644,7 +644,7 @@ func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testin
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called for headless --yes without saved model")
return "", nil
}

View file

@ -10,7 +10,9 @@ import (
"net/url"
"slices"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
@ -456,6 +458,28 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
}
}
func TestBuildModelList_PreservesRecommendationRequiredPlanForExistingCloudModel(t *testing.T) {
recommendations := []ModelItem{
{
Name: "glm-5:cloud",
Description: "Reasoning and code generation",
Recommended: true,
RequiredPlan: "pro",
ContextLength: 202_752,
},
}
existing := []modelInfo{{Name: "glm-5:cloud", Remote: true}}
items, _, _, _ := buildModelListWithRecommendations(existing, recommendations, nil, "")
if len(items) != 1 {
t.Fatalf("expected one item, got %v", items)
}
item := items[0]
if item.RequiredPlan != "pro" {
t.Fatalf("RequiredPlan = %q, want pro", item.RequiredPlan)
}
}
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
existing := []modelInfo{
{Name: "gemma4", Remote: false},
@ -1390,6 +1414,187 @@ func TestEnsureAuth_EmptyWhoamiRequiresSignIn(t *testing.T) {
}
}
func TestApplyAccountStateToSelectionItems_BadgesOnlyWhenActionRequired(t *testing.T) {
items := []ModelItem{
{Name: "qwen3.5:cloud", Recommended: true},
{Name: "kimi-k2.6:cloud", Recommended: true, RequiredPlan: "pro"},
{Name: "llama3.2", RequiredPlan: "pro"},
{Name: "glm-5:cloud"},
{Name: "nemotron-3-super:cloud", Recommended: true, RequiredPlan: "free"},
}
signedOut := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedOut})
if signedOut[0].AvailabilityBadge != "Sign in required" {
t.Fatalf("account cloud badge = %q", signedOut[0].AvailabilityBadge)
}
if signedOut[1].AvailabilityBadge != "Sign in required" {
t.Fatalf("subscription cloud signed-out badge = %q", signedOut[1].AvailabilityBadge)
}
if signedOut[4].AvailabilityBadge != "Sign in required" {
t.Fatalf("free-plan cloud signed-out badge = %q", signedOut[4].AvailabilityBadge)
}
if signedOut[2].AvailabilityBadge != "" || signedOut[3].AvailabilityBadge != "" {
t.Fatalf("unexpected badge for local or unmetadata item: %#v", signedOut)
}
freeUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "free"})
if freeUser[0].AvailabilityBadge != "" {
t.Fatalf("signed-in account model should not be badged, got %q", freeUser[0].AvailabilityBadge)
}
if freeUser[1].AvailabilityBadge != "Upgrade required" {
t.Fatalf("subscription cloud free-plan badge = %q", freeUser[1].AvailabilityBadge)
}
if freeUser[4].AvailabilityBadge != "" {
t.Fatalf("free required plan should be usable by free user, got %q", freeUser[4].AvailabilityBadge)
}
proUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "pro"})
if proUser[1].AvailabilityBadge != "" {
t.Fatalf("pro user should not see included badge, got %q", proUser[1].AvailabilityBadge)
}
maxUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "max"})
if maxUser[1].AvailabilityBadge != "" {
t.Fatalf("max user should not see upgrade badge, got %q", maxUser[1].AvailabilityBadge)
}
unknown := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateUnknown})
for _, item := range unknown {
if item.AvailabilityBadge != "" {
t.Fatalf("unknown account state should not render badges: %#v", unknown)
}
}
}
func TestSelectionItemsWithAccountState_SkipsBadgesWithoutBadgeableCloudItems(t *testing.T) {
items := []ModelItem{
{Name: "llama3.2"},
{Name: "custom:cloud"},
}
state := &AccountState{Status: accountStateSignedOut}
got := SelectionItemsWithAccountState(items, state)
if len(got) != len(items) {
t.Fatalf("got %d selection items, want %d", len(got), len(items))
}
for _, item := range got {
if item.AvailabilityBadge != "" {
t.Fatalf("unexpected badge without account state: %#v", got)
}
}
}
func TestSelectionItemsWithAccountState_UsesPrefetchedStateForRecommendedCloudItems(t *testing.T) {
state := &AccountState{Status: accountStateSignedOut}
got := SelectionItemsWithAccountState([]ModelItem{{Name: "qwen3.5:cloud", Recommended: true}}, state)
if got[0].AvailabilityBadge != "Sign in required" {
t.Fatalf("badge = %q, want Sign in required", got[0].AvailabilityBadge)
}
}
func TestRecommendedModelsDoNotIncludeRequiredPlanStubs(t *testing.T) {
byName := make(map[string]ModelItem, len(recommendedModels))
for _, item := range recommendedModels {
byName[item.Name] = item
}
if item := byName["kimi-k2.6:cloud"]; item.RequiredPlan != "" {
t.Fatalf("kimi fallback required plan should not be stubbed: %#v", item)
}
if item := byName["minimax-m2.7:cloud"]; item.RequiredPlan != "" {
t.Fatalf("minimax fallback required plan should not be stubbed: %#v", item)
}
if item := byName["qwen3.5:cloud"]; item.RequiredPlan != "" {
t.Fatalf("qwen fallback required plan = %#v", item)
}
if item := byName["glm-5.1:cloud"]; item.RequiredPlan != "" {
t.Fatalf("glm fallback required plan = %#v", item)
}
}
func TestLaunchAccountState(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
wantStatus accountStateStatus
wantPlan string
}{
{
name: "signed in",
statusCode: http.StatusOK,
body: `{"name":"parth","plan":"pro"}`,
wantStatus: accountStateSignedIn,
wantPlan: "pro",
},
{
name: "signed out",
statusCode: http.StatusUnauthorized,
body: `{"error":"unauthorized","signin_url":"https://example.com/signin"}`,
wantStatus: accountStateSignedOut,
},
{
name: "unreachable",
statusCode: http.StatusInternalServerError,
body: `{"error":"temporary failure"}`,
wantStatus: accountStateUnknown,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/me" {
http.NotFound(w, r)
return
}
w.WriteHeader(tt.statusCode)
fmt.Fprint(w, tt.body)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
got := launchAccountState(context.Background(), api.NewClient(u, srv.Client()))
if got.Status != tt.wantStatus {
t.Fatalf("Status = %v, want %v", got.Status, tt.wantStatus)
}
if got.Plan != tt.wantPlan {
t.Fatalf("Plan = %q, want %q", got.Plan, tt.wantPlan)
}
})
}
}
func TestStartAccountStatePrefetch_SkipsWhoamiWhenCloudDisabled(t *testing.T) {
var whoamiCalled atomic.Bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprint(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/me":
whoamiCalled.Store(true)
http.NotFound(w, r)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
prefetch := StartAccountStatePrefetch(context.Background())
select {
case <-prefetch.done:
case <-time.After(time.Second):
t.Fatal("account prefetch did not finish")
}
if whoamiCalled.Load() {
t.Fatal("prefetch should not call whoami when cloud is disabled")
}
state := prefetch.StateIfReady()
if state == nil || state.Status != accountStateUnknown {
t.Fatalf("prefetch state = %#v, want unknown", state)
}
}
func TestEnsureAuth_PreservesCancelledSignInHook(t *testing.T) {
oldSignIn := DefaultSignIn
DefaultSignIn = func(modelName, signInURL string) (string, error) {

View file

@ -22,6 +22,7 @@ type LauncherState struct {
RunModel string
RunModelUsable bool
Integrations map[string]LauncherIntegrationState
AccountState *AccountState
}
// LauncherIntegrationState is the launch-owned status for one launcher integration.
@ -41,8 +42,11 @@ type LauncherIntegrationState struct {
// RunModelRequest controls how the root launcher resolves the chat model.
type RunModelRequest struct {
ForcePicker bool
Policy *LaunchPolicy
ForcePicker bool
Policy *LaunchPolicy
AccountState *AccountState
AccountStateProvider func() *AccountState
AccountStateUpdates func(context.Context) <-chan *AccountState
}
// LaunchConfirmMode controls confirmation behavior across launch flows.
@ -117,13 +121,16 @@ func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
// IntegrationLaunchRequest controls the canonical integration launcher flow.
type IntegrationLaunchRequest struct {
Name string
ModelOverride string
ForceConfigure bool
ConfigureOnly bool
Restore bool
ExtraArgs []string
Policy *LaunchPolicy
Name string
ModelOverride string
ForceConfigure bool
ConfigureOnly bool
Restore bool
ExtraArgs []string
Policy *LaunchPolicy
AccountState *AccountState
AccountStateProvider func() *AccountState
AccountStateUpdates func(context.Context) <-chan *AccountState
}
var isInteractiveSession = func() bool {
@ -241,7 +248,7 @@ type modelInfo struct {
// ModelInfo re-exports launcher model inventory details for callers.
type ModelInfo = modelInfo
// ModelItem represents a model for selection UIs.
// ModelItem represents model metadata before selector-only UI state is derived.
type ModelItem struct {
Name string
Description string
@ -249,6 +256,15 @@ type ModelItem struct {
VRAMBytes int64
ContextLength int
MaxOutputTokens int
RequiredPlan string
}
// SelectionItem represents a model row after launch has derived selector-only UI state.
type SelectionItem struct {
Name string
Description string
Recommended bool
AvailabilityBadge string
}
// LaunchCmd returns the cobra command for launching integrations.
@ -384,10 +400,15 @@ func launchCommandCanSkipHeartbeat(args []string) bool {
}
type launcherClient struct {
apiClient *api.Client
modelInventory []ModelInfo
inventoryLoaded bool
policy LaunchPolicy
apiClient *api.Client
modelInventory []ModelInfo
inventoryLoaded bool
recommendationsLoaded bool
recommendationItems []ModelItem
accountState *AccountState
accountStateProvider func() *AccountState
accountStateUpdates func(context.Context) <-chan *AccountState
policy LaunchPolicy
}
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
@ -425,6 +446,9 @@ func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
if err != nil {
return "", err
}
launchClient.accountState = req.AccountState
launchClient.accountStateProvider = req.AccountStateProvider
launchClient.accountStateUpdates = req.AccountStateUpdates
return launchClient.resolveRunModel(ctx, req)
}
@ -449,6 +473,9 @@ func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error
if err != nil {
return err
}
launchClient.accountState = req.AccountState
launchClient.accountStateProvider = req.AccountStateProvider
launchClient.accountStateUpdates = req.AccountStateUpdates
if autodiscovery, ok := runner.(ManagedAutodiscoveryIntegration); ok {
if err := EnsureIntegrationInstalled(name, runner); err != nil {
@ -811,7 +838,10 @@ func (c *launcherClient) managedAutodiscoveryUsable(ctx context.Context, autodis
if !managedAutodiscoveryUsesOllamaCloud(autodiscovery) {
return true
}
return c.ollamaCloudSignedIn(ctx)
if disabled, known := cloudStatusDisabled(ctx, c.apiClient); known && disabled {
return false
}
return true
}
func (c *launcherClient) ensureManagedAutodiscoveryUsable(ctx context.Context, autodiscovery ManagedAutodiscoveryIntegration, label string) error {
@ -858,14 +888,6 @@ func printRestoreSuccess(integration any) {
}
}
func (c *launcherClient) ollamaCloudSignedIn(ctx context.Context) bool {
if disabled, known := cloudStatusDisabled(ctx, c.apiClient); known && disabled {
return false
}
user, err := c.apiClient.Whoami(ctx)
return err == nil && user != nil && user.Name != ""
}
func (c *launcherClient) managedSingleConfigureModels(ctx context.Context, managed ManagedSingleModel, target string) ([]string, error) {
models := []string{target}
if _, ok := managed.(ManagedModelListConfigurer); !ok {
@ -959,8 +981,15 @@ func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, titl
return c.selectSingleModelWithSelectorReady(ctx, title, current, selector, true)
}
func (c *launcherClient) latestAccountState() *AccountState {
if c.accountStateProvider != nil {
return c.accountStateProvider()
}
return c.accountState
}
func (c *launcherClient) selectSingleModelWithSelectorReady(ctx context.Context, title, current string, selector SingleSelector, ensureReady bool) (string, error) {
if selector == nil {
if selector == nil && DefaultSingleSelectorWithUpdates == nil {
return "", fmt.Errorf("no selector configured")
}
@ -969,45 +998,88 @@ func (c *launcherClient) selectSingleModelWithSelectorReady(ctx context.Context,
return "", err
}
selected, err := selector(title, items, current)
if err != nil {
return "", err
}
if selected == "" {
return "", ErrCancelled
}
if ensureReady {
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
for {
accountState := c.latestAccountState()
selectionItems := SelectionItemsWithAccountState(items, accountState)
var updates <-chan []SelectionItem
if DefaultSingleSelectorWithUpdates != nil {
updates = c.selectionItemUpdates(ctx, items, accountState)
}
selected, err := runSingleSelector(title, selectionItems, current, updates, selector)
if err != nil {
return "", err
}
if selected == "" {
return "", ErrCancelled
}
if ensureReady {
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
if errors.Is(err, errUpgradeCancelled) {
current = selected
continue
}
return "", err
}
}
return selected, nil
}
return selected, nil
}
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
if DefaultMultiSelector == nil {
if DefaultMultiSelector == nil && DefaultMultiSelectorWithUpdates == nil {
return nil, fmt.Errorf("no selector configured")
}
current := firstModel(preChecked)
items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
if err != nil {
return nil, err
}
selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
if err != nil {
return nil, err
for {
accountState := c.latestAccountState()
selectionItems := SelectionItemsWithAccountState(items, accountState)
var updates <-chan []SelectionItem
if DefaultMultiSelectorWithUpdates != nil {
updates = c.selectionItemUpdates(ctx, items, accountState)
}
selected, err := runMultiSelector(fmt.Sprintf("Select models for %s:", runner), selectionItems, orderedChecked, updates)
if err != nil {
return nil, err
}
accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected)
if err != nil {
if errors.Is(err, errUpgradeCancelled) {
orderedChecked = append([]string(nil), selected...)
continue
}
return nil, err
}
for _, skip := range skipped {
fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason)
}
return accepted, nil
}
accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected)
if err != nil {
return nil, err
}
func runSingleSelector(title string, items []SelectionItem, current string, updates <-chan []SelectionItem, fallback SingleSelector) (string, error) {
if DefaultSingleSelectorWithUpdates != nil {
return DefaultSingleSelectorWithUpdates(title, items, current, updates)
}
for _, skip := range skipped {
fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason)
if fallback == nil {
return "", fmt.Errorf("no selector configured")
}
return accepted, nil
return fallback(title, items, current)
}
func runMultiSelector(title string, items []SelectionItem, preChecked []string, updates <-chan []SelectionItem) ([]string, error) {
if DefaultMultiSelectorWithUpdates != nil {
return DefaultMultiSelectorWithUpdates(title, items, preChecked, updates)
}
if DefaultMultiSelector == nil {
return nil, fmt.Errorf("no selector configured")
}
return DefaultMultiSelector(title, items, preChecked)
}
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
@ -1029,16 +1101,24 @@ func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []
}
func (c *launcherClient) recommendations(ctx context.Context) []ModelItem {
if c.recommendationsLoaded {
return append([]ModelItem(nil), c.recommendationItems...)
}
recommendations, err := c.requestRecommendations(ctx)
if err != nil || len(recommendations) == 0 {
// Fail open: recommendation issues should not block launch flows.
// Fall back to built-in recommendations until server data is available.
fallback := append([]ModelItem(nil), recommendedModels...)
setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(fallback))
return fallback
c.recommendationItems = fallback
c.recommendationsLoaded = true
return append([]ModelItem(nil), fallback...)
}
setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(recommendations))
return recommendations
c.recommendationItems = recommendations
c.recommendationsLoaded = true
return append([]ModelItem(nil), recommendations...)
}
func (c *launcherClient) requestRecommendations(ctx context.Context) ([]ModelItem, error) {
@ -1076,6 +1156,7 @@ func (c *launcherClient) requestRecommendations(ctx context.Context) ([]ModelIte
VRAMBytes: rec.VRAMBytes,
ContextLength: rec.ContextLength,
MaxOutputTokens: rec.MaxOutputTokens,
RequiredPlan: strings.TrimSpace(rec.RequiredPlan),
})
}
@ -1093,6 +1174,9 @@ func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string)
isCloudModel := isCloudModelName(model)
if isCloudModel {
cloudModels[model] = true
if err := c.ensureCloudModelAccess(ctx, model); err != nil {
return err
}
}
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
return err
@ -1126,6 +1210,9 @@ func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected
for _, model := range selected {
if err := c.ensureModelsReady(ctx, []string{model}); err != nil {
if errors.Is(err, errUpgradeCancelled) {
return nil, nil, err
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, nil, err
}
@ -1142,6 +1229,9 @@ func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected
}
func skippedModelReason(model string, err error) string {
if errors.Is(err, errUpgradeCancelled) {
return "upgrade was cancelled"
}
if errors.Is(err, ErrCancelled) {
if isCloudModelName(model) {
return "sign in was cancelled"

View file

@ -12,6 +12,7 @@ import (
"slices"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/config"
@ -192,14 +193,20 @@ func withInteractiveSession(t *testing.T, interactive bool) {
func withLauncherHooks(t *testing.T) {
t.Helper()
oldSingle := DefaultSingleSelector
oldSingleWithUpdates := DefaultSingleSelectorWithUpdates
oldMulti := DefaultMultiSelector
oldMultiWithUpdates := DefaultMultiSelectorWithUpdates
oldConfirm := DefaultConfirmPrompt
oldSignIn := DefaultSignIn
oldUpgrade := DefaultUpgrade
t.Cleanup(func() {
DefaultSingleSelector = oldSingle
DefaultSingleSelectorWithUpdates = oldSingleWithUpdates
DefaultMultiSelector = oldMulti
DefaultMultiSelectorWithUpdates = oldMultiWithUpdates
DefaultConfirmPrompt = oldConfirm
DefaultSignIn = oldSignIn
DefaultUpgrade = oldUpgrade
})
}
@ -346,7 +353,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *
}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
return "gemma4", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
@ -501,7 +508,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(
runner := &launcherManagedRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called when saved model matches target")
return "", nil
}
@ -553,7 +560,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *t
runner := &launcherManagedRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called when model override is provided")
return "", nil
}
@ -607,7 +614,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenLiveConfigDrifts(
}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called when live config already provides the target")
return "", nil
}
@ -734,7 +741,7 @@ func TestLaunchIntegration_ManagedAutodiscoverySkipsModelPicker(t *testing.T) {
runner := &launcherManagedAutodiscoveryRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("model selector should not run for autodiscovery integrations")
return "", nil
}
@ -987,7 +994,7 @@ func TestLaunchIntegration_CloudAutodiscoveryUsesSignInHook(t *testing.T) {
}
}
func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *testing.T) {
func TestBuildLauncherIntegrationState_CloudAutodiscoveryDoesNotCheckSignIn(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
@ -1004,8 +1011,7 @@ func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *tes
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/me":
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
t.Fatal("build launcher state should not check whoami")
default:
http.NotFound(w, r)
}
@ -1028,8 +1034,8 @@ func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *tes
if state.CurrentModel != "Ollama Cloud" {
t.Fatalf("current model = %q, want Ollama Cloud", state.CurrentModel)
}
if state.ModelUsable {
t.Fatal("expected cloud autodiscovery config to be unusable while signed out")
if !state.ModelUsable {
t.Fatal("expected cloud autodiscovery config to stay usable until launch-time auth check")
}
}
@ -1296,7 +1302,7 @@ func TestResolveRunModel_UsesSavedModelWithoutSelector(t *testing.T) {
}
selectorCalled := false
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
selectorCalled = true
return "", nil
}
@ -1338,7 +1344,7 @@ func TestResolveRunModel_HeadlessYesAutoPicksLastModel(t *testing.T) {
t.Fatalf("failed to save last model: %v", err)
}
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called in headless --yes mode")
return "", nil
}
@ -1405,7 +1411,7 @@ func TestResolveRunModel_UsesRequestPolicy(t *testing.T) {
t.Fatalf("failed to save last model: %v", err)
}
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called when request policy enables headless auto-pick")
return "", nil
}
@ -1465,7 +1471,7 @@ func TestResolveRunModel_ForcePickerAlwaysUsesSelector(t *testing.T) {
}
var selectorCalls int
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
selectorCalls++
if current != "llama3.2" {
t.Fatalf("expected current selection to be last model, got %q", current)
@ -1513,7 +1519,7 @@ func TestResolveRunModel_ForcePicker_DoesNotReorderByLastModel(t *testing.T) {
}
var gotNames []string
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
if current != "qwen3.5" {
t.Fatalf("expected current selection to be last model, got %q", current)
}
@ -1564,7 +1570,7 @@ func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) {
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
return "glm-5:cloud", nil
}
@ -1610,6 +1616,241 @@ func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) {
}
}
func TestResolveRunModel_MetadataSignedOutUsesSignInHook(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
return "qwen3.5:cloud", nil
}
signedIn := false
signInCalled := false
DefaultSignIn = func(modelName, signInURL string) (string, error) {
signInCalled = true
signedIn = true
if modelName != "qwen3.5:cloud" {
t.Fatalf("unexpected model passed to sign-in: %q", modelName)
}
if signInURL != "https://example.com/signin" {
t.Fatalf("unexpected sign-in URL: %q", signInURL)
}
return "test-user", nil
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[{"model":"qwen3.5:cloud","description":"Reasoning","context_length":262144,"max_output_tokens":32768}]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[]}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/show":
fmt.Fprint(w, `{"remote_model":"qwen3.5"}`)
case "/api/me":
if !signedIn {
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
return
}
fmt.Fprint(w, `{"name":"test-user","plan":"free"}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true})
if err != nil {
t.Fatalf("ResolveRunModel returned error: %v", err)
}
if model != "qwen3.5:cloud" {
t.Fatalf("expected selected cloud model, got %q", model)
}
if !signInCalled {
t.Fatal("expected sign-in hook to be used for account-gated cloud model")
}
}
func TestResolveRunModel_SubscriptionModelUsesUpgradeHook(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
DefaultSingleSelectorWithUpdates = func(title string, items []SelectionItem, current string, updates <-chan []SelectionItem) (string, error) {
for _, item := range items {
if item.Name == "kimi-k2.6:cloud" && item.AvailabilityBadge != "" {
t.Fatalf("initial availability badge = %q, want empty before account update", item.AvailabilityBadge)
}
}
select {
case items = <-updates:
case <-time.After(time.Second):
t.Fatal("timed out waiting for selector item update")
}
for _, item := range items {
if item.Name == "kimi-k2.6:cloud" {
if item.AvailabilityBadge != "Upgrade required" {
t.Fatalf("availability badge = %q, want Upgrade required", item.AvailabilityBadge)
}
return "kimi-k2.6:cloud", nil
}
}
t.Fatalf("paid cloud model missing from selector items: %#v", items)
return "kimi-k2.6:cloud", nil
}
DefaultSignIn = func(modelName, signInURL string) (string, error) {
t.Fatalf("did not expect sign-in hook for signed-in user")
return "", nil
}
plan := "free"
upgradeCalled := false
DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
upgradeCalled = true
if modelName != "kimi-k2.6:cloud" {
t.Fatalf("unexpected model passed to upgrade: %q", modelName)
}
if requiredPlan != "pro" {
t.Fatalf("unexpected required plan: %q", requiredPlan)
}
plan = "max"
return plan, nil
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[]}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/show":
fmt.Fprint(w, `{"remote_model":"kimi-k2.6"}`)
case "/api/me":
fmt.Fprintf(w, `{"name":"test-user","plan":%q}`, plan)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true})
if err != nil {
t.Fatalf("ResolveRunModel returned error: %v", err)
}
if model != "kimi-k2.6:cloud" {
t.Fatalf("expected selected cloud model, got %q", model)
}
if !upgradeCalled {
t.Fatal("expected upgrade hook to be used for subscription-gated cloud model")
}
}
func TestResolveRunModel_UpgradeCancelledReturnsToModelSelector(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
selectorCalls := 0
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
selectorCalls++
switch selectorCalls {
case 1:
return "kimi-k2.6:cloud", nil
case 2:
return "llama3.2", nil
default:
t.Fatalf("selector called too many times: %d", selectorCalls)
return "", nil
}
}
DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
return "", ErrCancelled
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/show":
var req apiShowRequest
_ = json.NewDecoder(r.Body).Decode(&req)
fmt.Fprintf(w, `{"model":%q}`, req.Model)
case "/api/me":
fmt.Fprint(w, `{"name":"test-user","plan":"free"}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true})
if err != nil {
t.Fatalf("ResolveRunModel returned error: %v", err)
}
if model != "llama3.2" {
t.Fatalf("model = %q, want llama3.2", model)
}
if selectorCalls != 2 {
t.Fatalf("selector calls = %d, want 2", selectorCalls)
}
}
func TestResolveRunModel_SubscriptionModelUnavailableWhoamiFailsGracefully(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
return "kimi-k2.6:cloud", nil
}
DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
t.Fatalf("did not expect upgrade hook when plan could not be verified")
return "", nil
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[]}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/me":
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"error":"temporary failure"}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
_, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true})
if err == nil {
t.Fatal("expected plan verification error")
}
if !strings.Contains(err.Error(), "Could not verify your plan. Try again in a moment.") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestLaunchIntegration_EditorForceConfigure(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
@ -1623,7 +1864,7 @@ func TestLaunchIntegration_EditorForceConfigure(t *testing.T) {
withIntegrationOverride(t, "droid", editor)
var multiCalled bool
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
multiCalled = true
return []string{"llama3.2", "qwen3:8b"}, nil
}
@ -1688,7 +1929,7 @@ func TestLaunchIntegration_EditorForceConfigure_FloatsCheckedModelsInPicker(t *t
var gotItems []string
var gotPreChecked []string
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
for _, item := range items {
gotItems = append(gotItems, item.Name)
}
@ -1809,7 +2050,7 @@ func TestLaunchIntegration_EditorCloudDisabledFallsBackToSelector(t *testing.T)
}
var multiCalled bool
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
multiCalled = true
return []string{"llama3.2"}, nil
}
@ -1851,7 +2092,7 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsMissingLocalAndPersistsAccep
editor := &launcherEditorRunner{}
withIntegrationOverride(t, "droid", editor)
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
return []string{"glm-5:cloud", "missing-local"}, nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
@ -1932,7 +2173,7 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAcce
editor := &launcherEditorRunner{}
withIntegrationOverride(t, "droid", editor)
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
return []string{"llama3.2", "glm-5:cloud"}, nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
@ -2001,6 +2242,84 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAcce
}
}
func TestLaunchIntegration_EditorConfigureUpgradeCancelledReturnsToModelSelector(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
binDir := t.TempDir()
writeFakeBinary(t, binDir, "droid")
t.Setenv("PATH", binDir)
editor := &launcherEditorRunner{}
withIntegrationOverride(t, "droid", editor)
selectorCalls := 0
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
selectorCalls++
switch selectorCalls {
case 1:
return []string{"kimi-k2.6:cloud"}, nil
case 2:
if diff := compareStrings(preChecked, []string{"kimi-k2.6:cloud"}); diff != "" {
t.Fatalf("second selector preChecked (-want +got):\n%s", diff)
}
return []string{"llama3.2"}, nil
default:
t.Fatalf("selector called too many times: %d", selectorCalls)
return nil, nil
}
}
DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
return "", ErrCancelled
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
if prompt == "Proceed?" {
return true, nil
}
t.Fatalf("unexpected prompt: %q", prompt)
return false, nil
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/show":
var req apiShowRequest
_ = json.NewDecoder(r.Body).Decode(&req)
fmt.Fprintf(w, `{"model":%q}`, req.Model)
case "/api/me":
fmt.Fprint(w, `{"name":"test-user","plan":"free"}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "droid",
ForceConfigure: true,
}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if selectorCalls != 2 {
t.Fatalf("selector calls = %d, want 2", selectorCalls)
}
if editor.ranModel != "llama3.2" {
t.Fatalf("expected launch to use local model, got %q", editor.ranModel)
}
if diff := compareStringSlices(editor.edited, [][]string{{"llama3.2"}}); diff != "" {
t.Fatalf("unexpected edited models (-want +got):\n%s", diff)
}
}
func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
@ -2016,7 +2335,7 @@ func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t *
if err := config.SaveIntegration("droid", []string{"glm-5:cloud", "llama3.2"}); err != nil {
t.Fatalf("failed to seed config: %v", err)
}
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
return append([]string(nil), preChecked...), nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
@ -2102,7 +2421,7 @@ func TestLaunchIntegration_EditorConfigureMultiAllFailuresKeepsExistingAndSkipsL
t.Fatalf("failed to seed config: %v", err)
}
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
return []string{"missing-local-a", "missing-local-b"}, nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
@ -2342,7 +2661,7 @@ func TestLaunchIntegration_OpenclawInstallsBeforeConfigSideEffects(t *testing.T)
withIntegrationOverride(t, "openclaw", editor)
selectorCalled := false
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
selectorCalled = true
return []string{"llama3.2"}, nil
}
@ -2376,7 +2695,7 @@ func TestLaunchIntegration_PiInstallsBeforeConfigSideEffects(t *testing.T) {
withIntegrationOverride(t, "pi", editor)
selectorCalled := false
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
selectorCalled = true
return []string{"llama3.2"}, nil
}
@ -2408,7 +2727,7 @@ func TestLaunchIntegration_ConfigureOnlyDoesNotRequireInstalledBinary(t *testing
editor := &launcherEditorRunner{paths: []string{"/tmp/settings.json"}}
withIntegrationOverride(t, "droid", editor)
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
return []string{"llama3.2"}, nil
}
@ -2519,7 +2838,7 @@ func TestLaunchIntegration_ClaudeForceConfigureReprompts(t *testing.T) {
}
var selectorCalls int
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
selectorCalls++
return "glm-5:cloud", nil
}
@ -2572,7 +2891,7 @@ func TestLaunchIntegration_ClaudeForceConfigureMissingSelectionDoesNotSave(t *te
t.Fatalf("failed to seed config: %v", err)
}
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
return "missing-model", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
@ -2633,7 +2952,7 @@ func TestLaunchIntegration_ClaudeModelOverrideSkipsSelector(t *testing.T) {
t.Setenv("PATH", binDir)
var selectorCalls int
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
selectorCalls++
return "", fmt.Errorf("selector should not run when --model override is set")
}
@ -2698,7 +3017,7 @@ func TestLaunchIntegration_ConfigureOnlyPrompt(t *testing.T) {
runner := &launcherSingleRunner{}
withIntegrationOverride(t, "stubsingle", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
return "llama3.2", nil
}
@ -2926,7 +3245,7 @@ func TestLaunchIntegration_HeadlessSelectorFlowFailsWithoutPrompt(t *testing.T)
runner := &launcherSingleRunner{}
withIntegrationOverride(t, "droid", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
return "missing-model", nil
}

View file

@ -188,7 +188,7 @@ func ensureCloudAuth(ctx context.Context, client *api.Client, modelList string)
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
}
user, err := client.Whoami(ctx)
user, err := whoamiWithTimeout(ctx, client)
if err == nil && user != nil && user.Name != "" {
return nil
}
@ -243,7 +243,7 @@ func ensureCloudAuth(ctx context.Context, client *api.Client, modelList string)
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
if frame%10 == 0 {
u, err := client.Whoami(ctx)
u, err := whoamiWithTimeout(ctx, client)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
@ -348,9 +348,11 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M
var hasLocalModel, hasCloudModel bool
recDesc := make(map[string]string)
recByName := make(map[string]ModelItem)
for _, rec := range recommendations {
recommended[rec.Name] = true
recDesc[rec.Name] = rec.Description
recByName[rec.Name] = rec
}
for _, m := range existing {
@ -364,6 +366,9 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
if rec, ok := recByName[displayName]; ok {
item = copyModelRecommendationFields(displayName, rec)
}
items = append(items, item)
}
@ -472,6 +477,12 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M
return items, preChecked, existingModels, cloudModels
}
func copyModelRecommendationFields(name string, rec ModelItem) ModelItem {
rec.Name = name
rec.Recommended = true
return rec
}
// isCloudModelName reports whether the model name has an explicit cloud source.
func isCloudModelName(name string) bool {
return modelref.HasExplicitCloudSource(name)

View file

@ -35,22 +35,38 @@ type ConfirmOptions struct {
// SingleSelector is a function type for single item selection.
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
type SingleSelector func(title string, items []SelectionItem, current string) (string, error)
// SingleSelectorWithUpdates is a single item selector that can receive refreshed item state while open.
type SingleSelectorWithUpdates func(title string, items []SelectionItem, current string, updates <-chan []SelectionItem) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
type MultiSelector func(title string, items []SelectionItem, preChecked []string) ([]string, error)
// MultiSelectorWithUpdates is a multi item selector that can receive refreshed item state while open.
type MultiSelectorWithUpdates func(title string, items []SelectionItem, preChecked []string, updates <-chan []SelectionItem) ([]string, error)
// DefaultSingleSelector is the default single-select implementation.
var DefaultSingleSelector SingleSelector
// DefaultSingleSelectorWithUpdates is the default single-select implementation with live updates.
var DefaultSingleSelectorWithUpdates SingleSelectorWithUpdates
// DefaultMultiSelector is the default multi-select implementation.
var DefaultMultiSelector MultiSelector
// DefaultMultiSelectorWithUpdates is the default multi-select implementation with live updates.
var DefaultMultiSelectorWithUpdates MultiSelectorWithUpdates
// DefaultSignIn provides a TUI-based sign-in flow.
// When set, ensureAuth uses it instead of plain text prompts.
// Returns the signed-in username or an error.
var DefaultSignIn func(modelName, signInURL string) (string, error)
// DefaultUpgrade provides a TUI-based upgrade flow.
// Returns the updated plan or an error.
var DefaultUpgrade func(modelName, requiredPlan string) (string, error)
type launchConfirmPolicy struct {
yes bool
requireYesMessage bool

View file

@ -35,8 +35,7 @@ var (
Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"})
selectorDefaultTagStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
Italic(true)
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
selectorHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"})
@ -58,16 +57,39 @@ const maxSelectorItems = 10
var ErrCancelled = launch.ErrCancelled
type SelectItem struct {
Name string
Description string
Recommended bool
Name string
Description string
Recommended bool
AvailabilityBadge string
}
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
func ConvertItems(items []launch.ModelItem) []SelectItem {
type selectorItemsUpdatedMsg struct {
items []SelectItem
}
func waitForSelectorItems(updates <-chan []SelectItem) tea.Cmd {
if updates == nil {
return nil
}
return func() tea.Msg {
items, ok := <-updates
if !ok {
return nil
}
return selectorItemsUpdatedMsg{items: items}
}
}
// ConvertItems converts launch.SelectionItem slice to SelectItem slice.
func ConvertItems(items []launch.SelectionItem) []SelectItem {
out := make([]SelectItem, len(items))
for i, item := range items {
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
out[i] = SelectItem{
Name: item.Name,
Description: item.Description,
Recommended: item.Recommended,
AvailabilityBadge: item.AvailabilityBadge,
}
}
return out
}
@ -91,6 +113,7 @@ func ReorderItems(items []SelectItem) []SelectItem {
type selectorModel struct {
title string
items []SelectItem
updates <-chan []SelectItem
filter string
cursor int
scrollOffset int
@ -110,6 +133,33 @@ func selectorModelWithCurrent(title string, items []SelectItem, current string)
return m
}
func currentItemName(items []SelectItem, cursor int) string {
if cursor < 0 || cursor >= len(items) {
return ""
}
return items[cursor].Name
}
func cursorForItemName(items []SelectItem, name string, fallback int) int {
if len(items) == 0 {
return 0
}
if name != "" {
for i, item := range items {
if item.Name == name {
return i
}
}
}
if fallback < 0 {
return 0
}
if fallback >= len(items) {
return len(items) - 1
}
return fallback
}
func (m selectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
@ -125,7 +175,7 @@ func (m selectorModel) filteredItems() []SelectItem {
}
func (m selectorModel) Init() tea.Cmd {
return nil
return waitForSelectorItems(m.updates)
}
// otherStart returns the index of the first non-recommended item in the filtered list.
@ -235,6 +285,13 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
return m, nil
case selectorItemsUpdatedMsg:
current := currentItemName(m.filteredItems(), m.cursor)
m.items = msg.items
m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor)
m.updateScroll(m.otherStart())
return m, waitForSelectorItems(m.updates)
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
@ -260,9 +317,17 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
func cursorItemSuffix(item SelectItem) string {
if item.AvailabilityBadge == "" {
return ""
}
return " " + selectorDefaultTagStyle.Render("("+item.AvailabilityBadge+")")
}
func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
s.WriteString(cursorItemSuffix(item))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
@ -402,11 +467,16 @@ func cursorForCurrent(items []SelectItem, current string) int {
}
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
return SelectSingleWithUpdates(title, items, current, nil)
}
func SelectSingleWithUpdates(title string, items []SelectItem, current string, updates <-chan []SelectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
m := selectorModelWithCurrent(title, items, current)
m.updates = updates
p := tea.NewProgram(m)
finalModel, err := p.Run()
@ -426,6 +496,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
type multiSelectorModel struct {
title string
items []SelectItem
updates <-chan []SelectItem
itemIndex map[string]int
filter string
cursor int
@ -475,6 +546,36 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
return m
}
func (m *multiSelectorModel) rebuildItemIndex() {
m.itemIndex = make(map[string]int, len(m.items))
for i, item := range m.items {
m.itemIndex[item.Name] = i
}
}
func (m *multiSelectorModel) replaceItems(items []SelectItem) {
current := currentItemName(m.filteredItems(), m.cursor)
checkedNames := make([]string, 0, len(m.checkOrder))
for _, idx := range m.checkOrder {
if idx >= 0 && idx < len(m.items) {
checkedNames = append(checkedNames, m.items[idx].Name)
}
}
m.items = items
m.rebuildItemIndex()
m.checked = make(map[int]bool, len(checkedNames))
m.checkOrder = nil
for _, name := range checkedNames {
if idx, ok := m.itemIndex[name]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor)
m.updateScroll(m.otherStart())
}
func (m multiSelectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
@ -590,7 +691,7 @@ func (m multiSelectorModel) selectedCount() int {
}
func (m multiSelectorModel) Init() tea.Cmd {
return nil
return waitForSelectorItems(m.updates)
}
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@ -603,6 +704,10 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
return m, nil
case selectorItemsUpdatedMsg:
m.replaceItems(msg.items)
return m, waitForSelectorItems(m.updates)
case tea.KeyMsg:
filtered := m.filteredItems()
@ -689,6 +794,7 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
s.WriteString(cursorItemSuffix(item))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
@ -716,6 +822,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + check + item.Name))
s.WriteString(cursorItemSuffix(item))
} else {
s.WriteString(selectorItemStyle.Render(check + item.Name))
}
@ -841,11 +948,16 @@ func (m multiSelectorModel) View() string {
}
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
return SelectMultipleWithUpdates(title, items, preChecked, nil)
}
func SelectMultipleWithUpdates(title string, items []SelectItem, preChecked []string, updates <-chan []SelectItem) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
m := newMultiSelectorModel(title, items, preChecked)
m.updates = updates
p := tea.NewProgram(m)
finalModel, err := p.Run()

View file

@ -311,6 +311,91 @@ func TestRenderContent_SelectedItemIndicator(t *testing.T) {
}
}
func TestRenderContent_AvailabilityBadgeOnlyOnCursor(t *testing.T) {
m := selectorModel{
title: "Pick:",
items: []SelectItem{
{Name: "kimi-k2.6:cloud", AvailabilityBadge: "Upgrade required"},
{Name: "qwen3.5:cloud", AvailabilityBadge: "Sign in required"},
{Name: "glm-5:cloud", AvailabilityBadge: "Included"},
},
cursor: 0,
}
content := m.renderContent()
if !strings.Contains(content, "(Upgrade required)") {
t.Fatalf("cursor badge missing:\n%s", content)
}
if strings.Contains(content, "(Sign in required)") {
t.Fatalf("non-cursor badge should not render:\n%s", content)
}
if strings.Contains(content, "Included") {
t.Fatalf("included badge should not render:\n%s", content)
}
}
func TestSelectorModel_ItemsUpdatedPreservesCursorAndRendersBadge(t *testing.T) {
m := selectorModelWithCurrent("Pick:", []SelectItem{
{Name: "kimi-k2.6:cloud", Recommended: true},
{Name: "llama3.2"},
}, "kimi-k2.6:cloud")
updated, _ := m.Update(selectorItemsUpdatedMsg{items: []SelectItem{
{Name: "kimi-k2.6:cloud", Recommended: true, AvailabilityBadge: "Upgrade required"},
{Name: "llama3.2"},
}})
fm := updated.(selectorModel)
if fm.cursor != 0 {
t.Fatalf("cursor = %d, want 0", fm.cursor)
}
content := fm.renderContent()
if !strings.Contains(content, "(Upgrade required)") {
t.Fatalf("updated badge missing:\n%s", content)
}
}
func TestMultiSelector_AvailabilityBadgePreservesDefaultSuffix(t *testing.T) {
m := newMultiSelectorModel("Pick:", []SelectItem{
{Name: "kimi-k2.6:cloud", AvailabilityBadge: "Upgrade required"},
{Name: "qwen3.5:cloud"},
}, []string{"kimi-k2.6:cloud"})
m.multi = true
m.cursor = 0
content := m.View()
if !strings.Contains(content, "(Upgrade required)") {
t.Fatalf("cursor badge missing:\n%s", content)
}
if !strings.Contains(content, "(default)") {
t.Fatalf("default suffix missing:\n%s", content)
}
}
func TestMultiSelector_ItemsUpdatedPreservesCheckedStateAndRendersBadge(t *testing.T) {
m := newMultiSelectorModel("Pick:", []SelectItem{
{Name: "kimi-k2.6:cloud", Recommended: true},
{Name: "llama3.2"},
}, []string{"kimi-k2.6:cloud"})
m.multi = true
updated, _ := m.Update(selectorItemsUpdatedMsg{items: []SelectItem{
{Name: "kimi-k2.6:cloud", Recommended: true, AvailabilityBadge: "Upgrade required"},
{Name: "llama3.2"},
}})
fm := updated.(multiSelectorModel)
idx := fm.itemIndex["kimi-k2.6:cloud"]
if !fm.checked[idx] {
t.Fatalf("checked state was not preserved: %#v", fm.checked)
}
content := fm.View()
if !strings.Contains(content, "(Upgrade required)") {
t.Fatalf("updated badge missing:\n%s", content)
}
if !strings.Contains(content, "(default)") {
t.Fatalf("default suffix missing after update:\n%s", content)
}
}
func TestRenderContent_Description(t *testing.T) {
m := selectorModel{
title: "Pick:",

View file

@ -19,6 +19,14 @@ type signInCheckMsg struct {
userName string
}
type upgradeTickMsg struct{}
type upgradeCheckMsg struct {
upgraded bool
plan string
err error
}
type signInModel struct {
modelName string
signInURL string
@ -28,6 +36,18 @@ type signInModel struct {
cancelled bool
}
type upgradeModel struct {
modelName string
requiredPlan string
spinner int
width int
openNow bool
polling bool
plan string
cancelled bool
err error
}
func (m signInModel) Init() tea.Cmd {
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
@ -82,6 +102,85 @@ func (m signInModel) View() string {
return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width)
}
func (m upgradeModel) Init() tea.Cmd {
if m.polling {
return upgradeTickCmd()
}
return nil
}
func (m upgradeModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyLeft:
if !m.polling {
m.openNow = true
}
case tea.KeyRight:
if !m.polling {
m.openNow = false
}
case tea.KeyEnter:
if !m.polling {
if !m.openNow {
m.cancelled = true
return m, tea.Quit
}
launch.OpenBrowser(launch.DefaultUpgradeURL)
m.polling = true
return m, upgradeTickCmd()
}
}
case upgradeTickMsg:
if !m.polling {
return m, nil
}
m.spinner++
if m.spinner%5 == 0 {
return m, tea.Batch(
upgradeTickCmd(),
checkUpgrade(m.requiredPlan),
)
}
return m, upgradeTickCmd()
case upgradeCheckMsg:
if msg.err != nil {
m.err = msg.err
return m, tea.Quit
}
if msg.upgraded {
m.plan = msg.plan
return m, tea.Quit
}
}
return m, nil
}
func (m upgradeModel) View() string {
if m.plan != "" {
return ""
}
if m.err != nil {
return ""
}
return renderUpgrade(m.modelName, m.spinner, m.width, m.polling, m.openNow)
}
func renderSignIn(modelName, signInURL string, spinner, width int) string {
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
frame := spinnerFrames[spinner%len(spinnerFrames)]
@ -110,18 +209,88 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
}
func upgradeTickCmd() tea.Cmd {
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return upgradeTickMsg{}
})
}
func renderUpgrade(modelName string, spinner, width int, polling, openNow bool) string {
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
frame := spinnerFrames[spinner%len(spinnerFrames)]
urlColor := lipgloss.NewStyle().
Foreground(lipgloss.Color("117"))
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
if width > 4 {
urlWrap = urlWrap.Width(width - 4)
}
var s strings.Builder
fmt.Fprintf(&s, "To use %s, upgrade your Ollama plan.\n\n", selectorSelectedItemStyle.Render(modelName))
s.WriteString("Navigate to:\n")
s.WriteString(urlWrap.Render(urlColor.Render(launch.DefaultUpgradeURL)))
s.WriteString("\n\n")
if !polling {
var yesBtn, noBtn string
if openNow {
yesBtn = confirmActiveStyle.Render(" Yes ")
noBtn = confirmInactiveStyle.Render(" No ")
} else {
yesBtn = confirmInactiveStyle.Render(" Yes ")
noBtn = confirmActiveStyle.Render(" No ")
}
s.WriteString("Open now?\n")
s.WriteString(" " + yesBtn + " " + noBtn)
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel"))
} else {
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
frame + " Waiting for upgrade to complete..."))
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("esc cancel"))
}
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
}
func checkSignIn() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return signInCheckMsg{signedIn: false}
}
user, err := client.Whoami(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return signInCheckMsg{signedIn: true, userName: user.Name}
}
return signInCheckMsg{signedIn: false}
}
func checkUpgrade(requiredPlan string) tea.Cmd {
return func() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable}
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
user, err := client.Whoami(ctx)
if err != nil {
return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable}
}
if err == nil && user != nil && user.Name != "" && launch.PlanSatisfies(user.Plan, requiredPlan) {
return upgradeCheckMsg{upgraded: true, plan: user.Plan}
}
return upgradeCheckMsg{upgraded: false}
}
}
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
func RunSignIn(modelName, signInURL string) (string, error) {
launch.OpenBrowser(signInURL)
@ -144,3 +313,28 @@ func RunSignIn(modelName, signInURL string) (string, error) {
return fm.userName, nil
}
// RunUpgrade shows a bubbletea upgrade dialog and polls until the user's plan is updated or cancelled.
func RunUpgrade(modelName, requiredPlan string) (string, error) {
m := upgradeModel{
modelName: modelName,
requiredPlan: requiredPlan,
openNow: true,
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return "", fmt.Errorf("error running upgrade: %w", err)
}
fm := finalModel.(upgradeModel)
if fm.cancelled {
return "", ErrCancelled
}
if fm.err != nil {
return "", fm.err
}
return fm.plan, nil
}

View file

@ -5,6 +5,7 @@ import (
"testing"
tea "github.com/charmbracelet/bubbletea"
"github.com/ollama/ollama/cmd/launch"
)
func TestRenderSignIn_ContainsModelName(t *testing.T) {
@ -50,6 +51,35 @@ func TestRenderSignIn_ContainsEscHelp(t *testing.T) {
}
}
func TestRenderUpgrade_AsksBeforeOpening(t *testing.T) {
got := renderUpgrade("kimi-k2.6:cloud", 0, 80, false, true)
if !strings.Contains(got, "kimi-k2.6:cloud") {
t.Error("should contain model name")
}
if !strings.Contains(got, launch.DefaultUpgradeURL) {
t.Error("should contain upgrade URL")
}
if !strings.Contains(got, "Open now?") {
t.Error("should ask before opening")
}
if !strings.Contains(got, "Yes") || !strings.Contains(got, "No") {
t.Error("should show yes/no selector")
}
if strings.Contains(got, "Waiting for upgrade to complete") {
t.Error("should not start waiting before open choice is confirmed")
}
}
func TestRenderUpgrade_PollingShowsWaiting(t *testing.T) {
got := renderUpgrade("kimi-k2.6:cloud", 0, 80, true, true)
if !strings.Contains(got, "Waiting for upgrade to complete") {
t.Error("should contain waiting message")
}
if strings.Contains(got, "Open now?") {
t.Error("should not show open prompt while polling")
}
}
func TestSignInModel_EscCancels(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
@ -66,6 +96,35 @@ func TestSignInModel_EscCancels(t *testing.T) {
}
}
func TestUpgradeModel_NoCancelsWithoutPolling(t *testing.T) {
m := upgradeModel{
modelName: "kimi-k2.6:cloud",
requiredPlan: "pro",
openNow: true,
}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight})
fm := updated.(upgradeModel)
if fm.openNow {
t.Error("right should select no")
}
if fm.polling {
t.Error("right should not start polling")
}
updated, cmd := fm.Update(tea.KeyMsg{Type: tea.KeyEnter})
fm = updated.(upgradeModel)
if !fm.cancelled {
t.Error("enter on no should cancel")
}
if fm.polling {
t.Error("enter on no should not start polling")
}
if cmd == nil {
t.Error("enter on no should quit")
}
}
func TestSignInModel_CtrlCCancels(t *testing.T) {
m := signInModel{
modelName: "test:cloud",

View file

@ -20,9 +20,7 @@ import (
"github.com/ollama/ollama/format"
)
const (
modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations"
)
const modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations"
var (
modelRecommendationsRefreshInterval = 4 * time.Hour
@ -323,6 +321,7 @@ func validateModelRecommendations(recs []api.ModelRecommendation) ([]api.ModelRe
for _, rec := range recs {
rec.Model = strings.TrimSpace(rec.Model)
rec.Description = strings.TrimSpace(rec.Description)
rec.RequiredPlan = strings.TrimSpace(rec.RequiredPlan)
if rec.Model == "" {
return nil, errors.New("recommendation missing model")

View file

@ -255,7 +255,7 @@ func TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite(t *testing.T) {
func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing.T) {
input := []api.ModelRecommendation{
{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256},
{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: " pro "},
{Model: "bad-cloud:cloud", Description: "missing limits"},
{Model: " good-local ", Description: " good local ", VRAMBytes: 2 * format.GigaByte},
}
@ -266,7 +266,7 @@ func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing
}
want := []api.ModelRecommendation{
{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256},
{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: "pro"},
{Model: "good-local", Description: "good local", VRAMBytes: 2 * format.GigaByte},
}
if !slices.Equal(got, want) {
@ -274,6 +274,38 @@ func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing
}
}
func TestValidateModelRecommendationsDoesNotSynthesizeRequiredPlans(t *testing.T) {
input := []api.ModelRecommendation{
{Model: "kimi-k2.6:cloud", Description: "coding", ContextLength: 262_144, MaxOutputTokens: 262_144},
{Model: "qwen3.5:cloud", Description: "reasoning", ContextLength: 262_144, MaxOutputTokens: 32_768},
{Model: "custom:cloud", Description: "custom", ContextLength: 4096, MaxOutputTokens: 1024},
{Model: "minimax-m2.7:cloud", Description: "custom", ContextLength: 204_800, MaxOutputTokens: 128_000, RequiredPlan: "team"},
}
got, err := validateModelRecommendations(input)
if err != nil {
t.Fatalf("validateModelRecommendations failed: %v", err)
}
byName := make(map[string]api.ModelRecommendation, len(got))
for _, rec := range got {
byName[rec.Model] = rec
}
if rec := byName["kimi-k2.6:cloud"]; rec.RequiredPlan != "" {
t.Fatalf("kimi required plan should not be synthesized: %#v", rec)
}
if rec := byName["qwen3.5:cloud"]; rec.RequiredPlan != "" {
t.Fatalf("qwen required plan should not be synthesized: %#v", rec)
}
if rec := byName["custom:cloud"]; rec.RequiredPlan != "" {
t.Fatalf("custom required plan should not be synthesized: %#v", rec)
}
if rec := byName["minimax-m2.7:cloud"]; rec.RequiredPlan != "team" {
t.Fatalf("explicit required plan should not be overwritten: %#v", rec)
}
}
func TestModelRecommendationsHandlerReturnsDefaults(t *testing.T) {
gin.SetMode(gin.TestMode)

View file

@ -2044,11 +2044,30 @@ func (s *Server) WhoamiHandler(c *gin.Context) {
client := api.NewClient(u, http.DefaultClient)
user, err := client.Whoami(c)
if err != nil {
var authErr api.AuthorizationError
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
// Preserve an actionable sign-in response for launch; other failures
// below mean account or plan verification is temporarily unavailable.
sURL := authErr.SigninURL
if sURL == "" {
var sErr error
sURL, sErr = signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
slog.Error(err.Error())
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "account unavailable"})
return
}
// user isn't signed in
if user != nil && user.Name == "" {
if user == nil || user.Name == "" {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
@ -2060,6 +2079,10 @@ func (s *Server) WhoamiHandler(c *gin.Context) {
return
}
if strings.TrimSpace(user.Plan) == "" {
slog.Warn("account plan was not set; defaulting to free")
user.Plan = "free"
}
c.JSON(http.StatusOK, user)
}