list.go

  1package models
  2
  3import (
  4	"fmt"
  5	"slices"
  6
  7	tea "github.com/charmbracelet/bubbletea/v2"
  8	"github.com/charmbracelet/crush/internal/config"
  9	"github.com/charmbracelet/crush/internal/fur/provider"
 10	"github.com/charmbracelet/crush/internal/tui/components/completions"
 11	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 12	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15	"github.com/charmbracelet/lipgloss/v2"
 16)
 17
 18type ModelListComponent struct {
 19	list      list.ListModel
 20	modelType int
 21}
 22
 23func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
 24	modelList := list.New(
 25		list.WithFilterable(true),
 26		list.WithKeyMap(keyMap),
 27		list.WithInputStyle(inputStyle),
 28		list.WithFilterPlaceholder(inputPlaceholder),
 29		list.WithWrapNavigation(true),
 30	)
 31
 32	return &ModelListComponent{
 33		list:      modelList,
 34		modelType: LargeModelType,
 35	}
 36}
 37
 38func (m *ModelListComponent) Init() tea.Cmd {
 39	return tea.Batch(m.list.Init(), m.SetModelType(m.modelType))
 40}
 41
 42func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 43	u, cmd := m.list.Update(msg)
 44	m.list = u.(list.ListModel)
 45	return m, cmd
 46}
 47
 48func (m *ModelListComponent) View() tea.View {
 49	return m.list.View()
 50}
 51
 52func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 53	return m.list.SetSize(width, height)
 54}
 55
 56func (m *ModelListComponent) Items() []util.Model {
 57	return m.list.Items()
 58}
 59
 60func (m *ModelListComponent) SelectedIndex() int {
 61	return m.list.SelectedIndex()
 62}
 63
 64func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
 65	t := styles.CurrentTheme()
 66	m.modelType = modelType
 67
 68	providers, err := config.Providers()
 69	if err != nil {
 70		return util.ReportError(err)
 71	}
 72
 73	modelItems := []util.Model{}
 74	selectIndex := 0
 75
 76	cfg := config.Get()
 77	var currentModel config.SelectedModel
 78	if m.modelType == LargeModelType {
 79		currentModel = cfg.Models[config.SelectedModelTypeLarge]
 80	} else {
 81		currentModel = cfg.Models[config.SelectedModelTypeSmall]
 82	}
 83
 84	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
 85	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
 86
 87	// Create a map to track which providers we've already added
 88	addedProviders := make(map[string]bool)
 89
 90	// First, add any configured providers that are not in the known providers list
 91	// These should appear at the top of the list
 92	knownProviders := provider.KnownProviders()
 93	for providerID, providerConfig := range cfg.Providers {
 94		if providerConfig.Disable {
 95			continue
 96		}
 97
 98		// Check if this provider is not in the known providers list
 99		if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
100			// Convert config provider to provider.Provider format
101			configProvider := provider.Provider{
102				Name:   string(providerID), // Use provider ID as name for unknown providers
103				ID:     provider.InferenceProvider(providerID),
104				Models: make([]provider.Model, len(providerConfig.Models)),
105			}
106
107			// Convert models
108			for i, model := range providerConfig.Models {
109				configProvider.Models[i] = provider.Model{
110					ID:                     model.ID,
111					Name:                   model.Name,
112					CostPer1MIn:            model.CostPer1MIn,
113					CostPer1MOut:           model.CostPer1MOut,
114					CostPer1MInCached:      model.CostPer1MInCached,
115					CostPer1MOutCached:     model.CostPer1MOutCached,
116					ContextWindow:          model.ContextWindow,
117					DefaultMaxTokens:       model.DefaultMaxTokens,
118					CanReason:              model.CanReason,
119					HasReasoningEffort:     model.HasReasoningEffort,
120					DefaultReasoningEffort: model.DefaultReasoningEffort,
121					SupportsImages:         model.SupportsImages,
122				}
123			}
124
125			// Add this unknown provider to the list
126			name := configProvider.Name
127			if name == "" {
128				name = string(configProvider.ID)
129			}
130			section := commands.NewItemSection(name)
131			section.SetInfo(configured)
132			modelItems = append(modelItems, section)
133			for _, model := range configProvider.Models {
134				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
135					Provider: configProvider,
136					Model:    model,
137				}))
138				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
139					selectIndex = len(modelItems) - 1 // Set the selected index to the current model
140				}
141			}
142			addedProviders[providerID] = true
143		}
144	}
145
146	// Then add the known providers from the predefined list
147	for _, provider := range providers {
148		// Skip if we already added this provider as an unknown provider
149		if addedProviders[string(provider.ID)] {
150			continue
151		}
152
153		// Check if this provider is configured and not disabled
154		if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
155			continue
156		}
157
158		name := provider.Name
159		if name == "" {
160			name = string(provider.ID)
161		}
162
163		section := commands.NewItemSection(name)
164		if _, ok := cfg.Providers[string(provider.ID)]; ok {
165			section.SetInfo(configured)
166		}
167		modelItems = append(modelItems, section)
168		for _, model := range provider.Models {
169			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
170				Provider: provider,
171				Model:    model,
172			}))
173			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
174				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
175			}
176		}
177	}
178
179	return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
180}
181
182// GetModelType returns the current model type
183func (m *ModelListComponent) GetModelType() int {
184	return m.modelType
185}
186
187func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
188	m.list.SetFilterPlaceholder(placeholder)
189}