diff --git a/internal/config/config.go b/internal/config/config.go index a34568a7ec081441b088c1680c09a14336d81bd6..381a7ab384f3f78fbd2504db60b51b01e216c7db 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -289,6 +289,8 @@ type Config struct { // We currently only support large/small as values here. Models map[SelectedModelType]SelectedModel `json:"models,omitempty" jsonschema:"description=Model configurations for different model types,example={\"large\":{\"model\":\"gpt-4o\",\"provider\":\"openai\"}}"` + // Recently used models stored in the data directory config. + RecentModels map[SelectedModelType][]SelectedModel `json:"recent_models,omitempty" jsonschema:"description=Recently used models sorted by most recent first"` // The providers that are configured Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty" jsonschema:"description=AI provider configurations"` @@ -398,6 +400,9 @@ func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model Selecte if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { return fmt.Errorf("failed to update preferred model: %w", err) } + if err := c.recordRecentModel(modelType, model); err != nil { + return err + } return nil } @@ -465,6 +470,49 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return nil } +const maxRecentModelsPerType = 5 + +func (c *Config) recordRecentModel(modelType SelectedModelType, model SelectedModel) error { + if model.Provider == "" || model.Model == "" { + return nil + } + + if c.RecentModels == nil { + c.RecentModels = make(map[SelectedModelType][]SelectedModel) + } + + eq := func(a, b SelectedModel) bool { + return a.Provider == b.Provider && a.Model == b.Model + } + + entry := SelectedModel{ + Provider: model.Provider, + Model: model.Model, + } + + current := c.RecentModels[modelType] + withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { + return eq(existing, entry) + }) + + updated := append([]SelectedModel{entry}, withoutCurrent...) + if len(updated) > maxRecentModelsPerType { + updated = updated[:maxRecentModelsPerType] + } + + if slices.EqualFunc(current, updated, eq) { + return nil + } + + c.RecentModels[modelType] = updated + + if err := c.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { + return fmt.Errorf("failed to persist recent models: %w", err) + } + + return nil +} + func allToolNames() []string { return []string{ "agent", diff --git a/internal/config/load.go b/internal/config/load.go index cc7e54393857084b232a93f70d765090a2b513a8..7a7d3ae4da2d8970954461d4a1dc9b52a544636a 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -329,6 +329,9 @@ func (c *Config) setDefaults(workingDir, dataDir string) { if c.Models == nil { c.Models = make(map[SelectedModelType]SelectedModel) } + if c.RecentModels == nil { + c.RecentModels = make(map[SelectedModelType][]SelectedModel) + } if c.MCP == nil { c.MCP = make(map[string]MCPConfig) } diff --git a/internal/config/recent_models_test.go b/internal/config/recent_models_test.go new file mode 100644 index 0000000000000000000000000000000000000000..739ddc0031a65cab261723772c3f38658dcd1561 --- /dev/null +++ b/internal/config/recent_models_test.go @@ -0,0 +1,253 @@ +package config + +import ( + "encoding/json" + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +// readConfigJSON reads and unmarshals the JSON config file at path. +func readConfigJSON(t *testing.T, path string) map[string]any { + t.Helper() + baseDir := filepath.Dir(path) + fileName := filepath.Base(path) + b, err := fs.ReadFile(os.DirFS(baseDir), fileName) + require.NoError(t, err) + var out map[string]any + require.NoError(t, json.Unmarshal(b, &out)) + return out +} + +// readRecentModels reads the recent_models section from the config file. +func readRecentModels(t *testing.T, path string) map[string]any { + t.Helper() + out := readConfigJSON(t, path) + rm, ok := out["recent_models"].(map[string]any) + require.True(t, ok) + return rm +} + +func TestRecordRecentModel_AddsAndPersists(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfg := &Config{} + cfg.setDefaults(dir, "") + cfg.dataConfigDir = filepath.Join(dir, "config.json") + + err := cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) + require.NoError(t, err) + + // in-memory state + require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1) + require.Equal(t, "openai", cfg.RecentModels[SelectedModelTypeLarge][0].Provider) + require.Equal(t, "gpt-4o", cfg.RecentModels[SelectedModelTypeLarge][0].Model) + + // persisted state + rm := readRecentModels(t, cfg.dataConfigDir) + large, ok := rm[string(SelectedModelTypeLarge)].([]any) + require.True(t, ok) + require.Len(t, large, 1) + item, ok := large[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "openai", item["provider"]) + require.Equal(t, "gpt-4o", item["model"]) +} + +func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfg := &Config{} + cfg.setDefaults(dir, "") + cfg.dataConfigDir = filepath.Join(dir, "config.json") + + // Add two entries + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) + // Re-add first; should move to front and not duplicate + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + + got := cfg.RecentModels[SelectedModelTypeLarge] + require.Len(t, got, 2) + require.Equal(t, SelectedModel{Provider: "openai", Model: "gpt-4o"}, got[0]) + require.Equal(t, SelectedModel{Provider: "anthropic", Model: "claude"}, got[1]) +} + +func TestRecordRecentModel_TrimsToMax(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfg := &Config{} + cfg.setDefaults(dir, "") + cfg.dataConfigDir = filepath.Join(dir, "config.json") + + // Insert 6 unique models; max is 5 + entries := []SelectedModel{ + {Provider: "p1", Model: "m1"}, + {Provider: "p2", Model: "m2"}, + {Provider: "p3", Model: "m3"}, + {Provider: "p4", Model: "m4"}, + {Provider: "p5", Model: "m5"}, + {Provider: "p6", Model: "m6"}, + } + for _, e := range entries { + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, e)) + } + + // in-memory state + got := cfg.RecentModels[SelectedModelTypeLarge] + require.Len(t, got, 5) + // Newest first, capped at 5: p6..p2 + require.Equal(t, SelectedModel{Provider: "p6", Model: "m6"}, got[0]) + require.Equal(t, SelectedModel{Provider: "p5", Model: "m5"}, got[1]) + require.Equal(t, SelectedModel{Provider: "p4", Model: "m4"}, got[2]) + require.Equal(t, SelectedModel{Provider: "p3", Model: "m3"}, got[3]) + require.Equal(t, SelectedModel{Provider: "p2", Model: "m2"}, got[4]) + + // persisted state: verify trimmed to 5 and newest-first order + rm := readRecentModels(t, cfg.dataConfigDir) + large, ok := rm[string(SelectedModelTypeLarge)].([]any) + require.True(t, ok) + require.Len(t, large, 5) + // Build provider:model IDs and verify order + var ids []string + for _, v := range large { + m := v.(map[string]any) + ids = append(ids, m["provider"].(string)+":"+m["model"].(string)) + } + require.Equal(t, []string{"p6:m6", "p5:m5", "p4:m4", "p3:m3", "p2:m2"}, ids) +} + +func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfg := &Config{} + cfg.setDefaults(dir, "") + cfg.dataConfigDir = filepath.Join(dir, "config.json") + + // Missing provider + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) + // Missing model + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) + + _, ok := cfg.RecentModels[SelectedModelTypeLarge] + // Map may be initialized, but should have no entries + if ok { + require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 0) + } + // No file should be written (stat via fs.FS) + baseDir := filepath.Dir(cfg.dataConfigDir) + fileName := filepath.Base(cfg.dataConfigDir) + _, err := fs.Stat(os.DirFS(baseDir), fileName) + require.True(t, os.IsNotExist(err)) +} + +func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfg := &Config{} + cfg.setDefaults(dir, "") + cfg.dataConfigDir = filepath.Join(dir, "config.json") + + entry := SelectedModel{Provider: "openai", Model: "gpt-4o"} + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + + baseDir := filepath.Dir(cfg.dataConfigDir) + fileName := filepath.Base(cfg.dataConfigDir) + before, err := fs.ReadFile(os.DirFS(baseDir), fileName) + require.NoError(t, err) + + // Get file ModTime to verify no write occurs + stBefore, err := fs.Stat(os.DirFS(baseDir), fileName) + require.NoError(t, err) + beforeMod := stBefore.ModTime() + + // Re-record same entry should be a no-op (no write) + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + + after, err := fs.ReadFile(os.DirFS(baseDir), fileName) + require.NoError(t, err) + require.Equal(t, string(before), string(after)) + + // Verify ModTime unchanged to ensure truly no write occurred + stAfter, err := fs.Stat(os.DirFS(baseDir), fileName) + require.NoError(t, err) + require.True(t, stAfter.ModTime().Equal(beforeMod), "file ModTime should not change on noop") +} + +func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfg := &Config{} + cfg.setDefaults(dir, "") + cfg.dataConfigDir = filepath.Join(dir, "config.json") + + sel := SelectedModel{Provider: "openai", Model: "gpt-4o"} + require.NoError(t, cfg.UpdatePreferredModel(SelectedModelTypeSmall, sel)) + + // in-memory + require.Equal(t, sel, cfg.Models[SelectedModelTypeSmall]) + require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) + + // persisted (read via fs.FS) + rm := readRecentModels(t, cfg.dataConfigDir) + small, ok := rm[string(SelectedModelTypeSmall)].([]any) + require.True(t, ok) + require.Len(t, small, 1) +} + +func TestRecordRecentModel_TypeIsolation(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfg := &Config{} + cfg.setDefaults(dir, "") + cfg.dataConfigDir = filepath.Join(dir, "config.json") + + // Add models to both large and small types + largeModel := SelectedModel{Provider: "openai", Model: "gpt-4o"} + smallModel := SelectedModel{Provider: "anthropic", Model: "claude"} + + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, largeModel)) + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeSmall, smallModel)) + + // in-memory: verify types maintain separate histories + require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1) + require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) + require.Equal(t, largeModel, cfg.RecentModels[SelectedModelTypeLarge][0]) + require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0]) + + // Add another to large, verify small unchanged + anotherLarge := SelectedModel{Provider: "google", Model: "gemini"} + require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, anotherLarge)) + + require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 2) + require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) + require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0]) + + // persisted state: verify both types exist with correct lengths and contents + rm := readRecentModels(t, cfg.dataConfigDir) + + large, ok := rm[string(SelectedModelTypeLarge)].([]any) + require.True(t, ok) + require.Len(t, large, 2) + // Verify newest first for large type + require.Equal(t, "google", large[0].(map[string]any)["provider"]) + require.Equal(t, "gemini", large[0].(map[string]any)["model"]) + require.Equal(t, "openai", large[1].(map[string]any)["provider"]) + require.Equal(t, "gpt-4o", large[1].(map[string]any)["model"]) + + small, ok := rm[string(SelectedModelTypeSmall)].([]any) + require.True(t, ok) + require.Len(t, small, 1) + require.Equal(t, "anthropic", small[0].(map[string]any)["provider"]) + require.Equal(t, "claude", small[0].(map[string]any)["model"]) +} diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 87d333c4bd7e349b77cd2eff7e753743acde4296..2383f749de277e7fe915b57aac17fa0e7928756e 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -22,6 +22,13 @@ type ModelListComponent struct { providers []catwalk.Provider } +func modelKey(providerID, modelID string) string { + if providerID == "" || modelID == "" { + return "" + } + return providerID + ":" + modelID +} + func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent { t := styles.CurrentTheme() inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1) @@ -104,14 +111,19 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { var groups []list.Group[list.CompletionItem[ModelOption]] // first none section selectedItemID := "" + itemsByKey := make(map[string]list.CompletionItem[ModelOption]) cfg := config.Get() var currentModel config.SelectedModel + selectedType := config.SelectedModelTypeLarge if m.modelType == LargeModelType { currentModel = cfg.Models[config.SelectedModelTypeLarge] + selectedType = config.SelectedModelTypeLarge } else { currentModel = cfg.Models[config.SelectedModelTypeSmall] + selectedType = config.SelectedModelTypeSmall } + recentItems := cfg.RecentModels[selectedType] configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon) configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured")) @@ -169,14 +181,17 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { Section: section, } for _, model := range configProvider.Models { - item := list.NewCompletionItem(model.Name, ModelOption{ + modelOption := ModelOption{ Provider: configProvider, Model: model, - }, - list.WithCompletionID( - fmt.Sprintf("%s:%s", providerConfig.ID, model.ID), - ), + } + key := modelKey(string(configProvider.ID), model.ID) + item := list.NewCompletionItem( + model.Name, + modelOption, + list.WithCompletionID(key), ) + itemsByKey[key] = item group.Items = append(group.Items, item) if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider { @@ -239,14 +254,17 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { Section: section, } for _, model := range displayProvider.Models { - item := list.NewCompletionItem(model.Name, ModelOption{ + modelOption := ModelOption{ Provider: displayProvider, Model: model, - }, - list.WithCompletionID( - fmt.Sprintf("%s:%s", displayProvider.ID, model.ID), - ), + } + key := modelKey(string(displayProvider.ID), model.ID) + item := list.NewCompletionItem( + model.Name, + modelOption, + list.WithCompletionID(key), ) + itemsByKey[key] = item group.Items = append(group.Items, item) if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider { selectedItemID = item.ID() @@ -255,6 +273,48 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { groups = append(groups, group) } + if len(recentItems) > 0 { + recentSection := list.NewItemSection("Recently used") + recentGroup := list.Group[list.CompletionItem[ModelOption]]{ + Section: recentSection, + } + var validRecentItems []config.SelectedModel + for _, recent := range recentItems { + key := modelKey(recent.Provider, recent.Model) + option, ok := itemsByKey[key] + if !ok { + continue + } + validRecentItems = append(validRecentItems, recent) + recentID := fmt.Sprintf("recent::%s", key) + modelOption := option.Value() + providerName := modelOption.Provider.Name + if providerName == "" { + providerName = string(modelOption.Provider.ID) + } + item := list.NewCompletionItem( + modelOption.Model.Name, + option.Value(), + list.WithCompletionID(recentID), + list.WithCompletionShortcut(providerName), + ) + recentGroup.Items = append(recentGroup.Items, item) + if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider { + selectedItemID = recentID + } + } + + if len(validRecentItems) != len(recentItems) { + if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { + return util.ReportError(err) + } + } + + if len(recentGroup.Items) > 0 { + groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...) + } + } + var cmds []tea.Cmd cmd := m.list.SetGroups(groups) diff --git a/internal/tui/components/dialogs/models/list_recent_test.go b/internal/tui/components/dialogs/models/list_recent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..249c958ec4a34621eabac5190541f8449668f97f --- /dev/null +++ b/internal/tui/components/dialogs/models/list_recent_test.go @@ -0,0 +1,369 @@ +package models + +import ( + "encoding/json" + "io/fs" + "os" + "path/filepath" + "strings" + "testing" + + tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/log" + "github.com/charmbracelet/crush/internal/tui/exp/list" + "github.com/stretchr/testify/require" +) + +// execCmdML runs a tea.Cmd through the ModelListComponent's Update loop. +func execCmdML(t *testing.T, m *ModelListComponent, cmd tea.Cmd) { + t.Helper() + for cmd != nil { + msg := cmd() + var next tea.Cmd + _, next = m.Update(msg) + cmd = next + } +} + +// readConfigJSON reads and unmarshals the JSON config file at path. +func readConfigJSON(t *testing.T, path string) map[string]any { + t.Helper() + baseDir := filepath.Dir(path) + fileName := filepath.Base(path) + b, err := fs.ReadFile(os.DirFS(baseDir), fileName) + require.NoError(t, err) + var out map[string]any + require.NoError(t, json.Unmarshal(b, &out)) + return out +} + +// readRecentModels reads the recent_models section from the config file. +func readRecentModels(t *testing.T, path string) map[string]any { + t.Helper() + out := readConfigJSON(t, path) + rm, ok := out["recent_models"].(map[string]any) + require.True(t, ok) + return rm +} + +func TestModelList_RecentlyUsedSectionAndPrunesInvalid(t *testing.T) { + // Pre-initialize logger to os.DevNull to prevent file lock on Windows. + log.Setup(os.DevNull, false) + + // Isolate config/data paths + cfgDir := t.TempDir() + dataDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", cfgDir) + t.Setenv("XDG_DATA_HOME", dataDir) + + // Pre-seed config so provider auto-update is disabled and we have recents + confPath := filepath.Join(cfgDir, "crush", "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(confPath), 0o755)) + initial := map[string]any{ + "options": map[string]any{ + "disable_provider_auto_update": true, + }, + "models": map[string]any{ + "large": map[string]any{ + "model": "m1", + "provider": "p1", + }, + }, + "recent_models": map[string]any{ + "large": []any{ + map[string]any{"model": "m2", "provider": "p1"}, // valid + map[string]any{"model": "x", "provider": "unknown-provider"}, // invalid -> pruned + }, + }, + } + bts, err := json.Marshal(initial) + require.NoError(t, err) + require.NoError(t, os.WriteFile(confPath, bts, 0o644)) + + // Also create empty providers.json to prevent loading real providers + dataConfDir := filepath.Join(dataDir, "crush") + require.NoError(t, os.MkdirAll(dataConfDir, 0o755)) + emptyProviders := []byte("[]") + require.NoError(t, os.WriteFile(filepath.Join(dataConfDir, "providers.json"), emptyProviders, 0o644)) + + // Initialize global config instance (no network due to auto-update disabled) + _, err = config.Init(cfgDir, dataDir, false) + require.NoError(t, err) + + // Build a small provider set for the list component + provider := catwalk.Provider{ + ID: catwalk.InferenceProvider("p1"), + Name: "Provider One", + Models: []catwalk.Model{ + {ID: "m1", Name: "Model One", DefaultMaxTokens: 100}, + {ID: "m2", Name: "Model Two", DefaultMaxTokens: 100}, // recent + }, + } + + // Create and initialize the component with our provider set + listKeyMap := list.DefaultKeyMap() + cmp := NewModelListComponent(listKeyMap, "Find your fave", false) + cmp.providers = []catwalk.Provider{provider} + execCmdML(t, cmp, cmp.Init()) + + // Find all recent items (IDs prefixed with "recent::") and verify pruning + groups := cmp.list.Groups() + require.NotEmpty(t, groups) + var recentItems []list.CompletionItem[ModelOption] + for _, g := range groups { + for _, it := range g.Items { + if strings.HasPrefix(it.ID(), "recent::") { + recentItems = append(recentItems, it) + } + } + } + require.NotEmpty(t, recentItems, "no recent items found") + // Ensure the valid recent (p1:m2) is present and the invalid one is not + foundValid := false + for _, it := range recentItems { + if it.ID() == "recent::p1:m2" { + foundValid = true + } + require.NotEqual(t, "recent::unknown-provider:x", it.ID(), "invalid recent should be pruned") + } + require.True(t, foundValid, "expected valid recent not found") + + // Verify original config in cfgDir remains unchanged + origConfPath := filepath.Join(cfgDir, "crush", "crush.json") + afterOrig, err := fs.ReadFile(os.DirFS(filepath.Dir(origConfPath)), filepath.Base(origConfPath)) + require.NoError(t, err) + var origParsed map[string]any + require.NoError(t, json.Unmarshal(afterOrig, &origParsed)) + origRM := origParsed["recent_models"].(map[string]any) + origLarge := origRM["large"].([]any) + require.Len(t, origLarge, 2, "original config should be unchanged") + + // Config should be rewritten with pruned recents in dataDir + dataConf := filepath.Join(dataDir, "crush", "crush.json") + rm := readRecentModels(t, dataConf) + largeAny, ok := rm["large"].([]any) + require.True(t, ok) + // Ensure that only valid recent(s) remain and the invalid one is removed + found := false + for _, v := range largeAny { + m := v.(map[string]any) + require.NotEqual(t, "unknown-provider", m["provider"], "invalid provider should be pruned") + if m["provider"] == "p1" && m["model"] == "m2" { + found = true + } + } + require.True(t, found, "persisted recents should include p1:m2") +} + +func TestModelList_PrunesInvalidModelWithinValidProvider(t *testing.T) { + // Pre-initialize logger to os.DevNull to prevent file lock on Windows. + log.Setup(os.DevNull, false) + + // Isolate config/data paths + cfgDir := t.TempDir() + dataDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", cfgDir) + t.Setenv("XDG_DATA_HOME", dataDir) + + // Pre-seed config with valid provider but one invalid model + confPath := filepath.Join(cfgDir, "crush", "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(confPath), 0o755)) + initial := map[string]any{ + "options": map[string]any{ + "disable_provider_auto_update": true, + }, + "models": map[string]any{ + "large": map[string]any{ + "model": "m1", + "provider": "p1", + }, + }, + "recent_models": map[string]any{ + "large": []any{ + map[string]any{"model": "m1", "provider": "p1"}, // valid + map[string]any{"model": "missing", "provider": "p1"}, // invalid model + }, + }, + } + bts, err := json.Marshal(initial) + require.NoError(t, err) + require.NoError(t, os.WriteFile(confPath, bts, 0o644)) + + // Create empty providers.json + dataConfDir := filepath.Join(dataDir, "crush") + require.NoError(t, os.MkdirAll(dataConfDir, 0o755)) + emptyProviders := []byte("[]") + require.NoError(t, os.WriteFile(filepath.Join(dataConfDir, "providers.json"), emptyProviders, 0o644)) + + // Initialize global config instance + _, err = config.Init(cfgDir, dataDir, false) + require.NoError(t, err) + + // Build provider set that only includes m1, not "missing" + provider := catwalk.Provider{ + ID: catwalk.InferenceProvider("p1"), + Name: "Provider One", + Models: []catwalk.Model{ + {ID: "m1", Name: "Model One", DefaultMaxTokens: 100}, + }, + } + + // Create and initialize component + listKeyMap := list.DefaultKeyMap() + cmp := NewModelListComponent(listKeyMap, "Find your fave", false) + cmp.providers = []catwalk.Provider{provider} + execCmdML(t, cmp, cmp.Init()) + + // Find all recent items + groups := cmp.list.Groups() + require.NotEmpty(t, groups) + var recentItems []list.CompletionItem[ModelOption] + for _, g := range groups { + for _, it := range g.Items { + if strings.HasPrefix(it.ID(), "recent::") { + recentItems = append(recentItems, it) + } + } + } + require.NotEmpty(t, recentItems, "valid recent should exist") + + // Verify the valid recent is present and invalid model is not + foundValid := false + for _, it := range recentItems { + if it.ID() == "recent::p1:m1" { + foundValid = true + } + require.NotEqual(t, "recent::p1:missing", it.ID(), "invalid model should be pruned") + } + require.True(t, foundValid, "valid recent p1:m1 should be present") + + // Verify original config in cfgDir remains unchanged + origConfPath := filepath.Join(cfgDir, "crush", "crush.json") + afterOrig, err := fs.ReadFile(os.DirFS(filepath.Dir(origConfPath)), filepath.Base(origConfPath)) + require.NoError(t, err) + var origParsed map[string]any + require.NoError(t, json.Unmarshal(afterOrig, &origParsed)) + origRM := origParsed["recent_models"].(map[string]any) + origLarge := origRM["large"].([]any) + require.Len(t, origLarge, 2, "original config should be unchanged") + + // Config should be rewritten with pruned recents in dataDir + dataConf := filepath.Join(dataDir, "crush", "crush.json") + rm := readRecentModels(t, dataConf) + largeAny, ok := rm["large"].([]any) + require.True(t, ok) + require.Len(t, largeAny, 1, "should only have one valid model") + // Verify only p1:m1 remains + m := largeAny[0].(map[string]any) + require.Equal(t, "p1", m["provider"]) + require.Equal(t, "m1", m["model"]) +} + +func TestModelKey_EmptyInputs(t *testing.T) { + // Empty provider + require.Equal(t, "", modelKey("", "model")) + // Empty model + require.Equal(t, "", modelKey("provider", "")) + // Both empty + require.Equal(t, "", modelKey("", "")) + // Valid inputs + require.Equal(t, "p:m", modelKey("p", "m")) +} + +func TestModelList_AllRecentsInvalid(t *testing.T) { + // Pre-initialize logger to os.DevNull to prevent file lock on Windows. + log.Setup(os.DevNull, false) + + // Isolate config/data paths + cfgDir := t.TempDir() + dataDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", cfgDir) + t.Setenv("XDG_DATA_HOME", dataDir) + + // Pre-seed config with only invalid recents + confPath := filepath.Join(cfgDir, "crush", "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(confPath), 0o755)) + initial := map[string]any{ + "options": map[string]any{ + "disable_provider_auto_update": true, + }, + "models": map[string]any{ + "large": map[string]any{ + "model": "m1", + "provider": "p1", + }, + }, + "recent_models": map[string]any{ + "large": []any{ + map[string]any{"model": "x", "provider": "unknown1"}, + map[string]any{"model": "y", "provider": "unknown2"}, + }, + }, + } + bts, err := json.Marshal(initial) + require.NoError(t, err) + require.NoError(t, os.WriteFile(confPath, bts, 0o644)) + + // Also create empty providers.json and data config + dataConfDir := filepath.Join(dataDir, "crush") + require.NoError(t, os.MkdirAll(dataConfDir, 0o755)) + emptyProviders := []byte("[]") + require.NoError(t, os.WriteFile(filepath.Join(dataConfDir, "providers.json"), emptyProviders, 0o644)) + + // Initialize global config instance with isolated dataDir + _, err = config.Init(cfgDir, dataDir, false) + require.NoError(t, err) + + // Build provider set (doesn't include unknown1 or unknown2) + provider := catwalk.Provider{ + ID: catwalk.InferenceProvider("p1"), + Name: "Provider One", + Models: []catwalk.Model{ + {ID: "m1", Name: "Model One", DefaultMaxTokens: 100}, + }, + } + + // Create and initialize component + listKeyMap := list.DefaultKeyMap() + cmp := NewModelListComponent(listKeyMap, "Find your fave", false) + cmp.providers = []catwalk.Provider{provider} + execCmdML(t, cmp, cmp.Init()) + + // Verify no recent items exist in UI + groups := cmp.list.Groups() + require.NotEmpty(t, groups) + var recentItems []list.CompletionItem[ModelOption] + for _, g := range groups { + for _, it := range g.Items { + if strings.HasPrefix(it.ID(), "recent::") { + recentItems = append(recentItems, it) + } + } + } + require.Empty(t, recentItems, "all invalid recents should be pruned, resulting in no recent section") + + // Verify original config in cfgDir remains unchanged + origConfPath := filepath.Join(cfgDir, "crush", "crush.json") + afterOrig, err := fs.ReadFile(os.DirFS(filepath.Dir(origConfPath)), filepath.Base(origConfPath)) + require.NoError(t, err) + var origParsed map[string]any + require.NoError(t, json.Unmarshal(afterOrig, &origParsed)) + origRM := origParsed["recent_models"].(map[string]any) + origLarge := origRM["large"].([]any) + require.Len(t, origLarge, 2, "original config should be unchanged") + + // Config should be rewritten with empty recents in dataDir + dataConf := filepath.Join(dataDir, "crush", "crush.json") + rm := readRecentModels(t, dataConf) + // When all recents are pruned, the value may be nil or an empty array + largeVal := rm["large"] + if largeVal == nil { + // nil is acceptable - means empty + return + } + largeAny, ok := largeVal.([]any) + require.True(t, ok, "large key should be nil or array") + require.Empty(t, largeAny, "persisted recents should be empty after pruning all invalid entries") +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index a80af22ef3365021175164f512b741a8b20de3f1..793421f58307778307d95bda1c3aec13af7522eb 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -267,7 +267,10 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, util.ReportWarn("Agent is busy, please wait...") } - config.Get().UpdatePreferredModel(msg.ModelType, msg.Model) + cfg := config.Get() + if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil { + return a, util.ReportError(err) + } go a.app.UpdateAgentModel(context.TODO()) diff --git a/schema.json b/schema.json index 093012bcd5dcd811fb512203d10f2d8db1780255..7c11d60755bc1cf0523e458de571ea3618e0be57 100644 --- a/schema.json +++ b/schema.json @@ -53,6 +53,16 @@ "type": "object", "description": "Model configurations for different model types" }, + "recent_models": { + "additionalProperties": { + "items": { + "$ref": "#/$defs/SelectedModel" + }, + "type": "array" + }, + "type": "object", + "description": "Recently used models sorted by most recent first" + }, "providers": { "additionalProperties": { "$ref": "#/$defs/ProviderConfig"