splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/tui/components/chat"
 15	"github.com/charmbracelet/crush/internal/tui/components/completions"
 16	"github.com/charmbracelet/crush/internal/tui/components/core"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 18	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 20	"github.com/charmbracelet/crush/internal/tui/components/logo"
 21	"github.com/charmbracelet/crush/internal/tui/styles"
 22	"github.com/charmbracelet/crush/internal/tui/util"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/charmbracelet/lipgloss/v2"
 25)
 26
 27type Splash interface {
 28	util.Model
 29	layout.Sizeable
 30	layout.Help
 31	Cursor() *tea.Cursor
 32	// SetOnboarding controls whether the splash shows model selection UI
 33	SetOnboarding(bool)
 34	// SetProjectInit controls whether the splash shows project initialization prompt
 35	SetProjectInit(bool)
 36
 37	// Showing API key input
 38	IsShowingAPIKey() bool
 39}
 40
 41const (
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43
 44	LogoGap = 6
 45)
 46
 47// OnboardingCompleteMsg is sent when onboarding is complete
 48type OnboardingCompleteMsg struct{}
 49
 50type splashCmp struct {
 51	width, height int
 52	keyMap        KeyMap
 53	logoRendered  string
 54
 55	// State
 56	isOnboarding     bool
 57	needsProjectInit bool
 58	needsAPIKey      bool
 59	selectedNo       bool
 60
 61	listHeight    int
 62	modelList     *models.ModelListComponent
 63	apiKeyInput   *models.APIKeyInput
 64	selectedModel *models.ModelOption
 65}
 66
 67func New() Splash {
 68	keyMap := DefaultKeyMap()
 69	listKeyMap := list.DefaultKeyMap()
 70	listKeyMap.Down.SetEnabled(false)
 71	listKeyMap.Up.SetEnabled(false)
 72	listKeyMap.HalfPageDown.SetEnabled(false)
 73	listKeyMap.HalfPageUp.SetEnabled(false)
 74	listKeyMap.Home.SetEnabled(false)
 75	listKeyMap.End.SetEnabled(false)
 76	listKeyMap.DownOneItem = keyMap.Next
 77	listKeyMap.UpOneItem = keyMap.Previous
 78
 79	t := styles.CurrentTheme()
 80	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 81	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 82	apiKeyInput := models.NewAPIKeyInput()
 83
 84	return &splashCmp{
 85		width:        0,
 86		height:       0,
 87		keyMap:       keyMap,
 88		logoRendered: "",
 89		modelList:    modelList,
 90		apiKeyInput:  apiKeyInput,
 91		selectedNo:   false,
 92	}
 93}
 94
 95func (s *splashCmp) SetOnboarding(onboarding bool) {
 96	s.isOnboarding = onboarding
 97	if onboarding {
 98		providers, err := config.Providers()
 99		if err != nil {
100			return
101		}
102		filteredProviders := []provider.Provider{}
103		simpleProviders := []string{
104			"anthropic",
105			"openai",
106			"gemini",
107			"xai",
108			"openrouter",
109		}
110		for _, p := range providers {
111			if slices.Contains(simpleProviders, string(p.ID)) {
112				filteredProviders = append(filteredProviders, p)
113			}
114		}
115		s.modelList.SetProviders(filteredProviders)
116	}
117}
118
119func (s *splashCmp) SetProjectInit(needsInit bool) {
120	s.needsProjectInit = needsInit
121}
122
123// GetSize implements SplashPage.
124func (s *splashCmp) GetSize() (int, int) {
125	return s.width, s.height
126}
127
128// Init implements SplashPage.
129func (s *splashCmp) Init() tea.Cmd {
130	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
131}
132
133// SetSize implements SplashPage.
134func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
135	s.height = height
136	if width != s.width {
137		s.width = width
138		s.logoRendered = s.logoBlock()
139	}
140	// remove padding, logo height, gap, title space
141	s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
142	listWidth := min(60, width)
143	return s.modelList.SetSize(listWidth, s.listHeight)
144}
145
146// Update implements SplashPage.
147func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
148	switch msg := msg.(type) {
149	case tea.WindowSizeMsg:
150		return s, s.SetSize(msg.Width, msg.Height)
151	case tea.KeyPressMsg:
152		switch {
153		case key.Matches(msg, s.keyMap.Back):
154			if s.needsAPIKey {
155				// Go back to model selection
156				s.needsAPIKey = false
157				s.selectedModel = nil
158				return s, nil
159			}
160		case key.Matches(msg, s.keyMap.Select):
161			if s.isOnboarding && !s.needsAPIKey {
162				modelInx := s.modelList.SelectedIndex()
163				items := s.modelList.Items()
164				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
165				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
166					cmd := s.setPreferredModel(selectedItem)
167					s.isOnboarding = false
168					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
169				} else {
170					// Provider not configured, show API key input
171					s.needsAPIKey = true
172					s.selectedModel = &selectedItem
173					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
174					return s, nil
175				}
176			} else if s.needsAPIKey {
177				// Handle API key submission
178				apiKey := s.apiKeyInput.Value()
179				if apiKey != "" {
180					return s, s.saveAPIKeyAndContinue(apiKey)
181				}
182			} else if s.needsProjectInit {
183				return s, s.initializeProject()
184			}
185		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
186			if s.needsProjectInit {
187				s.selectedNo = !s.selectedNo
188				return s, nil
189			}
190		case key.Matches(msg, s.keyMap.Yes):
191			if s.needsProjectInit {
192				return s, s.initializeProject()
193			}
194		case key.Matches(msg, s.keyMap.No):
195			if s.needsProjectInit {
196				s.needsProjectInit = false
197				return s, util.CmdHandler(OnboardingCompleteMsg{})
198			}
199		default:
200			if s.needsAPIKey {
201				u, cmd := s.apiKeyInput.Update(msg)
202				s.apiKeyInput = u.(*models.APIKeyInput)
203				return s, cmd
204			} else if s.isOnboarding {
205				u, cmd := s.modelList.Update(msg)
206				s.modelList = u
207				return s, cmd
208			}
209		}
210	case tea.PasteMsg:
211		if s.needsAPIKey {
212			u, cmd := s.apiKeyInput.Update(msg)
213			s.apiKeyInput = u.(*models.APIKeyInput)
214			return s, cmd
215		} else if s.isOnboarding {
216			var cmd tea.Cmd
217			s.modelList, cmd = s.modelList.Update(msg)
218			return s, cmd
219		}
220	}
221	return s, nil
222}
223
224func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
225	if s.selectedModel == nil {
226		return util.ReportError(fmt.Errorf("no model selected"))
227	}
228
229	cfg := config.Get()
230	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
231	if err != nil {
232		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
233	}
234
235	// Reset API key state and continue with model selection
236	s.needsAPIKey = false
237	cmd := s.setPreferredModel(*s.selectedModel)
238	s.isOnboarding = false
239	s.selectedModel = nil
240
241	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
242}
243
244func (s *splashCmp) initializeProject() tea.Cmd {
245	s.needsProjectInit = false
246
247	if err := config.MarkProjectInitialized(); err != nil {
248		return util.ReportError(err)
249	}
250	var cmds []tea.Cmd
251
252	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
253	if !s.selectedNo {
254		cmds = append(cmds,
255			util.CmdHandler(chat.SessionClearedMsg{}),
256			util.CmdHandler(chat.SendMsg{
257				Text: prompt.Initialize(),
258			}),
259		)
260	}
261	return tea.Sequence(cmds...)
262}
263
264func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
265	cfg := config.Get()
266	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
267	if model == nil {
268		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
269	}
270
271	selectedModel := config.SelectedModel{
272		Model:           selectedItem.Model.ID,
273		Provider:        string(selectedItem.Provider.ID),
274		ReasoningEffort: model.DefaultReasoningEffort,
275		MaxTokens:       model.DefaultMaxTokens,
276	}
277
278	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
279	if err != nil {
280		return util.ReportError(err)
281	}
282
283	// Now lets automatically setup the small model
284	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
285	if err != nil {
286		return util.ReportError(err)
287	}
288	if knownProvider == nil {
289		// for local provider we just use the same model
290		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
291		if err != nil {
292			return util.ReportError(err)
293		}
294	} else {
295		smallModel := knownProvider.DefaultSmallModelID
296		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
297		// should never happen
298		if model == nil {
299			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
300			if err != nil {
301				return util.ReportError(err)
302			}
303			return nil
304		}
305		smallSelectedModel := config.SelectedModel{
306			Model:           smallModel,
307			Provider:        string(selectedItem.Provider.ID),
308			ReasoningEffort: model.DefaultReasoningEffort,
309			MaxTokens:       model.DefaultMaxTokens,
310		}
311		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
312		if err != nil {
313			return util.ReportError(err)
314		}
315	}
316	cfg.SetupAgents()
317	return nil
318}
319
320func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
321	providers, err := config.Providers()
322	if err != nil {
323		return nil, err
324	}
325	for _, p := range providers {
326		if p.ID == providerID {
327			return &p, nil
328		}
329	}
330	return nil, nil
331}
332
333func (s *splashCmp) isProviderConfigured(providerID string) bool {
334	cfg := config.Get()
335	if _, ok := cfg.Providers[providerID]; ok {
336		return true
337	}
338	return false
339}
340
341func (s *splashCmp) View() string {
342	t := styles.CurrentTheme()
343	var content string
344	if s.needsAPIKey {
345		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
346		apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
347		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
348			lipgloss.JoinVertical(
349				lipgloss.Left,
350				apiKeyView,
351			),
352		)
353		content = lipgloss.JoinVertical(
354			lipgloss.Left,
355			s.logoRendered,
356			apiKeySelector,
357		)
358	} else if s.isOnboarding {
359		modelListView := s.modelList.View()
360		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
361		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
362			lipgloss.JoinVertical(
363				lipgloss.Left,
364				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
365				"",
366				modelListView,
367			),
368		)
369		content = lipgloss.JoinVertical(
370			lipgloss.Left,
371			s.logoRendered,
372			modelSelector,
373		)
374	} else if s.needsProjectInit {
375		titleStyle := t.S().Base.Foreground(t.FgBase)
376		bodyStyle := t.S().Base.Foreground(t.FgMuted)
377		shortcutStyle := t.S().Base.Foreground(t.Success)
378
379		initText := lipgloss.JoinVertical(
380			lipgloss.Left,
381			titleStyle.Render("Would you like to initialize this project?"),
382			"",
383			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
384			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
385			"",
386			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
387			"",
388			bodyStyle.Render("Would you like to initialize now?"),
389		)
390
391		yesButton := core.SelectableButton(core.ButtonOpts{
392			Text:           "Yep!",
393			UnderlineIndex: 0,
394			Selected:       !s.selectedNo,
395		})
396
397		noButton := core.SelectableButton(core.ButtonOpts{
398			Text:           "Nope",
399			UnderlineIndex: 0,
400			Selected:       s.selectedNo,
401		})
402
403		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
404		infoSection := s.infoSection()
405
406		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
407
408		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
409			lipgloss.JoinVertical(
410				lipgloss.Left,
411				initText,
412				"",
413				buttons,
414			),
415		)
416
417		content = lipgloss.JoinVertical(
418			lipgloss.Left,
419			s.logoRendered,
420			infoSection,
421			initContent,
422		)
423	} else {
424		parts := []string{
425			s.logoRendered,
426			s.infoSection(),
427		}
428		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
429	}
430
431	return t.S().Base.
432		Width(s.width).
433		Height(s.height).
434		PaddingTop(SplashScreenPaddingY).
435		PaddingBottom(SplashScreenPaddingY).
436		Render(content)
437}
438
439func (s *splashCmp) Cursor() *tea.Cursor {
440	if s.needsAPIKey {
441		cursor := s.apiKeyInput.Cursor()
442		if cursor != nil {
443			return s.moveCursor(cursor)
444		}
445	} else if s.isOnboarding {
446		cursor := s.modelList.Cursor()
447		if cursor != nil {
448			return s.moveCursor(cursor)
449		}
450	} else {
451		return nil
452	}
453	return nil
454}
455
456func (s *splashCmp) infoSection() string {
457	t := styles.CurrentTheme()
458	return t.S().Base.PaddingLeft(2).Render(
459		lipgloss.JoinVertical(
460			lipgloss.Left,
461			s.cwd(),
462			"",
463			lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
464			"",
465		),
466	)
467}
468
469func (s *splashCmp) logoBlock() string {
470	t := styles.CurrentTheme()
471	return t.S().Base.Padding(0, 2).Width(s.width).Render(
472		logo.Render(version.Version, false, logo.Opts{
473			FieldColor:   t.Primary,
474			TitleColorA:  t.Secondary,
475			TitleColorB:  t.Primary,
476			CharmColor:   t.Secondary,
477			VersionColor: t.Primary,
478			Width:        s.width - 4,
479		}),
480	)
481}
482
483func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
484	if cursor == nil {
485		return nil
486	}
487	// Calculate the correct Y offset based on current state
488	logoHeight := lipgloss.Height(s.logoRendered)
489	if s.needsAPIKey {
490		infoSectionHeight := lipgloss.Height(s.infoSection())
491		baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
492		remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
493		offset := baseOffset + remainingHeight
494		cursor.Y += offset
495		cursor.X = cursor.X + 1
496	} else if s.isOnboarding {
497		offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
498		cursor.Y += offset
499		cursor.X = cursor.X + 1
500	}
501
502	return cursor
503}
504
505func (s *splashCmp) logoGap() int {
506	if s.height > 35 {
507		return LogoGap
508	}
509	return 0
510}
511
512// Bindings implements SplashPage.
513func (s *splashCmp) Bindings() []key.Binding {
514	if s.needsAPIKey {
515		return []key.Binding{
516			s.keyMap.Select,
517			s.keyMap.Back,
518		}
519	} else if s.isOnboarding {
520		return []key.Binding{
521			s.keyMap.Select,
522			s.keyMap.Next,
523			s.keyMap.Previous,
524		}
525	} else if s.needsProjectInit {
526		return []key.Binding{
527			s.keyMap.Select,
528			s.keyMap.Yes,
529			s.keyMap.No,
530			s.keyMap.Tab,
531			s.keyMap.LeftRight,
532		}
533	}
534	return []key.Binding{}
535}
536
537func (s *splashCmp) getMaxInfoWidth() int {
538	return min(s.width-2, 40) // 2 for left padding
539}
540
541func (s *splashCmp) cwd() string {
542	cwd := config.Get().WorkingDir()
543	t := styles.CurrentTheme()
544	homeDir, err := os.UserHomeDir()
545	if err == nil && cwd != homeDir {
546		cwd = strings.ReplaceAll(cwd, homeDir, "~")
547	}
548	maxWidth := s.getMaxInfoWidth()
549	return t.S().Muted.Width(maxWidth).Render(cwd)
550}
551
552func LSPList(maxWidth int) []string {
553	t := styles.CurrentTheme()
554	lspList := []string{}
555	lsp := config.Get().LSP.Sorted()
556	if len(lsp) == 0 {
557		return []string{t.S().Base.Foreground(t.Border).Render("None")}
558	}
559	for _, l := range lsp {
560		iconColor := t.Success
561		if l.LSP.Disabled {
562			iconColor = t.FgMuted
563		}
564		lspList = append(lspList,
565			core.Status(
566				core.StatusOpts{
567					IconColor:   iconColor,
568					Title:       l.Name,
569					Description: l.LSP.Command,
570				},
571				maxWidth,
572			),
573		)
574	}
575	return lspList
576}
577
578func (s *splashCmp) lspBlock() string {
579	t := styles.CurrentTheme()
580	maxWidth := s.getMaxInfoWidth() / 2
581	section := t.S().Subtle.Render("LSPs")
582	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
583	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
584		lipgloss.JoinVertical(
585			lipgloss.Left,
586			lspList...,
587		),
588	)
589}
590
591func MCPList(maxWidth int) []string {
592	t := styles.CurrentTheme()
593	mcpList := []string{}
594	mcps := config.Get().MCP.Sorted()
595	if len(mcps) == 0 {
596		return []string{t.S().Base.Foreground(t.Border).Render("None")}
597	}
598	for _, l := range mcps {
599		iconColor := t.Success
600		if l.MCP.Disabled {
601			iconColor = t.FgMuted
602		}
603		mcpList = append(mcpList,
604			core.Status(
605				core.StatusOpts{
606					IconColor:   iconColor,
607					Title:       l.Name,
608					Description: l.MCP.Command,
609				},
610				maxWidth,
611			),
612		)
613	}
614	return mcpList
615}
616
617func (s *splashCmp) mcpBlock() string {
618	t := styles.CurrentTheme()
619	maxWidth := s.getMaxInfoWidth() / 2
620	section := t.S().Subtle.Render("MCPs")
621	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
622	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
623		lipgloss.JoinVertical(
624			lipgloss.Left,
625			mcpList...,
626		),
627	)
628}
629
630func (s *splashCmp) IsShowingAPIKey() bool {
631	return s.needsAPIKey
632}