1package models
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"strings"
  7
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/tui/exp/list"
 12	"github.com/charmbracelet/crush/internal/tui/styles"
 13	"github.com/charmbracelet/crush/internal/tui/util"
 14)
 15
 16type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 17
 18type ModelListComponent struct {
 19	list      listModel
 20	modelType int
 21	providers []catwalk.Provider
 22}
 23
 24func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 25	t := styles.CurrentTheme()
 26	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 27	options := []list.ListOption{
 28		list.WithKeyMap(keyMap),
 29		list.WithWrapNavigation(),
 30	}
 31	if shouldResize {
 32		options = append(options, list.WithResizeByList())
 33	}
 34	modelList := list.NewFilterableGroupedList(
 35		[]list.Group[list.CompletionItem[ModelOption]]{},
 36		list.WithFilterInputStyle(inputStyle),
 37		list.WithFilterPlaceholder(inputPlaceholder),
 38		list.WithFilterListOptions(
 39			options...,
 40		),
 41	)
 42
 43	return &ModelListComponent{
 44		list:      modelList,
 45		modelType: LargeModelType,
 46	}
 47}
 48
 49func (m *ModelListComponent) Init() tea.Cmd {
 50	var cmds []tea.Cmd
 51	if len(m.providers) == 0 {
 52		providers, err := config.Providers()
 53		filteredProviders := []catwalk.Provider{}
 54		for _, p := range providers {
 55			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
 56			if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
 57				filteredProviders = append(filteredProviders, p)
 58			}
 59		}
 60
 61		m.providers = filteredProviders
 62		if err != nil {
 63			cmds = append(cmds, util.ReportError(err))
 64		}
 65	}
 66	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 67	return tea.Batch(cmds...)
 68}
 69
 70func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 71	u, cmd := m.list.Update(msg)
 72	m.list = u.(listModel)
 73	return m, cmd
 74}
 75
 76func (m *ModelListComponent) View() string {
 77	return m.list.View()
 78}
 79
 80func (m *ModelListComponent) Cursor() *tea.Cursor {
 81	return m.list.Cursor()
 82}
 83
 84func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 85	return m.list.SetSize(width, height)
 86}
 87
 88func (m *ModelListComponent) SelectedModel() *ModelOption {
 89	s := m.list.SelectedItem()
 90	if s == nil {
 91		return nil
 92	}
 93	sv := *s
 94	model := sv.Value()
 95	return &model
 96}
 97
 98func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
 99	t := styles.CurrentTheme()
100	m.modelType = modelType
101
102	var groups []list.Group[list.CompletionItem[ModelOption]]
103	// first none section
104	selectedItemID := ""
105
106	cfg := config.Get()
107	var currentModel config.SelectedModel
108	if m.modelType == LargeModelType {
109		currentModel = cfg.Models[config.SelectedModelTypeLarge]
110	} else {
111		currentModel = cfg.Models[config.SelectedModelTypeSmall]
112	}
113
114	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
115	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
116
117	// Create a map to track which providers we've already added
118	addedProviders := make(map[string]bool)
119
120	// First, add any configured providers that are not in the known providers list
121	// These should appear at the top of the list
122	knownProviders, err := config.Providers()
123	if err != nil {
124		return util.ReportError(err)
125	}
126	for providerID, providerConfig := range cfg.Providers.Seq2() {
127		if providerConfig.Disable {
128			continue
129		}
130
131		// Check if this provider is not in the known providers list
132		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
133			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
134			// Convert config provider to provider.Provider format
135			configProvider := catwalk.Provider{
136				Name:   providerConfig.Name,
137				ID:     catwalk.InferenceProvider(providerID),
138				Models: make([]catwalk.Model, len(providerConfig.Models)),
139			}
140
141			// Convert models
142			for i, model := range providerConfig.Models {
143				configProvider.Models[i] = catwalk.Model{
144					ID:                     model.ID,
145					Name:                   model.Name,
146					CostPer1MIn:            model.CostPer1MIn,
147					CostPer1MOut:           model.CostPer1MOut,
148					CostPer1MInCached:      model.CostPer1MInCached,
149					CostPer1MOutCached:     model.CostPer1MOutCached,
150					ContextWindow:          model.ContextWindow,
151					DefaultMaxTokens:       model.DefaultMaxTokens,
152					CanReason:              model.CanReason,
153					HasReasoningEffort:     model.HasReasoningEffort,
154					DefaultReasoningEffort: model.DefaultReasoningEffort,
155					SupportsImages:         model.SupportsImages,
156				}
157			}
158
159			// Add this unknown provider to the list
160			name := configProvider.Name
161			if name == "" {
162				name = string(configProvider.ID)
163			}
164			section := list.NewItemSection(name)
165			section.SetInfo(configured)
166			group := list.Group[list.CompletionItem[ModelOption]]{
167				Section: section,
168			}
169			for _, model := range configProvider.Models {
170				item := list.NewCompletionItem(model.Name, ModelOption{
171					Provider: configProvider,
172					Model:    model,
173				},
174					list.WithCompletionID(
175						fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
176					),
177				)
178
179				group.Items = append(group.Items, item)
180				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
181					selectedItemID = item.ID()
182				}
183			}
184			groups = append(groups, group)
185
186			addedProviders[providerID] = true
187		}
188	}
189
190	// Then add the known providers from the predefined list
191	for _, provider := range m.providers {
192		// Skip if we already added this provider as an unknown provider
193		if addedProviders[string(provider.ID)] {
194			continue
195		}
196
197		// Check if this provider is configured and not disabled
198		if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
199			continue
200		}
201
202		name := provider.Name
203		if name == "" {
204			name = string(provider.ID)
205		}
206
207		section := list.NewItemSection(name)
208		if _, ok := cfg.Providers.Get(string(provider.ID)); ok {
209			section.SetInfo(configured)
210		}
211		group := list.Group[list.CompletionItem[ModelOption]]{
212			Section: section,
213		}
214		for _, model := range provider.Models {
215			item := list.NewCompletionItem(model.Name, ModelOption{
216				Provider: provider,
217				Model:    model,
218			},
219				list.WithCompletionID(
220					fmt.Sprintf("%s:%s", provider.ID, model.ID),
221				),
222			)
223			group.Items = append(group.Items, item)
224			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
225				selectedItemID = item.ID()
226			}
227		}
228		groups = append(groups, group)
229	}
230
231	var cmds []tea.Cmd
232
233	cmd := m.list.SetGroups(groups)
234
235	if cmd != nil {
236		cmds = append(cmds, cmd)
237	}
238	cmd = m.list.SetSelected(selectedItemID)
239	if cmd != nil {
240		cmds = append(cmds, cmd)
241	}
242
243	return tea.Sequence(cmds...)
244}
245
246// GetModelType returns the current model type
247func (m *ModelListComponent) GetModelType() int {
248	return m.modelType
249}
250
// SetInputPlaceholder replaces the placeholder text of the list's filter input.
// The call goes through m.list directly so the change is applied to the stored
// list rather than to a copy.
func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
	m.list.SetInputPlaceholder(placeholder)
}