splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/tui/components/chat"
 14	"github.com/charmbracelet/crush/internal/tui/components/completions"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 19	"github.com/charmbracelet/crush/internal/tui/components/logo"
 20	"github.com/charmbracelet/crush/internal/tui/styles"
 21	"github.com/charmbracelet/crush/internal/tui/util"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/charmbracelet/lipgloss/v2"
 24)
 25
 26type Splash interface {
 27	util.Model
 28	layout.Sizeable
 29	layout.Help
 30	Cursor() *tea.Cursor
 31	// SetOnboarding controls whether the splash shows model selection UI
 32	SetOnboarding(bool)
 33	// SetProjectInit controls whether the splash shows project initialization prompt
 34	SetProjectInit(bool)
 35
 36	// Showing API key input
 37	IsShowingAPIKey() bool
 38}
 39
 40const (
 41	SplashScreenPaddingX = 2 // Padding X for the splash screen
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43)
 44
 45// OnboardingCompleteMsg is sent when onboarding is complete
 46type OnboardingCompleteMsg struct{}
 47
 48type splashCmp struct {
 49	width, height int
 50	keyMap        KeyMap
 51	logoRendered  string
 52
 53	// State
 54	isOnboarding     bool
 55	needsProjectInit bool
 56	needsAPIKey      bool
 57	selectedNo       bool
 58
 59	modelList     *models.ModelListComponent
 60	apiKeyInput   *models.APIKeyInput
 61	selectedModel *models.ModelOption
 62}
 63
 64func New() Splash {
 65	keyMap := DefaultKeyMap()
 66	listKeyMap := list.DefaultKeyMap()
 67	listKeyMap.Down.SetEnabled(false)
 68	listKeyMap.Up.SetEnabled(false)
 69	listKeyMap.HalfPageDown.SetEnabled(false)
 70	listKeyMap.HalfPageUp.SetEnabled(false)
 71	listKeyMap.Home.SetEnabled(false)
 72	listKeyMap.End.SetEnabled(false)
 73	listKeyMap.DownOneItem = keyMap.Next
 74	listKeyMap.UpOneItem = keyMap.Previous
 75
 76	t := styles.CurrentTheme()
 77	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 78	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 79	apiKeyInput := models.NewAPIKeyInput()
 80
 81	return &splashCmp{
 82		width:        0,
 83		height:       0,
 84		keyMap:       keyMap,
 85		logoRendered: "",
 86		modelList:    modelList,
 87		apiKeyInput:  apiKeyInput,
 88		selectedNo:   false,
 89	}
 90}
 91
 92func (s *splashCmp) SetOnboarding(onboarding bool) {
 93	s.isOnboarding = onboarding
 94	if onboarding {
 95		providers, err := config.Providers()
 96		if err != nil {
 97			return
 98		}
 99		filteredProviders := []provider.Provider{}
100		simpleProviders := []string{
101			"anthropic",
102			"openai",
103			"gemini",
104			"xai",
105			"openrouter",
106		}
107		for _, p := range providers {
108			if slices.Contains(simpleProviders, string(p.ID)) {
109				filteredProviders = append(filteredProviders, p)
110			}
111		}
112		s.modelList.SetProviders(filteredProviders)
113	}
114}
115
116func (s *splashCmp) SetProjectInit(needsInit bool) {
117	s.needsProjectInit = needsInit
118}
119
120// GetSize implements SplashPage.
121func (s *splashCmp) GetSize() (int, int) {
122	return s.width, s.height
123}
124
125// Init implements SplashPage.
126func (s *splashCmp) Init() tea.Cmd {
127	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
128}
129
130// SetSize implements SplashPage.
131func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
132	s.height = height
133	if width != s.width {
134		s.width = width
135		s.logoRendered = s.logoBlock()
136	}
137	infoSectionHeight := lipgloss.Height(s.infoSection())
138	listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2-infoSectionHeight)
139	listWidth := min(60, width-(SplashScreenPaddingX*2))
140
141	return s.modelList.SetSize(listWidth, listHeigh)
142}
143
144// Update implements SplashPage.
145func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
146	switch msg := msg.(type) {
147	case tea.WindowSizeMsg:
148		return s, s.SetSize(msg.Width, msg.Height)
149	case tea.KeyPressMsg:
150		switch {
151		case key.Matches(msg, s.keyMap.Back):
152			if s.needsAPIKey {
153				// Go back to model selection
154				s.needsAPIKey = false
155				s.selectedModel = nil
156				return s, nil
157			}
158		case key.Matches(msg, s.keyMap.Select):
159			if s.isOnboarding && !s.needsAPIKey {
160				modelInx := s.modelList.SelectedIndex()
161				items := s.modelList.Items()
162				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
163				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
164					cmd := s.setPreferredModel(selectedItem)
165					s.isOnboarding = false
166					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
167				} else {
168					// Provider not configured, show API key input
169					s.needsAPIKey = true
170					s.selectedModel = &selectedItem
171					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
172					return s, nil
173				}
174			} else if s.needsAPIKey {
175				// Handle API key submission
176				apiKey := s.apiKeyInput.Value()
177				if apiKey != "" {
178					return s, s.saveAPIKeyAndContinue(apiKey)
179				}
180			} else if s.needsProjectInit {
181				return s, s.initializeProject()
182			}
183		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
184			if s.needsProjectInit {
185				s.selectedNo = !s.selectedNo
186				return s, nil
187			}
188		case key.Matches(msg, s.keyMap.Yes):
189			if s.needsProjectInit {
190				return s, s.initializeProject()
191			}
192		case key.Matches(msg, s.keyMap.No):
193			if s.needsProjectInit {
194				s.needsProjectInit = false
195				return s, util.CmdHandler(OnboardingCompleteMsg{})
196			}
197		default:
198			if s.needsAPIKey {
199				u, cmd := s.apiKeyInput.Update(msg)
200				s.apiKeyInput = u.(*models.APIKeyInput)
201				return s, cmd
202			} else if s.isOnboarding {
203				u, cmd := s.modelList.Update(msg)
204				s.modelList = u
205				return s, cmd
206			}
207		}
208	case tea.PasteMsg:
209		if s.needsAPIKey {
210			u, cmd := s.apiKeyInput.Update(msg)
211			s.apiKeyInput = u.(*models.APIKeyInput)
212			return s, cmd
213		} else if s.isOnboarding {
214			var cmd tea.Cmd
215			s.modelList, cmd = s.modelList.Update(msg)
216			return s, cmd
217		}
218	}
219	return s, nil
220}
221
222func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
223	if s.selectedModel == nil {
224		return util.ReportError(fmt.Errorf("no model selected"))
225	}
226
227	cfg := config.Get()
228	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
229	if err != nil {
230		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
231	}
232
233	// Reset API key state and continue with model selection
234	s.needsAPIKey = false
235	cmd := s.setPreferredModel(*s.selectedModel)
236	s.isOnboarding = false
237	s.selectedModel = nil
238
239	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
240}
241
242func (s *splashCmp) initializeProject() tea.Cmd {
243	s.needsProjectInit = false
244	prompt := `Please analyze this codebase and create a CRUSH.md file containing:
2451. Build/lint/test commands - especially for running a single test
2462. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
247
248The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
249If there's already a CRUSH.md, improve it.
250If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
251Add the .crush directory to the .gitignore file if it's not already there.`
252
253	if err := config.MarkProjectInitialized(); err != nil {
254		return util.ReportError(err)
255	}
256	var cmds []tea.Cmd
257
258	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
259	if !s.selectedNo {
260		cmds = append(cmds,
261			util.CmdHandler(chat.SessionClearedMsg{}),
262			util.CmdHandler(chat.SendMsg{
263				Text: prompt,
264			}),
265		)
266	}
267	return tea.Sequence(cmds...)
268}
269
270func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
271	cfg := config.Get()
272	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
273	if model == nil {
274		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
275	}
276
277	selectedModel := config.SelectedModel{
278		Model:           selectedItem.Model.ID,
279		Provider:        string(selectedItem.Provider.ID),
280		ReasoningEffort: model.DefaultReasoningEffort,
281		MaxTokens:       model.DefaultMaxTokens,
282	}
283
284	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
285	if err != nil {
286		return util.ReportError(err)
287	}
288
289	// Now lets automatically setup the small model
290	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
291	if err != nil {
292		return util.ReportError(err)
293	}
294	if knownProvider == nil {
295		// for local provider we just use the same model
296		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
297		if err != nil {
298			return util.ReportError(err)
299		}
300	} else {
301		smallModel := knownProvider.DefaultSmallModelID
302		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
303		// should never happen
304		if model == nil {
305			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
306			if err != nil {
307				return util.ReportError(err)
308			}
309			return nil
310		}
311		smallSelectedModel := config.SelectedModel{
312			Model:           smallModel,
313			Provider:        string(selectedItem.Provider.ID),
314			ReasoningEffort: model.DefaultReasoningEffort,
315			MaxTokens:       model.DefaultMaxTokens,
316		}
317		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
318		if err != nil {
319			return util.ReportError(err)
320		}
321	}
322	return nil
323}
324
325func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
326	providers, err := config.Providers()
327	if err != nil {
328		return nil, err
329	}
330	for _, p := range providers {
331		if p.ID == providerID {
332			return &p, nil
333		}
334	}
335	return nil, nil
336}
337
338func (s *splashCmp) isProviderConfigured(providerID string) bool {
339	cfg := config.Get()
340	if _, ok := cfg.Providers[providerID]; ok {
341		return true
342	}
343	return false
344}
345
346func (s *splashCmp) View() string {
347	t := styles.CurrentTheme()
348	var content string
349	if s.needsAPIKey {
350		infoSection := s.infoSection()
351		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
352		apiKeyView := s.apiKeyInput.View()
353		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
354			lipgloss.JoinVertical(
355				lipgloss.Left,
356				apiKeyView,
357			),
358		)
359		content = lipgloss.JoinVertical(
360			lipgloss.Left,
361			s.logoRendered,
362			infoSection,
363			apiKeySelector,
364		)
365	} else if s.isOnboarding {
366		infoSection := s.infoSection()
367		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
368		modelListView := s.modelList.View()
369		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
370			lipgloss.JoinVertical(
371				lipgloss.Left,
372				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
373				"",
374				modelListView,
375			),
376		)
377		content = lipgloss.JoinVertical(
378			lipgloss.Left,
379			s.logoRendered,
380			infoSection,
381			modelSelector,
382		)
383	} else if s.needsProjectInit {
384		titleStyle := t.S().Base.Foreground(t.FgBase)
385		bodyStyle := t.S().Base.Foreground(t.FgMuted)
386		shortcutStyle := t.S().Base.Foreground(t.Success)
387
388		initText := lipgloss.JoinVertical(
389			lipgloss.Left,
390			titleStyle.Render("Would you like to initialize this project?"),
391			"",
392			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
393			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
394			"",
395			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
396			"",
397			bodyStyle.Render("Would you like to initialize now?"),
398		)
399
400		yesButton := core.SelectableButton(core.ButtonOpts{
401			Text:           "Yep!",
402			UnderlineIndex: 0,
403			Selected:       !s.selectedNo,
404		})
405
406		noButton := core.SelectableButton(core.ButtonOpts{
407			Text:           "Nope",
408			UnderlineIndex: 0,
409			Selected:       s.selectedNo,
410		})
411
412		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
413		infoSection := s.infoSection()
414
415		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
416
417		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
418			lipgloss.JoinVertical(
419				lipgloss.Left,
420				initText,
421				"",
422				buttons,
423			),
424		)
425
426		content = lipgloss.JoinVertical(
427			lipgloss.Left,
428			s.logoRendered,
429			infoSection,
430			initContent,
431		)
432	} else {
433		parts := []string{
434			s.logoRendered,
435			s.infoSection(),
436		}
437		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
438	}
439
440	return t.S().Base.
441		Width(s.width).
442		Height(s.height).
443		PaddingTop(SplashScreenPaddingY).
444		PaddingLeft(SplashScreenPaddingX).
445		PaddingRight(SplashScreenPaddingX).
446		PaddingBottom(SplashScreenPaddingY).
447		Render(content)
448}
449
450func (s *splashCmp) Cursor() *tea.Cursor {
451	if s.needsAPIKey {
452		cursor := s.apiKeyInput.Cursor()
453		if cursor != nil {
454			return s.moveCursor(cursor)
455		}
456	} else if s.isOnboarding {
457		cursor := s.modelList.Cursor()
458		if cursor != nil {
459			return s.moveCursor(cursor)
460		}
461	} else {
462		return nil
463	}
464	return nil
465}
466
467func (s *splashCmp) infoSection() string {
468	return lipgloss.JoinVertical(
469		lipgloss.Left,
470		s.cwd(),
471		"",
472		lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
473		"",
474	)
475}
476
477func (s *splashCmp) logoBlock() string {
478	t := styles.CurrentTheme()
479	const padding = 2
480	return logo.Render(version.Version, false, logo.Opts{
481		FieldColor:   t.Primary,
482		TitleColorA:  t.Secondary,
483		TitleColorB:  t.Primary,
484		CharmColor:   t.Secondary,
485		VersionColor: t.Primary,
486		Width:        s.width - (SplashScreenPaddingX * 2),
487	})
488}
489
490func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
491	if cursor == nil {
492		return nil
493	}
494	// Calculate the correct Y offset based on current state
495	logoHeight := lipgloss.Height(m.logoRendered)
496	infoSectionHeight := lipgloss.Height(m.infoSection())
497	baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
498	if m.needsAPIKey {
499		remainingHeight := m.height - baseOffset - lipgloss.Height(m.apiKeyInput.View()) - SplashScreenPaddingY
500		offset := baseOffset + remainingHeight
501		cursor.Y += offset
502		cursor.X = cursor.X + SplashScreenPaddingX
503	} else if m.isOnboarding {
504		listHeight := min(40, m.height-(SplashScreenPaddingY)-baseOffset)
505		offset := m.height - listHeight + 2
506		cursor.Y += offset
507		cursor.X = cursor.X + SplashScreenPaddingX + 1
508	}
509
510	return cursor
511}
512
513// Bindings implements SplashPage.
514func (s *splashCmp) Bindings() []key.Binding {
515	if s.needsAPIKey {
516		return []key.Binding{
517			s.keyMap.Select,
518			s.keyMap.Back,
519		}
520	} else if s.isOnboarding {
521		return []key.Binding{
522			s.keyMap.Select,
523			s.keyMap.Next,
524			s.keyMap.Previous,
525		}
526	} else if s.needsProjectInit {
527		return []key.Binding{
528			s.keyMap.Select,
529			s.keyMap.Yes,
530			s.keyMap.No,
531			s.keyMap.Tab,
532			s.keyMap.LeftRight,
533		}
534	}
535	return []key.Binding{}
536}
537
538func (s *splashCmp) getMaxInfoWidth() int {
539	return min(s.width-(SplashScreenPaddingX*2), 40)
540}
541
542func (s *splashCmp) cwd() string {
543	cwd := config.Get().WorkingDir()
544	t := styles.CurrentTheme()
545	homeDir, err := os.UserHomeDir()
546	if err == nil && cwd != homeDir {
547		cwd = strings.ReplaceAll(cwd, homeDir, "~")
548	}
549	maxWidth := s.getMaxInfoWidth()
550	return t.S().Muted.Width(maxWidth).Render(cwd)
551}
552
553func LSPList(maxWidth int) []string {
554	t := styles.CurrentTheme()
555	lspList := []string{}
556	lsp := config.Get().LSP.Sorted()
557	if len(lsp) == 0 {
558		return []string{t.S().Base.Foreground(t.Border).Render("None")}
559	}
560	for _, l := range lsp {
561		iconColor := t.Success
562		if l.LSP.Disabled {
563			iconColor = t.FgMuted
564		}
565		lspList = append(lspList,
566			core.Status(
567				core.StatusOpts{
568					IconColor:   iconColor,
569					Title:       l.Name,
570					Description: l.LSP.Command,
571				},
572				maxWidth,
573			),
574		)
575	}
576	return lspList
577}
578
579func (s *splashCmp) lspBlock() string {
580	t := styles.CurrentTheme()
581	maxWidth := s.getMaxInfoWidth() / 2
582	section := t.S().Subtle.Render("LSPs")
583	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
584	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
585		lipgloss.JoinVertical(
586			lipgloss.Left,
587			lspList...,
588		),
589	)
590}
591
592func MCPList(maxWidth int) []string {
593	t := styles.CurrentTheme()
594	mcpList := []string{}
595	mcps := config.Get().MCP.Sorted()
596	if len(mcps) == 0 {
597		return []string{t.S().Base.Foreground(t.Border).Render("None")}
598	}
599	for _, l := range mcps {
600		iconColor := t.Success
601		if l.MCP.Disabled {
602			iconColor = t.FgMuted
603		}
604		mcpList = append(mcpList,
605			core.Status(
606				core.StatusOpts{
607					IconColor:   iconColor,
608					Title:       l.Name,
609					Description: l.MCP.Command,
610				},
611				maxWidth,
612			),
613		)
614	}
615	return mcpList
616}
617
618func (s *splashCmp) mcpBlock() string {
619	t := styles.CurrentTheme()
620	maxWidth := s.getMaxInfoWidth() / 2
621	section := t.S().Subtle.Render("MCPs")
622	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
623	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
624		lipgloss.JoinVertical(
625			lipgloss.Left,
626			mcpList...,
627		),
628	)
629}
630
631func (s *splashCmp) IsShowingAPIKey() bool {
632	return s.needsAPIKey
633}