1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/tui/components/chat"
 15	"github.com/charmbracelet/crush/internal/tui/components/completions"
 16	"github.com/charmbracelet/crush/internal/tui/components/core"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 18	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 20	"github.com/charmbracelet/crush/internal/tui/components/logo"
 21	"github.com/charmbracelet/crush/internal/tui/styles"
 22	"github.com/charmbracelet/crush/internal/tui/util"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/charmbracelet/lipgloss/v2"
 25)
 26
 27type Splash interface {
 28	util.Model
 29	layout.Sizeable
 30	layout.Help
 31	Cursor() *tea.Cursor
 32	// SetOnboarding controls whether the splash shows model selection UI
 33	SetOnboarding(bool)
 34	// SetProjectInit controls whether the splash shows project initialization prompt
 35	SetProjectInit(bool)
 36
 37	// Showing API key input
 38	IsShowingAPIKey() bool
 39}
 40
 41const (
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43
 44	LogoGap = 6
 45)
 46
// OnboardingCompleteMsg is emitted by the splash once its interactive steps
// are done: after a model is chosen (and, if needed, its API key saved), or
// after the project-initialization prompt is answered. Presumably consumed by
// the parent page to leave the splash — confirm against callers.
type OnboardingCompleteMsg struct{}
 49
// splashCmp is the concrete Splash implementation returned by New.
type splashCmp struct {
	// width/height are the last dimensions passed to SetSize.
	width, height int
	keyMap        KeyMap
	// logoRendered caches logoBlock(); SetSize re-renders it only when the
	// width changes.
	logoRendered string

	// State. View checks these in order: needsAPIKey, isOnboarding,
	// needsProjectInit.
	isOnboarding     bool // model picker is shown
	needsProjectInit bool // "initialize this project?" prompt is shown
	needsAPIKey      bool // API key input is shown for selectedModel's provider
	selectedNo       bool // "Nope" (rather than "Yep!") is highlighted in the init prompt

	// listHeight is the model list height computed in SetSize.
	listHeight  int
	modelList   *models.ModelListComponent
	apiKeyInput *models.APIKeyInput
	// selectedModel is the model awaiting an API key; nil otherwise.
	selectedModel *models.ModelOption
}
 66
 67func New() Splash {
 68	keyMap := DefaultKeyMap()
 69	listKeyMap := list.DefaultKeyMap()
 70	listKeyMap.Down.SetEnabled(false)
 71	listKeyMap.Up.SetEnabled(false)
 72	listKeyMap.HalfPageDown.SetEnabled(false)
 73	listKeyMap.HalfPageUp.SetEnabled(false)
 74	listKeyMap.Home.SetEnabled(false)
 75	listKeyMap.End.SetEnabled(false)
 76	listKeyMap.DownOneItem = keyMap.Next
 77	listKeyMap.UpOneItem = keyMap.Previous
 78
 79	t := styles.CurrentTheme()
 80	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 81	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 82	apiKeyInput := models.NewAPIKeyInput()
 83
 84	return &splashCmp{
 85		width:        0,
 86		height:       0,
 87		keyMap:       keyMap,
 88		logoRendered: "",
 89		modelList:    modelList,
 90		apiKeyInput:  apiKeyInput,
 91		selectedNo:   false,
 92	}
 93}
 94
 95func (s *splashCmp) SetOnboarding(onboarding bool) {
 96	s.isOnboarding = onboarding
 97	if onboarding {
 98		providers, err := config.Providers()
 99		if err != nil {
100			return
101		}
102		filteredProviders := []provider.Provider{}
103		simpleProviders := []string{
104			"anthropic",
105			"openai",
106			"gemini",
107			"xai",
108			"openrouter",
109		}
110		for _, p := range providers {
111			if slices.Contains(simpleProviders, string(p.ID)) {
112				filteredProviders = append(filteredProviders, p)
113			}
114		}
115		s.modelList.SetProviders(filteredProviders)
116	}
117}
118
119func (s *splashCmp) SetProjectInit(needsInit bool) {
120	s.needsProjectInit = needsInit
121}
122
123// GetSize implements SplashPage.
124func (s *splashCmp) GetSize() (int, int) {
125	return s.width, s.height
126}
127
128// Init implements SplashPage.
129func (s *splashCmp) Init() tea.Cmd {
130	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
131}
132
133// SetSize implements SplashPage.
134func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
135	s.height = height
136	if width != s.width {
137		s.width = width
138		s.logoRendered = s.logoBlock()
139	}
140	// remove padding, logo height, gap, title space
141	s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
142	listWidth := min(60, width)
143	return s.modelList.SetSize(listWidth, s.listHeight)
144}
145
146// Update implements SplashPage.
147func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
148	switch msg := msg.(type) {
149	case tea.WindowSizeMsg:
150		return s, s.SetSize(msg.Width, msg.Height)
151	case tea.KeyPressMsg:
152		switch {
153		case key.Matches(msg, s.keyMap.Back):
154			if s.needsAPIKey {
155				// Go back to model selection
156				s.needsAPIKey = false
157				s.selectedModel = nil
158				return s, nil
159			}
160		case key.Matches(msg, s.keyMap.Select):
161			if s.isOnboarding && !s.needsAPIKey {
162				modelInx := s.modelList.SelectedIndex()
163				items := s.modelList.Items()
164				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
165				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
166					cmd := s.setPreferredModel(selectedItem)
167					s.isOnboarding = false
168					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
169				} else {
170					// Provider not configured, show API key input
171					s.needsAPIKey = true
172					s.selectedModel = &selectedItem
173					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
174					return s, nil
175				}
176			} else if s.needsAPIKey {
177				// Handle API key submission
178				apiKey := s.apiKeyInput.Value()
179				if apiKey != "" {
180					return s, s.saveAPIKeyAndContinue(apiKey)
181				}
182			} else if s.needsProjectInit {
183				return s, s.initializeProject()
184			}
185		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
186			if s.needsProjectInit {
187				s.selectedNo = !s.selectedNo
188				return s, nil
189			}
190		case key.Matches(msg, s.keyMap.Yes):
191			if s.needsProjectInit {
192				return s, s.initializeProject()
193			}
194		case key.Matches(msg, s.keyMap.No):
195			if s.needsProjectInit {
196				s.needsProjectInit = false
197				return s, util.CmdHandler(OnboardingCompleteMsg{})
198			}
199		default:
200			if s.needsAPIKey {
201				u, cmd := s.apiKeyInput.Update(msg)
202				s.apiKeyInput = u.(*models.APIKeyInput)
203				return s, cmd
204			} else if s.isOnboarding {
205				u, cmd := s.modelList.Update(msg)
206				s.modelList = u
207				return s, cmd
208			}
209		}
210	case tea.PasteMsg:
211		if s.needsAPIKey {
212			u, cmd := s.apiKeyInput.Update(msg)
213			s.apiKeyInput = u.(*models.APIKeyInput)
214			return s, cmd
215		} else if s.isOnboarding {
216			var cmd tea.Cmd
217			s.modelList, cmd = s.modelList.Update(msg)
218			return s, cmd
219		}
220	}
221	return s, nil
222}
223
224func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
225	if s.selectedModel == nil {
226		return util.ReportError(fmt.Errorf("no model selected"))
227	}
228
229	cfg := config.Get()
230	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
231	if err != nil {
232		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
233	}
234
235	// Reset API key state and continue with model selection
236	s.needsAPIKey = false
237	cmd := s.setPreferredModel(*s.selectedModel)
238	s.isOnboarding = false
239	s.selectedModel = nil
240
241	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
242}
243
244func (s *splashCmp) initializeProject() tea.Cmd {
245	s.needsProjectInit = false
246
247	if err := config.MarkProjectInitialized(); err != nil {
248		return util.ReportError(err)
249	}
250	var cmds []tea.Cmd
251
252	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
253	if !s.selectedNo {
254		cmds = append(cmds,
255			util.CmdHandler(chat.SessionClearedMsg{}),
256			util.CmdHandler(chat.SendMsg{
257				Text: prompt.Initialize(),
258			}),
259		)
260	}
261	return tea.Sequence(cmds...)
262}
263
264func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
265	cfg := config.Get()
266	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
267	if model == nil {
268		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
269	}
270
271	selectedModel := config.SelectedModel{
272		Model:           selectedItem.Model.ID,
273		Provider:        string(selectedItem.Provider.ID),
274		ReasoningEffort: model.DefaultReasoningEffort,
275		MaxTokens:       model.DefaultMaxTokens,
276	}
277
278	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
279	if err != nil {
280		return util.ReportError(err)
281	}
282
283	// Now lets automatically setup the small model
284	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
285	if err != nil {
286		return util.ReportError(err)
287	}
288	if knownProvider == nil {
289		// for local provider we just use the same model
290		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
291		if err != nil {
292			return util.ReportError(err)
293		}
294	} else {
295		smallModel := knownProvider.DefaultSmallModelID
296		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
297		// should never happen
298		if model == nil {
299			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
300			if err != nil {
301				return util.ReportError(err)
302			}
303			return nil
304		}
305		smallSelectedModel := config.SelectedModel{
306			Model:           smallModel,
307			Provider:        string(selectedItem.Provider.ID),
308			ReasoningEffort: model.DefaultReasoningEffort,
309			MaxTokens:       model.DefaultMaxTokens,
310		}
311		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
312		if err != nil {
313			return util.ReportError(err)
314		}
315	}
316	return nil
317}
318
319func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
320	providers, err := config.Providers()
321	if err != nil {
322		return nil, err
323	}
324	for _, p := range providers {
325		if p.ID == providerID {
326			return &p, nil
327		}
328	}
329	return nil, nil
330}
331
332func (s *splashCmp) isProviderConfigured(providerID string) bool {
333	cfg := config.Get()
334	if _, ok := cfg.Providers[providerID]; ok {
335		return true
336	}
337	return false
338}
339
340func (s *splashCmp) View() string {
341	t := styles.CurrentTheme()
342	var content string
343	if s.needsAPIKey {
344		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
345		apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
346		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
347			lipgloss.JoinVertical(
348				lipgloss.Left,
349				apiKeyView,
350			),
351		)
352		content = lipgloss.JoinVertical(
353			lipgloss.Left,
354			s.logoRendered,
355			apiKeySelector,
356		)
357	} else if s.isOnboarding {
358		modelListView := s.modelList.View()
359		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
360		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
361			lipgloss.JoinVertical(
362				lipgloss.Left,
363				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
364				"",
365				modelListView,
366			),
367		)
368		content = lipgloss.JoinVertical(
369			lipgloss.Left,
370			s.logoRendered,
371			modelSelector,
372		)
373	} else if s.needsProjectInit {
374		titleStyle := t.S().Base.Foreground(t.FgBase)
375		bodyStyle := t.S().Base.Foreground(t.FgMuted)
376		shortcutStyle := t.S().Base.Foreground(t.Success)
377
378		initText := lipgloss.JoinVertical(
379			lipgloss.Left,
380			titleStyle.Render("Would you like to initialize this project?"),
381			"",
382			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
383			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
384			"",
385			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
386			"",
387			bodyStyle.Render("Would you like to initialize now?"),
388		)
389
390		yesButton := core.SelectableButton(core.ButtonOpts{
391			Text:           "Yep!",
392			UnderlineIndex: 0,
393			Selected:       !s.selectedNo,
394		})
395
396		noButton := core.SelectableButton(core.ButtonOpts{
397			Text:           "Nope",
398			UnderlineIndex: 0,
399			Selected:       s.selectedNo,
400		})
401
402		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
403		infoSection := s.infoSection()
404
405		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
406
407		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
408			lipgloss.JoinVertical(
409				lipgloss.Left,
410				initText,
411				"",
412				buttons,
413			),
414		)
415
416		content = lipgloss.JoinVertical(
417			lipgloss.Left,
418			s.logoRendered,
419			infoSection,
420			initContent,
421		)
422	} else {
423		parts := []string{
424			s.logoRendered,
425			s.infoSection(),
426		}
427		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
428	}
429
430	return t.S().Base.
431		Width(s.width).
432		Height(s.height).
433		PaddingTop(SplashScreenPaddingY).
434		PaddingBottom(SplashScreenPaddingY).
435		Render(content)
436}
437
438func (s *splashCmp) Cursor() *tea.Cursor {
439	if s.needsAPIKey {
440		cursor := s.apiKeyInput.Cursor()
441		if cursor != nil {
442			return s.moveCursor(cursor)
443		}
444	} else if s.isOnboarding {
445		cursor := s.modelList.Cursor()
446		if cursor != nil {
447			return s.moveCursor(cursor)
448		}
449	} else {
450		return nil
451	}
452	return nil
453}
454
455func (s *splashCmp) infoSection() string {
456	t := styles.CurrentTheme()
457	return t.S().Base.PaddingLeft(2).Render(
458		lipgloss.JoinVertical(
459			lipgloss.Left,
460			s.cwd(),
461			"",
462			lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
463			"",
464		),
465	)
466}
467
468func (s *splashCmp) logoBlock() string {
469	t := styles.CurrentTheme()
470	return t.S().Base.Padding(0, 2).Width(s.width).Render(
471		logo.Render(version.Version, false, logo.Opts{
472			FieldColor:   t.Primary,
473			TitleColorA:  t.Secondary,
474			TitleColorB:  t.Primary,
475			CharmColor:   t.Secondary,
476			VersionColor: t.Primary,
477			Width:        s.width - 4,
478		}),
479	)
480}
481
482func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
483	if cursor == nil {
484		return nil
485	}
486	// Calculate the correct Y offset based on current state
487	logoHeight := lipgloss.Height(s.logoRendered)
488	if s.needsAPIKey {
489		infoSectionHeight := lipgloss.Height(s.infoSection())
490		baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
491		remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
492		offset := baseOffset + remainingHeight
493		cursor.Y += offset
494		cursor.X = cursor.X + 1
495	} else if s.isOnboarding {
496		offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
497		cursor.Y += offset
498		cursor.X = cursor.X + 1
499	}
500
501	return cursor
502}
503
504func (s *splashCmp) logoGap() int {
505	if s.height > 35 {
506		return LogoGap
507	}
508	return 0
509}
510
511// Bindings implements SplashPage.
512func (s *splashCmp) Bindings() []key.Binding {
513	if s.needsAPIKey {
514		return []key.Binding{
515			s.keyMap.Select,
516			s.keyMap.Back,
517		}
518	} else if s.isOnboarding {
519		return []key.Binding{
520			s.keyMap.Select,
521			s.keyMap.Next,
522			s.keyMap.Previous,
523		}
524	} else if s.needsProjectInit {
525		return []key.Binding{
526			s.keyMap.Select,
527			s.keyMap.Yes,
528			s.keyMap.No,
529			s.keyMap.Tab,
530			s.keyMap.LeftRight,
531		}
532	}
533	return []key.Binding{}
534}
535
536func (s *splashCmp) getMaxInfoWidth() int {
537	return min(s.width-2, 40) // 2 for left padding
538}
539
540func (s *splashCmp) cwd() string {
541	cwd := config.Get().WorkingDir()
542	t := styles.CurrentTheme()
543	homeDir, err := os.UserHomeDir()
544	if err == nil && cwd != homeDir {
545		cwd = strings.ReplaceAll(cwd, homeDir, "~")
546	}
547	maxWidth := s.getMaxInfoWidth()
548	return t.S().Muted.Width(maxWidth).Render(cwd)
549}
550
551func LSPList(maxWidth int) []string {
552	t := styles.CurrentTheme()
553	lspList := []string{}
554	lsp := config.Get().LSP.Sorted()
555	if len(lsp) == 0 {
556		return []string{t.S().Base.Foreground(t.Border).Render("None")}
557	}
558	for _, l := range lsp {
559		iconColor := t.Success
560		if l.LSP.Disabled {
561			iconColor = t.FgMuted
562		}
563		lspList = append(lspList,
564			core.Status(
565				core.StatusOpts{
566					IconColor:   iconColor,
567					Title:       l.Name,
568					Description: l.LSP.Command,
569				},
570				maxWidth,
571			),
572		)
573	}
574	return lspList
575}
576
577func (s *splashCmp) lspBlock() string {
578	t := styles.CurrentTheme()
579	maxWidth := s.getMaxInfoWidth() / 2
580	section := t.S().Subtle.Render("LSPs")
581	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
582	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
583		lipgloss.JoinVertical(
584			lipgloss.Left,
585			lspList...,
586		),
587	)
588}
589
590func MCPList(maxWidth int) []string {
591	t := styles.CurrentTheme()
592	mcpList := []string{}
593	mcps := config.Get().MCP.Sorted()
594	if len(mcps) == 0 {
595		return []string{t.S().Base.Foreground(t.Border).Render("None")}
596	}
597	for _, l := range mcps {
598		iconColor := t.Success
599		if l.MCP.Disabled {
600			iconColor = t.FgMuted
601		}
602		mcpList = append(mcpList,
603			core.Status(
604				core.StatusOpts{
605					IconColor:   iconColor,
606					Title:       l.Name,
607					Description: l.MCP.Command,
608				},
609				maxWidth,
610			),
611		)
612	}
613	return mcpList
614}
615
616func (s *splashCmp) mcpBlock() string {
617	t := styles.CurrentTheme()
618	maxWidth := s.getMaxInfoWidth() / 2
619	section := t.S().Subtle.Render("MCPs")
620	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
621	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
622		lipgloss.JoinVertical(
623			lipgloss.Left,
624			mcpList...,
625		),
626	)
627}
628
629func (s *splashCmp) IsShowingAPIKey() bool {
630	return s.needsAPIKey
631}