1package models
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	tea "github.com/charmbracelet/bubbletea/v2"
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/tui/exp/list"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15)
 16
 17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 18
 19type ModelListComponent struct {
 20	list      listModel
 21	modelType int
 22	providers []catwalk.Provider
 23}
 24
 25func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 26	t := styles.CurrentTheme()
 27	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 28	options := []list.ListOption{
 29		list.WithKeyMap(keyMap),
 30		list.WithWrapNavigation(),
 31	}
 32	if shouldResize {
 33		options = append(options, list.WithResizeByList())
 34	}
 35	modelList := list.NewFilterableGroupedList(
 36		[]list.Group[list.CompletionItem[ModelOption]]{},
 37		list.WithFilterInputStyle(inputStyle),
 38		list.WithFilterPlaceholder(inputPlaceholder),
 39		list.WithFilterListOptions(
 40			options...,
 41		),
 42	)
 43
 44	return &ModelListComponent{
 45		list:      modelList,
 46		modelType: LargeModelType,
 47	}
 48}
 49
 50func (m *ModelListComponent) Init() tea.Cmd {
 51	var cmds []tea.Cmd
 52	if len(m.providers) == 0 {
 53		cfg := config.Get()
 54		providers, err := config.Providers(cfg)
 55		filteredProviders := []catwalk.Provider{}
 56		for _, p := range providers {
 57			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
 58			if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
 59				filteredProviders = append(filteredProviders, p)
 60			}
 61		}
 62
 63		m.providers = filteredProviders
 64		if err != nil {
 65			cmds = append(cmds, util.ReportError(err))
 66		}
 67	}
 68	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 69	return tea.Batch(cmds...)
 70}
 71
 72func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 73	u, cmd := m.list.Update(msg)
 74	m.list = u.(listModel)
 75	return m, cmd
 76}
 77
 78func (m *ModelListComponent) View() string {
 79	return m.list.View()
 80}
 81
 82func (m *ModelListComponent) Cursor() *tea.Cursor {
 83	return m.list.Cursor()
 84}
 85
 86func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 87	return m.list.SetSize(width, height)
 88}
 89
 90func (m *ModelListComponent) SelectedModel() *ModelOption {
 91	s := m.list.SelectedItem()
 92	if s == nil {
 93		return nil
 94	}
 95	sv := *s
 96	model := sv.Value()
 97	return &model
 98}
 99
100func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
101	t := styles.CurrentTheme()
102	m.modelType = modelType
103
104	var groups []list.Group[list.CompletionItem[ModelOption]]
105	// first none section
106	selectedItemID := ""
107
108	cfg := config.Get()
109	var currentModel config.SelectedModel
110	if m.modelType == LargeModelType {
111		currentModel = cfg.Models[config.SelectedModelTypeLarge]
112	} else {
113		currentModel = cfg.Models[config.SelectedModelTypeSmall]
114	}
115
116	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
117	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
118
119	// Create a map to track which providers we've already added
120	addedProviders := make(map[string]bool)
121
122	// First, add any configured providers that are not in the known providers list
123	// These should appear at the top of the list
124	knownProviders, err := config.Providers(cfg)
125	if err != nil {
126		return util.ReportError(err)
127	}
128	for providerID, providerConfig := range cfg.Providers.Seq2() {
129		if providerConfig.Disable {
130			continue
131		}
132
133		// Check if this provider is not in the known providers list
134		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
135			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
136			// Convert config provider to provider.Provider format
137			configProvider := catwalk.Provider{
138				Name:   providerConfig.Name,
139				ID:     catwalk.InferenceProvider(providerID),
140				Models: make([]catwalk.Model, len(providerConfig.Models)),
141			}
142
143			// Convert models
144			for i, model := range providerConfig.Models {
145				configProvider.Models[i] = catwalk.Model{
146					ID:                     model.ID,
147					Name:                   model.Name,
148					CostPer1MIn:            model.CostPer1MIn,
149					CostPer1MOut:           model.CostPer1MOut,
150					CostPer1MInCached:      model.CostPer1MInCached,
151					CostPer1MOutCached:     model.CostPer1MOutCached,
152					ContextWindow:          model.ContextWindow,
153					DefaultMaxTokens:       model.DefaultMaxTokens,
154					CanReason:              model.CanReason,
155					ReasoningLevels:        model.ReasoningLevels,
156					DefaultReasoningEffort: model.DefaultReasoningEffort,
157					SupportsImages:         model.SupportsImages,
158				}
159			}
160
161			// Add this unknown provider to the list
162			name := configProvider.Name
163			if name == "" {
164				name = string(configProvider.ID)
165			}
166			section := list.NewItemSection(name)
167			section.SetInfo(configured)
168			group := list.Group[list.CompletionItem[ModelOption]]{
169				Section: section,
170			}
171			for _, model := range configProvider.Models {
172				item := list.NewCompletionItem(model.Name, ModelOption{
173					Provider: configProvider,
174					Model:    model,
175				},
176					list.WithCompletionID(
177						fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
178					),
179				)
180
181				group.Items = append(group.Items, item)
182				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
183					selectedItemID = item.ID()
184				}
185			}
186			groups = append(groups, group)
187
188			addedProviders[providerID] = true
189		}
190	}
191
192	// Then add the known providers from the predefined list
193	for _, provider := range m.providers {
194		// Skip if we already added this provider as an unknown provider
195		if addedProviders[string(provider.ID)] {
196			continue
197		}
198
199		providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
200		if providerConfigured && providerConfig.Disable {
201			continue
202		}
203
204		displayProvider := provider
205		if providerConfigured {
206			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
207			modelIndex := make(map[string]int, len(displayProvider.Models))
208			for i, model := range displayProvider.Models {
209				modelIndex[model.ID] = i
210			}
211			for _, model := range providerConfig.Models {
212				if model.ID == "" {
213					continue
214				}
215				if idx, ok := modelIndex[model.ID]; ok {
216					if model.Name != "" {
217						displayProvider.Models[idx].Name = model.Name
218					}
219					continue
220				}
221				if model.Name == "" {
222					model.Name = model.ID
223				}
224				displayProvider.Models = append(displayProvider.Models, model)
225				modelIndex[model.ID] = len(displayProvider.Models) - 1
226			}
227		}
228
229		name := displayProvider.Name
230		if name == "" {
231			name = string(displayProvider.ID)
232		}
233
234		section := list.NewItemSection(name)
235		if providerConfigured {
236			section.SetInfo(configured)
237		}
238		group := list.Group[list.CompletionItem[ModelOption]]{
239			Section: section,
240		}
241		for _, model := range displayProvider.Models {
242			item := list.NewCompletionItem(model.Name, ModelOption{
243				Provider: displayProvider,
244				Model:    model,
245			},
246				list.WithCompletionID(
247					fmt.Sprintf("%s:%s", displayProvider.ID, model.ID),
248				),
249			)
250			group.Items = append(group.Items, item)
251			if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
252				selectedItemID = item.ID()
253			}
254		}
255		groups = append(groups, group)
256	}
257
258	var cmds []tea.Cmd
259
260	cmd := m.list.SetGroups(groups)
261
262	if cmd != nil {
263		cmds = append(cmds, cmd)
264	}
265	cmd = m.list.SetSelected(selectedItemID)
266	if cmd != nil {
267		cmds = append(cmds, cmd)
268	}
269
270	return tea.Sequence(cmds...)
271}
272
273// GetModelType returns the current model type
274func (m *ModelListComponent) GetModelType() int {
275	return m.modelType
276}
277
278func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
279	m.list.SetInputPlaceholder(placeholder)
280}