list.go

  1package models
  2
  3import (
  4	"slices"
  5
  6	tea "github.com/charmbracelet/bubbletea/v2"
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/tui/components/completions"
 10	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 11	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 12	"github.com/charmbracelet/crush/internal/tui/util"
 13	"github.com/charmbracelet/lipgloss/v2"
 14)
 15
 16type ModelListComponent struct {
 17	list      list.ListModel
 18	modelType int
 19}
 20
 21func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style) *ModelListComponent {
 22	modelList := list.New(
 23		list.WithFilterable(true),
 24		list.WithKeyMap(keyMap),
 25		list.WithInputStyle(inputStyle),
 26		list.WithWrapNavigation(true),
 27	)
 28
 29	return &ModelListComponent{
 30		list:      modelList,
 31		modelType: LargeModelType,
 32	}
 33}
 34
 35func (m *ModelListComponent) Init() tea.Cmd {
 36	return tea.Batch(m.list.Init(), m.SetModelType(m.modelType))
 37}
 38
 39func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 40	u, cmd := m.list.Update(msg)
 41	m.list = u.(list.ListModel)
 42	return m, cmd
 43}
 44
 45func (m *ModelListComponent) View() tea.View {
 46	return m.list.View()
 47}
 48
 49func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 50	return m.list.SetSize(width, height)
 51}
 52
 53func (m *ModelListComponent) Items() []util.Model {
 54	return m.list.Items()
 55}
 56
 57func (m *ModelListComponent) SelectedIndex() int {
 58	return m.list.SelectedIndex()
 59}
 60
 61func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
 62	m.modelType = modelType
 63
 64	providers, err := config.Providers()
 65	if err != nil {
 66		return util.ReportError(err)
 67	}
 68
 69	modelItems := []util.Model{}
 70	selectIndex := 0
 71
 72	cfg := config.Get()
 73	var currentModel config.SelectedModel
 74	if m.modelType == LargeModelType {
 75		currentModel = cfg.Models[config.SelectedModelTypeLarge]
 76	} else {
 77		currentModel = cfg.Models[config.SelectedModelTypeSmall]
 78	}
 79
 80	// Create a map to track which providers we've already added
 81	addedProviders := make(map[string]bool)
 82
 83	// First, add any configured providers that are not in the known providers list
 84	// These should appear at the top of the list
 85	knownProviders := provider.KnownProviders()
 86	for providerID, providerConfig := range cfg.Providers {
 87		if providerConfig.Disable {
 88			continue
 89		}
 90
 91		// Check if this provider is not in the known providers list
 92		if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
 93			// Convert config provider to provider.Provider format
 94			configProvider := provider.Provider{
 95				Name:   string(providerID), // Use provider ID as name for unknown providers
 96				ID:     provider.InferenceProvider(providerID),
 97				Models: make([]provider.Model, len(providerConfig.Models)),
 98			}
 99
100			// Convert models
101			for i, model := range providerConfig.Models {
102				configProvider.Models[i] = provider.Model{
103					ID:                     model.ID,
104					Name:                   model.Name,
105					CostPer1MIn:            model.CostPer1MIn,
106					CostPer1MOut:           model.CostPer1MOut,
107					CostPer1MInCached:      model.CostPer1MInCached,
108					CostPer1MOutCached:     model.CostPer1MOutCached,
109					ContextWindow:          model.ContextWindow,
110					DefaultMaxTokens:       model.DefaultMaxTokens,
111					CanReason:              model.CanReason,
112					HasReasoningEffort:     model.HasReasoningEffort,
113					DefaultReasoningEffort: model.DefaultReasoningEffort,
114					SupportsImages:         model.SupportsImages,
115				}
116			}
117
118			// Add this unknown provider to the list
119			name := configProvider.Name
120			if name == "" {
121				name = string(configProvider.ID)
122			}
123			modelItems = append(modelItems, commands.NewItemSection(name))
124			for _, model := range configProvider.Models {
125				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
126					Provider: configProvider,
127					Model:    model,
128				}))
129				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
130					selectIndex = len(modelItems) - 1 // Set the selected index to the current model
131				}
132			}
133			addedProviders[providerID] = true
134		}
135	}
136
137	// Then add the known providers from the predefined list
138	for _, provider := range providers {
139		// Skip if we already added this provider as an unknown provider
140		if addedProviders[string(provider.ID)] {
141			continue
142		}
143
144		// Check if this provider is configured and not disabled
145		if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
146			continue
147		}
148
149		name := provider.Name
150		if name == "" {
151			name = string(provider.ID)
152		}
153		modelItems = append(modelItems, commands.NewItemSection(name))
154		for _, model := range provider.Models {
155			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
156				Provider: provider,
157				Model:    model,
158			}))
159			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
160				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
161			}
162		}
163	}
164
165	return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
166}
167
168// GetModelType returns the current model type
169func (m *ModelListComponent) GetModelType() int {
170	return m.modelType
171}