splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/tui/components/chat"
 14	"github.com/charmbracelet/crush/internal/tui/components/completions"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 19	"github.com/charmbracelet/crush/internal/tui/components/logo"
 20	"github.com/charmbracelet/crush/internal/tui/styles"
 21	"github.com/charmbracelet/crush/internal/tui/util"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/charmbracelet/lipgloss/v2"
 24)
 25
 26type Splash interface {
 27	util.Model
 28	layout.Sizeable
 29	layout.Help
 30	Cursor() *tea.Cursor
 31	// SetOnboarding controls whether the splash shows model selection UI
 32	SetOnboarding(bool)
 33	// SetProjectInit controls whether the splash shows project initialization prompt
 34	SetProjectInit(bool)
 35
 36	// Showing API key input
 37	IsShowingAPIKey() bool
 38}
 39
 40const (
 41	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 42
 43	LogoGap = 6
 44)
 45
 46// OnboardingCompleteMsg is sent when onboarding is complete
 47type OnboardingCompleteMsg struct{}
 48
 49type splashCmp struct {
 50	width, height int
 51	keyMap        KeyMap
 52	logoRendered  string
 53
 54	// State
 55	isOnboarding     bool
 56	needsProjectInit bool
 57	needsAPIKey      bool
 58	selectedNo       bool
 59
 60	listHeight    int
 61	modelList     *models.ModelListComponent
 62	apiKeyInput   *models.APIKeyInput
 63	selectedModel *models.ModelOption
 64}
 65
 66func New() Splash {
 67	keyMap := DefaultKeyMap()
 68	listKeyMap := list.DefaultKeyMap()
 69	listKeyMap.Down.SetEnabled(false)
 70	listKeyMap.Up.SetEnabled(false)
 71	listKeyMap.HalfPageDown.SetEnabled(false)
 72	listKeyMap.HalfPageUp.SetEnabled(false)
 73	listKeyMap.Home.SetEnabled(false)
 74	listKeyMap.End.SetEnabled(false)
 75	listKeyMap.DownOneItem = keyMap.Next
 76	listKeyMap.UpOneItem = keyMap.Previous
 77
 78	t := styles.CurrentTheme()
 79	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 80	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 81	apiKeyInput := models.NewAPIKeyInput()
 82
 83	return &splashCmp{
 84		width:        0,
 85		height:       0,
 86		keyMap:       keyMap,
 87		logoRendered: "",
 88		modelList:    modelList,
 89		apiKeyInput:  apiKeyInput,
 90		selectedNo:   false,
 91	}
 92}
 93
 94func (s *splashCmp) SetOnboarding(onboarding bool) {
 95	s.isOnboarding = onboarding
 96	if onboarding {
 97		providers, err := config.Providers()
 98		if err != nil {
 99			return
100		}
101		filteredProviders := []provider.Provider{}
102		simpleProviders := []string{
103			"anthropic",
104			"openai",
105			"gemini",
106			"xai",
107			"openrouter",
108		}
109		for _, p := range providers {
110			if slices.Contains(simpleProviders, string(p.ID)) {
111				filteredProviders = append(filteredProviders, p)
112			}
113		}
114		s.modelList.SetProviders(filteredProviders)
115	}
116}
117
118func (s *splashCmp) SetProjectInit(needsInit bool) {
119	s.needsProjectInit = needsInit
120}
121
122// GetSize implements SplashPage.
123func (s *splashCmp) GetSize() (int, int) {
124	return s.width, s.height
125}
126
127// Init implements SplashPage.
128func (s *splashCmp) Init() tea.Cmd {
129	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
130}
131
132// SetSize implements SplashPage.
133func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
134	s.height = height
135	if width != s.width {
136		s.width = width
137		s.logoRendered = s.logoBlock()
138	}
139	// remove padding, logo height, gap, title space
140	s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
141	listWidth := min(60, width)
142	return s.modelList.SetSize(listWidth, s.listHeight)
143}
144
145// Update implements SplashPage.
146func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
147	switch msg := msg.(type) {
148	case tea.WindowSizeMsg:
149		return s, s.SetSize(msg.Width, msg.Height)
150	case tea.KeyPressMsg:
151		switch {
152		case key.Matches(msg, s.keyMap.Back):
153			if s.needsAPIKey {
154				// Go back to model selection
155				s.needsAPIKey = false
156				s.selectedModel = nil
157				return s, nil
158			}
159		case key.Matches(msg, s.keyMap.Select):
160			if s.isOnboarding && !s.needsAPIKey {
161				modelInx := s.modelList.SelectedIndex()
162				items := s.modelList.Items()
163				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
164				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
165					cmd := s.setPreferredModel(selectedItem)
166					s.isOnboarding = false
167					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
168				} else {
169					// Provider not configured, show API key input
170					s.needsAPIKey = true
171					s.selectedModel = &selectedItem
172					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
173					return s, nil
174				}
175			} else if s.needsAPIKey {
176				// Handle API key submission
177				apiKey := s.apiKeyInput.Value()
178				if apiKey != "" {
179					return s, s.saveAPIKeyAndContinue(apiKey)
180				}
181			} else if s.needsProjectInit {
182				return s, s.initializeProject()
183			}
184		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
185			if s.needsProjectInit {
186				s.selectedNo = !s.selectedNo
187				return s, nil
188			}
189		case key.Matches(msg, s.keyMap.Yes):
190			if s.needsProjectInit {
191				return s, s.initializeProject()
192			}
193		case key.Matches(msg, s.keyMap.No):
194			if s.needsProjectInit {
195				s.needsProjectInit = false
196				return s, util.CmdHandler(OnboardingCompleteMsg{})
197			}
198		default:
199			if s.needsAPIKey {
200				u, cmd := s.apiKeyInput.Update(msg)
201				s.apiKeyInput = u.(*models.APIKeyInput)
202				return s, cmd
203			} else if s.isOnboarding {
204				u, cmd := s.modelList.Update(msg)
205				s.modelList = u
206				return s, cmd
207			}
208		}
209	case tea.PasteMsg:
210		if s.needsAPIKey {
211			u, cmd := s.apiKeyInput.Update(msg)
212			s.apiKeyInput = u.(*models.APIKeyInput)
213			return s, cmd
214		} else if s.isOnboarding {
215			var cmd tea.Cmd
216			s.modelList, cmd = s.modelList.Update(msg)
217			return s, cmd
218		}
219	}
220	return s, nil
221}
222
223func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
224	if s.selectedModel == nil {
225		return util.ReportError(fmt.Errorf("no model selected"))
226	}
227
228	cfg := config.Get()
229	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
230	if err != nil {
231		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
232	}
233
234	// Reset API key state and continue with model selection
235	s.needsAPIKey = false
236	cmd := s.setPreferredModel(*s.selectedModel)
237	s.isOnboarding = false
238	s.selectedModel = nil
239
240	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
241}
242
243func (s *splashCmp) initializeProject() tea.Cmd {
244	s.needsProjectInit = false
245	prompt := `Please analyze this codebase and create a CRUSH.md file containing:
2461. Build/lint/test commands - especially for running a single test
2472. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
248
249The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
250If there's already a CRUSH.md, improve it.
251If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
252Add the .crush directory to the .gitignore file if it's not already there.`
253
254	if err := config.MarkProjectInitialized(); err != nil {
255		return util.ReportError(err)
256	}
257	var cmds []tea.Cmd
258
259	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
260	if !s.selectedNo {
261		cmds = append(cmds,
262			util.CmdHandler(chat.SessionClearedMsg{}),
263			util.CmdHandler(chat.SendMsg{
264				Text: prompt,
265			}),
266		)
267	}
268	return tea.Sequence(cmds...)
269}
270
271func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
272	cfg := config.Get()
273	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
274	if model == nil {
275		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
276	}
277
278	selectedModel := config.SelectedModel{
279		Model:           selectedItem.Model.ID,
280		Provider:        string(selectedItem.Provider.ID),
281		ReasoningEffort: model.DefaultReasoningEffort,
282		MaxTokens:       model.DefaultMaxTokens,
283	}
284
285	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
286	if err != nil {
287		return util.ReportError(err)
288	}
289
290	// Now lets automatically setup the small model
291	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
292	if err != nil {
293		return util.ReportError(err)
294	}
295	if knownProvider == nil {
296		// for local provider we just use the same model
297		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
298		if err != nil {
299			return util.ReportError(err)
300		}
301	} else {
302		smallModel := knownProvider.DefaultSmallModelID
303		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
304		// should never happen
305		if model == nil {
306			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
307			if err != nil {
308				return util.ReportError(err)
309			}
310			return nil
311		}
312		smallSelectedModel := config.SelectedModel{
313			Model:           smallModel,
314			Provider:        string(selectedItem.Provider.ID),
315			ReasoningEffort: model.DefaultReasoningEffort,
316			MaxTokens:       model.DefaultMaxTokens,
317		}
318		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
319		if err != nil {
320			return util.ReportError(err)
321		}
322	}
323	return nil
324}
325
326func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
327	providers, err := config.Providers()
328	if err != nil {
329		return nil, err
330	}
331	for _, p := range providers {
332		if p.ID == providerID {
333			return &p, nil
334		}
335	}
336	return nil, nil
337}
338
339func (s *splashCmp) isProviderConfigured(providerID string) bool {
340	cfg := config.Get()
341	if _, ok := cfg.Providers[providerID]; ok {
342		return true
343	}
344	return false
345}
346
347func (s *splashCmp) View() string {
348	t := styles.CurrentTheme()
349	var content string
350	if s.needsAPIKey {
351		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
352		apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
353		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
354			lipgloss.JoinVertical(
355				lipgloss.Left,
356				apiKeyView,
357			),
358		)
359		content = lipgloss.JoinVertical(
360			lipgloss.Left,
361			s.logoRendered,
362			apiKeySelector,
363		)
364	} else if s.isOnboarding {
365		modelListView := s.modelList.View()
366		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
367		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
368			lipgloss.JoinVertical(
369				lipgloss.Left,
370				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
371				"",
372				modelListView,
373			),
374		)
375		content = lipgloss.JoinVertical(
376			lipgloss.Left,
377			s.logoRendered,
378			modelSelector,
379		)
380	} else if s.needsProjectInit {
381		titleStyle := t.S().Base.Foreground(t.FgBase)
382		bodyStyle := t.S().Base.Foreground(t.FgMuted)
383		shortcutStyle := t.S().Base.Foreground(t.Success)
384
385		initText := lipgloss.JoinVertical(
386			lipgloss.Left,
387			titleStyle.Render("Would you like to initialize this project?"),
388			"",
389			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
390			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
391			"",
392			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
393			"",
394			bodyStyle.Render("Would you like to initialize now?"),
395		)
396
397		yesButton := core.SelectableButton(core.ButtonOpts{
398			Text:           "Yep!",
399			UnderlineIndex: 0,
400			Selected:       !s.selectedNo,
401		})
402
403		noButton := core.SelectableButton(core.ButtonOpts{
404			Text:           "Nope",
405			UnderlineIndex: 0,
406			Selected:       s.selectedNo,
407		})
408
409		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
410		infoSection := s.infoSection()
411
412		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
413
414		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
415			lipgloss.JoinVertical(
416				lipgloss.Left,
417				initText,
418				"",
419				buttons,
420			),
421		)
422
423		content = lipgloss.JoinVertical(
424			lipgloss.Left,
425			s.logoRendered,
426			infoSection,
427			initContent,
428		)
429	} else {
430		parts := []string{
431			s.logoRendered,
432			s.infoSection(),
433		}
434		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
435	}
436
437	return t.S().Base.
438		Width(s.width).
439		Height(s.height).
440		PaddingTop(SplashScreenPaddingY).
441		PaddingBottom(SplashScreenPaddingY).
442		Render(content)
443}
444
445func (s *splashCmp) Cursor() *tea.Cursor {
446	if s.needsAPIKey {
447		cursor := s.apiKeyInput.Cursor()
448		if cursor != nil {
449			return s.moveCursor(cursor)
450		}
451	} else if s.isOnboarding {
452		cursor := s.modelList.Cursor()
453		if cursor != nil {
454			return s.moveCursor(cursor)
455		}
456	} else {
457		return nil
458	}
459	return nil
460}
461
462func (s *splashCmp) infoSection() string {
463	return lipgloss.JoinVertical(
464		lipgloss.Left,
465		s.cwd(),
466		"",
467		lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
468		"",
469	)
470}
471
472func (s *splashCmp) logoBlock() string {
473	t := styles.CurrentTheme()
474	return t.S().Base.Padding(0, 2).Width(s.width).Render(
475		logo.Render(version.Version, false, logo.Opts{
476			FieldColor:   t.Primary,
477			TitleColorA:  t.Secondary,
478			TitleColorB:  t.Primary,
479			CharmColor:   t.Secondary,
480			VersionColor: t.Primary,
481			Width:        s.width - 4,
482		}),
483	)
484}
485
486func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
487	if cursor == nil {
488		return nil
489	}
490	// Calculate the correct Y offset based on current state
491	logoHeight := lipgloss.Height(s.logoRendered)
492	if s.needsAPIKey {
493		infoSectionHeight := lipgloss.Height(s.infoSection())
494		baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
495		remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
496		offset := baseOffset + remainingHeight
497		cursor.Y += offset
498		cursor.X = cursor.X + 1
499	} else if s.isOnboarding {
500		offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
501		cursor.Y += offset
502		cursor.X = cursor.X + 1
503	}
504
505	return cursor
506}
507
508func (s *splashCmp) logoGap() int {
509	if s.height > 35 {
510		return LogoGap
511	}
512	return 0
513}
514
515// Bindings implements SplashPage.
516func (s *splashCmp) Bindings() []key.Binding {
517	if s.needsAPIKey {
518		return []key.Binding{
519			s.keyMap.Select,
520			s.keyMap.Back,
521		}
522	} else if s.isOnboarding {
523		return []key.Binding{
524			s.keyMap.Select,
525			s.keyMap.Next,
526			s.keyMap.Previous,
527		}
528	} else if s.needsProjectInit {
529		return []key.Binding{
530			s.keyMap.Select,
531			s.keyMap.Yes,
532			s.keyMap.No,
533			s.keyMap.Tab,
534			s.keyMap.LeftRight,
535		}
536	}
537	return []key.Binding{}
538}
539
540func (s *splashCmp) getMaxInfoWidth() int {
541	return min(s.width, 40)
542}
543
544func (s *splashCmp) cwd() string {
545	cwd := config.Get().WorkingDir()
546	t := styles.CurrentTheme()
547	homeDir, err := os.UserHomeDir()
548	if err == nil && cwd != homeDir {
549		cwd = strings.ReplaceAll(cwd, homeDir, "~")
550	}
551	maxWidth := s.getMaxInfoWidth()
552	return t.S().Muted.Width(maxWidth).Render(cwd)
553}
554
555func LSPList(maxWidth int) []string {
556	t := styles.CurrentTheme()
557	lspList := []string{}
558	lsp := config.Get().LSP.Sorted()
559	if len(lsp) == 0 {
560		return []string{t.S().Base.Foreground(t.Border).Render("None")}
561	}
562	for _, l := range lsp {
563		iconColor := t.Success
564		if l.LSP.Disabled {
565			iconColor = t.FgMuted
566		}
567		lspList = append(lspList,
568			core.Status(
569				core.StatusOpts{
570					IconColor:   iconColor,
571					Title:       l.Name,
572					Description: l.LSP.Command,
573				},
574				maxWidth,
575			),
576		)
577	}
578	return lspList
579}
580
581func (s *splashCmp) lspBlock() string {
582	t := styles.CurrentTheme()
583	maxWidth := s.getMaxInfoWidth() / 2
584	section := t.S().Subtle.Render("LSPs")
585	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
586	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
587		lipgloss.JoinVertical(
588			lipgloss.Left,
589			lspList...,
590		),
591	)
592}
593
594func MCPList(maxWidth int) []string {
595	t := styles.CurrentTheme()
596	mcpList := []string{}
597	mcps := config.Get().MCP.Sorted()
598	if len(mcps) == 0 {
599		return []string{t.S().Base.Foreground(t.Border).Render("None")}
600	}
601	for _, l := range mcps {
602		iconColor := t.Success
603		if l.MCP.Disabled {
604			iconColor = t.FgMuted
605		}
606		mcpList = append(mcpList,
607			core.Status(
608				core.StatusOpts{
609					IconColor:   iconColor,
610					Title:       l.Name,
611					Description: l.MCP.Command,
612				},
613				maxWidth,
614			),
615		)
616	}
617	return mcpList
618}
619
620func (s *splashCmp) mcpBlock() string {
621	t := styles.CurrentTheme()
622	maxWidth := s.getMaxInfoWidth() / 2
623	section := t.S().Subtle.Render("MCPs")
624	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
625	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
626		lipgloss.JoinVertical(
627			lipgloss.Left,
628			mcpList...,
629		),
630	)
631}
632
633func (s *splashCmp) IsShowingAPIKey() bool {
634	return s.needsAPIKey
635}