1package models
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/bubbles/v2/help"
  7	"github.com/charmbracelet/bubbles/v2/key"
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/fur/provider"
 11	"github.com/charmbracelet/crush/internal/tui/components/completions"
 12	"github.com/charmbracelet/crush/internal/tui/components/core"
 13	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27const (
 28	LargeModelType int = iota
 29	SmallModelType
 30)
 31
 32// ModelSelectedMsg is sent when a model is selected
 33type ModelSelectedMsg struct {
 34	Model     config.PreferredModel
 35	ModelType config.ModelType
 36}
 37
 38// CloseModelDialogMsg is sent when a model is selected
 39type CloseModelDialogMsg struct{}
 40
 41// ModelDialog interface for the model selection dialog
 42type ModelDialog interface {
 43	dialogs.DialogModel
 44}
 45
 46type ModelOption struct {
 47	Provider provider.Provider
 48	Model    provider.Model
 49}
 50
 51type modelDialogCmp struct {
 52	width   int
 53	wWidth  int
 54	wHeight int
 55
 56	modelList list.ListModel
 57	keyMap    KeyMap
 58	help      help.Model
 59	modelType int
 60}
 61
 62func NewModelDialogCmp() ModelDialog {
 63	listKeyMap := list.DefaultKeyMap()
 64	keyMap := DefaultKeyMap()
 65
 66	listKeyMap.Down.SetEnabled(false)
 67	listKeyMap.Up.SetEnabled(false)
 68	listKeyMap.HalfPageDown.SetEnabled(false)
 69	listKeyMap.HalfPageUp.SetEnabled(false)
 70	listKeyMap.Home.SetEnabled(false)
 71	listKeyMap.End.SetEnabled(false)
 72
 73	listKeyMap.DownOneItem = keyMap.Next
 74	listKeyMap.UpOneItem = keyMap.Previous
 75
 76	t := styles.CurrentTheme()
 77	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 78	modelList := list.New(
 79		list.WithFilterable(true),
 80		list.WithKeyMap(listKeyMap),
 81		list.WithInputStyle(inputStyle),
 82		list.WithWrapNavigation(true),
 83	)
 84	help := help.New()
 85	help.Styles = t.S().Help
 86
 87	return &modelDialogCmp{
 88		modelList: modelList,
 89		width:     defaultWidth,
 90		keyMap:    DefaultKeyMap(),
 91		help:      help,
 92		modelType: LargeModelType,
 93	}
 94}
 95
 96func (m *modelDialogCmp) Init() tea.Cmd {
 97	m.SetModelType(m.modelType)
 98	return m.modelList.Init()
 99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102	switch msg := msg.(type) {
103	case tea.WindowSizeMsg:
104		m.wWidth = msg.Width
105		m.wHeight = msg.Height
106		m.SetModelType(m.modelType)
107		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
108	case tea.KeyPressMsg:
109		switch {
110		case key.Matches(msg, m.keyMap.Select):
111			selectedItemInx := m.modelList.SelectedIndex()
112			if selectedItemInx == list.NoSelection {
113				return m, nil
114			}
115			items := m.modelList.Items()
116			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
117
118			var modelType config.ModelType
119			if m.modelType == LargeModelType {
120				modelType = config.LargeModel
121			} else {
122				modelType = config.SmallModel
123			}
124
125			return m, tea.Sequence(
126				util.CmdHandler(dialogs.CloseDialogMsg{}),
127				util.CmdHandler(ModelSelectedMsg{
128					Model: config.PreferredModel{
129						ModelID:  selectedItem.Model.ID,
130						Provider: selectedItem.Provider.ID,
131					},
132					ModelType: modelType,
133				}),
134			)
135		case key.Matches(msg, m.keyMap.Tab):
136			if m.modelType == LargeModelType {
137				return m, m.SetModelType(SmallModelType)
138			} else {
139				return m, m.SetModelType(LargeModelType)
140			}
141		case key.Matches(msg, m.keyMap.Close):
142			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
143		default:
144			u, cmd := m.modelList.Update(msg)
145			m.modelList = u.(list.ListModel)
146			return m, cmd
147		}
148	}
149	return m, nil
150}
151
152func (m *modelDialogCmp) View() tea.View {
153	t := styles.CurrentTheme()
154	listView := m.modelList.View()
155	radio := m.modelTypeRadio()
156	content := lipgloss.JoinVertical(
157		lipgloss.Left,
158		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
159		listView.String(),
160		"",
161		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
162	)
163	v := tea.NewView(m.style().Render(content))
164	if listView.Cursor() != nil {
165		c := m.moveCursor(listView.Cursor())
166		v.SetCursor(c)
167	}
168	return v
169}
170
171func (m *modelDialogCmp) style() lipgloss.Style {
172	t := styles.CurrentTheme()
173	return t.S().Base.
174		Width(m.width).
175		Border(lipgloss.RoundedBorder()).
176		BorderForeground(t.BorderFocus)
177}
178
179func (m *modelDialogCmp) listWidth() int {
180	return defaultWidth - 2 // 4 for padding
181}
182
183func (m *modelDialogCmp) listHeight() int {
184	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
185	return min(listHeigh, m.wHeight/2)
186}
187
188func (m *modelDialogCmp) Position() (int, int) {
189	row := m.wHeight/4 - 2 // just a bit above the center
190	col := m.wWidth / 2
191	col -= m.width / 2
192	return row, col
193}
194
195func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
196	row, col := m.Position()
197	offset := row + 3 // Border + title
198	cursor.Y += offset
199	cursor.X = cursor.X + col + 2
200	return cursor
201}
202
203func (m *modelDialogCmp) ID() dialogs.DialogID {
204	return ModelsDialogID
205}
206
207func (m *modelDialogCmp) modelTypeRadio() string {
208	t := styles.CurrentTheme()
209	choices := []string{"Large Task", "Small Task"}
210	iconSelected := "◉"
211	iconUnselected := "○"
212	if m.modelType == LargeModelType {
213		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
214	}
215	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
216}
217
218func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
219	m.modelType = modelType
220
221	providers := config.Providers()
222	modelItems := []util.Model{}
223	selectIndex := 0
224
225	cfg := config.Get()
226	var currentModel config.PreferredModel
227	if m.modelType == LargeModelType {
228		currentModel = cfg.Models.Large
229	} else {
230		currentModel = cfg.Models.Small
231	}
232
233	// Create a map to track which providers we've already added
234	addedProviders := make(map[provider.InferenceProvider]bool)
235
236	// First, add any configured providers that are not in the known providers list
237	// These should appear at the top of the list
238	knownProviders := provider.KnownProviders()
239	for providerID, providerConfig := range cfg.Providers {
240		if providerConfig.Disabled {
241			continue
242		}
243
244		// Check if this provider is not in the known providers list
245		if !slices.Contains(knownProviders, providerID) {
246			// Convert config provider to provider.Provider format
247			configProvider := provider.Provider{
248				Name:   string(providerID), // Use provider ID as name for unknown providers
249				ID:     providerID,
250				Models: make([]provider.Model, len(providerConfig.Models)),
251			}
252
253			// Convert models
254			for i, model := range providerConfig.Models {
255				configProvider.Models[i] = provider.Model{
256					ID:                     model.ID,
257					Name:                   model.Name,
258					CostPer1MIn:            model.CostPer1MIn,
259					CostPer1MOut:           model.CostPer1MOut,
260					CostPer1MInCached:      model.CostPer1MInCached,
261					CostPer1MOutCached:     model.CostPer1MOutCached,
262					ContextWindow:          model.ContextWindow,
263					DefaultMaxTokens:       model.DefaultMaxTokens,
264					CanReason:              model.CanReason,
265					HasReasoningEffort:     model.HasReasoningEffort,
266					DefaultReasoningEffort: model.ReasoningEffort,
267					SupportsImages:         model.SupportsImages,
268				}
269			}
270
271			// Add this unknown provider to the list
272			name := configProvider.Name
273			if name == "" {
274				name = string(configProvider.ID)
275			}
276			modelItems = append(modelItems, commands.NewItemSection(name))
277			for _, model := range configProvider.Models {
278				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
279					Provider: configProvider,
280					Model:    model,
281				}))
282				if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
283					selectIndex = len(modelItems) - 1 // Set the selected index to the current model
284				}
285			}
286			addedProviders[providerID] = true
287		}
288	}
289
290	// Then add the known providers from the predefined list
291	for _, provider := range providers {
292		// Skip if we already added this provider as an unknown provider
293		if addedProviders[provider.ID] {
294			continue
295		}
296
297		// Check if this provider is configured and not disabled
298		if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
299			continue
300		}
301
302		name := provider.Name
303		if name == "" {
304			name = string(provider.ID)
305		}
306		modelItems = append(modelItems, commands.NewItemSection(name))
307		for _, model := range provider.Models {
308			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
309				Provider: provider,
310				Model:    model,
311			}))
312			if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
313				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
314			}
315		}
316	}
317
318	return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
319}