list.go

  1package models
  2
  3import (
  4	"fmt"
  5	"slices"
  6
  7	tea "github.com/charmbracelet/bubbletea/v2"
  8	"github.com/charmbracelet/catwalk/pkg/catwalk"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/agent"
 11	"github.com/charmbracelet/crush/internal/tui/exp/list"
 12	"github.com/charmbracelet/crush/internal/tui/styles"
 13	"github.com/charmbracelet/crush/internal/tui/util"
 14)
 15
 16type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 17
 18type ModelListComponent struct {
 19	list      listModel
 20	modelType int
 21	providers []catwalk.Provider
 22	config    *config.Config
 23}
 24
 25func NewModelListComponent(cfg *config.Config, keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 26	t := styles.CurrentTheme()
 27	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 28	options := []list.ListOption{
 29		list.WithKeyMap(keyMap),
 30		list.WithWrapNavigation(),
 31	}
 32	if shouldResize {
 33		options = append(options, list.WithResizeByList())
 34	}
 35	modelList := list.NewFilterableGroupedList(
 36		[]list.Group[list.CompletionItem[ModelOption]]{},
 37		list.WithFilterInputStyle(inputStyle),
 38		list.WithFilterPlaceholder(inputPlaceholder),
 39		list.WithFilterListOptions(
 40			options...,
 41		),
 42	)
 43
 44	return &ModelListComponent{
 45		list:      modelList,
 46		modelType: LargeModelType,
 47		config:    cfg,
 48	}
 49}
 50
 51func (m *ModelListComponent) Init() tea.Cmd {
 52	var cmds []tea.Cmd
 53	if len(m.providers) == 0 {
 54		providers, err := config.Providers()
 55		m.providers = providers
 56		if err != nil {
 57			cmds = append(cmds, util.ReportError(err))
 58		}
 59	}
 60	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 61	return tea.Batch(cmds...)
 62}
 63
 64func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 65	u, cmd := m.list.Update(msg)
 66	m.list = u.(listModel)
 67	return m, cmd
 68}
 69
 70func (m *ModelListComponent) View() string {
 71	return m.list.View()
 72}
 73
 74func (m *ModelListComponent) Cursor() *tea.Cursor {
 75	return m.list.Cursor()
 76}
 77
 78func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 79	return m.list.SetSize(width, height)
 80}
 81
 82func (m *ModelListComponent) SelectedModel() *ModelOption {
 83	s := m.list.SelectedItem()
 84	if s == nil {
 85		return nil
 86	}
 87	sv := *s
 88	model := sv.Value()
 89	return &model
 90}
 91
 92func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
 93	t := styles.CurrentTheme()
 94	m.modelType = modelType
 95
 96	var groups []list.Group[list.CompletionItem[ModelOption]]
 97	// first none section
 98	selectedItemID := ""
 99
100	var currentModel agent.Model
101	if m.modelType == LargeModelType {
102		currentModel = m.config.Models[config.SelectedModelTypeLarge]
103	} else {
104		currentModel = m.config.Models[config.SelectedModelTypeSmall]
105	}
106
107	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
108	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
109
110	// Create a map to track which providers we've already added
111	addedProviders := make(map[string]bool)
112
113	// First, add any configured providers that are not in the known providers list
114	// These should appear at the top of the list
115	knownProviders, err := config.Providers()
116	if err != nil {
117		return util.ReportError(err)
118	}
119	for providerID, providerConfig := range m.config.Providers.Seq2() {
120		if providerConfig.Disable {
121			continue
122		}
123
124		// Check if this provider is not in the known providers list
125		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
126			// Convert config provider to provider.Provider format
127			configProvider := catwalk.Provider{
128				Name:   providerConfig.Name,
129				ID:     catwalk.InferenceProvider(providerID),
130				Models: make([]catwalk.Model, len(providerConfig.Models)),
131			}
132
133			// Convert models
134			for i, model := range providerConfig.Models {
135				configProvider.Models[i] = catwalk.Model{
136					ID:                     model.ID,
137					Name:                   model.Name,
138					CostPer1MIn:            model.CostPer1MIn,
139					CostPer1MOut:           model.CostPer1MOut,
140					CostPer1MInCached:      model.CostPer1MInCached,
141					CostPer1MOutCached:     model.CostPer1MOutCached,
142					ContextWindow:          model.ContextWindow,
143					DefaultMaxTokens:       model.DefaultMaxTokens,
144					CanReason:              model.CanReason,
145					HasReasoningEffort:     model.HasReasoningEffort,
146					DefaultReasoningEffort: model.DefaultReasoningEffort,
147					SupportsImages:         model.SupportsImages,
148				}
149			}
150
151			// Add this unknown provider to the list
152			name := configProvider.Name
153			if name == "" {
154				name = string(configProvider.ID)
155			}
156			section := list.NewItemSection(name)
157			section.SetInfo(configured)
158			group := list.Group[list.CompletionItem[ModelOption]]{
159				Section: section,
160			}
161			for _, model := range configProvider.Models {
162				item := list.NewCompletionItem(model.Name, ModelOption{
163					Provider: configProvider,
164					Model:    model,
165				},
166					list.WithCompletionID(
167						fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
168					),
169				)
170
171				group.Items = append(group.Items, item)
172				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
173					selectedItemID = item.ID()
174				}
175			}
176			groups = append(groups, group)
177
178			addedProviders[providerID] = true
179		}
180	}
181
182	// Then add the known providers from the predefined list
183	for _, provider := range m.providers {
184		// Skip if we already added this provider as an unknown provider
185		if addedProviders[string(provider.ID)] {
186			continue
187		}
188
189		// Check if this provider is configured and not disabled
190		if providerConfig, exists := m.config.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
191			continue
192		}
193
194		name := provider.Name
195		if name == "" {
196			name = string(provider.ID)
197		}
198
199		section := list.NewItemSection(name)
200		if _, ok := m.config.Providers.Get(string(provider.ID)); ok {
201			section.SetInfo(configured)
202		}
203		group := list.Group[list.CompletionItem[ModelOption]]{
204			Section: section,
205		}
206		for _, model := range provider.Models {
207			item := list.NewCompletionItem(model.Name, ModelOption{
208				Provider: provider,
209				Model:    model,
210			},
211				list.WithCompletionID(
212					fmt.Sprintf("%s:%s", provider.ID, model.ID),
213				),
214			)
215			group.Items = append(group.Items, item)
216			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
217				selectedItemID = item.ID()
218			}
219		}
220		groups = append(groups, group)
221	}
222
223	var cmds []tea.Cmd
224
225	cmd := m.list.SetGroups(groups)
226
227	if cmd != nil {
228		cmds = append(cmds, cmd)
229	}
230	cmd = m.list.SetSelected(selectedItemID)
231	if cmd != nil {
232		cmds = append(cmds, cmd)
233	}
234
235	return tea.Sequence(cmds...)
236}
237
238// GetModelType returns the current model type
239func (m *ModelListComponent) GetModelType() int {
240	return m.modelType
241}
242
243func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
244	m.list.SetInputPlaceholder(placeholder)
245}
246
247func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
248	m.providers = providers
249}