list.go

  1package models
  2
  3import (
  4	"fmt"
  5	"slices"
  6
  7	tea "github.com/charmbracelet/bubbletea/v2"
  8	"github.com/charmbracelet/catwalk/pkg/catwalk"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/tui/components/completions"
 11	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 12	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15	"github.com/charmbracelet/lipgloss/v2"
 16)
 17
 18type ModelListComponent struct {
 19	list      list.ListModel
 20	modelType int
 21	providers []catwalk.Provider
 22}
 23
 24func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
 25	modelList := list.New(
 26		list.WithFilterable(true),
 27		list.WithKeyMap(keyMap),
 28		list.WithInputStyle(inputStyle),
 29		list.WithFilterPlaceholder(inputPlaceholder),
 30		list.WithWrapNavigation(true),
 31	)
 32
 33	return &ModelListComponent{
 34		list:      modelList,
 35		modelType: LargeModelType,
 36	}
 37}
 38
 39func (m *ModelListComponent) Init() tea.Cmd {
 40	var cmds []tea.Cmd
 41	if len(m.providers) == 0 {
 42		providers, err := config.Providers()
 43		m.providers = providers
 44		if err != nil {
 45			cmds = append(cmds, util.ReportError(err))
 46		}
 47	}
 48	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 49	return tea.Batch(cmds...)
 50}
 51
 52func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 53	u, cmd := m.list.Update(msg)
 54	m.list = u.(list.ListModel)
 55	return m, cmd
 56}
 57
 58func (m *ModelListComponent) View() string {
 59	return m.list.View()
 60}
 61
 62func (m *ModelListComponent) Cursor() *tea.Cursor {
 63	return m.list.Cursor()
 64}
 65
 66func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 67	return m.list.SetSize(width, height)
 68}
 69
 70func (m *ModelListComponent) Items() []util.Model {
 71	return m.list.Items()
 72}
 73
 74func (m *ModelListComponent) SelectedIndex() int {
 75	return m.list.SelectedIndex()
 76}
 77
 78func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
 79	t := styles.CurrentTheme()
 80	m.modelType = modelType
 81
 82	modelItems := []util.Model{}
 83	// first none section
 84	selectIndex := 1
 85
 86	cfg := config.Get()
 87	var currentModel config.SelectedModel
 88	if m.modelType == LargeModelType {
 89		currentModel = cfg.Models[config.SelectedModelTypeLarge]
 90	} else {
 91		currentModel = cfg.Models[config.SelectedModelTypeSmall]
 92	}
 93
 94	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
 95	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
 96
 97	// Create a map to track which providers we've already added
 98	addedProviders := make(map[string]bool)
 99
100	// First, add any configured providers that are not in the known providers list
101	// These should appear at the top of the list
102	knownProviders, err := config.Providers()
103	if err != nil {
104		return util.ReportError(err)
105	}
106	for providerID, providerConfig := range cfg.Providers {
107		if providerConfig.Disable {
108			continue
109		}
110
111		// Check if this provider is not in the known providers list
112		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
113			// Convert config provider to provider.Provider format
114			configProvider := catwalk.Provider{
115				Name:   providerConfig.Name,
116				ID:     catwalk.InferenceProvider(providerID),
117				Models: make([]catwalk.Model, len(providerConfig.Models)),
118			}
119
120			// Convert models
121			for i, model := range providerConfig.Models {
122				configProvider.Models[i] = catwalk.Model{
123					ID:                     model.ID,
124					Name:                   model.Name,
125					CostPer1MIn:            model.CostPer1MIn,
126					CostPer1MOut:           model.CostPer1MOut,
127					CostPer1MInCached:      model.CostPer1MInCached,
128					CostPer1MOutCached:     model.CostPer1MOutCached,
129					ContextWindow:          model.ContextWindow,
130					DefaultMaxTokens:       model.DefaultMaxTokens,
131					CanReason:              model.CanReason,
132					HasReasoningEffort:     model.HasReasoningEffort,
133					DefaultReasoningEffort: model.DefaultReasoningEffort,
134					SupportsImages:         model.SupportsImages,
135				}
136			}
137
138			// Add this unknown provider to the list
139			name := configProvider.Name
140			if name == "" {
141				name = string(configProvider.ID)
142			}
143			section := commands.NewItemSection(name)
144			section.SetInfo(configured)
145			modelItems = append(modelItems, section)
146			for _, model := range configProvider.Models {
147				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
148					Provider: configProvider,
149					Model:    model,
150				}))
151				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
152					selectIndex = len(modelItems) - 1 // Set the selected index to the current model
153				}
154			}
155			addedProviders[providerID] = true
156		}
157	}
158
159	// Then add the known providers from the predefined list
160	for _, provider := range m.providers {
161		// Skip if we already added this provider as an unknown provider
162		if addedProviders[string(provider.ID)] {
163			continue
164		}
165
166		// Check if this provider is configured and not disabled
167		if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
168			continue
169		}
170
171		name := provider.Name
172		if name == "" {
173			name = string(provider.ID)
174		}
175
176		section := commands.NewItemSection(name)
177		if _, ok := cfg.Providers[string(provider.ID)]; ok {
178			section.SetInfo(configured)
179		}
180		modelItems = append(modelItems, section)
181		for _, model := range provider.Models {
182			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
183				Provider: provider,
184				Model:    model,
185			}))
186			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
187				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
188			}
189		}
190	}
191
192	return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
193}
194
195// GetModelType returns the current model type
196func (m *ModelListComponent) GetModelType() int {
197	return m.modelType
198}
199
200func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
201	m.list.SetFilterPlaceholder(placeholder)
202}
203
204func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
205	m.providers = providers
206}