splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/tui/components/chat"
 14	"github.com/charmbracelet/crush/internal/tui/components/completions"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 19	"github.com/charmbracelet/crush/internal/tui/components/logo"
 20	"github.com/charmbracelet/crush/internal/tui/styles"
 21	"github.com/charmbracelet/crush/internal/tui/util"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/charmbracelet/lipgloss/v2"
 24)
 25
 26type Splash interface {
 27	util.Model
 28	layout.Sizeable
 29	layout.Help
 30	Cursor() *tea.Cursor
 31	// SetOnboarding controls whether the splash shows model selection UI
 32	SetOnboarding(bool)
 33	// SetProjectInit controls whether the splash shows project initialization prompt
 34	SetProjectInit(bool)
 35}
 36
 // Padding applied by View around the whole splash screen.
const (
	// SplashScreenPaddingX is the number of columns left empty on the left and right.
	SplashScreenPaddingX = 2
	// SplashScreenPaddingY is the number of rows left empty at the top and bottom.
	SplashScreenPaddingY = 1
)
 41
 // OnboardingCompleteMsg is emitted by the splash screen when the user has
// finished the step it was showing: picking a model (possibly after entering
// an API key) or answering the project-initialization prompt.
type OnboardingCompleteMsg struct{}
 44
 45type splashCmp struct {
 46	width, height int
 47	keyMap        KeyMap
 48	logoRendered  string
 49
 50	// State
 51	isOnboarding     bool
 52	needsProjectInit bool
 53	needsAPIKey      bool
 54	selectedNo       bool
 55
 56	modelList     *models.ModelListComponent
 57	apiKeyInput   *models.APIKeyInput
 58	selectedModel *models.ModelOption
 59}
 60
 61func New() Splash {
 62	keyMap := DefaultKeyMap()
 63	listKeyMap := list.DefaultKeyMap()
 64	listKeyMap.Down.SetEnabled(false)
 65	listKeyMap.Up.SetEnabled(false)
 66	listKeyMap.HalfPageDown.SetEnabled(false)
 67	listKeyMap.HalfPageUp.SetEnabled(false)
 68	listKeyMap.Home.SetEnabled(false)
 69	listKeyMap.End.SetEnabled(false)
 70	listKeyMap.DownOneItem = keyMap.Next
 71	listKeyMap.UpOneItem = keyMap.Previous
 72
 73	t := styles.CurrentTheme()
 74	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 75	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 76	apiKeyInput := models.NewAPIKeyInput()
 77
 78	return &splashCmp{
 79		width:        0,
 80		height:       0,
 81		keyMap:       keyMap,
 82		logoRendered: "",
 83		modelList:    modelList,
 84		apiKeyInput:  apiKeyInput,
 85		selectedNo:   false,
 86	}
 87}
 88
 89func (s *splashCmp) SetOnboarding(onboarding bool) {
 90	s.isOnboarding = onboarding
 91	if onboarding {
 92		providers, err := config.Providers()
 93		if err != nil {
 94			return
 95		}
 96		filteredProviders := []provider.Provider{}
 97		simpleProviders := []string{
 98			"anthropic",
 99			"openai",
100			"gemini",
101			"xai",
102			"openrouter",
103		}
104		for _, p := range providers {
105			if slices.Contains(simpleProviders, string(p.ID)) {
106				filteredProviders = append(filteredProviders, p)
107			}
108		}
109		s.modelList.SetProviders(filteredProviders)
110	}
111}
112
113func (s *splashCmp) SetProjectInit(needsInit bool) {
114	s.needsProjectInit = needsInit
115}
116
117// GetSize implements SplashPage.
118func (s *splashCmp) GetSize() (int, int) {
119	return s.width, s.height
120}
121
122// Init implements SplashPage.
123func (s *splashCmp) Init() tea.Cmd {
124	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
125}
126
127// SetSize implements SplashPage.
128func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
129	s.width = width
130	s.height = height
131	s.logoRendered = s.logoBlock()
132	listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
133	listWidth := min(60, width-(SplashScreenPaddingX*2))
134
135	return s.modelList.SetSize(listWidth, listHeigh)
136}
137
138// Update implements SplashPage.
139func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
140	switch msg := msg.(type) {
141	case tea.WindowSizeMsg:
142		return s, s.SetSize(msg.Width, msg.Height)
143	case tea.KeyPressMsg:
144		switch {
145		case key.Matches(msg, s.keyMap.Back):
146			if s.needsAPIKey {
147				// Go back to model selection
148				s.needsAPIKey = false
149				s.selectedModel = nil
150				return s, nil
151			}
152		case key.Matches(msg, s.keyMap.Select):
153			if s.isOnboarding && !s.needsAPIKey {
154				modelInx := s.modelList.SelectedIndex()
155				items := s.modelList.Items()
156				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
157				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
158					cmd := s.setPreferredModel(selectedItem)
159					s.isOnboarding = false
160					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
161				} else {
162					// Provider not configured, show API key input
163					s.needsAPIKey = true
164					s.selectedModel = &selectedItem
165					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
166					return s, nil
167				}
168			} else if s.needsAPIKey {
169				// Handle API key submission
170				apiKey := s.apiKeyInput.Value()
171				if apiKey != "" {
172					return s, s.saveAPIKeyAndContinue(apiKey)
173				}
174			} else if s.needsProjectInit {
175				return s, s.initializeProject()
176			}
177		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
178			if s.needsProjectInit {
179				s.selectedNo = !s.selectedNo
180				return s, nil
181			}
182		case key.Matches(msg, s.keyMap.Yes):
183			if s.needsProjectInit {
184				return s, s.initializeProject()
185			}
186		case key.Matches(msg, s.keyMap.No):
187			if s.needsProjectInit {
188				s.needsProjectInit = false
189				return s, util.CmdHandler(OnboardingCompleteMsg{})
190			}
191		default:
192			if s.needsAPIKey {
193				u, cmd := s.apiKeyInput.Update(msg)
194				s.apiKeyInput = u.(*models.APIKeyInput)
195				return s, cmd
196			} else if s.isOnboarding {
197				u, cmd := s.modelList.Update(msg)
198				s.modelList = u
199				return s, cmd
200			}
201		}
202	case tea.PasteMsg:
203		if s.needsAPIKey {
204			u, cmd := s.apiKeyInput.Update(msg)
205			s.apiKeyInput = u.(*models.APIKeyInput)
206			return s, cmd
207		} else if s.isOnboarding {
208			var cmd tea.Cmd
209			s.modelList, cmd = s.modelList.Update(msg)
210			return s, cmd
211		}
212	}
213	return s, nil
214}
215
216func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
217	if s.selectedModel == nil {
218		return util.ReportError(fmt.Errorf("no model selected"))
219	}
220
221	cfg := config.Get()
222	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
223	if err != nil {
224		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
225	}
226
227	// Reset API key state and continue with model selection
228	s.needsAPIKey = false
229	cmd := s.setPreferredModel(*s.selectedModel)
230	s.isOnboarding = false
231	s.selectedModel = nil
232
233	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
234}
235
236func (s *splashCmp) initializeProject() tea.Cmd {
237	s.needsProjectInit = false
238	prompt := `Please analyze this codebase and create a CRUSH.md file containing:
2391. Build/lint/test commands - especially for running a single test
2402. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
241
242The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
243If there's already a CRUSH.md, improve it.
244If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
245Add the .crush directory to the .gitignore file if it's not already there.`
246
247	if err := config.MarkProjectInitialized(); err != nil {
248		return util.ReportError(err)
249	}
250	var cmds []tea.Cmd
251
252	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
253	if !s.selectedNo {
254		cmds = append(cmds,
255			util.CmdHandler(chat.SessionClearedMsg{}),
256			util.CmdHandler(chat.SendMsg{
257				Text: prompt,
258			}),
259		)
260	}
261	return tea.Sequence(cmds...)
262}
263
264func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
265	cfg := config.Get()
266	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
267	if model == nil {
268		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
269	}
270
271	selectedModel := config.SelectedModel{
272		Model:           selectedItem.Model.ID,
273		Provider:        string(selectedItem.Provider.ID),
274		ReasoningEffort: model.DefaultReasoningEffort,
275		MaxTokens:       model.DefaultMaxTokens,
276	}
277
278	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
279	if err != nil {
280		return util.ReportError(err)
281	}
282
283	// Now lets automatically setup the small model
284	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
285	if err != nil {
286		return util.ReportError(err)
287	}
288	if knownProvider == nil {
289		// for local provider we just use the same model
290		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
291		if err != nil {
292			return util.ReportError(err)
293		}
294	} else {
295		smallModel := knownProvider.DefaultSmallModelID
296		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
297		// should never happen
298		if model == nil {
299			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
300			if err != nil {
301				return util.ReportError(err)
302			}
303			return nil
304		}
305		smallSelectedModel := config.SelectedModel{
306			Model:           smallModel,
307			Provider:        string(selectedItem.Provider.ID),
308			ReasoningEffort: model.DefaultReasoningEffort,
309			MaxTokens:       model.DefaultMaxTokens,
310		}
311		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
312		if err != nil {
313			return util.ReportError(err)
314		}
315	}
316	return nil
317}
318
319func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
320	providers, err := config.Providers()
321	if err != nil {
322		return nil, err
323	}
324	for _, p := range providers {
325		if p.ID == providerID {
326			return &p, nil
327		}
328	}
329	return nil, nil
330}
331
332func (s *splashCmp) isProviderConfigured(providerID string) bool {
333	cfg := config.Get()
334	if _, ok := cfg.Providers[providerID]; ok {
335		return true
336	}
337	return false
338}
339
340func (s *splashCmp) View() string {
341	t := styles.CurrentTheme()
342	var content string
343	if s.needsAPIKey {
344		infoSection := s.infoSection()
345		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
346		apiKeyView := s.apiKeyInput.View()
347		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
348			lipgloss.JoinVertical(
349				lipgloss.Left,
350				apiKeyView,
351			),
352		)
353		content = lipgloss.JoinVertical(
354			lipgloss.Left,
355			s.logoRendered,
356			infoSection,
357			apiKeySelector,
358		)
359	} else if s.isOnboarding {
360		infoSection := s.infoSection()
361		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
362		modelListView := s.modelList.View()
363		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
364			lipgloss.JoinVertical(
365				lipgloss.Left,
366				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
367				"",
368				modelListView,
369			),
370		)
371		content = lipgloss.JoinVertical(
372			lipgloss.Left,
373			s.logoRendered,
374			infoSection,
375			modelSelector,
376		)
377	} else if s.needsProjectInit {
378		titleStyle := t.S().Base.Foreground(t.FgBase)
379		bodyStyle := t.S().Base.Foreground(t.FgMuted)
380		shortcutStyle := t.S().Base.Foreground(t.Success)
381
382		initText := lipgloss.JoinVertical(
383			lipgloss.Left,
384			titleStyle.Render("Would you like to initialize this project?"),
385			"",
386			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
387			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
388			"",
389			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
390			"",
391			bodyStyle.Render("Would you like to initialize now?"),
392		)
393
394		yesButton := core.SelectableButton(core.ButtonOpts{
395			Text:           "Yep!",
396			UnderlineIndex: 0,
397			Selected:       !s.selectedNo,
398		})
399
400		noButton := core.SelectableButton(core.ButtonOpts{
401			Text:           "Nope",
402			UnderlineIndex: 0,
403			Selected:       s.selectedNo,
404		})
405
406		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
407		infoSection := s.infoSection()
408
409		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
410
411		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
412			lipgloss.JoinVertical(
413				lipgloss.Left,
414				initText,
415				"",
416				buttons,
417			),
418		)
419
420		content = lipgloss.JoinVertical(
421			lipgloss.Left,
422			s.logoRendered,
423			infoSection,
424			initContent,
425		)
426	} else {
427		parts := []string{
428			s.logoRendered,
429			s.infoSection(),
430		}
431		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
432	}
433
434	return t.S().Base.
435		Width(s.width).
436		Height(s.height).
437		PaddingTop(SplashScreenPaddingY).
438		PaddingLeft(SplashScreenPaddingX).
439		PaddingRight(SplashScreenPaddingX).
440		PaddingBottom(SplashScreenPaddingY).
441		Render(content)
442}
443
444func (s *splashCmp) Cursor() *tea.Cursor {
445	if s.needsAPIKey {
446		cursor := s.apiKeyInput.Cursor()
447		if cursor != nil {
448			return s.moveCursor(cursor)
449		}
450	} else if s.isOnboarding {
451		cursor := s.modelList.Cursor()
452		if cursor != nil {
453			return s.moveCursor(cursor)
454		}
455	} else {
456		return nil
457	}
458	return nil
459}
460
461func (s *splashCmp) infoSection() string {
462	return lipgloss.JoinVertical(
463		lipgloss.Left,
464		s.cwd(),
465		"",
466		lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
467		"",
468	)
469}
470
471func (s *splashCmp) logoBlock() string {
472	t := styles.CurrentTheme()
473	const padding = 2
474	return logo.Render(version.Version, false, logo.Opts{
475		FieldColor:   t.Primary,
476		TitleColorA:  t.Secondary,
477		TitleColorB:  t.Primary,
478		CharmColor:   t.Secondary,
479		VersionColor: t.Primary,
480		Width:        s.width - (SplashScreenPaddingX * 2),
481	})
482}
483
484func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
485	if cursor == nil {
486		return nil
487	}
488
489	// Calculate the correct Y offset based on current state
490	logoHeight := lipgloss.Height(m.logoRendered)
491	infoSectionHeight := lipgloss.Height(m.infoSection())
492	baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
493	if m.needsAPIKey {
494		// For API key input, position at the bottom of the remaining space
495		remainingHeight := m.height - logoHeight - (SplashScreenPaddingY * 2)
496		offset := baseOffset + remainingHeight - lipgloss.Height(m.apiKeyInput.View())
497		cursor.Y += offset
498		// API key input already includes prompt in its cursor positioning
499		cursor.X = cursor.X + SplashScreenPaddingX
500	} else if m.isOnboarding {
501		// For model list, use the original calculation
502		listHeight := min(40, m.height-(SplashScreenPaddingY*2)-logoHeight-1-infoSectionHeight)
503		offset := m.height - listHeight
504		cursor.Y += offset
505		// Model list doesn't have a prompt, so add padding + space for list styling
506		cursor.X = cursor.X + SplashScreenPaddingX + 1
507	}
508
509	return cursor
510}
511
512// Bindings implements SplashPage.
513func (s *splashCmp) Bindings() []key.Binding {
514	if s.needsAPIKey {
515		return []key.Binding{
516			s.keyMap.Select,
517			s.keyMap.Back,
518		}
519	} else if s.isOnboarding {
520		return []key.Binding{
521			s.keyMap.Select,
522			s.keyMap.Next,
523			s.keyMap.Previous,
524		}
525	} else if s.needsProjectInit {
526		return []key.Binding{
527			s.keyMap.Select,
528			s.keyMap.Yes,
529			s.keyMap.No,
530			s.keyMap.Tab,
531			s.keyMap.LeftRight,
532		}
533	}
534	return []key.Binding{}
535}
536
537func (s *splashCmp) getMaxInfoWidth() int {
538	return min(s.width-(SplashScreenPaddingX*2), 40)
539}
540
541func (s *splashCmp) cwd() string {
542	cwd := config.Get().WorkingDir()
543	t := styles.CurrentTheme()
544	homeDir, err := os.UserHomeDir()
545	if err == nil && cwd != homeDir {
546		cwd = strings.ReplaceAll(cwd, homeDir, "~")
547	}
548	maxWidth := s.getMaxInfoWidth()
549	return t.S().Muted.Width(maxWidth).Render(cwd)
550}
551
552func LSPList(maxWidth int) []string {
553	t := styles.CurrentTheme()
554	lspList := []string{}
555	lsp := config.Get().LSP.Sorted()
556	if len(lsp) == 0 {
557		return []string{t.S().Base.Foreground(t.Border).Render("None")}
558	}
559	for _, l := range lsp {
560		iconColor := t.Success
561		if l.LSP.Disabled {
562			iconColor = t.FgMuted
563		}
564		lspList = append(lspList,
565			core.Status(
566				core.StatusOpts{
567					IconColor:   iconColor,
568					Title:       l.Name,
569					Description: l.LSP.Command,
570				},
571				maxWidth,
572			),
573		)
574	}
575	return lspList
576}
577
578func (s *splashCmp) lspBlock() string {
579	t := styles.CurrentTheme()
580	maxWidth := s.getMaxInfoWidth() / 2
581	section := t.S().Subtle.Render("LSPs")
582	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
583	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
584		lipgloss.JoinVertical(
585			lipgloss.Left,
586			lspList...,
587		),
588	)
589}
590
591func MCPList(maxWidth int) []string {
592	t := styles.CurrentTheme()
593	mcpList := []string{}
594	mcps := config.Get().MCP.Sorted()
595	if len(mcps) == 0 {
596		return []string{t.S().Base.Foreground(t.Border).Render("None")}
597	}
598	for _, l := range mcps {
599		iconColor := t.Success
600		if l.MCP.Disabled {
601			iconColor = t.FgMuted
602		}
603		mcpList = append(mcpList,
604			core.Status(
605				core.StatusOpts{
606					IconColor:   iconColor,
607					Title:       l.Name,
608					Description: l.MCP.Command,
609				},
610				maxWidth,
611			),
612		)
613	}
614	return mcpList
615}
616
617func (s *splashCmp) mcpBlock() string {
618	t := styles.CurrentTheme()
619	maxWidth := s.getMaxInfoWidth() / 2
620	section := t.S().Subtle.Render("MCPs")
621	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
622	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
623		lipgloss.JoinVertical(
624			lipgloss.Left,
625			mcpList...,
626		),
627	)
628}