list.go

  1package models
  2
  3import (
  4	"fmt"
  5	"slices"
  6
  7	tea "github.com/charmbracelet/bubbletea/v2"
  8	"github.com/charmbracelet/catwalk/pkg/catwalk"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/tui/exp/list"
 11	"github.com/charmbracelet/crush/internal/tui/styles"
 12	"github.com/charmbracelet/crush/internal/tui/util"
 13)
 14
 15type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 16
 17type ModelListComponent struct {
 18	list      listModel
 19	modelType int
 20	providers []catwalk.Provider
 21}
 22
 23func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 24	t := styles.CurrentTheme()
 25	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 26	options := []list.ListOption{
 27		list.WithKeyMap(keyMap),
 28		list.WithWrapNavigation(),
 29	}
 30	if shouldResize {
 31		options = append(options, list.WithResizeByList())
 32	}
 33	modelList := list.NewFilterableGroupedList(
 34		[]list.Group[list.CompletionItem[ModelOption]]{},
 35		list.WithFilterInputStyle(inputStyle),
 36		list.WithFilterPlaceholder(inputPlaceholder),
 37		list.WithFilterListOptions(
 38			options...,
 39		),
 40	)
 41
 42	return &ModelListComponent{
 43		list:      modelList,
 44		modelType: LargeModelType,
 45	}
 46}
 47
 48func (m *ModelListComponent) Init() tea.Cmd {
 49	var cmds []tea.Cmd
 50	if len(m.providers) == 0 {
 51		providers, err := config.Providers()
 52		m.providers = providers
 53		if err != nil {
 54			cmds = append(cmds, util.ReportError(err))
 55		}
 56	}
 57	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 58	return tea.Batch(cmds...)
 59}
 60
 61func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 62	u, cmd := m.list.Update(msg)
 63	m.list = u.(listModel)
 64	return m, cmd
 65}
 66
 67func (m *ModelListComponent) View() string {
 68	return m.list.View()
 69}
 70
 71func (m *ModelListComponent) Cursor() *tea.Cursor {
 72	return m.list.Cursor()
 73}
 74
 75func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 76	return m.list.SetSize(width, height)
 77}
 78
 79func (m *ModelListComponent) SelectedModel() *ModelOption {
 80	s := m.list.SelectedItem()
 81	if s == nil {
 82		return nil
 83	}
 84	sv := *s
 85	model := sv.Value()
 86	return &model
 87}
 88
 89func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
 90	t := styles.CurrentTheme()
 91	m.modelType = modelType
 92
 93	var groups []list.Group[list.CompletionItem[ModelOption]]
 94	// first none section
 95	selectedItemID := ""
 96
 97	cfg := config.Get()
 98	var currentModel config.SelectedModel
 99	if m.modelType == LargeModelType {
100		currentModel = cfg.Models[config.SelectedModelTypeLarge]
101	} else {
102		currentModel = cfg.Models[config.SelectedModelTypeSmall]
103	}
104
105	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
106	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
107
108	// Create a map to track which providers we've already added
109	addedProviders := make(map[string]bool)
110
111	// First, add any configured providers that are not in the known providers list
112	// These should appear at the top of the list
113	knownProviders, err := config.Providers()
114	if err != nil {
115		return util.ReportError(err)
116	}
117	for providerID, providerConfig := range cfg.Providers.Seq2() {
118		if providerConfig.Disable {
119			continue
120		}
121
122		// Check if this provider is not in the known providers list
123		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
124			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
125			// Convert config provider to provider.Provider format
126			configProvider := catwalk.Provider{
127				Name:   providerConfig.Name,
128				ID:     catwalk.InferenceProvider(providerID),
129				Models: make([]catwalk.Model, len(providerConfig.Models)),
130			}
131
132			// Convert models
133			for i, model := range providerConfig.Models {
134				configProvider.Models[i] = catwalk.Model{
135					ID:                     model.ID,
136					Name:                   model.Name,
137					CostPer1MIn:            model.CostPer1MIn,
138					CostPer1MOut:           model.CostPer1MOut,
139					CostPer1MInCached:      model.CostPer1MInCached,
140					CostPer1MOutCached:     model.CostPer1MOutCached,
141					ContextWindow:          model.ContextWindow,
142					DefaultMaxTokens:       model.DefaultMaxTokens,
143					CanReason:              model.CanReason,
144					HasReasoningEffort:     model.HasReasoningEffort,
145					DefaultReasoningEffort: model.DefaultReasoningEffort,
146					SupportsImages:         model.SupportsImages,
147				}
148			}
149
150			// Add this unknown provider to the list
151			name := configProvider.Name
152			if name == "" {
153				name = string(configProvider.ID)
154			}
155			section := list.NewItemSection(name)
156			section.SetInfo(configured)
157			group := list.Group[list.CompletionItem[ModelOption]]{
158				Section: section,
159			}
160			for _, model := range configProvider.Models {
161				item := list.NewCompletionItem(model.Name, ModelOption{
162					Provider: configProvider,
163					Model:    model,
164				},
165					list.WithCompletionID(
166						fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
167					),
168				)
169
170				group.Items = append(group.Items, item)
171				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
172					selectedItemID = item.ID()
173				}
174			}
175			groups = append(groups, group)
176
177			addedProviders[providerID] = true
178		}
179	}
180
181	// Then add the known providers from the predefined list
182	for _, provider := range m.providers {
183		// Skip if we already added this provider as an unknown provider
184		if addedProviders[string(provider.ID)] {
185			continue
186		}
187
188		// Check if this provider is configured and not disabled
189		if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
190			continue
191		}
192
193		name := provider.Name
194		if name == "" {
195			name = string(provider.ID)
196		}
197
198		section := list.NewItemSection(name)
199		if _, ok := cfg.Providers.Get(string(provider.ID)); ok {
200			section.SetInfo(configured)
201		}
202		group := list.Group[list.CompletionItem[ModelOption]]{
203			Section: section,
204		}
205		for _, model := range provider.Models {
206			item := list.NewCompletionItem(model.Name, ModelOption{
207				Provider: provider,
208				Model:    model,
209			},
210				list.WithCompletionID(
211					fmt.Sprintf("%s:%s", provider.ID, model.ID),
212				),
213			)
214			group.Items = append(group.Items, item)
215			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
216				selectedItemID = item.ID()
217			}
218		}
219		groups = append(groups, group)
220	}
221
222	var cmds []tea.Cmd
223
224	cmd := m.list.SetGroups(groups)
225
226	if cmd != nil {
227		cmds = append(cmds, cmd)
228	}
229	cmd = m.list.SetSelected(selectedItemID)
230	if cmd != nil {
231		cmds = append(cmds, cmd)
232	}
233
234	return tea.Sequence(cmds...)
235}
236
237// GetModelType returns the current model type
238func (m *ModelListComponent) GetModelType() int {
239	return m.modelType
240}
241
242func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
243	m.list.SetInputPlaceholder(placeholder)
244}
245
246func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
247	m.providers = providers
248}