splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/tui/components/chat"
 14	"github.com/charmbracelet/crush/internal/tui/components/completions"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 19	"github.com/charmbracelet/crush/internal/tui/components/logo"
 20	"github.com/charmbracelet/crush/internal/tui/styles"
 21	"github.com/charmbracelet/crush/internal/tui/util"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/charmbracelet/lipgloss/v2"
 24)
 25
 26type Splash interface {
 27	util.Model
 28	layout.Sizeable
 29	layout.Help
 30	Cursor() *tea.Cursor
 31	// SetOnboarding controls whether the splash shows model selection UI
 32	SetOnboarding(bool)
 33	// SetProjectInit controls whether the splash shows project initialization prompt
 34	SetProjectInit(bool)
 35
 36	// Showing API key input
 37	IsShowingAPIKey() bool
 38}
 39
 // Padding applied around the splash screen content.
const (
	// SplashScreenPaddingX is the padding on the left and on the right.
	SplashScreenPaddingX = 2
	// SplashScreenPaddingY is the padding above and below.
	SplashScreenPaddingY = 1
)
 44
 // OnboardingCompleteMsg is emitted by the splash once its onboarding flow
// (model selection, API key entry or project initialization) has finished.
type OnboardingCompleteMsg struct{}
 47
 48type splashCmp struct {
 49	width, height int
 50	keyMap        KeyMap
 51	logoRendered  string
 52
 53	// State
 54	isOnboarding     bool
 55	needsProjectInit bool
 56	needsAPIKey      bool
 57	selectedNo       bool
 58
 59	modelList     *models.ModelListComponent
 60	apiKeyInput   *models.APIKeyInput
 61	selectedModel *models.ModelOption
 62}
 63
 64func New() Splash {
 65	keyMap := DefaultKeyMap()
 66	listKeyMap := list.DefaultKeyMap()
 67	listKeyMap.Down.SetEnabled(false)
 68	listKeyMap.Up.SetEnabled(false)
 69	listKeyMap.HalfPageDown.SetEnabled(false)
 70	listKeyMap.HalfPageUp.SetEnabled(false)
 71	listKeyMap.Home.SetEnabled(false)
 72	listKeyMap.End.SetEnabled(false)
 73	listKeyMap.DownOneItem = keyMap.Next
 74	listKeyMap.UpOneItem = keyMap.Previous
 75
 76	t := styles.CurrentTheme()
 77	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 78	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 79	apiKeyInput := models.NewAPIKeyInput()
 80
 81	return &splashCmp{
 82		width:        0,
 83		height:       0,
 84		keyMap:       keyMap,
 85		logoRendered: "",
 86		modelList:    modelList,
 87		apiKeyInput:  apiKeyInput,
 88		selectedNo:   false,
 89	}
 90}
 91
 92func (s *splashCmp) SetOnboarding(onboarding bool) {
 93	s.isOnboarding = onboarding
 94	if onboarding {
 95		providers, err := config.Providers()
 96		if err != nil {
 97			return
 98		}
 99		filteredProviders := []provider.Provider{}
100		simpleProviders := []string{
101			"anthropic",
102			"openai",
103			"gemini",
104			"xai",
105			"openrouter",
106		}
107		for _, p := range providers {
108			if slices.Contains(simpleProviders, string(p.ID)) {
109				filteredProviders = append(filteredProviders, p)
110			}
111		}
112		s.modelList.SetProviders(filteredProviders)
113	}
114}
115
116func (s *splashCmp) SetProjectInit(needsInit bool) {
117	s.needsProjectInit = needsInit
118}
119
120// GetSize implements SplashPage.
121func (s *splashCmp) GetSize() (int, int) {
122	return s.width, s.height
123}
124
125// Init implements SplashPage.
126func (s *splashCmp) Init() tea.Cmd {
127	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
128}
129
130// SetSize implements SplashPage.
131func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
132	s.height = height
133	if width != s.width {
134		s.width = width
135		s.logoRendered = s.logoBlock()
136	}
137	listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
138	listWidth := min(60, width-(SplashScreenPaddingX*2))
139
140	return s.modelList.SetSize(listWidth, listHeigh)
141}
142
143// Update implements SplashPage.
144func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
145	switch msg := msg.(type) {
146	case tea.WindowSizeMsg:
147		return s, s.SetSize(msg.Width, msg.Height)
148	case tea.KeyPressMsg:
149		switch {
150		case key.Matches(msg, s.keyMap.Back):
151			if s.needsAPIKey {
152				// Go back to model selection
153				s.needsAPIKey = false
154				s.selectedModel = nil
155				return s, nil
156			}
157		case key.Matches(msg, s.keyMap.Select):
158			if s.isOnboarding && !s.needsAPIKey {
159				modelInx := s.modelList.SelectedIndex()
160				items := s.modelList.Items()
161				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
162				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
163					cmd := s.setPreferredModel(selectedItem)
164					s.isOnboarding = false
165					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
166				} else {
167					// Provider not configured, show API key input
168					s.needsAPIKey = true
169					s.selectedModel = &selectedItem
170					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
171					return s, nil
172				}
173			} else if s.needsAPIKey {
174				// Handle API key submission
175				apiKey := s.apiKeyInput.Value()
176				if apiKey != "" {
177					return s, s.saveAPIKeyAndContinue(apiKey)
178				}
179			} else if s.needsProjectInit {
180				return s, s.initializeProject()
181			}
182		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
183			if s.needsProjectInit {
184				s.selectedNo = !s.selectedNo
185				return s, nil
186			}
187		case key.Matches(msg, s.keyMap.Yes):
188			if s.needsProjectInit {
189				return s, s.initializeProject()
190			}
191		case key.Matches(msg, s.keyMap.No):
192			if s.needsProjectInit {
193				s.needsProjectInit = false
194				return s, util.CmdHandler(OnboardingCompleteMsg{})
195			}
196		default:
197			if s.needsAPIKey {
198				u, cmd := s.apiKeyInput.Update(msg)
199				s.apiKeyInput = u.(*models.APIKeyInput)
200				return s, cmd
201			} else if s.isOnboarding {
202				u, cmd := s.modelList.Update(msg)
203				s.modelList = u
204				return s, cmd
205			}
206		}
207	case tea.PasteMsg:
208		if s.needsAPIKey {
209			u, cmd := s.apiKeyInput.Update(msg)
210			s.apiKeyInput = u.(*models.APIKeyInput)
211			return s, cmd
212		} else if s.isOnboarding {
213			var cmd tea.Cmd
214			s.modelList, cmd = s.modelList.Update(msg)
215			return s, cmd
216		}
217	}
218	return s, nil
219}
220
221func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
222	if s.selectedModel == nil {
223		return util.ReportError(fmt.Errorf("no model selected"))
224	}
225
226	cfg := config.Get()
227	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
228	if err != nil {
229		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
230	}
231
232	// Reset API key state and continue with model selection
233	s.needsAPIKey = false
234	cmd := s.setPreferredModel(*s.selectedModel)
235	s.isOnboarding = false
236	s.selectedModel = nil
237
238	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
239}
240
241func (s *splashCmp) initializeProject() tea.Cmd {
242	s.needsProjectInit = false
243	prompt := `Please analyze this codebase and create a CRUSH.md file containing:
2441. Build/lint/test commands - especially for running a single test
2452. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
246
247The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
248If there's already a CRUSH.md, improve it.
249If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
250Add the .crush directory to the .gitignore file if it's not already there.`
251
252	if err := config.MarkProjectInitialized(); err != nil {
253		return util.ReportError(err)
254	}
255	var cmds []tea.Cmd
256
257	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
258	if !s.selectedNo {
259		cmds = append(cmds,
260			util.CmdHandler(chat.SessionClearedMsg{}),
261			util.CmdHandler(chat.SendMsg{
262				Text: prompt,
263			}),
264		)
265	}
266	return tea.Sequence(cmds...)
267}
268
269func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
270	cfg := config.Get()
271	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
272	if model == nil {
273		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
274	}
275
276	selectedModel := config.SelectedModel{
277		Model:           selectedItem.Model.ID,
278		Provider:        string(selectedItem.Provider.ID),
279		ReasoningEffort: model.DefaultReasoningEffort,
280		MaxTokens:       model.DefaultMaxTokens,
281	}
282
283	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
284	if err != nil {
285		return util.ReportError(err)
286	}
287
288	// Now lets automatically setup the small model
289	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
290	if err != nil {
291		return util.ReportError(err)
292	}
293	if knownProvider == nil {
294		// for local provider we just use the same model
295		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
296		if err != nil {
297			return util.ReportError(err)
298		}
299	} else {
300		smallModel := knownProvider.DefaultSmallModelID
301		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
302		// should never happen
303		if model == nil {
304			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
305			if err != nil {
306				return util.ReportError(err)
307			}
308			return nil
309		}
310		smallSelectedModel := config.SelectedModel{
311			Model:           smallModel,
312			Provider:        string(selectedItem.Provider.ID),
313			ReasoningEffort: model.DefaultReasoningEffort,
314			MaxTokens:       model.DefaultMaxTokens,
315		}
316		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
317		if err != nil {
318			return util.ReportError(err)
319		}
320	}
321	return nil
322}
323
324func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
325	providers, err := config.Providers()
326	if err != nil {
327		return nil, err
328	}
329	for _, p := range providers {
330		if p.ID == providerID {
331			return &p, nil
332		}
333	}
334	return nil, nil
335}
336
337func (s *splashCmp) isProviderConfigured(providerID string) bool {
338	cfg := config.Get()
339	if _, ok := cfg.Providers[providerID]; ok {
340		return true
341	}
342	return false
343}
344
345func (s *splashCmp) View() string {
346	t := styles.CurrentTheme()
347	var content string
348	if s.needsAPIKey {
349		infoSection := s.infoSection()
350		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
351		apiKeyView := s.apiKeyInput.View()
352		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
353			lipgloss.JoinVertical(
354				lipgloss.Left,
355				apiKeyView,
356			),
357		)
358		content = lipgloss.JoinVertical(
359			lipgloss.Left,
360			s.logoRendered,
361			infoSection,
362			apiKeySelector,
363		)
364	} else if s.isOnboarding {
365		infoSection := s.infoSection()
366		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
367		modelListView := s.modelList.View()
368		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
369			lipgloss.JoinVertical(
370				lipgloss.Left,
371				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
372				"",
373				modelListView,
374			),
375		)
376		content = lipgloss.JoinVertical(
377			lipgloss.Left,
378			s.logoRendered,
379			infoSection,
380			modelSelector,
381		)
382	} else if s.needsProjectInit {
383		titleStyle := t.S().Base.Foreground(t.FgBase)
384		bodyStyle := t.S().Base.Foreground(t.FgMuted)
385		shortcutStyle := t.S().Base.Foreground(t.Success)
386
387		initText := lipgloss.JoinVertical(
388			lipgloss.Left,
389			titleStyle.Render("Would you like to initialize this project?"),
390			"",
391			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
392			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
393			"",
394			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
395			"",
396			bodyStyle.Render("Would you like to initialize now?"),
397		)
398
399		yesButton := core.SelectableButton(core.ButtonOpts{
400			Text:           "Yep!",
401			UnderlineIndex: 0,
402			Selected:       !s.selectedNo,
403		})
404
405		noButton := core.SelectableButton(core.ButtonOpts{
406			Text:           "Nope",
407			UnderlineIndex: 0,
408			Selected:       s.selectedNo,
409		})
410
411		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
412		infoSection := s.infoSection()
413
414		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
415
416		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
417			lipgloss.JoinVertical(
418				lipgloss.Left,
419				initText,
420				"",
421				buttons,
422			),
423		)
424
425		content = lipgloss.JoinVertical(
426			lipgloss.Left,
427			s.logoRendered,
428			infoSection,
429			initContent,
430		)
431	} else {
432		parts := []string{
433			s.logoRendered,
434			s.infoSection(),
435		}
436		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
437	}
438
439	return t.S().Base.
440		Width(s.width).
441		Height(s.height).
442		PaddingTop(SplashScreenPaddingY).
443		PaddingLeft(SplashScreenPaddingX).
444		PaddingRight(SplashScreenPaddingX).
445		PaddingBottom(SplashScreenPaddingY).
446		Render(content)
447}
448
449func (s *splashCmp) Cursor() *tea.Cursor {
450	if s.needsAPIKey {
451		cursor := s.apiKeyInput.Cursor()
452		if cursor != nil {
453			return s.moveCursor(cursor)
454		}
455	} else if s.isOnboarding {
456		cursor := s.modelList.Cursor()
457		if cursor != nil {
458			return s.moveCursor(cursor)
459		}
460	} else {
461		return nil
462	}
463	return nil
464}
465
466func (s *splashCmp) infoSection() string {
467	return lipgloss.JoinVertical(
468		lipgloss.Left,
469		s.cwd(),
470		"",
471		lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
472		"",
473	)
474}
475
476func (s *splashCmp) logoBlock() string {
477	t := styles.CurrentTheme()
478	const padding = 2
479	return logo.Render(version.Version, false, logo.Opts{
480		FieldColor:   t.Primary,
481		TitleColorA:  t.Secondary,
482		TitleColorB:  t.Primary,
483		CharmColor:   t.Secondary,
484		VersionColor: t.Primary,
485		Width:        s.width - (SplashScreenPaddingX * 2),
486	})
487}
488
489func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
490	if cursor == nil {
491		return nil
492	}
493
494	// Calculate the correct Y offset based on current state
495	logoHeight := lipgloss.Height(m.logoRendered)
496	infoSectionHeight := lipgloss.Height(m.infoSection())
497	baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
498	if m.needsAPIKey {
499		// For API key input, position at the bottom of the remaining space
500		remainingHeight := m.height - logoHeight - (SplashScreenPaddingY * 2)
501		offset := baseOffset + remainingHeight - lipgloss.Height(m.apiKeyInput.View())
502		cursor.Y += offset
503		// API key input already includes prompt in its cursor positioning
504		cursor.X = cursor.X + SplashScreenPaddingX
505	} else if m.isOnboarding {
506		// For model list, use the original calculation
507		listHeight := min(40, m.height-(SplashScreenPaddingY*2)-logoHeight-1-infoSectionHeight)
508		offset := m.height - listHeight
509		cursor.Y += offset
510		// Model list doesn't have a prompt, so add padding + space for list styling
511		cursor.X = cursor.X + SplashScreenPaddingX + 1
512	}
513
514	return cursor
515}
516
517// Bindings implements SplashPage.
518func (s *splashCmp) Bindings() []key.Binding {
519	if s.needsAPIKey {
520		return []key.Binding{
521			s.keyMap.Select,
522			s.keyMap.Back,
523		}
524	} else if s.isOnboarding {
525		return []key.Binding{
526			s.keyMap.Select,
527			s.keyMap.Next,
528			s.keyMap.Previous,
529		}
530	} else if s.needsProjectInit {
531		return []key.Binding{
532			s.keyMap.Select,
533			s.keyMap.Yes,
534			s.keyMap.No,
535			s.keyMap.Tab,
536			s.keyMap.LeftRight,
537		}
538	}
539	return []key.Binding{}
540}
541
542func (s *splashCmp) getMaxInfoWidth() int {
543	return min(s.width-(SplashScreenPaddingX*2), 40)
544}
545
546func (s *splashCmp) cwd() string {
547	cwd := config.Get().WorkingDir()
548	t := styles.CurrentTheme()
549	homeDir, err := os.UserHomeDir()
550	if err == nil && cwd != homeDir {
551		cwd = strings.ReplaceAll(cwd, homeDir, "~")
552	}
553	maxWidth := s.getMaxInfoWidth()
554	return t.S().Muted.Width(maxWidth).Render(cwd)
555}
556
557func LSPList(maxWidth int) []string {
558	t := styles.CurrentTheme()
559	lspList := []string{}
560	lsp := config.Get().LSP.Sorted()
561	if len(lsp) == 0 {
562		return []string{t.S().Base.Foreground(t.Border).Render("None")}
563	}
564	for _, l := range lsp {
565		iconColor := t.Success
566		if l.LSP.Disabled {
567			iconColor = t.FgMuted
568		}
569		lspList = append(lspList,
570			core.Status(
571				core.StatusOpts{
572					IconColor:   iconColor,
573					Title:       l.Name,
574					Description: l.LSP.Command,
575				},
576				maxWidth,
577			),
578		)
579	}
580	return lspList
581}
582
583func (s *splashCmp) lspBlock() string {
584	t := styles.CurrentTheme()
585	maxWidth := s.getMaxInfoWidth() / 2
586	section := t.S().Subtle.Render("LSPs")
587	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
588	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
589		lipgloss.JoinVertical(
590			lipgloss.Left,
591			lspList...,
592		),
593	)
594}
595
596func MCPList(maxWidth int) []string {
597	t := styles.CurrentTheme()
598	mcpList := []string{}
599	mcps := config.Get().MCP.Sorted()
600	if len(mcps) == 0 {
601		return []string{t.S().Base.Foreground(t.Border).Render("None")}
602	}
603	for _, l := range mcps {
604		iconColor := t.Success
605		if l.MCP.Disabled {
606			iconColor = t.FgMuted
607		}
608		mcpList = append(mcpList,
609			core.Status(
610				core.StatusOpts{
611					IconColor:   iconColor,
612					Title:       l.Name,
613					Description: l.MCP.Command,
614				},
615				maxWidth,
616			),
617		)
618	}
619	return mcpList
620}
621
622func (s *splashCmp) mcpBlock() string {
623	t := styles.CurrentTheme()
624	maxWidth := s.getMaxInfoWidth() / 2
625	section := t.S().Subtle.Render("MCPs")
626	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
627	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
628		lipgloss.JoinVertical(
629			lipgloss.Left,
630			mcpList...,
631		),
632	)
633}
634
635func (s *splashCmp) IsShowingAPIKey() bool {
636	return s.needsAPIKey
637}