splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/tui/components/chat"
 14	"github.com/charmbracelet/crush/internal/tui/components/completions"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 19	"github.com/charmbracelet/crush/internal/tui/components/logo"
 20	"github.com/charmbracelet/crush/internal/tui/styles"
 21	"github.com/charmbracelet/crush/internal/tui/util"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/charmbracelet/lipgloss/v2"
 24)
 25
 26type Splash interface {
 27	util.Model
 28	layout.Sizeable
 29	layout.Help
 30	Cursor() *tea.Cursor
 31	// SetOnboarding controls whether the splash shows model selection UI
 32	SetOnboarding(bool)
 33	// SetProjectInit controls whether the splash shows project initialization prompt
 34	SetProjectInit(bool)
 35
 36	// Showing API key input
 37	IsShowingAPIKey() bool
 38}
 39
 40const (
 41	SplashScreenPaddingX = 2 // Padding X for the splash screen
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43)
 44
 45// OnboardingCompleteMsg is sent when onboarding is complete
 46type OnboardingCompleteMsg struct{}
 47
 48type splashCmp struct {
 49	width, height int
 50	keyMap        KeyMap
 51	logoRendered  string
 52
 53	// State
 54	isOnboarding     bool
 55	needsProjectInit bool
 56	needsAPIKey      bool
 57	selectedNo       bool
 58
 59	modelList     *models.ModelListComponent
 60	apiKeyInput   *models.APIKeyInput
 61	selectedModel *models.ModelOption
 62}
 63
 64func New() Splash {
 65	keyMap := DefaultKeyMap()
 66	listKeyMap := list.DefaultKeyMap()
 67	listKeyMap.Down.SetEnabled(false)
 68	listKeyMap.Up.SetEnabled(false)
 69	listKeyMap.HalfPageDown.SetEnabled(false)
 70	listKeyMap.HalfPageUp.SetEnabled(false)
 71	listKeyMap.Home.SetEnabled(false)
 72	listKeyMap.End.SetEnabled(false)
 73	listKeyMap.DownOneItem = keyMap.Next
 74	listKeyMap.UpOneItem = keyMap.Previous
 75
 76	t := styles.CurrentTheme()
 77	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 78	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 79	apiKeyInput := models.NewAPIKeyInput()
 80
 81	return &splashCmp{
 82		width:        0,
 83		height:       0,
 84		keyMap:       keyMap,
 85		logoRendered: "",
 86		modelList:    modelList,
 87		apiKeyInput:  apiKeyInput,
 88		selectedNo:   false,
 89	}
 90}
 91
 92func (s *splashCmp) SetOnboarding(onboarding bool) {
 93	s.isOnboarding = onboarding
 94	if onboarding {
 95		providers, err := config.Providers()
 96		if err != nil {
 97			return
 98		}
 99		filteredProviders := []provider.Provider{}
100		simpleProviders := []string{
101			"anthropic",
102			"openai",
103			"gemini",
104			"xai",
105			"openrouter",
106		}
107		for _, p := range providers {
108			if slices.Contains(simpleProviders, string(p.ID)) {
109				filteredProviders = append(filteredProviders, p)
110			}
111		}
112		s.modelList.SetProviders(filteredProviders)
113	}
114}
115
116func (s *splashCmp) SetProjectInit(needsInit bool) {
117	s.needsProjectInit = needsInit
118}
119
120// GetSize implements SplashPage.
121func (s *splashCmp) GetSize() (int, int) {
122	return s.width, s.height
123}
124
125// Init implements SplashPage.
126func (s *splashCmp) Init() tea.Cmd {
127	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
128}
129
130// SetSize implements SplashPage.
131func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
132	s.height = height
133	if width != s.width {
134		s.width = width
135		s.logoRendered = s.logoBlock()
136	}
137	listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
138	listWidth := min(60, width-(SplashScreenPaddingX*2))
139
140	return s.modelList.SetSize(listWidth, listHeigh)
141}
142
143// Update implements SplashPage.
144func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
145	switch msg := msg.(type) {
146	case tea.WindowSizeMsg:
147		return s, s.SetSize(msg.Width, msg.Height)
148	case tea.KeyPressMsg:
149		switch {
150		case key.Matches(msg, s.keyMap.Back):
151			if s.needsAPIKey {
152				// Go back to model selection
153				s.needsAPIKey = false
154				s.selectedModel = nil
155				return s, nil
156			}
157		case key.Matches(msg, s.keyMap.Select):
158			if s.isOnboarding && !s.needsAPIKey {
159				modelInx := s.modelList.SelectedIndex()
160				items := s.modelList.Items()
161				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
162				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
163					cmd := s.setPreferredModel(selectedItem)
164					s.isOnboarding = false
165					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
166				} else {
167					// Provider not configured, show API key input
168					s.needsAPIKey = true
169					s.selectedModel = &selectedItem
170					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
171					return s, nil
172				}
173			} else if s.needsAPIKey {
174				// Handle API key submission
175				apiKey := s.apiKeyInput.Value()
176				if apiKey != "" {
177					return s, s.saveAPIKeyAndContinue(apiKey)
178				}
179			} else if s.needsProjectInit {
180				return s, s.initializeProject()
181			}
182		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
183			if s.needsProjectInit {
184				s.selectedNo = !s.selectedNo
185				return s, nil
186			}
187		case key.Matches(msg, s.keyMap.Yes):
188			if s.needsProjectInit {
189				return s, s.initializeProject()
190			}
191		case key.Matches(msg, s.keyMap.No):
192			if s.needsProjectInit {
193				s.needsProjectInit = false
194				return s, util.CmdHandler(OnboardingCompleteMsg{})
195			}
196		default:
197			if s.needsAPIKey {
198				u, cmd := s.apiKeyInput.Update(msg)
199				s.apiKeyInput = u.(*models.APIKeyInput)
200				return s, cmd
201			} else if s.isOnboarding {
202				u, cmd := s.modelList.Update(msg)
203				s.modelList = u
204				return s, cmd
205			}
206		}
207	case tea.PasteMsg:
208		if s.needsAPIKey {
209			u, cmd := s.apiKeyInput.Update(msg)
210			s.apiKeyInput = u.(*models.APIKeyInput)
211			return s, cmd
212		} else if s.isOnboarding {
213			var cmd tea.Cmd
214			s.modelList, cmd = s.modelList.Update(msg)
215			return s, cmd
216		}
217	}
218	return s, nil
219}
220
221func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
222	if s.selectedModel == nil {
223		return util.ReportError(fmt.Errorf("no model selected"))
224	}
225
226	cfg := config.Get()
227	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
228	if err != nil {
229		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
230	}
231
232	// Reset API key state and continue with model selection
233	s.needsAPIKey = false
234	cmd := s.setPreferredModel(*s.selectedModel)
235	s.isOnboarding = false
236	s.selectedModel = nil
237
238	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
239}
240
241func (s *splashCmp) initializeProject() tea.Cmd {
242	s.needsProjectInit = false
243	prompt := `Please analyze this codebase and create a CRUSH.md file containing:
2441. Build/lint/test commands - especially for running a single test
2452. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
246
247The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
248If there's already a CRUSH.md, improve it.
249If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
250Add the .crush directory to the .gitignore file if it's not already there.`
251
252	if err := config.MarkProjectInitialized(); err != nil {
253		return util.ReportError(err)
254	}
255	var cmds []tea.Cmd
256
257	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
258	if !s.selectedNo {
259		cmds = append(cmds,
260			util.CmdHandler(chat.SessionClearedMsg{}),
261			util.CmdHandler(chat.SendMsg{
262				Text: prompt,
263			}),
264		)
265	}
266	return tea.Sequence(cmds...)
267}
268
269func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
270	cfg := config.Get()
271	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
272	if model == nil {
273		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
274	}
275
276	selectedModel := config.SelectedModel{
277		Model:           selectedItem.Model.ID,
278		Provider:        string(selectedItem.Provider.ID),
279		ReasoningEffort: model.DefaultReasoningEffort,
280		MaxTokens:       model.DefaultMaxTokens,
281	}
282
283	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
284	if err != nil {
285		return util.ReportError(err)
286	}
287
288	// Now lets automatically setup the small model
289	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
290	if err != nil {
291		return util.ReportError(err)
292	}
293	if knownProvider == nil {
294		// for local provider we just use the same model
295		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
296		if err != nil {
297			return util.ReportError(err)
298		}
299	} else {
300		smallModel := knownProvider.DefaultSmallModelID
301		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
302		// should never happen
303		if model == nil {
304			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
305			if err != nil {
306				return util.ReportError(err)
307			}
308			return nil
309		}
310		smallSelectedModel := config.SelectedModel{
311			Model:           smallModel,
312			Provider:        string(selectedItem.Provider.ID),
313			ReasoningEffort: model.DefaultReasoningEffort,
314			MaxTokens:       model.DefaultMaxTokens,
315		}
316		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
317		if err != nil {
318			return util.ReportError(err)
319		}
320	}
321	return nil
322}
323
324func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
325	providers, err := config.Providers()
326	if err != nil {
327		return nil, err
328	}
329	for _, p := range providers {
330		if p.ID == providerID {
331			return &p, nil
332		}
333	}
334	return nil, nil
335}
336
337func (s *splashCmp) isProviderConfigured(providerID string) bool {
338	cfg := config.Get()
339	if _, ok := cfg.Providers[providerID]; ok {
340		return true
341	}
342	return false
343}
344
345func (s *splashCmp) View() string {
346	t := styles.CurrentTheme()
347	var content string
348	if s.needsAPIKey {
349		infoSection := s.infoSection()
350		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
351		apiKeyView := s.apiKeyInput.View()
352		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
353			lipgloss.JoinVertical(
354				lipgloss.Left,
355				apiKeyView,
356			),
357		)
358		content = lipgloss.JoinVertical(
359			lipgloss.Left,
360			s.logoRendered,
361			infoSection,
362			apiKeySelector,
363		)
364	} else if s.isOnboarding {
365		infoSection := s.infoSection()
366		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
367		modelListView := s.modelList.View()
368		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
369			lipgloss.JoinVertical(
370				lipgloss.Left,
371				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
372				"",
373				modelListView,
374			),
375		)
376		content = lipgloss.JoinVertical(
377			lipgloss.Left,
378			s.logoRendered,
379			infoSection,
380			modelSelector,
381		)
382	} else if s.needsProjectInit {
383		titleStyle := t.S().Base.Foreground(t.FgBase)
384		bodyStyle := t.S().Base.Foreground(t.FgMuted)
385		shortcutStyle := t.S().Base.Foreground(t.Success)
386
387		initText := lipgloss.JoinVertical(
388			lipgloss.Left,
389			titleStyle.Render("Would you like to initialize this project?"),
390			"",
391			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
392			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
393			"",
394			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
395			"",
396			bodyStyle.Render("Would you like to initialize now?"),
397		)
398
399		yesButton := core.SelectableButton(core.ButtonOpts{
400			Text:           "Yep!",
401			UnderlineIndex: 0,
402			Selected:       !s.selectedNo,
403		})
404
405		noButton := core.SelectableButton(core.ButtonOpts{
406			Text:           "Nope",
407			UnderlineIndex: 0,
408			Selected:       s.selectedNo,
409		})
410
411		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
412		infoSection := s.infoSection()
413
414		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
415
416		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
417			lipgloss.JoinVertical(
418				lipgloss.Left,
419				initText,
420				"",
421				buttons,
422			),
423		)
424
425		content = lipgloss.JoinVertical(
426			lipgloss.Left,
427			s.logoRendered,
428			infoSection,
429			initContent,
430		)
431	} else {
432		parts := []string{
433			s.logoRendered,
434			s.infoSection(),
435		}
436		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
437	}
438
439	return t.S().Base.
440		Width(s.width).
441		Height(s.height).
442		PaddingTop(SplashScreenPaddingY).
443		PaddingLeft(SplashScreenPaddingX).
444		PaddingRight(SplashScreenPaddingX).
445		PaddingBottom(SplashScreenPaddingY).
446		Render(content)
447}
448
449func (s *splashCmp) Cursor() *tea.Cursor {
450	if s.needsAPIKey {
451		cursor := s.apiKeyInput.Cursor()
452		if cursor != nil {
453			return s.moveCursor(cursor)
454		}
455	} else if s.isOnboarding {
456		cursor := s.modelList.Cursor()
457		if cursor != nil {
458			return s.moveCursor(cursor)
459		}
460	} else {
461		return nil
462	}
463	return nil
464}
465
466func (s *splashCmp) infoSection() string {
467	return lipgloss.JoinVertical(
468		lipgloss.Left,
469		s.cwd(),
470		"",
471		lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
472		"",
473	)
474}
475
476func (s *splashCmp) logoBlock() string {
477	t := styles.CurrentTheme()
478	const padding = 2
479	return logo.Render(version.Version, false, logo.Opts{
480		FieldColor:   t.Primary,
481		TitleColorA:  t.Secondary,
482		TitleColorB:  t.Primary,
483		CharmColor:   t.Secondary,
484		VersionColor: t.Primary,
485		Width:        s.width - (SplashScreenPaddingX * 2),
486	})
487}
488
489func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
490	if cursor == nil {
491		return nil
492	}
493
494	// Calculate the correct Y offset based on current state
495	logoHeight := lipgloss.Height(m.logoRendered)
496	infoSectionHeight := lipgloss.Height(m.infoSection())
497	baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
498	if m.needsAPIKey {
499		remainingHeight := m.height - baseOffset - lipgloss.Height(m.apiKeyInput.View()) - SplashScreenPaddingY
500		offset := baseOffset + remainingHeight
501		cursor.Y += offset
502		cursor.X = cursor.X + SplashScreenPaddingX
503	} else if m.isOnboarding {
504		listHeight := min(40, m.height-(SplashScreenPaddingY)-baseOffset)
505		offset := m.height - listHeight + 1 // +1 for the title
506		cursor.Y += offset
507		cursor.X = cursor.X + SplashScreenPaddingX + 1
508	}
509
510	return cursor
511}
512
513// Bindings implements SplashPage.
514func (s *splashCmp) Bindings() []key.Binding {
515	if s.needsAPIKey {
516		return []key.Binding{
517			s.keyMap.Select,
518			s.keyMap.Back,
519		}
520	} else if s.isOnboarding {
521		return []key.Binding{
522			s.keyMap.Select,
523			s.keyMap.Next,
524			s.keyMap.Previous,
525		}
526	} else if s.needsProjectInit {
527		return []key.Binding{
528			s.keyMap.Select,
529			s.keyMap.Yes,
530			s.keyMap.No,
531			s.keyMap.Tab,
532			s.keyMap.LeftRight,
533		}
534	}
535	return []key.Binding{}
536}
537
538func (s *splashCmp) getMaxInfoWidth() int {
539	return min(s.width-(SplashScreenPaddingX*2), 40)
540}
541
542func (s *splashCmp) cwd() string {
543	cwd := config.Get().WorkingDir()
544	t := styles.CurrentTheme()
545	homeDir, err := os.UserHomeDir()
546	if err == nil && cwd != homeDir {
547		cwd = strings.ReplaceAll(cwd, homeDir, "~")
548	}
549	maxWidth := s.getMaxInfoWidth()
550	return t.S().Muted.Width(maxWidth).Render(cwd)
551}
552
553func LSPList(maxWidth int) []string {
554	t := styles.CurrentTheme()
555	lspList := []string{}
556	lsp := config.Get().LSP.Sorted()
557	if len(lsp) == 0 {
558		return []string{t.S().Base.Foreground(t.Border).Render("None")}
559	}
560	for _, l := range lsp {
561		iconColor := t.Success
562		if l.LSP.Disabled {
563			iconColor = t.FgMuted
564		}
565		lspList = append(lspList,
566			core.Status(
567				core.StatusOpts{
568					IconColor:   iconColor,
569					Title:       l.Name,
570					Description: l.LSP.Command,
571				},
572				maxWidth,
573			),
574		)
575	}
576	return lspList
577}
578
579func (s *splashCmp) lspBlock() string {
580	t := styles.CurrentTheme()
581	maxWidth := s.getMaxInfoWidth() / 2
582	section := t.S().Subtle.Render("LSPs")
583	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
584	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
585		lipgloss.JoinVertical(
586			lipgloss.Left,
587			lspList...,
588		),
589	)
590}
591
592func MCPList(maxWidth int) []string {
593	t := styles.CurrentTheme()
594	mcpList := []string{}
595	mcps := config.Get().MCP.Sorted()
596	if len(mcps) == 0 {
597		return []string{t.S().Base.Foreground(t.Border).Render("None")}
598	}
599	for _, l := range mcps {
600		iconColor := t.Success
601		if l.MCP.Disabled {
602			iconColor = t.FgMuted
603		}
604		mcpList = append(mcpList,
605			core.Status(
606				core.StatusOpts{
607					IconColor:   iconColor,
608					Title:       l.Name,
609					Description: l.MCP.Command,
610				},
611				maxWidth,
612			),
613		)
614	}
615	return mcpList
616}
617
618func (s *splashCmp) mcpBlock() string {
619	t := styles.CurrentTheme()
620	maxWidth := s.getMaxInfoWidth() / 2
621	section := t.S().Subtle.Render("MCPs")
622	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
623	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
624		lipgloss.JoinVertical(
625			lipgloss.Left,
626			mcpList...,
627		),
628	)
629}
630
631func (s *splashCmp) IsShowingAPIKey() bool {
632	return s.needsAPIKey
633}