list.go

  1package models
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"strings"
  7
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/tui/exp/list"
 12	"github.com/charmbracelet/crush/internal/tui/styles"
 13	"github.com/charmbracelet/crush/internal/tui/util"
 14)
 15
// listModel is the concrete list widget wrapped by ModelListComponent: a
// filterable list of provider groups whose items are completion items
// carrying a ModelOption.
type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 17
// ModelListComponent shows the available models grouped by provider and lets
// the user pick one for a single model slot at a time (see SetModelType).
type ModelListComponent struct {
	list      listModel          // underlying filterable, grouped list widget
	modelType int                // slot being edited; compared against LargeModelType
	providers []catwalk.Provider // provider catalog filtered and cached by Init; empty until loaded
	cfg       *config.Config     // source of configured providers and the currently selected models
}
 24
 25func NewModelListComponent(cfg *config.Config, keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 26	t := styles.CurrentTheme()
 27	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 28	options := []list.ListOption{
 29		list.WithKeyMap(keyMap),
 30		list.WithWrapNavigation(),
 31	}
 32	if shouldResize {
 33		options = append(options, list.WithResizeByList())
 34	}
 35	modelList := list.NewFilterableGroupedList(
 36		[]list.Group[list.CompletionItem[ModelOption]]{},
 37		list.WithFilterInputStyle(inputStyle),
 38		list.WithFilterPlaceholder(inputPlaceholder),
 39		list.WithFilterListOptions(
 40			options...,
 41		),
 42	)
 43
 44	return &ModelListComponent{
 45		list:      modelList,
 46		modelType: LargeModelType,
 47		cfg:       cfg,
 48	}
 49}
 50
 51func (m *ModelListComponent) Init() tea.Cmd {
 52	var cmds []tea.Cmd
 53	if len(m.providers) == 0 {
 54		providers, err := config.Providers(m.cfg)
 55		filteredProviders := []catwalk.Provider{}
 56		for _, p := range providers {
 57			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
 58			if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
 59				filteredProviders = append(filteredProviders, p)
 60			}
 61		}
 62
 63		m.providers = filteredProviders
 64		if err != nil {
 65			cmds = append(cmds, util.ReportError(err))
 66		}
 67	}
 68	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 69	return tea.Batch(cmds...)
 70}
 71
 72func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 73	u, cmd := m.list.Update(msg)
 74	m.list = u.(listModel)
 75	return m, cmd
 76}
 77
 78func (m *ModelListComponent) View() string {
 79	return m.list.View()
 80}
 81
 82func (m *ModelListComponent) Cursor() *tea.Cursor {
 83	return m.list.Cursor()
 84}
 85
 86func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 87	return m.list.SetSize(width, height)
 88}
 89
 90func (m *ModelListComponent) SelectedModel() *ModelOption {
 91	s := m.list.SelectedItem()
 92	if s == nil {
 93		return nil
 94	}
 95	sv := *s
 96	model := sv.Value()
 97	return &model
 98}
 99
100func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
101	t := styles.CurrentTheme()
102	m.modelType = modelType
103
104	var groups []list.Group[list.CompletionItem[ModelOption]]
105	// first none section
106	selectedItemID := ""
107
108	var currentModel config.SelectedModel
109	if m.modelType == LargeModelType {
110		currentModel = m.cfg.Models[config.SelectedModelTypeLarge]
111	} else {
112		currentModel = m.cfg.Models[config.SelectedModelTypeSmall]
113	}
114
115	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
116	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
117
118	// Create a map to track which providers we've already added
119	addedProviders := make(map[string]bool)
120
121	// First, add any configured providers that are not in the known providers list
122	// These should appear at the top of the list
123	knownProviders, err := config.Providers(m.cfg)
124	if err != nil {
125		return util.ReportError(err)
126	}
127	for providerID, providerConfig := range m.cfg.Providers.Seq2() {
128		if providerConfig.Disable {
129			continue
130		}
131
132		// Check if this provider is not in the known providers list
133		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
134			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
135			// Convert config provider to provider.Provider format
136			configProvider := catwalk.Provider{
137				Name:   providerConfig.Name,
138				ID:     catwalk.InferenceProvider(providerID),
139				Models: make([]catwalk.Model, len(providerConfig.Models)),
140			}
141
142			// Convert models
143			for i, model := range providerConfig.Models {
144				configProvider.Models[i] = catwalk.Model{
145					ID:                     model.ID,
146					Name:                   model.Name,
147					CostPer1MIn:            model.CostPer1MIn,
148					CostPer1MOut:           model.CostPer1MOut,
149					CostPer1MInCached:      model.CostPer1MInCached,
150					CostPer1MOutCached:     model.CostPer1MOutCached,
151					ContextWindow:          model.ContextWindow,
152					DefaultMaxTokens:       model.DefaultMaxTokens,
153					CanReason:              model.CanReason,
154					HasReasoningEffort:     model.HasReasoningEffort,
155					DefaultReasoningEffort: model.DefaultReasoningEffort,
156					SupportsImages:         model.SupportsImages,
157				}
158			}
159
160			// Add this unknown provider to the list
161			name := configProvider.Name
162			if name == "" {
163				name = string(configProvider.ID)
164			}
165			section := list.NewItemSection(name)
166			section.SetInfo(configured)
167			group := list.Group[list.CompletionItem[ModelOption]]{
168				Section: section,
169			}
170			for _, model := range configProvider.Models {
171				item := list.NewCompletionItem(model.Name, ModelOption{
172					Provider: configProvider,
173					Model:    model,
174				},
175					list.WithCompletionID(
176						fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
177					),
178				)
179
180				group.Items = append(group.Items, item)
181				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
182					selectedItemID = item.ID()
183				}
184			}
185			groups = append(groups, group)
186
187			addedProviders[providerID] = true
188		}
189	}
190
191	// Then add the known providers from the predefined list
192	for _, provider := range m.providers {
193		// Skip if we already added this provider as an unknown provider
194		if addedProviders[string(provider.ID)] {
195			continue
196		}
197
198		// Check if this provider is configured and not disabled
199		if providerConfig, exists := m.cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
200			continue
201		}
202
203		name := provider.Name
204		if name == "" {
205			name = string(provider.ID)
206		}
207
208		section := list.NewItemSection(name)
209		if _, ok := m.cfg.Providers.Get(string(provider.ID)); ok {
210			section.SetInfo(configured)
211		}
212		group := list.Group[list.CompletionItem[ModelOption]]{
213			Section: section,
214		}
215		for _, model := range provider.Models {
216			item := list.NewCompletionItem(model.Name, ModelOption{
217				Provider: provider,
218				Model:    model,
219			},
220				list.WithCompletionID(
221					fmt.Sprintf("%s:%s", provider.ID, model.ID),
222				),
223			)
224			group.Items = append(group.Items, item)
225			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
226				selectedItemID = item.ID()
227			}
228		}
229		groups = append(groups, group)
230	}
231
232	var cmds []tea.Cmd
233
234	cmd := m.list.SetGroups(groups)
235
236	if cmd != nil {
237		cmds = append(cmds, cmd)
238	}
239	cmd = m.list.SetSelected(selectedItemID)
240	if cmd != nil {
241		cmds = append(cmds, cmd)
242	}
243
244	return tea.Sequence(cmds...)
245}
246
247// GetModelType returns the current model type
248func (m *ModelListComponent) GetModelType() int {
249	return m.modelType
250}
251
// SetInputPlaceholder replaces the placeholder text of the list's filter input
// by delegating to the wrapped list.
func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
	m.list.SetInputPlaceholder(placeholder)
}