splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/tui/components/chat"
 15	"github.com/charmbracelet/crush/internal/tui/components/completions"
 16	"github.com/charmbracelet/crush/internal/tui/components/core"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 18	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 20	"github.com/charmbracelet/crush/internal/tui/components/logo"
 21	"github.com/charmbracelet/crush/internal/tui/styles"
 22	"github.com/charmbracelet/crush/internal/tui/util"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/charmbracelet/lipgloss/v2"
 25)
 26
 27type Splash interface {
 28	util.Model
 29	layout.Sizeable
 30	layout.Help
 31	Cursor() *tea.Cursor
 32	// SetOnboarding controls whether the splash shows model selection UI
 33	SetOnboarding(bool)
 34	// SetProjectInit controls whether the splash shows project initialization prompt
 35	SetProjectInit(bool)
 36
 37	// Showing API key input
 38	IsShowingAPIKey() bool
 39}
 40
 41const (
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43
 44	LogoGap = 6
 45)
 46
 47// OnboardingCompleteMsg is sent when onboarding is complete
 48type OnboardingCompleteMsg struct{}
 49
 50type splashCmp struct {
 51	width, height int
 52	keyMap        KeyMap
 53	logoRendered  string
 54
 55	// State
 56	isOnboarding     bool
 57	needsProjectInit bool
 58	needsAPIKey      bool
 59	selectedNo       bool
 60
 61	listHeight    int
 62	modelList     *models.ModelListComponent
 63	apiKeyInput   *models.APIKeyInput
 64	selectedModel *models.ModelOption
 65}
 66
 67func New() Splash {
 68	keyMap := DefaultKeyMap()
 69	listKeyMap := list.DefaultKeyMap()
 70	listKeyMap.Down.SetEnabled(false)
 71	listKeyMap.Up.SetEnabled(false)
 72	listKeyMap.HalfPageDown.SetEnabled(false)
 73	listKeyMap.HalfPageUp.SetEnabled(false)
 74	listKeyMap.Home.SetEnabled(false)
 75	listKeyMap.End.SetEnabled(false)
 76	listKeyMap.DownOneItem = keyMap.Next
 77	listKeyMap.UpOneItem = keyMap.Previous
 78
 79	t := styles.CurrentTheme()
 80	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 81	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 82	apiKeyInput := models.NewAPIKeyInput()
 83
 84	return &splashCmp{
 85		width:        0,
 86		height:       0,
 87		keyMap:       keyMap,
 88		logoRendered: "",
 89		modelList:    modelList,
 90		apiKeyInput:  apiKeyInput,
 91		selectedNo:   false,
 92	}
 93}
 94
 95func (s *splashCmp) SetOnboarding(onboarding bool) {
 96	s.isOnboarding = onboarding
 97	if onboarding {
 98		providers, err := config.Providers()
 99		if err != nil {
100			return
101		}
102		filteredProviders := []provider.Provider{}
103		simpleProviders := []string{
104			"anthropic",
105			"openai",
106			"gemini",
107			"xai",
108			"groq",
109			"openrouter",
110		}
111		for _, p := range providers {
112			if slices.Contains(simpleProviders, string(p.ID)) {
113				filteredProviders = append(filteredProviders, p)
114			}
115		}
116		s.modelList.SetProviders(filteredProviders)
117	}
118}
119
120func (s *splashCmp) SetProjectInit(needsInit bool) {
121	s.needsProjectInit = needsInit
122}
123
124// GetSize implements SplashPage.
125func (s *splashCmp) GetSize() (int, int) {
126	return s.width, s.height
127}
128
129// Init implements SplashPage.
130func (s *splashCmp) Init() tea.Cmd {
131	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
132}
133
134// SetSize implements SplashPage.
135func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
136	s.height = height
137	if width != s.width {
138		s.width = width
139		s.logoRendered = s.logoBlock()
140	}
141	// remove padding, logo height, gap, title space
142	s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
143	listWidth := min(60, width)
144	return s.modelList.SetSize(listWidth, s.listHeight)
145}
146
147// Update implements SplashPage.
148func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
149	switch msg := msg.(type) {
150	case tea.WindowSizeMsg:
151		return s, s.SetSize(msg.Width, msg.Height)
152	case tea.KeyPressMsg:
153		switch {
154		case key.Matches(msg, s.keyMap.Back):
155			if s.needsAPIKey {
156				// Go back to model selection
157				s.needsAPIKey = false
158				s.selectedModel = nil
159				return s, nil
160			}
161		case key.Matches(msg, s.keyMap.Select):
162			if s.isOnboarding && !s.needsAPIKey {
163				modelInx := s.modelList.SelectedIndex()
164				items := s.modelList.Items()
165				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
166				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
167					cmd := s.setPreferredModel(selectedItem)
168					s.isOnboarding = false
169					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
170				} else {
171					// Provider not configured, show API key input
172					s.needsAPIKey = true
173					s.selectedModel = &selectedItem
174					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
175					return s, nil
176				}
177			} else if s.needsAPIKey {
178				// Handle API key submission
179				apiKey := s.apiKeyInput.Value()
180				if apiKey != "" {
181					return s, s.saveAPIKeyAndContinue(apiKey)
182				}
183			} else if s.needsProjectInit {
184				return s, s.initializeProject()
185			}
186		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
187			if s.needsProjectInit {
188				s.selectedNo = !s.selectedNo
189				return s, nil
190			}
191		case key.Matches(msg, s.keyMap.Yes):
192			if s.needsProjectInit {
193				return s, s.initializeProject()
194			}
195		case key.Matches(msg, s.keyMap.No):
196			s.selectedNo = true
197			return s, s.initializeProject()
198		default:
199			if s.needsAPIKey {
200				u, cmd := s.apiKeyInput.Update(msg)
201				s.apiKeyInput = u.(*models.APIKeyInput)
202				return s, cmd
203			} else if s.isOnboarding {
204				u, cmd := s.modelList.Update(msg)
205				s.modelList = u
206				return s, cmd
207			}
208		}
209	case tea.PasteMsg:
210		if s.needsAPIKey {
211			u, cmd := s.apiKeyInput.Update(msg)
212			s.apiKeyInput = u.(*models.APIKeyInput)
213			return s, cmd
214		} else if s.isOnboarding {
215			var cmd tea.Cmd
216			s.modelList, cmd = s.modelList.Update(msg)
217			return s, cmd
218		}
219	}
220	return s, nil
221}
222
// saveAPIKeyAndContinue stores apiKey for the provider of the pending
// selectedModel, makes that model the preferred one, and finishes onboarding.
// It returns an error-reporting command when no model is pending or the key
// cannot be saved; the screen state is left untouched in those cases.
func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
	pending := s.selectedModel
	if pending == nil {
		return util.ReportError(fmt.Errorf("no model selected"))
	}

	providerID := string(pending.Provider.ID)
	if err := config.Get().SetProviderAPIKey(providerID, apiKey); err != nil {
		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
	}

	// Leave the API key screen, then apply the model and close onboarding.
	s.needsAPIKey = false
	modelCmd := s.setPreferredModel(*pending)
	s.isOnboarding = false
	s.selectedModel = nil

	return tea.Batch(modelCmd, util.CmdHandler(OnboardingCompleteMsg{}))
}
242
243func (s *splashCmp) initializeProject() tea.Cmd {
244	s.needsProjectInit = false
245
246	if err := config.MarkProjectInitialized(); err != nil {
247		return util.ReportError(err)
248	}
249	var cmds []tea.Cmd
250
251	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
252	if !s.selectedNo {
253		cmds = append(cmds,
254			util.CmdHandler(chat.SessionClearedMsg{}),
255			util.CmdHandler(chat.SendMsg{
256				Text: prompt.Initialize(),
257			}),
258		)
259	}
260	return tea.Sequence(cmds...)
261}
262
// setPreferredModel makes selectedItem the preferred large model, chooses a
// small model to go with it, and then calls cfg.SetupAgents.
//
// The small model is the provider's DefaultSmallModelID when the provider is
// in the known catalogue (config.Providers) and that model resolves through
// cfg.GetModel. Otherwise the large model is reused. SetupAgents runs on
// every successful path; previously the "small model not found" fallback
// returned early and skipped it.
//
// It returns nil on success and an error-reporting command on failure.
func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
	cfg := config.Get()
	providerID := string(selectedItem.Provider.ID)

	largeModel := cfg.GetModel(providerID, selectedItem.Model.ID)
	if largeModel == nil {
		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
	}
	large := config.SelectedModel{
		Model:           selectedItem.Model.ID,
		Provider:        providerID,
		ReasoningEffort: largeModel.DefaultReasoningEffort,
		MaxTokens:       largeModel.DefaultMaxTokens,
	}
	if err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, large); err != nil {
		return util.ReportError(err)
	}

	// Pick the small model automatically.
	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
	if err != nil {
		return util.ReportError(err)
	}
	// Providers outside the catalogue (e.g. local ones) reuse the large model.
	small := large
	if knownProvider != nil {
		smallID := knownProvider.DefaultSmallModelID
		// A nil lookup "should never happen"; fall back to the large model.
		if smallModel := cfg.GetModel(providerID, smallID); smallModel != nil {
			small = config.SelectedModel{
				Model:           smallID,
				Provider:        providerID,
				ReasoningEffort: smallModel.DefaultReasoningEffort,
				MaxTokens:       smallModel.DefaultMaxTokens,
			}
		}
	}
	if err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, small); err != nil {
		return util.ReportError(err)
	}

	cfg.SetupAgents()
	return nil
}
318
// getProvider looks providerID up in the known provider catalogue
// (config.Providers). It returns a pointer to a copy of the entry, or
// (nil, nil) when the provider is not in the catalogue.
func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
	catalogue, err := config.Providers()
	if err != nil {
		return nil, err
	}
	idx := slices.IndexFunc(catalogue, func(candidate provider.Provider) bool {
		return candidate.ID == providerID
	})
	if idx < 0 {
		return nil, nil
	}
	found := catalogue[idx]
	return &found, nil
}
331
332func (s *splashCmp) isProviderConfigured(providerID string) bool {
333	cfg := config.Get()
334	if _, ok := cfg.Providers[providerID]; ok {
335		return true
336	}
337	return false
338}
339
340func (s *splashCmp) View() string {
341	t := styles.CurrentTheme()
342	var content string
343	if s.needsAPIKey {
344		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
345		apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
346		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
347			lipgloss.JoinVertical(
348				lipgloss.Left,
349				apiKeyView,
350			),
351		)
352		content = lipgloss.JoinVertical(
353			lipgloss.Left,
354			s.logoRendered,
355			apiKeySelector,
356		)
357	} else if s.isOnboarding {
358		modelListView := s.modelList.View()
359		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
360		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
361			lipgloss.JoinVertical(
362				lipgloss.Left,
363				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
364				"",
365				modelListView,
366			),
367		)
368		content = lipgloss.JoinVertical(
369			lipgloss.Left,
370			s.logoRendered,
371			modelSelector,
372		)
373	} else if s.needsProjectInit {
374		titleStyle := t.S().Base.Foreground(t.FgBase)
375		bodyStyle := t.S().Base.Foreground(t.FgMuted)
376		shortcutStyle := t.S().Base.Foreground(t.Success)
377
378		initText := lipgloss.JoinVertical(
379			lipgloss.Left,
380			titleStyle.Render("Would you like to initialize this project?"),
381			"",
382			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
383			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
384			"",
385			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
386			"",
387			bodyStyle.Render("Would you like to initialize now?"),
388		)
389
390		yesButton := core.SelectableButton(core.ButtonOpts{
391			Text:           "Yep!",
392			UnderlineIndex: 0,
393			Selected:       !s.selectedNo,
394		})
395
396		noButton := core.SelectableButton(core.ButtonOpts{
397			Text:           "Nope",
398			UnderlineIndex: 0,
399			Selected:       s.selectedNo,
400		})
401
402		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
403		infoSection := s.infoSection()
404
405		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
406
407		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
408			lipgloss.JoinVertical(
409				lipgloss.Left,
410				initText,
411				"",
412				buttons,
413			),
414		)
415
416		content = lipgloss.JoinVertical(
417			lipgloss.Left,
418			s.logoRendered,
419			infoSection,
420			initContent,
421		)
422	} else {
423		parts := []string{
424			s.logoRendered,
425			s.infoSection(),
426		}
427		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
428	}
429
430	return t.S().Base.
431		Width(s.width).
432		Height(s.height).
433		PaddingTop(SplashScreenPaddingY).
434		PaddingBottom(SplashScreenPaddingY).
435		Render(content)
436}
437
438func (s *splashCmp) Cursor() *tea.Cursor {
439	if s.needsAPIKey {
440		cursor := s.apiKeyInput.Cursor()
441		if cursor != nil {
442			return s.moveCursor(cursor)
443		}
444	} else if s.isOnboarding {
445		cursor := s.modelList.Cursor()
446		if cursor != nil {
447			return s.moveCursor(cursor)
448		}
449	} else {
450		return nil
451	}
452	return nil
453}
454
455func (s *splashCmp) infoSection() string {
456	t := styles.CurrentTheme()
457	return t.S().Base.PaddingLeft(2).Render(
458		lipgloss.JoinVertical(
459			lipgloss.Left,
460			s.cwd(),
461			"",
462			lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
463			"",
464		),
465	)
466}
467
468func (s *splashCmp) logoBlock() string {
469	t := styles.CurrentTheme()
470	return t.S().Base.Padding(0, 2).Width(s.width).Render(
471		logo.Render(version.Version, false, logo.Opts{
472			FieldColor:   t.Primary,
473			TitleColorA:  t.Secondary,
474			TitleColorB:  t.Primary,
475			CharmColor:   t.Secondary,
476			VersionColor: t.Primary,
477			Width:        s.width - 4,
478		}),
479	)
480}
481
482func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
483	if cursor == nil {
484		return nil
485	}
486	// Calculate the correct Y offset based on current state
487	logoHeight := lipgloss.Height(s.logoRendered)
488	if s.needsAPIKey {
489		infoSectionHeight := lipgloss.Height(s.infoSection())
490		baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
491		remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
492		offset := baseOffset + remainingHeight
493		cursor.Y += offset
494		cursor.X = cursor.X + 1
495	} else if s.isOnboarding {
496		offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
497		cursor.Y += offset
498		cursor.X = cursor.X + 1
499	}
500
501	return cursor
502}
503
504func (s *splashCmp) logoGap() int {
505	if s.height > 35 {
506		return LogoGap
507	}
508	return 0
509}
510
511// Bindings implements SplashPage.
512func (s *splashCmp) Bindings() []key.Binding {
513	if s.needsAPIKey {
514		return []key.Binding{
515			s.keyMap.Select,
516			s.keyMap.Back,
517		}
518	} else if s.isOnboarding {
519		return []key.Binding{
520			s.keyMap.Select,
521			s.keyMap.Next,
522			s.keyMap.Previous,
523		}
524	} else if s.needsProjectInit {
525		return []key.Binding{
526			s.keyMap.Select,
527			s.keyMap.Yes,
528			s.keyMap.No,
529			s.keyMap.Tab,
530			s.keyMap.LeftRight,
531		}
532	}
533	return []key.Binding{}
534}
535
536func (s *splashCmp) getMaxInfoWidth() int {
537	return min(s.width-2, 40) // 2 for left padding
538}
539
// cwd renders the working directory for the info section in the muted style,
// sized to getMaxInfoWidth columns. A leading home directory is shown as "~".
func (s *splashCmp) cwd() string {
	dir := config.Get().WorkingDir()
	t := styles.CurrentTheme()
	if home, err := os.UserHomeDir(); err == nil {
		dir = abbreviateHome(dir, home)
	}
	return t.S().Muted.Width(s.getMaxInfoWidth()).Render(dir)
}

// abbreviateHome replaces home with "~" only when home is a leading,
// whole-component prefix of path ("/home/al/src" -> "~/src"). A plain
// substring replace also rewrote look-alikes ("/home/alice" with home
// "/home/al" became "~ice") and later occurrences inside path. The home
// directory itself is returned unabbreviated, as before.
func abbreviateHome(path, home string) string {
	sep := string(os.PathSeparator)
	home = strings.TrimRight(home, sep)
	if home == "" {
		// Root or empty home: abbreviating would rewrite every path.
		return path
	}
	rest, ok := strings.CutPrefix(path, home)
	if !ok || !strings.HasPrefix(rest, sep) {
		return path
	}
	return "~" + rest
}
550
551func LSPList(maxWidth int) []string {
552	t := styles.CurrentTheme()
553	lspList := []string{}
554	lsp := config.Get().LSP.Sorted()
555	if len(lsp) == 0 {
556		return []string{t.S().Base.Foreground(t.Border).Render("None")}
557	}
558	for _, l := range lsp {
559		iconColor := t.Success
560		if l.LSP.Disabled {
561			iconColor = t.FgMuted
562		}
563		lspList = append(lspList,
564			core.Status(
565				core.StatusOpts{
566					IconColor:   iconColor,
567					Title:       l.Name,
568					Description: l.LSP.Command,
569				},
570				maxWidth,
571			),
572		)
573	}
574	return lspList
575}
576
577func (s *splashCmp) lspBlock() string {
578	t := styles.CurrentTheme()
579	maxWidth := s.getMaxInfoWidth() / 2
580	section := t.S().Subtle.Render("LSPs")
581	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
582	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
583		lipgloss.JoinVertical(
584			lipgloss.Left,
585			lspList...,
586		),
587	)
588}
589
590func MCPList(maxWidth int) []string {
591	t := styles.CurrentTheme()
592	mcpList := []string{}
593	mcps := config.Get().MCP.Sorted()
594	if len(mcps) == 0 {
595		return []string{t.S().Base.Foreground(t.Border).Render("None")}
596	}
597	for _, l := range mcps {
598		iconColor := t.Success
599		if l.MCP.Disabled {
600			iconColor = t.FgMuted
601		}
602		mcpList = append(mcpList,
603			core.Status(
604				core.StatusOpts{
605					IconColor:   iconColor,
606					Title:       l.Name,
607					Description: l.MCP.Command,
608				},
609				maxWidth,
610			),
611		)
612	}
613	return mcpList
614}
615
616func (s *splashCmp) mcpBlock() string {
617	t := styles.CurrentTheme()
618	maxWidth := s.getMaxInfoWidth() / 2
619	section := t.S().Subtle.Render("MCPs")
620	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
621	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
622		lipgloss.JoinVertical(
623			lipgloss.Left,
624			mcpList...,
625		),
626	)
627}
628
629func (s *splashCmp) IsShowingAPIKey() bool {
630	return s.needsAPIKey
631}