mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
launch: add plan-aware model gating (#16027)
This commit is contained in:
parent
7c2c36bda2
commit
bab59072fb
17 changed files with 1747 additions and 138 deletions
|
|
@ -814,6 +814,7 @@ type ModelRecommendation struct {
|
|||
ContextLength int `json:"context_length,omitempty"`
|
||||
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
|
||||
VRAMBytes int64 `json:"vram_bytes,omitempty"`
|
||||
RequiredPlan string `json:"required_plan,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessResponse is the response from [Client.Process].
|
||||
|
|
|
|||
122
cmd/cmd.go
122
cmd/cmd.go
|
|
@ -61,28 +61,20 @@ import (
|
|||
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||
}
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
launch.DefaultSingleSelector = func(title string, items []launch.SelectionItem, current string) (string, error) {
|
||||
return runTUISingleSelector(title, items, current, nil)
|
||||
}
|
||||
|
||||
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||
}
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
launch.DefaultSingleSelectorWithUpdates = func(title string, items []launch.SelectionItem, current string, updates <-chan []launch.SelectionItem) (string, error) {
|
||||
return runTUISingleSelector(title, items, current, updates)
|
||||
}
|
||||
|
||||
launch.DefaultMultiSelector = func(title string, items []launch.SelectionItem, preChecked []string) ([]string, error) {
|
||||
return runTUIMultiSelector(title, items, preChecked, nil)
|
||||
}
|
||||
|
||||
launch.DefaultMultiSelectorWithUpdates = func(title string, items []launch.SelectionItem, preChecked []string, updates <-chan []launch.SelectionItem) ([]string, error) {
|
||||
return runTUIMultiSelector(title, items, preChecked, updates)
|
||||
}
|
||||
|
||||
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
|
|
@ -93,9 +85,55 @@ func init() {
|
|||
return userName, err
|
||||
}
|
||||
|
||||
launch.DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
|
||||
plan, err := tui.RunUpgrade(modelName, requiredPlan)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return plan, err
|
||||
}
|
||||
|
||||
launch.DefaultConfirmPrompt = tui.RunConfirmWithOptions
|
||||
}
|
||||
|
||||
func runTUISingleSelector(title string, items []launch.SelectionItem, current string, updates <-chan []launch.SelectionItem) (string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||
}
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingleWithUpdates(title, tuiItems, current, convertSelectionItemUpdates(updates))
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func runTUIMultiSelector(title string, items []launch.SelectionItem, preChecked []string, updates <-chan []launch.SelectionItem) ([]string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||
}
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultipleWithUpdates(title, tuiItems, preChecked, convertSelectionItemUpdates(updates))
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func convertSelectionItemUpdates(updates <-chan []launch.SelectionItem) <-chan []tui.SelectItem {
|
||||
if updates == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(chan []tui.SelectItem, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
for items := range updates {
|
||||
out <- tui.ReorderItems(tui.ConvertItems(items))
|
||||
}
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
|
|
@ -2090,12 +2128,15 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
|||
return
|
||||
}
|
||||
|
||||
accountPrefetch := launch.StartAccountStatePrefetch(cmd.Context())
|
||||
deps := launcherDeps{
|
||||
buildState: launch.BuildLauncherState,
|
||||
runMenu: tui.RunMenu,
|
||||
resolveRunModel: launch.ResolveRunModel,
|
||||
launchIntegration: launch.LaunchIntegration,
|
||||
runModel: launchInteractiveModel,
|
||||
buildState: launch.BuildLauncherState,
|
||||
runMenu: tui.RunMenu,
|
||||
resolveRunModel: launch.ResolveRunModel,
|
||||
launchIntegration: launch.LaunchIntegration,
|
||||
runModel: launchInteractiveModel,
|
||||
accountState: accountPrefetch.StateIfReady,
|
||||
accountStateUpdates: accountPrefetch.StateUpdates,
|
||||
}
|
||||
|
||||
for {
|
||||
|
|
@ -2110,11 +2151,13 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
|||
}
|
||||
|
||||
type launcherDeps struct {
|
||||
buildState func(context.Context) (*launch.LauncherState, error)
|
||||
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||
runModel func(*cobra.Command, string) error
|
||||
buildState func(context.Context) (*launch.LauncherState, error)
|
||||
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||
runModel func(*cobra.Command, string) error
|
||||
accountState func() *launch.AccountState
|
||||
accountStateUpdates func(context.Context) <-chan *launch.AccountState
|
||||
}
|
||||
|
||||
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
|
||||
|
|
@ -2122,6 +2165,9 @@ func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error)
|
|||
if err != nil {
|
||||
return false, fmt.Errorf("build launcher state: %w", err)
|
||||
}
|
||||
if state != nil && deps.accountState != nil {
|
||||
state.AccountState = deps.accountState()
|
||||
}
|
||||
|
||||
action, err := deps.runMenu(state)
|
||||
if err != nil {
|
||||
|
|
@ -2142,7 +2188,13 @@ func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDe
|
|||
return false, nil
|
||||
case tui.TUIActionRunModel:
|
||||
saveLauncherSelection(action)
|
||||
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
|
||||
req := action.RunModelRequest()
|
||||
if deps.accountState != nil {
|
||||
req.AccountState = deps.accountState()
|
||||
req.AccountStateProvider = deps.accountState
|
||||
}
|
||||
req.AccountStateUpdates = deps.accountStateUpdates
|
||||
modelName, err := deps.resolveRunModel(cmd.Context(), req)
|
||||
if errors.Is(err, launch.ErrCancelled) {
|
||||
return true, nil
|
||||
}
|
||||
|
|
@ -2155,7 +2207,13 @@ func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDe
|
|||
return true, nil
|
||||
case tui.TUIActionLaunchIntegration:
|
||||
saveLauncherSelection(action)
|
||||
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
|
||||
req := action.IntegrationLaunchRequest()
|
||||
if deps.accountState != nil {
|
||||
req.AccountState = deps.accountState()
|
||||
req.AccountStateProvider = deps.accountState
|
||||
}
|
||||
req.AccountStateUpdates = deps.accountStateUpdates
|
||||
err := deps.launchIntegration(cmd.Context(), req)
|
||||
if errors.Is(err, launch.ErrCancelled) {
|
||||
return true, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,11 +76,18 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
|||
|
||||
var gotReq launch.RunModelRequest
|
||||
var launched string
|
||||
prefetchedAccount := &launch.AccountState{}
|
||||
accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil }
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: runMenu,
|
||||
runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
if state.AccountState != prefetchedAccount {
|
||||
t.Fatalf("prefetched account state was not piped to menu state")
|
||||
}
|
||||
return runMenu(state)
|
||||
},
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
gotReq = req
|
||||
return tt.wantModel, nil
|
||||
|
|
@ -90,6 +97,10 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
|||
launched = model
|
||||
return nil
|
||||
},
|
||||
accountState: func() *launch.AccountState {
|
||||
return prefetchedAccount
|
||||
},
|
||||
accountStateUpdates: accountUpdates,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
|
|
@ -107,6 +118,12 @@ func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
|||
if gotReq.ForcePicker != tt.wantForce {
|
||||
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||
}
|
||||
if gotReq.AccountState != prefetchedAccount {
|
||||
t.Fatalf("expected prefetched account state to be passed to run model request")
|
||||
}
|
||||
if gotReq.AccountStateUpdates == nil {
|
||||
t.Fatalf("expected account state updates to be passed to run model request")
|
||||
}
|
||||
if launched != tt.wantModel {
|
||||
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||
}
|
||||
|
|
@ -148,17 +165,28 @@ func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T)
|
|||
}
|
||||
|
||||
var gotReq launch.IntegrationLaunchRequest
|
||||
prefetchedAccount := &launch.AccountState{}
|
||||
accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil }
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: runMenu,
|
||||
runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
if state.AccountState != prefetchedAccount {
|
||||
t.Fatalf("prefetched account state was not piped to menu state")
|
||||
}
|
||||
return runMenu(state)
|
||||
},
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
gotReq = req
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
accountState: func() *launch.AccountState {
|
||||
return prefetchedAccount
|
||||
},
|
||||
accountStateUpdates: accountUpdates,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
|
|
@ -179,6 +207,12 @@ func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T)
|
|||
if gotReq.ForceConfigure != tt.wantForce {
|
||||
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||
}
|
||||
if gotReq.AccountState != prefetchedAccount {
|
||||
t.Fatalf("expected prefetched account state to be passed to integration request")
|
||||
}
|
||||
if gotReq.AccountStateUpdates == nil {
|
||||
t.Fatalf("expected account state updates to be passed to integration request")
|
||||
}
|
||||
if got := config.LastSelection(); got != "claude" {
|
||||
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||
}
|
||||
|
|
|
|||
371
cmd/launch/account.go
Normal file
371
cmd/launch/account.go
Normal file
|
|
@ -0,0 +1,371 @@
|
|||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultUpgradeURL is the fixed destination for subscription upgrades.
|
||||
DefaultUpgradeURL = "https://ollama.com/upgrade"
|
||||
|
||||
accountCheckTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPlanVerificationUnavailable = errors.New("Could not verify your plan. Try again in a moment.")
|
||||
errUpgradeCancelled = errors.New("upgrade cancelled")
|
||||
)
|
||||
|
||||
type accountStateStatus int
|
||||
|
||||
const (
|
||||
accountStateUnknown accountStateStatus = iota
|
||||
accountStateSignedOut
|
||||
accountStateSignedIn
|
||||
)
|
||||
|
||||
type AccountState struct {
|
||||
Status accountStateStatus
|
||||
Plan string
|
||||
}
|
||||
|
||||
type AccountStatePrefetch struct {
|
||||
done chan struct{}
|
||||
state AccountState
|
||||
}
|
||||
|
||||
func StartAccountStatePrefetch(ctx context.Context) *AccountStatePrefetch {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
p := &AccountStatePrefetch{done: make(chan struct{})}
|
||||
go func() {
|
||||
state := AccountState{Status: accountStateUnknown}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err == nil {
|
||||
prefetchCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout)
|
||||
defer cancel()
|
||||
if disabled, known := cloudStatusDisabled(prefetchCtx, client); !known || !disabled {
|
||||
state = launchAccountState(prefetchCtx, client)
|
||||
}
|
||||
}
|
||||
p.state = state
|
||||
close(p.done)
|
||||
}()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *AccountStatePrefetch) StateIfReady() *AccountState {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-p.done:
|
||||
state := p.state
|
||||
return &state
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountStatePrefetch) StateUpdates(ctx context.Context) <-chan *AccountState {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
out := make(chan *AccountState, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
select {
|
||||
case <-p.done:
|
||||
if p.state.Status == accountStateUnknown {
|
||||
return
|
||||
}
|
||||
state := p.state
|
||||
select {
|
||||
case out <- &state:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func launchAccountState(ctx context.Context, client *api.Client) AccountState {
|
||||
if client == nil {
|
||||
return AccountState{Status: accountStateUnknown}
|
||||
}
|
||||
|
||||
user, err := whoamiWithTimeout(ctx, client)
|
||||
if err != nil {
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||
return AccountState{Status: accountStateSignedOut}
|
||||
}
|
||||
return AccountState{Status: accountStateUnknown}
|
||||
}
|
||||
if user == nil || strings.TrimSpace(user.Name) == "" {
|
||||
return AccountState{Status: accountStateSignedOut}
|
||||
}
|
||||
return AccountState{
|
||||
Status: accountStateSignedIn,
|
||||
Plan: strings.TrimSpace(user.Plan),
|
||||
}
|
||||
}
|
||||
|
||||
func whoamiWithTimeout(ctx context.Context, client *api.Client) (*api.UserResponse, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
checkCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout)
|
||||
defer cancel()
|
||||
return client.Whoami(checkCtx)
|
||||
}
|
||||
|
||||
func ApplyAccountStateToSelectionItems(items []ModelItem, state AccountState) []SelectionItem {
|
||||
out := make([]SelectionItem, len(items))
|
||||
for i, item := range items {
|
||||
out[i] = SelectionItem{
|
||||
Name: item.Name,
|
||||
Description: item.Description,
|
||||
Recommended: item.Recommended,
|
||||
AvailabilityBadge: availabilityBadge(item, state),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func SelectionItemsWithAccountState(items []ModelItem, state *AccountState) []SelectionItem {
|
||||
if state == nil || !selectionItemsNeedAccountState(items) {
|
||||
return ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateUnknown})
|
||||
}
|
||||
return ApplyAccountStateToSelectionItems(items, *state)
|
||||
}
|
||||
|
||||
func selectionItemsNeedAccountState(items []ModelItem) bool {
|
||||
for _, item := range items {
|
||||
if isCloudModelName(item.Name) && itemHasRecommendationMetadata(item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectionItemUpdates(ctx context.Context, items []ModelItem, state *AccountState) <-chan []SelectionItem {
|
||||
if !selectionItemsNeedAccountState(items) || state != nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
stateUpdates := c.accountStateUpdateSource(ctx)
|
||||
if stateUpdates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make(chan []SelectionItem, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
select {
|
||||
case state, ok := <-stateUpdates:
|
||||
if !ok || state == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case out <- SelectionItemsWithAccountState(items, state):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *launcherClient) accountStateUpdateSource(ctx context.Context) <-chan *AccountState {
|
||||
if c.accountStateUpdates != nil {
|
||||
return c.accountStateUpdates(ctx)
|
||||
}
|
||||
if c.apiClient == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(chan *AccountState, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
state := launchAccountState(ctx, c.apiClient)
|
||||
if state.Status == accountStateUnknown {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case out <- &state:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func availabilityBadge(item ModelItem, state AccountState) string {
|
||||
if !isCloudModelName(item.Name) {
|
||||
return ""
|
||||
}
|
||||
switch state.Status {
|
||||
case accountStateSignedOut:
|
||||
if itemHasRecommendationMetadata(item) {
|
||||
return "Sign in required"
|
||||
}
|
||||
case accountStateSignedIn:
|
||||
if item.RequiredPlan != "" && !PlanSatisfies(state.Plan, item.RequiredPlan) {
|
||||
return "Upgrade required"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func itemHasRecommendationMetadata(item ModelItem) bool {
|
||||
return item.Recommended || strings.TrimSpace(item.RequiredPlan) != ""
|
||||
}
|
||||
|
||||
func (c *launcherClient) ensureCloudModelAccess(ctx context.Context, model string) error {
|
||||
item, ok := c.modelRecommendationItem(ctx, model)
|
||||
if !ok || strings.TrimSpace(item.RequiredPlan) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := launchAccountState(ctx, c.apiClient)
|
||||
if state.Status != accountStateUnknown {
|
||||
c.accountState = &state
|
||||
}
|
||||
if state.Status == accountStateUnknown {
|
||||
return ErrPlanVerificationUnavailable
|
||||
}
|
||||
|
||||
if state.Status == accountStateSignedOut {
|
||||
if err := ensureCloudAuth(ctx, c.apiClient, model); err != nil {
|
||||
return err
|
||||
}
|
||||
state = launchAccountState(ctx, c.apiClient)
|
||||
if state.Status != accountStateUnknown {
|
||||
c.accountState = &state
|
||||
}
|
||||
if state.Status == accountStateUnknown {
|
||||
return ErrPlanVerificationUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
if PlanSatisfies(state.Plan, item.RequiredPlan) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.runUpgradeFlow(ctx, item); err != nil {
|
||||
return err
|
||||
}
|
||||
state = launchAccountState(ctx, c.apiClient)
|
||||
if state.Status == accountStateUnknown {
|
||||
return ErrPlanVerificationUnavailable
|
||||
}
|
||||
if state.Status != accountStateSignedIn || !PlanSatisfies(state.Plan, item.RequiredPlan) {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) modelRecommendationItem(ctx context.Context, model string) (ModelItem, bool) {
|
||||
for _, item := range c.recommendations(ctx) {
|
||||
if item.Name == model {
|
||||
return item, true
|
||||
}
|
||||
}
|
||||
return ModelItem{}, false
|
||||
}
|
||||
|
||||
func (c *launcherClient) runUpgradeFlow(ctx context.Context, item ModelItem) error {
|
||||
if DefaultUpgrade != nil {
|
||||
if _, err := DefaultUpgrade(item.Name, item.RequiredPlan); err != nil {
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
yes, err := ConfirmPrompt(fmt.Sprintf("Upgrade to use %s?", item.Name))
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !yes {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo upgrade, navigate to:\n %s\n\n", DefaultUpgradeURL)
|
||||
openNow, err := ConfirmPrompt("Open now?")
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if openNow {
|
||||
OpenBrowser(DefaultUpgradeURL)
|
||||
} else {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
if frame%10 != 0 {
|
||||
continue
|
||||
}
|
||||
state := launchAccountState(ctx, c.apiClient)
|
||||
if state.Status == accountStateUnknown {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ErrPlanVerificationUnavailable
|
||||
}
|
||||
if state.Status == accountStateSignedIn && PlanSatisfies(state.Plan, item.RequiredPlan) {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1mplan updated\033[0m\n")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PlanSatisfies reports whether currentPlan can use a model that has a requiredPlan.
|
||||
func PlanSatisfies(currentPlan, requiredPlan string) bool {
|
||||
required := normalizePlan(requiredPlan)
|
||||
if required == "" || required == "free" {
|
||||
return true
|
||||
}
|
||||
current := normalizePlan(currentPlan)
|
||||
return current != "" && current != "free"
|
||||
}
|
||||
|
||||
func normalizePlan(plan string) string {
|
||||
return strings.ToLower(strings.TrimSpace(plan))
|
||||
}
|
||||
|
|
@ -319,7 +319,7 @@ func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
|||
|
||||
var selectorCalls int
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
selectorCalls++
|
||||
gotCurrent = current
|
||||
return "llama3.2", nil
|
||||
|
|
@ -553,7 +553,7 @@ func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T)
|
|||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
gotCurrent = current
|
||||
return "qwen3:8b", nil
|
||||
}
|
||||
|
|
@ -607,7 +607,7 @@ func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T
|
|||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called for headless --yes saved-model launch")
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -644,7 +644,7 @@ func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testin
|
|||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called for headless --yes without saved model")
|
||||
return "", nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ import (
|
|||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
|
|
@ -456,6 +458,28 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_PreservesRecommendationRequiredPlanForExistingCloudModel(t *testing.T) {
|
||||
recommendations := []ModelItem{
|
||||
{
|
||||
Name: "glm-5:cloud",
|
||||
Description: "Reasoning and code generation",
|
||||
Recommended: true,
|
||||
RequiredPlan: "pro",
|
||||
ContextLength: 202_752,
|
||||
},
|
||||
}
|
||||
existing := []modelInfo{{Name: "glm-5:cloud", Remote: true}}
|
||||
|
||||
items, _, _, _ := buildModelListWithRecommendations(existing, recommendations, nil, "")
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected one item, got %v", items)
|
||||
}
|
||||
item := items[0]
|
||||
if item.RequiredPlan != "pro" {
|
||||
t.Fatalf("RequiredPlan = %q, want pro", item.RequiredPlan)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "gemma4", Remote: false},
|
||||
|
|
@ -1390,6 +1414,187 @@ func TestEnsureAuth_EmptyWhoamiRequiresSignIn(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestApplyAccountStateToSelectionItems_BadgesOnlyWhenActionRequired(t *testing.T) {
|
||||
items := []ModelItem{
|
||||
{Name: "qwen3.5:cloud", Recommended: true},
|
||||
{Name: "kimi-k2.6:cloud", Recommended: true, RequiredPlan: "pro"},
|
||||
{Name: "llama3.2", RequiredPlan: "pro"},
|
||||
{Name: "glm-5:cloud"},
|
||||
{Name: "nemotron-3-super:cloud", Recommended: true, RequiredPlan: "free"},
|
||||
}
|
||||
|
||||
signedOut := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedOut})
|
||||
if signedOut[0].AvailabilityBadge != "Sign in required" {
|
||||
t.Fatalf("account cloud badge = %q", signedOut[0].AvailabilityBadge)
|
||||
}
|
||||
if signedOut[1].AvailabilityBadge != "Sign in required" {
|
||||
t.Fatalf("subscription cloud signed-out badge = %q", signedOut[1].AvailabilityBadge)
|
||||
}
|
||||
if signedOut[4].AvailabilityBadge != "Sign in required" {
|
||||
t.Fatalf("free-plan cloud signed-out badge = %q", signedOut[4].AvailabilityBadge)
|
||||
}
|
||||
if signedOut[2].AvailabilityBadge != "" || signedOut[3].AvailabilityBadge != "" {
|
||||
t.Fatalf("unexpected badge for local or unmetadata item: %#v", signedOut)
|
||||
}
|
||||
|
||||
freeUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "free"})
|
||||
if freeUser[0].AvailabilityBadge != "" {
|
||||
t.Fatalf("signed-in account model should not be badged, got %q", freeUser[0].AvailabilityBadge)
|
||||
}
|
||||
if freeUser[1].AvailabilityBadge != "Upgrade required" {
|
||||
t.Fatalf("subscription cloud free-plan badge = %q", freeUser[1].AvailabilityBadge)
|
||||
}
|
||||
if freeUser[4].AvailabilityBadge != "" {
|
||||
t.Fatalf("free required plan should be usable by free user, got %q", freeUser[4].AvailabilityBadge)
|
||||
}
|
||||
|
||||
proUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "pro"})
|
||||
if proUser[1].AvailabilityBadge != "" {
|
||||
t.Fatalf("pro user should not see included badge, got %q", proUser[1].AvailabilityBadge)
|
||||
}
|
||||
|
||||
maxUser := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateSignedIn, Plan: "max"})
|
||||
if maxUser[1].AvailabilityBadge != "" {
|
||||
t.Fatalf("max user should not see upgrade badge, got %q", maxUser[1].AvailabilityBadge)
|
||||
}
|
||||
|
||||
unknown := ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateUnknown})
|
||||
for _, item := range unknown {
|
||||
if item.AvailabilityBadge != "" {
|
||||
t.Fatalf("unknown account state should not render badges: %#v", unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectionItemsWithAccountState_SkipsBadgesWithoutBadgeableCloudItems(t *testing.T) {
|
||||
items := []ModelItem{
|
||||
{Name: "llama3.2"},
|
||||
{Name: "custom:cloud"},
|
||||
}
|
||||
state := &AccountState{Status: accountStateSignedOut}
|
||||
got := SelectionItemsWithAccountState(items, state)
|
||||
if len(got) != len(items) {
|
||||
t.Fatalf("got %d selection items, want %d", len(got), len(items))
|
||||
}
|
||||
for _, item := range got {
|
||||
if item.AvailabilityBadge != "" {
|
||||
t.Fatalf("unexpected badge without account state: %#v", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectionItemsWithAccountState_UsesPrefetchedStateForRecommendedCloudItems(t *testing.T) {
|
||||
state := &AccountState{Status: accountStateSignedOut}
|
||||
got := SelectionItemsWithAccountState([]ModelItem{{Name: "qwen3.5:cloud", Recommended: true}}, state)
|
||||
if got[0].AvailabilityBadge != "Sign in required" {
|
||||
t.Fatalf("badge = %q, want Sign in required", got[0].AvailabilityBadge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecommendedModelsDoNotIncludeRequiredPlanStubs(t *testing.T) {
|
||||
byName := make(map[string]ModelItem, len(recommendedModels))
|
||||
for _, item := range recommendedModels {
|
||||
byName[item.Name] = item
|
||||
}
|
||||
|
||||
if item := byName["kimi-k2.6:cloud"]; item.RequiredPlan != "" {
|
||||
t.Fatalf("kimi fallback required plan should not be stubbed: %#v", item)
|
||||
}
|
||||
if item := byName["minimax-m2.7:cloud"]; item.RequiredPlan != "" {
|
||||
t.Fatalf("minimax fallback required plan should not be stubbed: %#v", item)
|
||||
}
|
||||
if item := byName["qwen3.5:cloud"]; item.RequiredPlan != "" {
|
||||
t.Fatalf("qwen fallback required plan = %#v", item)
|
||||
}
|
||||
if item := byName["glm-5.1:cloud"]; item.RequiredPlan != "" {
|
||||
t.Fatalf("glm fallback required plan = %#v", item)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchAccountState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
wantStatus accountStateStatus
|
||||
wantPlan string
|
||||
}{
|
||||
{
|
||||
name: "signed in",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{"name":"parth","plan":"pro"}`,
|
||||
wantStatus: accountStateSignedIn,
|
||||
wantPlan: "pro",
|
||||
},
|
||||
{
|
||||
name: "signed out",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
body: `{"error":"unauthorized","signin_url":"https://example.com/signin"}`,
|
||||
wantStatus: accountStateSignedOut,
|
||||
},
|
||||
{
|
||||
name: "unreachable",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
body: `{"error":"temporary failure"}`,
|
||||
wantStatus: accountStateUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/me" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(tt.statusCode)
|
||||
fmt.Fprint(w, tt.body)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
got := launchAccountState(context.Background(), api.NewClient(u, srv.Client()))
|
||||
if got.Status != tt.wantStatus {
|
||||
t.Fatalf("Status = %v, want %v", got.Status, tt.wantStatus)
|
||||
}
|
||||
if got.Plan != tt.wantPlan {
|
||||
t.Fatalf("Plan = %q, want %q", got.Plan, tt.wantPlan)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAccountStatePrefetch_SkipsWhoamiWhenCloudDisabled(t *testing.T) {
|
||||
var whoamiCalled atomic.Bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprint(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/me":
|
||||
whoamiCalled.Store(true)
|
||||
http.NotFound(w, r)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
prefetch := StartAccountStatePrefetch(context.Background())
|
||||
select {
|
||||
case <-prefetch.done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("account prefetch did not finish")
|
||||
}
|
||||
if whoamiCalled.Load() {
|
||||
t.Fatal("prefetch should not call whoami when cloud is disabled")
|
||||
}
|
||||
state := prefetch.StateIfReady()
|
||||
if state == nil || state.Status != accountStateUnknown {
|
||||
t.Fatalf("prefetch state = %#v, want unknown", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_PreservesCancelledSignInHook(t *testing.T) {
|
||||
oldSignIn := DefaultSignIn
|
||||
DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ type LauncherState struct {
|
|||
RunModel string
|
||||
RunModelUsable bool
|
||||
Integrations map[string]LauncherIntegrationState
|
||||
AccountState *AccountState
|
||||
}
|
||||
|
||||
// LauncherIntegrationState is the launch-owned status for one launcher integration.
|
||||
|
|
@ -41,8 +42,11 @@ type LauncherIntegrationState struct {
|
|||
|
||||
// RunModelRequest controls how the root launcher resolves the chat model.
|
||||
type RunModelRequest struct {
|
||||
ForcePicker bool
|
||||
Policy *LaunchPolicy
|
||||
ForcePicker bool
|
||||
Policy *LaunchPolicy
|
||||
AccountState *AccountState
|
||||
AccountStateProvider func() *AccountState
|
||||
AccountStateUpdates func(context.Context) <-chan *AccountState
|
||||
}
|
||||
|
||||
// LaunchConfirmMode controls confirmation behavior across launch flows.
|
||||
|
|
@ -117,13 +121,16 @@ func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
|
|||
|
||||
// IntegrationLaunchRequest controls the canonical integration launcher flow.
|
||||
type IntegrationLaunchRequest struct {
|
||||
Name string
|
||||
ModelOverride string
|
||||
ForceConfigure bool
|
||||
ConfigureOnly bool
|
||||
Restore bool
|
||||
ExtraArgs []string
|
||||
Policy *LaunchPolicy
|
||||
Name string
|
||||
ModelOverride string
|
||||
ForceConfigure bool
|
||||
ConfigureOnly bool
|
||||
Restore bool
|
||||
ExtraArgs []string
|
||||
Policy *LaunchPolicy
|
||||
AccountState *AccountState
|
||||
AccountStateProvider func() *AccountState
|
||||
AccountStateUpdates func(context.Context) <-chan *AccountState
|
||||
}
|
||||
|
||||
var isInteractiveSession = func() bool {
|
||||
|
|
@ -241,7 +248,7 @@ type modelInfo struct {
|
|||
// ModelInfo re-exports launcher model inventory details for callers.
|
||||
type ModelInfo = modelInfo
|
||||
|
||||
// ModelItem represents a model for selection UIs.
|
||||
// ModelItem represents model metadata before selector-only UI state is derived.
|
||||
type ModelItem struct {
|
||||
Name string
|
||||
Description string
|
||||
|
|
@ -249,6 +256,15 @@ type ModelItem struct {
|
|||
VRAMBytes int64
|
||||
ContextLength int
|
||||
MaxOutputTokens int
|
||||
RequiredPlan string
|
||||
}
|
||||
|
||||
// SelectionItem represents a model row after launch has derived selector-only UI state.
|
||||
type SelectionItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
AvailabilityBadge string
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
|
|
@ -384,10 +400,15 @@ func launchCommandCanSkipHeartbeat(args []string) bool {
|
|||
}
|
||||
|
||||
type launcherClient struct {
|
||||
apiClient *api.Client
|
||||
modelInventory []ModelInfo
|
||||
inventoryLoaded bool
|
||||
policy LaunchPolicy
|
||||
apiClient *api.Client
|
||||
modelInventory []ModelInfo
|
||||
inventoryLoaded bool
|
||||
recommendationsLoaded bool
|
||||
recommendationItems []ModelItem
|
||||
accountState *AccountState
|
||||
accountStateProvider func() *AccountState
|
||||
accountStateUpdates func(context.Context) <-chan *AccountState
|
||||
policy LaunchPolicy
|
||||
}
|
||||
|
||||
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
|
||||
|
|
@ -425,6 +446,9 @@ func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
launchClient.accountState = req.AccountState
|
||||
launchClient.accountStateProvider = req.AccountStateProvider
|
||||
launchClient.accountStateUpdates = req.AccountStateUpdates
|
||||
return launchClient.resolveRunModel(ctx, req)
|
||||
}
|
||||
|
||||
|
|
@ -449,6 +473,9 @@ func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
launchClient.accountState = req.AccountState
|
||||
launchClient.accountStateProvider = req.AccountStateProvider
|
||||
launchClient.accountStateUpdates = req.AccountStateUpdates
|
||||
|
||||
if autodiscovery, ok := runner.(ManagedAutodiscoveryIntegration); ok {
|
||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||
|
|
@ -811,7 +838,10 @@ func (c *launcherClient) managedAutodiscoveryUsable(ctx context.Context, autodis
|
|||
if !managedAutodiscoveryUsesOllamaCloud(autodiscovery) {
|
||||
return true
|
||||
}
|
||||
return c.ollamaCloudSignedIn(ctx)
|
||||
if disabled, known := cloudStatusDisabled(ctx, c.apiClient); known && disabled {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *launcherClient) ensureManagedAutodiscoveryUsable(ctx context.Context, autodiscovery ManagedAutodiscoveryIntegration, label string) error {
|
||||
|
|
@ -858,14 +888,6 @@ func printRestoreSuccess(integration any) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *launcherClient) ollamaCloudSignedIn(ctx context.Context) bool {
|
||||
if disabled, known := cloudStatusDisabled(ctx, c.apiClient); known && disabled {
|
||||
return false
|
||||
}
|
||||
user, err := c.apiClient.Whoami(ctx)
|
||||
return err == nil && user != nil && user.Name != ""
|
||||
}
|
||||
|
||||
func (c *launcherClient) managedSingleConfigureModels(ctx context.Context, managed ManagedSingleModel, target string) ([]string, error) {
|
||||
models := []string{target}
|
||||
if _, ok := managed.(ManagedModelListConfigurer); !ok {
|
||||
|
|
@ -959,8 +981,15 @@ func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, titl
|
|||
return c.selectSingleModelWithSelectorReady(ctx, title, current, selector, true)
|
||||
}
|
||||
|
||||
func (c *launcherClient) latestAccountState() *AccountState {
|
||||
if c.accountStateProvider != nil {
|
||||
return c.accountStateProvider()
|
||||
}
|
||||
return c.accountState
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectSingleModelWithSelectorReady(ctx context.Context, title, current string, selector SingleSelector, ensureReady bool) (string, error) {
|
||||
if selector == nil {
|
||||
if selector == nil && DefaultSingleSelectorWithUpdates == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
|
|
@ -969,45 +998,88 @@ func (c *launcherClient) selectSingleModelWithSelectorReady(ctx context.Context,
|
|||
return "", err
|
||||
}
|
||||
|
||||
selected, err := selector(title, items, current)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if selected == "" {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
if ensureReady {
|
||||
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||
for {
|
||||
accountState := c.latestAccountState()
|
||||
selectionItems := SelectionItemsWithAccountState(items, accountState)
|
||||
var updates <-chan []SelectionItem
|
||||
if DefaultSingleSelectorWithUpdates != nil {
|
||||
updates = c.selectionItemUpdates(ctx, items, accountState)
|
||||
}
|
||||
selected, err := runSingleSelector(title, selectionItems, current, updates, selector)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if selected == "" {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
if ensureReady {
|
||||
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||
if errors.Is(err, errUpgradeCancelled) {
|
||||
current = selected
|
||||
continue
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
|
||||
if DefaultMultiSelector == nil {
|
||||
if DefaultMultiSelector == nil && DefaultMultiSelectorWithUpdates == nil {
|
||||
return nil, fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
current := firstModel(preChecked)
|
||||
|
||||
items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
for {
|
||||
accountState := c.latestAccountState()
|
||||
selectionItems := SelectionItemsWithAccountState(items, accountState)
|
||||
var updates <-chan []SelectionItem
|
||||
if DefaultMultiSelectorWithUpdates != nil {
|
||||
updates = c.selectionItemUpdates(ctx, items, accountState)
|
||||
}
|
||||
selected, err := runMultiSelector(fmt.Sprintf("Select models for %s:", runner), selectionItems, orderedChecked, updates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected)
|
||||
if err != nil {
|
||||
if errors.Is(err, errUpgradeCancelled) {
|
||||
orderedChecked = append([]string(nil), selected...)
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
for _, skip := range skipped {
|
||||
fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason)
|
||||
}
|
||||
return accepted, nil
|
||||
}
|
||||
accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func runSingleSelector(title string, items []SelectionItem, current string, updates <-chan []SelectionItem, fallback SingleSelector) (string, error) {
|
||||
if DefaultSingleSelectorWithUpdates != nil {
|
||||
return DefaultSingleSelectorWithUpdates(title, items, current, updates)
|
||||
}
|
||||
for _, skip := range skipped {
|
||||
fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason)
|
||||
if fallback == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
}
|
||||
return accepted, nil
|
||||
return fallback(title, items, current)
|
||||
}
|
||||
|
||||
func runMultiSelector(title string, items []SelectionItem, preChecked []string, updates <-chan []SelectionItem) ([]string, error) {
|
||||
if DefaultMultiSelectorWithUpdates != nil {
|
||||
return DefaultMultiSelectorWithUpdates(title, items, preChecked, updates)
|
||||
}
|
||||
if DefaultMultiSelector == nil {
|
||||
return nil, fmt.Errorf("no selector configured")
|
||||
}
|
||||
return DefaultMultiSelector(title, items, preChecked)
|
||||
}
|
||||
|
||||
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
|
||||
|
|
@ -1029,16 +1101,24 @@ func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []
|
|||
}
|
||||
|
||||
func (c *launcherClient) recommendations(ctx context.Context) []ModelItem {
|
||||
if c.recommendationsLoaded {
|
||||
return append([]ModelItem(nil), c.recommendationItems...)
|
||||
}
|
||||
|
||||
recommendations, err := c.requestRecommendations(ctx)
|
||||
if err != nil || len(recommendations) == 0 {
|
||||
// Fail open: recommendation issues should not block launch flows.
|
||||
// Fall back to built-in recommendations until server data is available.
|
||||
fallback := append([]ModelItem(nil), recommendedModels...)
|
||||
setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(fallback))
|
||||
return fallback
|
||||
c.recommendationItems = fallback
|
||||
c.recommendationsLoaded = true
|
||||
return append([]ModelItem(nil), fallback...)
|
||||
}
|
||||
setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(recommendations))
|
||||
return recommendations
|
||||
c.recommendationItems = recommendations
|
||||
c.recommendationsLoaded = true
|
||||
return append([]ModelItem(nil), recommendations...)
|
||||
}
|
||||
|
||||
func (c *launcherClient) requestRecommendations(ctx context.Context) ([]ModelItem, error) {
|
||||
|
|
@ -1076,6 +1156,7 @@ func (c *launcherClient) requestRecommendations(ctx context.Context) ([]ModelIte
|
|||
VRAMBytes: rec.VRAMBytes,
|
||||
ContextLength: rec.ContextLength,
|
||||
MaxOutputTokens: rec.MaxOutputTokens,
|
||||
RequiredPlan: strings.TrimSpace(rec.RequiredPlan),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -1093,6 +1174,9 @@ func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string)
|
|||
isCloudModel := isCloudModelName(model)
|
||||
if isCloudModel {
|
||||
cloudModels[model] = true
|
||||
if err := c.ensureCloudModelAccess(ctx, model); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
|
||||
return err
|
||||
|
|
@ -1126,6 +1210,9 @@ func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected
|
|||
|
||||
for _, model := range selected {
|
||||
if err := c.ensureModelsReady(ctx, []string{model}); err != nil {
|
||||
if errors.Is(err, errUpgradeCancelled) {
|
||||
return nil, nil, err
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
@ -1142,6 +1229,9 @@ func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected
|
|||
}
|
||||
|
||||
func skippedModelReason(model string, err error) string {
|
||||
if errors.Is(err, errUpgradeCancelled) {
|
||||
return "upgrade was cancelled"
|
||||
}
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
if isCloudModelName(model) {
|
||||
return "sign in was cancelled"
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
|
|
@ -192,14 +193,20 @@ func withInteractiveSession(t *testing.T, interactive bool) {
|
|||
func withLauncherHooks(t *testing.T) {
|
||||
t.Helper()
|
||||
oldSingle := DefaultSingleSelector
|
||||
oldSingleWithUpdates := DefaultSingleSelectorWithUpdates
|
||||
oldMulti := DefaultMultiSelector
|
||||
oldMultiWithUpdates := DefaultMultiSelectorWithUpdates
|
||||
oldConfirm := DefaultConfirmPrompt
|
||||
oldSignIn := DefaultSignIn
|
||||
oldUpgrade := DefaultUpgrade
|
||||
t.Cleanup(func() {
|
||||
DefaultSingleSelector = oldSingle
|
||||
DefaultSingleSelectorWithUpdates = oldSingleWithUpdates
|
||||
DefaultMultiSelector = oldMulti
|
||||
DefaultMultiSelectorWithUpdates = oldMultiWithUpdates
|
||||
DefaultConfirmPrompt = oldConfirm
|
||||
DefaultSignIn = oldSignIn
|
||||
DefaultUpgrade = oldUpgrade
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -346,7 +353,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *
|
|||
}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
return "gemma4", nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
|
|
@ -501,7 +508,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(
|
|||
runner := &launcherManagedRunner{}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called when saved model matches target")
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -553,7 +560,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *t
|
|||
runner := &launcherManagedRunner{}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called when model override is provided")
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -607,7 +614,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenLiveConfigDrifts(
|
|||
}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called when live config already provides the target")
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -734,7 +741,7 @@ func TestLaunchIntegration_ManagedAutodiscoverySkipsModelPicker(t *testing.T) {
|
|||
runner := &launcherManagedAutodiscoveryRunner{}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("model selector should not run for autodiscovery integrations")
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -987,7 +994,7 @@ func TestLaunchIntegration_CloudAutodiscoveryUsesSignInHook(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *testing.T) {
|
||||
func TestBuildLauncherIntegrationState_CloudAutodiscoveryDoesNotCheckSignIn(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
|
|
@ -1004,8 +1011,7 @@ func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *tes
|
|||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
|
||||
t.Fatal("build launcher state should not check whoami")
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
|
@ -1028,8 +1034,8 @@ func TestBuildLauncherIntegrationState_CloudAutodiscoveryRequiresSignedIn(t *tes
|
|||
if state.CurrentModel != "Ollama Cloud" {
|
||||
t.Fatalf("current model = %q, want Ollama Cloud", state.CurrentModel)
|
||||
}
|
||||
if state.ModelUsable {
|
||||
t.Fatal("expected cloud autodiscovery config to be unusable while signed out")
|
||||
if !state.ModelUsable {
|
||||
t.Fatal("expected cloud autodiscovery config to stay usable until launch-time auth check")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1296,7 +1302,7 @@ func TestResolveRunModel_UsesSavedModelWithoutSelector(t *testing.T) {
|
|||
}
|
||||
|
||||
selectorCalled := false
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
selectorCalled = true
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -1338,7 +1344,7 @@ func TestResolveRunModel_HeadlessYesAutoPicksLastModel(t *testing.T) {
|
|||
t.Fatalf("failed to save last model: %v", err)
|
||||
}
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called in headless --yes mode")
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -1405,7 +1411,7 @@ func TestResolveRunModel_UsesRequestPolicy(t *testing.T) {
|
|||
t.Fatalf("failed to save last model: %v", err)
|
||||
}
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called when request policy enables headless auto-pick")
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -1465,7 +1471,7 @@ func TestResolveRunModel_ForcePickerAlwaysUsesSelector(t *testing.T) {
|
|||
}
|
||||
|
||||
var selectorCalls int
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
selectorCalls++
|
||||
if current != "llama3.2" {
|
||||
t.Fatalf("expected current selection to be last model, got %q", current)
|
||||
|
|
@ -1513,7 +1519,7 @@ func TestResolveRunModel_ForcePicker_DoesNotReorderByLastModel(t *testing.T) {
|
|||
}
|
||||
|
||||
var gotNames []string
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
if current != "qwen3.5" {
|
||||
t.Fatalf("expected current selection to be last model, got %q", current)
|
||||
}
|
||||
|
|
@ -1564,7 +1570,7 @@ func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) {
|
|||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
return "glm-5:cloud", nil
|
||||
}
|
||||
|
||||
|
|
@ -1610,6 +1616,241 @@ func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestResolveRunModel_MetadataSignedOutUsesSignInHook(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
return "qwen3.5:cloud", nil
|
||||
}
|
||||
|
||||
signedIn := false
|
||||
signInCalled := false
|
||||
DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
signInCalled = true
|
||||
signedIn = true
|
||||
if modelName != "qwen3.5:cloud" {
|
||||
t.Fatalf("unexpected model passed to sign-in: %q", modelName)
|
||||
}
|
||||
if signInURL != "https://example.com/signin" {
|
||||
t.Fatalf("unexpected sign-in URL: %q", signInURL)
|
||||
}
|
||||
return "test-user", nil
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[{"model":"qwen3.5:cloud","description":"Reasoning","context_length":262144,"max_output_tokens":32768}]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[]}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"remote_model":"qwen3.5"}`)
|
||||
case "/api/me":
|
||||
if !signedIn {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
|
||||
return
|
||||
}
|
||||
fmt.Fprint(w, `{"name":"test-user","plan":"free"}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true})
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveRunModel returned error: %v", err)
|
||||
}
|
||||
if model != "qwen3.5:cloud" {
|
||||
t.Fatalf("expected selected cloud model, got %q", model)
|
||||
}
|
||||
if !signInCalled {
|
||||
t.Fatal("expected sign-in hook to be used for account-gated cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRunModel_SubscriptionModelUsesUpgradeHook(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
|
||||
DefaultSingleSelectorWithUpdates = func(title string, items []SelectionItem, current string, updates <-chan []SelectionItem) (string, error) {
|
||||
for _, item := range items {
|
||||
if item.Name == "kimi-k2.6:cloud" && item.AvailabilityBadge != "" {
|
||||
t.Fatalf("initial availability badge = %q, want empty before account update", item.AvailabilityBadge)
|
||||
}
|
||||
}
|
||||
select {
|
||||
case items = <-updates:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for selector item update")
|
||||
}
|
||||
for _, item := range items {
|
||||
if item.Name == "kimi-k2.6:cloud" {
|
||||
if item.AvailabilityBadge != "Upgrade required" {
|
||||
t.Fatalf("availability badge = %q, want Upgrade required", item.AvailabilityBadge)
|
||||
}
|
||||
return "kimi-k2.6:cloud", nil
|
||||
}
|
||||
}
|
||||
t.Fatalf("paid cloud model missing from selector items: %#v", items)
|
||||
return "kimi-k2.6:cloud", nil
|
||||
}
|
||||
DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
t.Fatalf("did not expect sign-in hook for signed-in user")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
plan := "free"
|
||||
upgradeCalled := false
|
||||
DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
|
||||
upgradeCalled = true
|
||||
if modelName != "kimi-k2.6:cloud" {
|
||||
t.Fatalf("unexpected model passed to upgrade: %q", modelName)
|
||||
}
|
||||
if requiredPlan != "pro" {
|
||||
t.Fatalf("unexpected required plan: %q", requiredPlan)
|
||||
}
|
||||
plan = "max"
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[]}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"remote_model":"kimi-k2.6"}`)
|
||||
case "/api/me":
|
||||
fmt.Fprintf(w, `{"name":"test-user","plan":%q}`, plan)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true})
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveRunModel returned error: %v", err)
|
||||
}
|
||||
if model != "kimi-k2.6:cloud" {
|
||||
t.Fatalf("expected selected cloud model, got %q", model)
|
||||
}
|
||||
if !upgradeCalled {
|
||||
t.Fatal("expected upgrade hook to be used for subscription-gated cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRunModel_UpgradeCancelledReturnsToModelSelector(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
|
||||
selectorCalls := 0
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
selectorCalls++
|
||||
switch selectorCalls {
|
||||
case 1:
|
||||
return "kimi-k2.6:cloud", nil
|
||||
case 2:
|
||||
return "llama3.2", nil
|
||||
default:
|
||||
t.Fatalf("selector called too many times: %d", selectorCalls)
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
case "/api/show":
|
||||
var req apiShowRequest
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
fmt.Fprintf(w, `{"model":%q}`, req.Model)
|
||||
case "/api/me":
|
||||
fmt.Fprint(w, `{"name":"test-user","plan":"free"}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true})
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveRunModel returned error: %v", err)
|
||||
}
|
||||
if model != "llama3.2" {
|
||||
t.Fatalf("model = %q, want llama3.2", model)
|
||||
}
|
||||
if selectorCalls != 2 {
|
||||
t.Fatalf("selector calls = %d, want 2", selectorCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRunModel_SubscriptionModelUnavailableWhoamiFailsGracefully(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
return "kimi-k2.6:cloud", nil
|
||||
}
|
||||
DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
|
||||
t.Fatalf("did not expect upgrade hook when plan could not be verified")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[]}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
fmt.Fprint(w, `{"error":"temporary failure"}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
_, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true})
|
||||
if err == nil {
|
||||
t.Fatal("expected plan verification error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Could not verify your plan. Try again in a moment.") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_EditorForceConfigure(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
|
@ -1623,7 +1864,7 @@ func TestLaunchIntegration_EditorForceConfigure(t *testing.T) {
|
|||
withIntegrationOverride(t, "droid", editor)
|
||||
|
||||
var multiCalled bool
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
multiCalled = true
|
||||
return []string{"llama3.2", "qwen3:8b"}, nil
|
||||
}
|
||||
|
|
@ -1688,7 +1929,7 @@ func TestLaunchIntegration_EditorForceConfigure_FloatsCheckedModelsInPicker(t *t
|
|||
|
||||
var gotItems []string
|
||||
var gotPreChecked []string
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
for _, item := range items {
|
||||
gotItems = append(gotItems, item.Name)
|
||||
}
|
||||
|
|
@ -1809,7 +2050,7 @@ func TestLaunchIntegration_EditorCloudDisabledFallsBackToSelector(t *testing.T)
|
|||
}
|
||||
|
||||
var multiCalled bool
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
multiCalled = true
|
||||
return []string{"llama3.2"}, nil
|
||||
}
|
||||
|
|
@ -1851,7 +2092,7 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsMissingLocalAndPersistsAccep
|
|||
editor := &launcherEditorRunner{}
|
||||
withIntegrationOverride(t, "droid", editor)
|
||||
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
return []string{"glm-5:cloud", "missing-local"}, nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
|
|
@ -1932,7 +2173,7 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAcce
|
|||
editor := &launcherEditorRunner{}
|
||||
withIntegrationOverride(t, "droid", editor)
|
||||
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
return []string{"llama3.2", "glm-5:cloud"}, nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
|
|
@ -2001,6 +2242,84 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAcce
|
|||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_EditorConfigureUpgradeCancelledReturnsToModelSelector(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
|
||||
binDir := t.TempDir()
|
||||
writeFakeBinary(t, binDir, "droid")
|
||||
t.Setenv("PATH", binDir)
|
||||
|
||||
editor := &launcherEditorRunner{}
|
||||
withIntegrationOverride(t, "droid", editor)
|
||||
|
||||
selectorCalls := 0
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
selectorCalls++
|
||||
switch selectorCalls {
|
||||
case 1:
|
||||
return []string{"kimi-k2.6:cloud"}, nil
|
||||
case 2:
|
||||
if diff := compareStrings(preChecked, []string{"kimi-k2.6:cloud"}); diff != "" {
|
||||
t.Fatalf("second selector preChecked (-want +got):\n%s", diff)
|
||||
}
|
||||
return []string{"llama3.2"}, nil
|
||||
default:
|
||||
t.Fatalf("selector called too many times: %d", selectorCalls)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
DefaultUpgrade = func(modelName, requiredPlan string) (string, error) {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
if prompt == "Proceed?" {
|
||||
return true, nil
|
||||
}
|
||||
t.Fatalf("unexpected prompt: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[{"model":"kimi-k2.6:cloud","description":"Coding","context_length":262144,"max_output_tokens":262144,"required_plan":"pro"}]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
case "/api/show":
|
||||
var req apiShowRequest
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
fmt.Fprintf(w, `{"model":%q}`, req.Model)
|
||||
case "/api/me":
|
||||
fmt.Fprint(w, `{"name":"test-user","plan":"free"}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||
Name: "droid",
|
||||
ForceConfigure: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||
}
|
||||
if selectorCalls != 2 {
|
||||
t.Fatalf("selector calls = %d, want 2", selectorCalls)
|
||||
}
|
||||
if editor.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to use local model, got %q", editor.ranModel)
|
||||
}
|
||||
if diff := compareStringSlices(editor.edited, [][]string{{"llama3.2"}}); diff != "" {
|
||||
t.Fatalf("unexpected edited models (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
|
@ -2016,7 +2335,7 @@ func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t *
|
|||
if err := config.SaveIntegration("droid", []string{"glm-5:cloud", "llama3.2"}); err != nil {
|
||||
t.Fatalf("failed to seed config: %v", err)
|
||||
}
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
return append([]string(nil), preChecked...), nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
|
|
@ -2102,7 +2421,7 @@ func TestLaunchIntegration_EditorConfigureMultiAllFailuresKeepsExistingAndSkipsL
|
|||
t.Fatalf("failed to seed config: %v", err)
|
||||
}
|
||||
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
return []string{"missing-local-a", "missing-local-b"}, nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
|
|
@ -2342,7 +2661,7 @@ func TestLaunchIntegration_OpenclawInstallsBeforeConfigSideEffects(t *testing.T)
|
|||
withIntegrationOverride(t, "openclaw", editor)
|
||||
|
||||
selectorCalled := false
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
selectorCalled = true
|
||||
return []string{"llama3.2"}, nil
|
||||
}
|
||||
|
|
@ -2376,7 +2695,7 @@ func TestLaunchIntegration_PiInstallsBeforeConfigSideEffects(t *testing.T) {
|
|||
withIntegrationOverride(t, "pi", editor)
|
||||
|
||||
selectorCalled := false
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
selectorCalled = true
|
||||
return []string{"llama3.2"}, nil
|
||||
}
|
||||
|
|
@ -2408,7 +2727,7 @@ func TestLaunchIntegration_ConfigureOnlyDoesNotRequireInstalledBinary(t *testing
|
|||
editor := &launcherEditorRunner{paths: []string{"/tmp/settings.json"}}
|
||||
withIntegrationOverride(t, "droid", editor)
|
||||
|
||||
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
DefaultMultiSelector = func(title string, items []SelectionItem, preChecked []string) ([]string, error) {
|
||||
return []string{"llama3.2"}, nil
|
||||
}
|
||||
|
||||
|
|
@ -2519,7 +2838,7 @@ func TestLaunchIntegration_ClaudeForceConfigureReprompts(t *testing.T) {
|
|||
}
|
||||
|
||||
var selectorCalls int
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
selectorCalls++
|
||||
return "glm-5:cloud", nil
|
||||
}
|
||||
|
|
@ -2572,7 +2891,7 @@ func TestLaunchIntegration_ClaudeForceConfigureMissingSelectionDoesNotSave(t *te
|
|||
t.Fatalf("failed to seed config: %v", err)
|
||||
}
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
return "missing-model", nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
|
|
@ -2633,7 +2952,7 @@ func TestLaunchIntegration_ClaudeModelOverrideSkipsSelector(t *testing.T) {
|
|||
t.Setenv("PATH", binDir)
|
||||
|
||||
var selectorCalls int
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
selectorCalls++
|
||||
return "", fmt.Errorf("selector should not run when --model override is set")
|
||||
}
|
||||
|
|
@ -2698,7 +3017,7 @@ func TestLaunchIntegration_ConfigureOnlyPrompt(t *testing.T) {
|
|||
runner := &launcherSingleRunner{}
|
||||
withIntegrationOverride(t, "stubsingle", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
return "llama3.2", nil
|
||||
}
|
||||
|
||||
|
|
@ -2926,7 +3245,7 @@ func TestLaunchIntegration_HeadlessSelectorFlowFailsWithoutPrompt(t *testing.T)
|
|||
runner := &launcherSingleRunner{}
|
||||
withIntegrationOverride(t, "droid", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
return "missing-model", nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -188,7 +188,7 @@ func ensureCloudAuth(ctx context.Context, client *api.Client, modelList string)
|
|||
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||
}
|
||||
|
||||
user, err := client.Whoami(ctx)
|
||||
user, err := whoamiWithTimeout(ctx, client)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -243,7 +243,7 @@ func ensureCloudAuth(ctx context.Context, client *api.Client, modelList string)
|
|||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
u, err := whoamiWithTimeout(ctx, client)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
|
|
@ -348,9 +348,11 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M
|
|||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
recDesc := make(map[string]string)
|
||||
recByName := make(map[string]ModelItem)
|
||||
for _, rec := range recommendations {
|
||||
recommended[rec.Name] = true
|
||||
recDesc[rec.Name] = rec.Description
|
||||
recByName[rec.Name] = rec
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
|
|
@ -364,6 +366,9 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M
|
|||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
|
||||
if rec, ok := recByName[displayName]; ok {
|
||||
item = copyModelRecommendationFields(displayName, rec)
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
|
|
@ -472,6 +477,12 @@ func buildModelListWithRecommendations(existing []modelInfo, recommendations []M
|
|||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
func copyModelRecommendationFields(name string, rec ModelItem) ModelItem {
|
||||
rec.Name = name
|
||||
rec.Recommended = true
|
||||
return rec
|
||||
}
|
||||
|
||||
// isCloudModelName reports whether the model name has an explicit cloud source.
|
||||
func isCloudModelName(name string) bool {
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
|
|
|
|||
|
|
@ -35,22 +35,38 @@ type ConfirmOptions struct {
|
|||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||
type SingleSelector func(title string, items []SelectionItem, current string) (string, error)
|
||||
|
||||
// SingleSelectorWithUpdates is a single item selector that can receive refreshed item state while open.
|
||||
type SingleSelectorWithUpdates func(title string, items []SelectionItem, current string, updates <-chan []SelectionItem) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||
type MultiSelector func(title string, items []SelectionItem, preChecked []string) ([]string, error)
|
||||
|
||||
// MultiSelectorWithUpdates is a multi item selector that can receive refreshed item state while open.
|
||||
type MultiSelectorWithUpdates func(title string, items []SelectionItem, preChecked []string, updates <-chan []SelectionItem) ([]string, error)
|
||||
|
||||
// DefaultSingleSelector is the default single-select implementation.
|
||||
var DefaultSingleSelector SingleSelector
|
||||
|
||||
// DefaultSingleSelectorWithUpdates is the default single-select implementation with live updates.
|
||||
var DefaultSingleSelectorWithUpdates SingleSelectorWithUpdates
|
||||
|
||||
// DefaultMultiSelector is the default multi-select implementation.
|
||||
var DefaultMultiSelector MultiSelector
|
||||
|
||||
// DefaultMultiSelectorWithUpdates is the default multi-select implementation with live updates.
|
||||
var DefaultMultiSelectorWithUpdates MultiSelectorWithUpdates
|
||||
|
||||
// DefaultSignIn provides a TUI-based sign-in flow.
|
||||
// When set, ensureAuth uses it instead of plain text prompts.
|
||||
// Returns the signed-in username or an error.
|
||||
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||
|
||||
// DefaultUpgrade provides a TUI-based upgrade flow.
|
||||
// Returns the updated plan or an error.
|
||||
var DefaultUpgrade func(modelName, requiredPlan string) (string, error)
|
||||
|
||||
type launchConfirmPolicy struct {
|
||||
yes bool
|
||||
requireYesMessage bool
|
||||
|
|
|
|||
|
|
@ -35,8 +35,7 @@ var (
|
|||
Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"})
|
||||
|
||||
selectorDefaultTagStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
selectorHelpStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"})
|
||||
|
|
@ -58,16 +57,39 @@ const maxSelectorItems = 10
|
|||
var ErrCancelled = launch.ErrCancelled
|
||||
|
||||
type SelectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
AvailabilityBadge string
|
||||
}
|
||||
|
||||
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
|
||||
func ConvertItems(items []launch.ModelItem) []SelectItem {
|
||||
type selectorItemsUpdatedMsg struct {
|
||||
items []SelectItem
|
||||
}
|
||||
|
||||
func waitForSelectorItems(updates <-chan []SelectItem) tea.Cmd {
|
||||
if updates == nil {
|
||||
return nil
|
||||
}
|
||||
return func() tea.Msg {
|
||||
items, ok := <-updates
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return selectorItemsUpdatedMsg{items: items}
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertItems converts launch.SelectionItem slice to SelectItem slice.
|
||||
func ConvertItems(items []launch.SelectionItem) []SelectItem {
|
||||
out := make([]SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
out[i] = SelectItem{
|
||||
Name: item.Name,
|
||||
Description: item.Description,
|
||||
Recommended: item.Recommended,
|
||||
AvailabilityBadge: item.AvailabilityBadge,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
|
@ -91,6 +113,7 @@ func ReorderItems(items []SelectItem) []SelectItem {
|
|||
type selectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
updates <-chan []SelectItem
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
|
|
@ -110,6 +133,33 @@ func selectorModelWithCurrent(title string, items []SelectItem, current string)
|
|||
return m
|
||||
}
|
||||
|
||||
func currentItemName(items []SelectItem, cursor int) string {
|
||||
if cursor < 0 || cursor >= len(items) {
|
||||
return ""
|
||||
}
|
||||
return items[cursor].Name
|
||||
}
|
||||
|
||||
func cursorForItemName(items []SelectItem, name string, fallback int) int {
|
||||
if len(items) == 0 {
|
||||
return 0
|
||||
}
|
||||
if name != "" {
|
||||
for i, item := range items {
|
||||
if item.Name == name {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
if fallback < 0 {
|
||||
return 0
|
||||
}
|
||||
if fallback >= len(items) {
|
||||
return len(items) - 1
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
|
|
@ -125,7 +175,7 @@ func (m selectorModel) filteredItems() []SelectItem {
|
|||
}
|
||||
|
||||
func (m selectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
return waitForSelectorItems(m.updates)
|
||||
}
|
||||
|
||||
// otherStart returns the index of the first non-recommended item in the filtered list.
|
||||
|
|
@ -235,6 +285,13 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
}
|
||||
return m, nil
|
||||
|
||||
case selectorItemsUpdatedMsg:
|
||||
current := currentItemName(m.filteredItems(), m.cursor)
|
||||
m.items = msg.items
|
||||
m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor)
|
||||
m.updateScroll(m.otherStart())
|
||||
return m, waitForSelectorItems(m.updates)
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
|
|
@ -260,9 +317,17 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
func cursorItemSuffix(item SelectItem) string {
|
||||
if item.AvailabilityBadge == "" {
|
||||
return ""
|
||||
}
|
||||
return " " + selectorDefaultTagStyle.Render("("+item.AvailabilityBadge+")")
|
||||
}
|
||||
|
||||
func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
s.WriteString(cursorItemSuffix(item))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
|
|
@ -402,11 +467,16 @@ func cursorForCurrent(items []SelectItem, current string) int {
|
|||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
|
||||
return SelectSingleWithUpdates(title, items, current, nil)
|
||||
}
|
||||
|
||||
func SelectSingleWithUpdates(title string, items []SelectItem, current string, updates <-chan []SelectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModelWithCurrent(title, items, current)
|
||||
m.updates = updates
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
|
|
@ -426,6 +496,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
|
|||
type multiSelectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
updates <-chan []SelectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
cursor int
|
||||
|
|
@ -475,6 +546,36 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
|
|||
return m
|
||||
}
|
||||
|
||||
func (m *multiSelectorModel) rebuildItemIndex() {
|
||||
m.itemIndex = make(map[string]int, len(m.items))
|
||||
for i, item := range m.items {
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
}
|
||||
|
||||
func (m *multiSelectorModel) replaceItems(items []SelectItem) {
|
||||
current := currentItemName(m.filteredItems(), m.cursor)
|
||||
checkedNames := make([]string, 0, len(m.checkOrder))
|
||||
for _, idx := range m.checkOrder {
|
||||
if idx >= 0 && idx < len(m.items) {
|
||||
checkedNames = append(checkedNames, m.items[idx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
m.items = items
|
||||
m.rebuildItemIndex()
|
||||
m.checked = make(map[int]bool, len(checkedNames))
|
||||
m.checkOrder = nil
|
||||
for _, name := range checkedNames {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor)
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
|
|
@ -590,7 +691,7 @@ func (m multiSelectorModel) selectedCount() int {
|
|||
}
|
||||
|
||||
func (m multiSelectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
return waitForSelectorItems(m.updates)
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
|
@ -603,6 +704,10 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
}
|
||||
return m, nil
|
||||
|
||||
case selectorItemsUpdatedMsg:
|
||||
m.replaceItems(msg.items)
|
||||
return m, waitForSelectorItems(m.updates)
|
||||
|
||||
case tea.KeyMsg:
|
||||
filtered := m.filteredItems()
|
||||
|
||||
|
|
@ -689,6 +794,7 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
s.WriteString(cursorItemSuffix(item))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
|
|
@ -716,6 +822,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
|
|||
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + check + item.Name))
|
||||
s.WriteString(cursorItemSuffix(item))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(check + item.Name))
|
||||
}
|
||||
|
|
@ -841,11 +948,16 @@ func (m multiSelectorModel) View() string {
|
|||
}
|
||||
|
||||
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
|
||||
return SelectMultipleWithUpdates(title, items, preChecked, nil)
|
||||
}
|
||||
|
||||
func SelectMultipleWithUpdates(title string, items []SelectItem, preChecked []string, updates <-chan []SelectItem) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := newMultiSelectorModel(title, items, preChecked)
|
||||
m.updates = updates
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
|
|
|
|||
|
|
@ -311,6 +311,91 @@ func TestRenderContent_SelectedItemIndicator(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_AvailabilityBadgeOnlyOnCursor(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: []SelectItem{
|
||||
{Name: "kimi-k2.6:cloud", AvailabilityBadge: "Upgrade required"},
|
||||
{Name: "qwen3.5:cloud", AvailabilityBadge: "Sign in required"},
|
||||
{Name: "glm-5:cloud", AvailabilityBadge: "Included"},
|
||||
},
|
||||
cursor: 0,
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "(Upgrade required)") {
|
||||
t.Fatalf("cursor badge missing:\n%s", content)
|
||||
}
|
||||
if strings.Contains(content, "(Sign in required)") {
|
||||
t.Fatalf("non-cursor badge should not render:\n%s", content)
|
||||
}
|
||||
if strings.Contains(content, "Included") {
|
||||
t.Fatalf("included badge should not render:\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectorModel_ItemsUpdatedPreservesCursorAndRendersBadge(t *testing.T) {
|
||||
m := selectorModelWithCurrent("Pick:", []SelectItem{
|
||||
{Name: "kimi-k2.6:cloud", Recommended: true},
|
||||
{Name: "llama3.2"},
|
||||
}, "kimi-k2.6:cloud")
|
||||
|
||||
updated, _ := m.Update(selectorItemsUpdatedMsg{items: []SelectItem{
|
||||
{Name: "kimi-k2.6:cloud", Recommended: true, AvailabilityBadge: "Upgrade required"},
|
||||
{Name: "llama3.2"},
|
||||
}})
|
||||
fm := updated.(selectorModel)
|
||||
if fm.cursor != 0 {
|
||||
t.Fatalf("cursor = %d, want 0", fm.cursor)
|
||||
}
|
||||
content := fm.renderContent()
|
||||
if !strings.Contains(content, "(Upgrade required)") {
|
||||
t.Fatalf("updated badge missing:\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSelector_AvailabilityBadgePreservesDefaultSuffix(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", []SelectItem{
|
||||
{Name: "kimi-k2.6:cloud", AvailabilityBadge: "Upgrade required"},
|
||||
{Name: "qwen3.5:cloud"},
|
||||
}, []string{"kimi-k2.6:cloud"})
|
||||
m.multi = true
|
||||
m.cursor = 0
|
||||
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "(Upgrade required)") {
|
||||
t.Fatalf("cursor badge missing:\n%s", content)
|
||||
}
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Fatalf("default suffix missing:\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSelector_ItemsUpdatedPreservesCheckedStateAndRendersBadge(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", []SelectItem{
|
||||
{Name: "kimi-k2.6:cloud", Recommended: true},
|
||||
{Name: "llama3.2"},
|
||||
}, []string{"kimi-k2.6:cloud"})
|
||||
m.multi = true
|
||||
|
||||
updated, _ := m.Update(selectorItemsUpdatedMsg{items: []SelectItem{
|
||||
{Name: "kimi-k2.6:cloud", Recommended: true, AvailabilityBadge: "Upgrade required"},
|
||||
{Name: "llama3.2"},
|
||||
}})
|
||||
fm := updated.(multiSelectorModel)
|
||||
idx := fm.itemIndex["kimi-k2.6:cloud"]
|
||||
if !fm.checked[idx] {
|
||||
t.Fatalf("checked state was not preserved: %#v", fm.checked)
|
||||
}
|
||||
content := fm.View()
|
||||
if !strings.Contains(content, "(Upgrade required)") {
|
||||
t.Fatalf("updated badge missing:\n%s", content)
|
||||
}
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Fatalf("default suffix missing after update:\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_Description(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
|
|
|
|||
|
|
@ -19,6 +19,14 @@ type signInCheckMsg struct {
|
|||
userName string
|
||||
}
|
||||
|
||||
type upgradeTickMsg struct{}
|
||||
|
||||
type upgradeCheckMsg struct {
|
||||
upgraded bool
|
||||
plan string
|
||||
err error
|
||||
}
|
||||
|
||||
type signInModel struct {
|
||||
modelName string
|
||||
signInURL string
|
||||
|
|
@ -28,6 +36,18 @@ type signInModel struct {
|
|||
cancelled bool
|
||||
}
|
||||
|
||||
type upgradeModel struct {
|
||||
modelName string
|
||||
requiredPlan string
|
||||
spinner int
|
||||
width int
|
||||
openNow bool
|
||||
polling bool
|
||||
plan string
|
||||
cancelled bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (m signInModel) Init() tea.Cmd {
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
|
|
@ -82,6 +102,85 @@ func (m signInModel) View() string {
|
|||
return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width)
|
||||
}
|
||||
|
||||
func (m upgradeModel) Init() tea.Cmd {
|
||||
if m.polling {
|
||||
return upgradeTickCmd()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m upgradeModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
case tea.KeyLeft:
|
||||
if !m.polling {
|
||||
m.openNow = true
|
||||
}
|
||||
case tea.KeyRight:
|
||||
if !m.polling {
|
||||
m.openNow = false
|
||||
}
|
||||
case tea.KeyEnter:
|
||||
if !m.polling {
|
||||
if !m.openNow {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
launch.OpenBrowser(launch.DefaultUpgradeURL)
|
||||
m.polling = true
|
||||
return m, upgradeTickCmd()
|
||||
}
|
||||
}
|
||||
|
||||
case upgradeTickMsg:
|
||||
if !m.polling {
|
||||
return m, nil
|
||||
}
|
||||
m.spinner++
|
||||
if m.spinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
upgradeTickCmd(),
|
||||
checkUpgrade(m.requiredPlan),
|
||||
)
|
||||
}
|
||||
return m, upgradeTickCmd()
|
||||
|
||||
case upgradeCheckMsg:
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
return m, tea.Quit
|
||||
}
|
||||
if msg.upgraded {
|
||||
m.plan = msg.plan
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m upgradeModel) View() string {
|
||||
if m.plan != "" {
|
||||
return ""
|
||||
}
|
||||
if m.err != nil {
|
||||
return ""
|
||||
}
|
||||
return renderUpgrade(m.modelName, m.spinner, m.width, m.polling, m.openNow)
|
||||
}
|
||||
|
||||
func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
|
||||
frame := spinnerFrames[spinner%len(spinnerFrames)]
|
||||
|
|
@ -110,18 +209,88 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
|||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
func upgradeTickCmd() tea.Cmd {
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return upgradeTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func renderUpgrade(modelName string, spinner, width int, polling, openNow bool) string {
|
||||
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
|
||||
frame := spinnerFrames[spinner%len(spinnerFrames)]
|
||||
|
||||
urlColor := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("117"))
|
||||
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
|
||||
if width > 4 {
|
||||
urlWrap = urlWrap.Width(width - 4)
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
fmt.Fprintf(&s, "To use %s, upgrade your Ollama plan.\n\n", selectorSelectedItemStyle.Render(modelName))
|
||||
|
||||
s.WriteString("Navigate to:\n")
|
||||
s.WriteString(urlWrap.Render(urlColor.Render(launch.DefaultUpgradeURL)))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
if !polling {
|
||||
var yesBtn, noBtn string
|
||||
if openNow {
|
||||
yesBtn = confirmActiveStyle.Render(" Yes ")
|
||||
noBtn = confirmInactiveStyle.Render(" No ")
|
||||
} else {
|
||||
yesBtn = confirmInactiveStyle.Render(" Yes ")
|
||||
noBtn = confirmActiveStyle.Render(" No ")
|
||||
}
|
||||
|
||||
s.WriteString("Open now?\n")
|
||||
s.WriteString(" " + yesBtn + " " + noBtn)
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel"))
|
||||
} else {
|
||||
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
||||
frame + " Waiting for upgrade to complete..."))
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("esc cancel"))
|
||||
}
|
||||
|
||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
func checkUpgrade(requiredPlan string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
user, err := client.Whoami(ctx)
|
||||
if err != nil {
|
||||
return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable}
|
||||
}
|
||||
if err == nil && user != nil && user.Name != "" && launch.PlanSatisfies(user.Plan, requiredPlan) {
|
||||
return upgradeCheckMsg{upgraded: true, plan: user.Plan}
|
||||
}
|
||||
return upgradeCheckMsg{upgraded: false}
|
||||
}
|
||||
}
|
||||
|
||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||
launch.OpenBrowser(signInURL)
|
||||
|
|
@ -144,3 +313,28 @@ func RunSignIn(modelName, signInURL string) (string, error) {
|
|||
|
||||
return fm.userName, nil
|
||||
}
|
||||
|
||||
// RunUpgrade shows a bubbletea upgrade dialog and polls until the user's plan is updated or cancelled.
|
||||
func RunUpgrade(modelName, requiredPlan string) (string, error) {
|
||||
m := upgradeModel{
|
||||
modelName: modelName,
|
||||
requiredPlan: requiredPlan,
|
||||
openNow: true,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running upgrade: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(upgradeModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
if fm.err != nil {
|
||||
return "", fm.err
|
||||
}
|
||||
|
||||
return fm.plan, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
func TestRenderSignIn_ContainsModelName(t *testing.T) {
|
||||
|
|
@ -50,6 +51,35 @@ func TestRenderSignIn_ContainsEscHelp(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRenderUpgrade_AsksBeforeOpening(t *testing.T) {
|
||||
got := renderUpgrade("kimi-k2.6:cloud", 0, 80, false, true)
|
||||
if !strings.Contains(got, "kimi-k2.6:cloud") {
|
||||
t.Error("should contain model name")
|
||||
}
|
||||
if !strings.Contains(got, launch.DefaultUpgradeURL) {
|
||||
t.Error("should contain upgrade URL")
|
||||
}
|
||||
if !strings.Contains(got, "Open now?") {
|
||||
t.Error("should ask before opening")
|
||||
}
|
||||
if !strings.Contains(got, "Yes") || !strings.Contains(got, "No") {
|
||||
t.Error("should show yes/no selector")
|
||||
}
|
||||
if strings.Contains(got, "Waiting for upgrade to complete") {
|
||||
t.Error("should not start waiting before open choice is confirmed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderUpgrade_PollingShowsWaiting(t *testing.T) {
|
||||
got := renderUpgrade("kimi-k2.6:cloud", 0, 80, true, true)
|
||||
if !strings.Contains(got, "Waiting for upgrade to complete") {
|
||||
t.Error("should contain waiting message")
|
||||
}
|
||||
if strings.Contains(got, "Open now?") {
|
||||
t.Error("should not show open prompt while polling")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_EscCancels(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
|
|
@ -66,6 +96,35 @@ func TestSignInModel_EscCancels(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUpgradeModel_NoCancelsWithoutPolling(t *testing.T) {
|
||||
m := upgradeModel{
|
||||
modelName: "kimi-k2.6:cloud",
|
||||
requiredPlan: "pro",
|
||||
openNow: true,
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
fm := updated.(upgradeModel)
|
||||
if fm.openNow {
|
||||
t.Error("right should select no")
|
||||
}
|
||||
if fm.polling {
|
||||
t.Error("right should not start polling")
|
||||
}
|
||||
|
||||
updated, cmd := fm.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm = updated.(upgradeModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("enter on no should cancel")
|
||||
}
|
||||
if fm.polling {
|
||||
t.Error("enter on no should not start polling")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("enter on no should quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_CtrlCCancels(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
|
|
|
|||
|
|
@ -20,9 +20,7 @@ import (
|
|||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
const (
|
||||
modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations"
|
||||
)
|
||||
const modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations"
|
||||
|
||||
var (
|
||||
modelRecommendationsRefreshInterval = 4 * time.Hour
|
||||
|
|
@ -323,6 +321,7 @@ func validateModelRecommendations(recs []api.ModelRecommendation) ([]api.ModelRe
|
|||
for _, rec := range recs {
|
||||
rec.Model = strings.TrimSpace(rec.Model)
|
||||
rec.Description = strings.TrimSpace(rec.Description)
|
||||
rec.RequiredPlan = strings.TrimSpace(rec.RequiredPlan)
|
||||
|
||||
if rec.Model == "" {
|
||||
return nil, errors.New("recommendation missing model")
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ func TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite(t *testing.T) {
|
|||
|
||||
func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing.T) {
|
||||
input := []api.ModelRecommendation{
|
||||
{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256},
|
||||
{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: " pro "},
|
||||
{Model: "bad-cloud:cloud", Description: "missing limits"},
|
||||
{Model: " good-local ", Description: " good local ", VRAMBytes: 2 * format.GigaByte},
|
||||
}
|
||||
|
|
@ -266,7 +266,7 @@ func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing
|
|||
}
|
||||
|
||||
want := []api.ModelRecommendation{
|
||||
{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256},
|
||||
{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: "pro"},
|
||||
{Model: "good-local", Description: "good local", VRAMBytes: 2 * format.GigaByte},
|
||||
}
|
||||
if !slices.Equal(got, want) {
|
||||
|
|
@ -274,6 +274,38 @@ func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing
|
|||
}
|
||||
}
|
||||
|
||||
func TestValidateModelRecommendationsDoesNotSynthesizeRequiredPlans(t *testing.T) {
|
||||
input := []api.ModelRecommendation{
|
||||
{Model: "kimi-k2.6:cloud", Description: "coding", ContextLength: 262_144, MaxOutputTokens: 262_144},
|
||||
{Model: "qwen3.5:cloud", Description: "reasoning", ContextLength: 262_144, MaxOutputTokens: 32_768},
|
||||
{Model: "custom:cloud", Description: "custom", ContextLength: 4096, MaxOutputTokens: 1024},
|
||||
{Model: "minimax-m2.7:cloud", Description: "custom", ContextLength: 204_800, MaxOutputTokens: 128_000, RequiredPlan: "team"},
|
||||
}
|
||||
|
||||
got, err := validateModelRecommendations(input)
|
||||
if err != nil {
|
||||
t.Fatalf("validateModelRecommendations failed: %v", err)
|
||||
}
|
||||
|
||||
byName := make(map[string]api.ModelRecommendation, len(got))
|
||||
for _, rec := range got {
|
||||
byName[rec.Model] = rec
|
||||
}
|
||||
|
||||
if rec := byName["kimi-k2.6:cloud"]; rec.RequiredPlan != "" {
|
||||
t.Fatalf("kimi required plan should not be synthesized: %#v", rec)
|
||||
}
|
||||
if rec := byName["qwen3.5:cloud"]; rec.RequiredPlan != "" {
|
||||
t.Fatalf("qwen required plan should not be synthesized: %#v", rec)
|
||||
}
|
||||
if rec := byName["custom:cloud"]; rec.RequiredPlan != "" {
|
||||
t.Fatalf("custom required plan should not be synthesized: %#v", rec)
|
||||
}
|
||||
if rec := byName["minimax-m2.7:cloud"]; rec.RequiredPlan != "team" {
|
||||
t.Fatalf("explicit required plan should not be overwritten: %#v", rec)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsHandlerReturnsDefaults(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
|
|
|||
|
|
@ -2044,11 +2044,30 @@ func (s *Server) WhoamiHandler(c *gin.Context) {
|
|||
client := api.NewClient(u, http.DefaultClient)
|
||||
user, err := client.Whoami(c)
|
||||
if err != nil {
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||
// Preserve an actionable sign-in response for launch; other failures
|
||||
// below mean account or plan verification is temporarily unavailable.
|
||||
sURL := authErr.SigninURL
|
||||
if sURL == "" {
|
||||
var sErr error
|
||||
sURL, sErr = signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
|
||||
slog.Error(err.Error())
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "account unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
// user isn't signed in
|
||||
if user != nil && user.Name == "" {
|
||||
if user == nil || user.Name == "" {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
|
|
@ -2060,6 +2079,10 @@ func (s *Server) WhoamiHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(user.Plan) == "" {
|
||||
slog.Warn("account plan was not set; defaulting to free")
|
||||
user.Plan = "free"
|
||||
}
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue