list.go

  1package models
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"strings"
  7
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/env"
 12	"github.com/charmbracelet/crush/internal/tui/exp/list"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15)
 16
 17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 18
 19type ModelListComponent struct {
 20	list      listModel
 21	modelType int
 22	providers []catwalk.Provider
 23}
 24
 25func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 26	t := styles.CurrentTheme()
 27	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 28	options := []list.ListOption{
 29		list.WithKeyMap(keyMap),
 30		list.WithWrapNavigation(),
 31	}
 32	if shouldResize {
 33		options = append(options, list.WithResizeByList())
 34	}
 35	modelList := list.NewFilterableGroupedList(
 36		[]list.Group[list.CompletionItem[ModelOption]]{},
 37		list.WithFilterInputStyle(inputStyle),
 38		list.WithFilterPlaceholder(inputPlaceholder),
 39		list.WithFilterListOptions(
 40			options...,
 41		),
 42	)
 43
 44	return &ModelListComponent{
 45		list:      modelList,
 46		modelType: LargeModelType,
 47	}
 48}
 49
 50func (m *ModelListComponent) Init() tea.Cmd {
 51	var cmds []tea.Cmd
 52	if len(m.providers) == 0 {
 53		providers, err := config.Providers()
 54		filteredProviders := []catwalk.Provider{}
 55		for _, p := range providers {
 56			hasApiKeyEnv := strings.HasPrefix(p.APIKey, "$")
 57			resolver := config.NewEnvironmentVariableResolver(env.New())
 58			endpoint, _ := resolver.ResolveValue(p.APIEndpoint)
 59			if endpoint != "" && hasApiKeyEnv {
 60				filteredProviders = append(filteredProviders, p)
 61			}
 62		}
 63
 64		m.providers = filteredProviders
 65		if err != nil {
 66			cmds = append(cmds, util.ReportError(err))
 67		}
 68	}
 69	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 70	return tea.Batch(cmds...)
 71}
 72
 73func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 74	u, cmd := m.list.Update(msg)
 75	m.list = u.(listModel)
 76	return m, cmd
 77}
 78
 79func (m *ModelListComponent) View() string {
 80	return m.list.View()
 81}
 82
 83func (m *ModelListComponent) Cursor() *tea.Cursor {
 84	return m.list.Cursor()
 85}
 86
 87func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 88	return m.list.SetSize(width, height)
 89}
 90
 91func (m *ModelListComponent) SelectedModel() *ModelOption {
 92	s := m.list.SelectedItem()
 93	if s == nil {
 94		return nil
 95	}
 96	sv := *s
 97	model := sv.Value()
 98	return &model
 99}
100
101func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
102	t := styles.CurrentTheme()
103	m.modelType = modelType
104
105	var groups []list.Group[list.CompletionItem[ModelOption]]
106	// first none section
107	selectedItemID := ""
108
109	cfg := config.Get()
110	var currentModel config.SelectedModel
111	if m.modelType == LargeModelType {
112		currentModel = cfg.Models[config.SelectedModelTypeLarge]
113	} else {
114		currentModel = cfg.Models[config.SelectedModelTypeSmall]
115	}
116
117	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
118	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
119
120	// Create a map to track which providers we've already added
121	addedProviders := make(map[string]bool)
122
123	// First, add any configured providers that are not in the known providers list
124	// These should appear at the top of the list
125	knownProviders, err := config.Providers()
126	if err != nil {
127		return util.ReportError(err)
128	}
129	for providerID, providerConfig := range cfg.Providers.Seq2() {
130		if providerConfig.Disable {
131			continue
132		}
133
134		// Check if this provider is not in the known providers list
135		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
136			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
137			// Convert config provider to provider.Provider format
138			configProvider := catwalk.Provider{
139				Name:   providerConfig.Name,
140				ID:     catwalk.InferenceProvider(providerID),
141				Models: make([]catwalk.Model, len(providerConfig.Models)),
142			}
143
144			// Convert models
145			for i, model := range providerConfig.Models {
146				configProvider.Models[i] = catwalk.Model{
147					ID:                     model.ID,
148					Name:                   model.Name,
149					CostPer1MIn:            model.CostPer1MIn,
150					CostPer1MOut:           model.CostPer1MOut,
151					CostPer1MInCached:      model.CostPer1MInCached,
152					CostPer1MOutCached:     model.CostPer1MOutCached,
153					ContextWindow:          model.ContextWindow,
154					DefaultMaxTokens:       model.DefaultMaxTokens,
155					CanReason:              model.CanReason,
156					HasReasoningEffort:     model.HasReasoningEffort,
157					DefaultReasoningEffort: model.DefaultReasoningEffort,
158					SupportsImages:         model.SupportsImages,
159				}
160			}
161
162			// Add this unknown provider to the list
163			name := configProvider.Name
164			if name == "" {
165				name = string(configProvider.ID)
166			}
167			section := list.NewItemSection(name)
168			section.SetInfo(configured)
169			group := list.Group[list.CompletionItem[ModelOption]]{
170				Section: section,
171			}
172			for _, model := range configProvider.Models {
173				item := list.NewCompletionItem(model.Name, ModelOption{
174					Provider: configProvider,
175					Model:    model,
176				},
177					list.WithCompletionID(
178						fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
179					),
180				)
181
182				group.Items = append(group.Items, item)
183				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
184					selectedItemID = item.ID()
185				}
186			}
187			groups = append(groups, group)
188
189			addedProviders[providerID] = true
190		}
191	}
192
193	// Then add the known providers from the predefined list
194	for _, provider := range m.providers {
195		// Skip if we already added this provider as an unknown provider
196		if addedProviders[string(provider.ID)] {
197			continue
198		}
199
200		// Check if this provider is configured and not disabled
201		if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
202			continue
203		}
204
205		name := provider.Name
206		if name == "" {
207			name = string(provider.ID)
208		}
209
210		section := list.NewItemSection(name)
211		if _, ok := cfg.Providers.Get(string(provider.ID)); ok {
212			section.SetInfo(configured)
213		}
214		group := list.Group[list.CompletionItem[ModelOption]]{
215			Section: section,
216		}
217		for _, model := range provider.Models {
218			item := list.NewCompletionItem(model.Name, ModelOption{
219				Provider: provider,
220				Model:    model,
221			},
222				list.WithCompletionID(
223					fmt.Sprintf("%s:%s", provider.ID, model.ID),
224				),
225			)
226			group.Items = append(group.Items, item)
227			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
228				selectedItemID = item.ID()
229			}
230		}
231		groups = append(groups, group)
232	}
233
234	var cmds []tea.Cmd
235
236	cmd := m.list.SetGroups(groups)
237
238	if cmd != nil {
239		cmds = append(cmds, cmd)
240	}
241	cmd = m.list.SetSelected(selectedItemID)
242	if cmd != nil {
243		cmds = append(cmds, cmd)
244	}
245
246	return tea.Sequence(cmds...)
247}
248
249// GetModelType returns the current model type
250func (m *ModelListComponent) GetModelType() int {
251	return m.modelType
252}
253
254func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
255	m.list.SetInputPlaceholder(placeholder)
256}