Detailed changes
@@ -87,6 +87,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.getUsageByModelStmt, err = db.PrepareContext(ctx, getUsageByModel); err != nil {
return nil, fmt.Errorf("error preparing query GetUsageByModel: %w", err)
}
+ if q.listAllUserMessagesStmt, err = db.PrepareContext(ctx, listAllUserMessages); err != nil {
+ return nil, fmt.Errorf("error preparing query ListAllUserMessages: %w", err)
+ }
if q.listFilesByPathStmt, err = db.PrepareContext(ctx, listFilesByPath); err != nil {
return nil, fmt.Errorf("error preparing query ListFilesByPath: %w", err)
}
@@ -105,6 +108,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil {
return nil, fmt.Errorf("error preparing query ListSessions: %w", err)
}
+ if q.listUserMessagesBySessionStmt, err = db.PrepareContext(ctx, listUserMessagesBySession); err != nil {
+ return nil, fmt.Errorf("error preparing query ListUserMessagesBySession: %w", err)
+ }
if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
}
@@ -224,6 +230,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing getUsageByModelStmt: %w", cerr)
}
}
+ if q.listAllUserMessagesStmt != nil {
+ if cerr := q.listAllUserMessagesStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing listAllUserMessagesStmt: %w", cerr)
+ }
+ }
if q.listFilesByPathStmt != nil {
if cerr := q.listFilesByPathStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listFilesByPathStmt: %w", cerr)
@@ -254,6 +265,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing listSessionsStmt: %w", cerr)
}
}
+ if q.listUserMessagesBySessionStmt != nil {
+ if cerr := q.listUserMessagesBySessionStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing listUserMessagesBySessionStmt: %w", cerr)
+ }
+ }
if q.updateMessageStmt != nil {
if cerr := q.updateMessageStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
@@ -329,12 +345,14 @@ type Queries struct {
getUsageByDayOfWeekStmt *sql.Stmt
getUsageByHourStmt *sql.Stmt
getUsageByModelStmt *sql.Stmt
+ listAllUserMessagesStmt *sql.Stmt
listFilesByPathStmt *sql.Stmt
listFilesBySessionStmt *sql.Stmt
listLatestSessionFilesStmt *sql.Stmt
listMessagesBySessionStmt *sql.Stmt
listNewFilesStmt *sql.Stmt
listSessionsStmt *sql.Stmt
+ listUserMessagesBySessionStmt *sql.Stmt
updateMessageStmt *sql.Stmt
updateSessionStmt *sql.Stmt
updateSessionTitleAndUsageStmt *sql.Stmt
@@ -365,12 +383,14 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
getUsageByDayOfWeekStmt: q.getUsageByDayOfWeekStmt,
getUsageByHourStmt: q.getUsageByHourStmt,
getUsageByModelStmt: q.getUsageByModelStmt,
+ listAllUserMessagesStmt: q.listAllUserMessagesStmt,
listFilesByPathStmt: q.listFilesByPathStmt,
listFilesBySessionStmt: q.listFilesBySessionStmt,
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
listNewFilesStmt: q.listNewFilesStmt,
listSessionsStmt: q.listSessionsStmt,
+ listUserMessagesBySessionStmt: q.listUserMessagesBySessionStmt,
updateMessageStmt: q.updateMessageStmt,
updateSessionStmt: q.updateSessionStmt,
updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt,
@@ -107,6 +107,47 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
return i, err
}
+const listAllUserMessages = `-- name: ListAllUserMessages :many
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+FROM messages
+WHERE role = 'user'
+ORDER BY created_at DESC
+`
+
+func (q *Queries) ListAllUserMessages(ctx context.Context) ([]Message, error) {
+ rows, err := q.query(ctx, q.listAllUserMessagesStmt, listAllUserMessages)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []Message{}
+ for rows.Next() {
+ var i Message
+ if err := rows.Scan(
+ &i.ID,
+ &i.SessionID,
+ &i.Role,
+ &i.Parts,
+ &i.Model,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ &i.FinishedAt,
+ &i.Provider,
+ &i.IsSummaryMessage,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
const listMessagesBySession = `-- name: ListMessagesBySession :many
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
FROM messages
@@ -148,6 +189,47 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
return items, nil
}
+const listUserMessagesBySession = `-- name: ListUserMessagesBySession :many
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+FROM messages
+WHERE session_id = ? AND role = 'user'
+ORDER BY created_at DESC
+`
+
+func (q *Queries) ListUserMessagesBySession(ctx context.Context, sessionID string) ([]Message, error) {
+ rows, err := q.query(ctx, q.listUserMessagesBySessionStmt, listUserMessagesBySession, sessionID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []Message{}
+ for rows.Next() {
+ var i Message
+ if err := rows.Scan(
+ &i.ID,
+ &i.SessionID,
+ &i.Role,
+ &i.Parts,
+ &i.Model,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ &i.FinishedAt,
+ &i.Provider,
+ &i.IsSummaryMessage,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
const updateMessage = `-- name: UpdateMessage :exec
UPDATE messages
SET
@@ -30,12 +30,14 @@ type Querier interface {
GetUsageByDayOfWeek(ctx context.Context) ([]GetUsageByDayOfWeekRow, error)
GetUsageByHour(ctx context.Context) ([]GetUsageByHourRow, error)
GetUsageByModel(ctx context.Context) ([]GetUsageByModelRow, error)
+ ListAllUserMessages(ctx context.Context) ([]Message, error)
ListFilesByPath(ctx context.Context, path string) ([]File, error)
ListFilesBySession(ctx context.Context, sessionID string) ([]File, error)
ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
ListNewFiles(ctx context.Context) ([]File, error)
ListSessions(ctx context.Context) ([]Session, error)
+ ListUserMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error
@@ -41,3 +41,15 @@ WHERE id = ?;
-- name: DeleteSessionMessages :exec
DELETE FROM messages
WHERE session_id = ?;
+
+-- name: ListUserMessagesBySession :many
+SELECT *
+FROM messages
+WHERE session_id = ? AND role = 'user'
+ORDER BY created_at DESC;
+
+-- name: ListAllUserMessages :many
+SELECT *
+FROM messages
+WHERE role = 'user'
+ORDER BY created_at DESC;
@@ -26,6 +26,8 @@ type Service interface {
Update(ctx context.Context, message Message) error
Get(ctx context.Context, id string) (Message, error)
List(ctx context.Context, sessionID string) ([]Message, error)
+ ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
+ ListAllUserMessages(ctx context.Context) ([]Message, error)
Delete(ctx context.Context, id string) error
DeleteSessionMessages(ctx context.Context, sessionID string) error
}
@@ -157,6 +159,36 @@ func (s *service) List(ctx context.Context, sessionID string) ([]Message, error)
return messages, nil
}
+func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
+ dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
+ if err != nil {
+ return nil, err
+ }
+ messages := make([]Message, len(dbMessages))
+ for i, dbMessage := range dbMessages {
+ messages[i], err = s.fromDBItem(dbMessage)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return messages, nil
+}
+
+func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
+ dbMessages, err := s.q.ListAllUserMessages(ctx)
+ if err != nil {
+ return nil, err
+ }
+ messages := make([]Message, len(dbMessages))
+ for i, dbMessage := range dbMessages {
+ messages[i], err = s.fromDBItem(dbMessage)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return messages, nil
+}
+
func (s *service) fromDBItem(item db.Message) (Message, error) {
parts, err := unmarshalParts([]byte(item.Parts))
if err != nil {
@@ -0,0 +1,184 @@
+package model
+
+import (
+ "context"
+ "log/slog"
+
+ tea "charm.land/bubbletea/v2"
+
+ "github.com/charmbracelet/crush/internal/message"
+)
+
+// promptHistoryLoadedMsg is sent when prompt history is loaded.
+type promptHistoryLoadedMsg struct {
+ messages []string
+}
+
+// loadPromptHistory loads user messages for history navigation.
+func (m *UI) loadPromptHistory() tea.Cmd {
+ return func() tea.Msg {
+ ctx := context.Background()
+ var messages []message.Message
+ var err error
+
+ if m.session != nil {
+ messages, err = m.com.App.Messages.ListUserMessages(ctx, m.session.ID)
+ } else {
+ messages, err = m.com.App.Messages.ListAllUserMessages(ctx)
+ }
+ if err != nil {
+ slog.Error("failed to load prompt history", "error", err)
+ return promptHistoryLoadedMsg{messages: nil}
+ }
+
+ texts := make([]string, 0, len(messages))
+ for _, msg := range messages {
+ if text := msg.Content().Text; text != "" {
+ texts = append(texts, text)
+ }
+ }
+ return promptHistoryLoadedMsg{messages: texts}
+ }
+}
+
+// handleHistoryUp handles up arrow for history navigation.
+func (m *UI) handleHistoryUp(msg tea.Msg) tea.Cmd {
+ // Navigate to older history entry from cursor position (0,0).
+ if m.textarea.Length() == 0 || m.isAtEditorStart() {
+ if m.historyPrev() {
+ // we send this so that the textarea moves the view to the correct position
+ // without this the cursor will show up in the wrong place.
+ ta, cmd := m.textarea.Update(nil)
+ m.textarea = ta
+ return cmd
+ }
+ }
+
+ // First move cursor to start before entering history.
+ if m.textarea.Line() == 0 {
+ m.textarea.CursorStart()
+ return nil
+ }
+
+ // Let textarea handle normal cursor movement.
+ ta, cmd := m.textarea.Update(msg)
+ m.textarea = ta
+ return cmd
+}
+
+// handleHistoryDown handles down arrow for history navigation.
+func (m *UI) handleHistoryDown(msg tea.Msg) tea.Cmd {
+ // Navigate to newer history entry from end of text.
+ if m.isAtEditorEnd() {
+ if m.historyNext() {
+ // we send this so that the textarea moves the view to the correct position
+ // without this the cursor will show up in the wrong place.
+ ta, cmd := m.textarea.Update(nil)
+ m.textarea = ta
+ return cmd
+ }
+ }
+
+ // First move cursor to end before navigating history.
+ if m.textarea.Line() == max(m.textarea.LineCount()-1, 0) {
+ m.textarea.MoveToEnd()
+ ta, cmd := m.textarea.Update(nil)
+ m.textarea = ta
+ return cmd
+ }
+
+ // Let textarea handle normal cursor movement.
+ ta, cmd := m.textarea.Update(msg)
+ m.textarea = ta
+ return cmd
+}
+
+// handleHistoryEscape handles escape for exiting history navigation.
+func (m *UI) handleHistoryEscape(msg tea.Msg) tea.Cmd {
+ // Return to current draft when browsing history.
+ if m.promptHistory.index >= 0 {
+ m.promptHistory.index = -1
+ m.textarea.Reset()
+ m.textarea.InsertString(m.promptHistory.draft)
+ ta, cmd := m.textarea.Update(nil)
+ m.textarea = ta
+ return cmd
+ }
+
+ // Let textarea handle escape normally.
+ ta, cmd := m.textarea.Update(msg)
+ m.textarea = ta
+ return cmd
+}
+
+// updateHistoryDraft updates history state when text is modified.
+func (m *UI) updateHistoryDraft(oldValue string) {
+ if m.textarea.Value() != oldValue {
+ m.promptHistory.draft = m.textarea.Value()
+ m.promptHistory.index = -1
+ }
+}
+
+// historyPrev changes the text area content to the previous message in the history
+// it returns false if it could not find the previous message.
+func (m *UI) historyPrev() bool {
+ if len(m.promptHistory.messages) == 0 {
+ return false
+ }
+ if m.promptHistory.index == -1 {
+ m.promptHistory.draft = m.textarea.Value()
+ }
+ nextIndex := m.promptHistory.index + 1
+ if nextIndex >= len(m.promptHistory.messages) {
+ return false
+ }
+ m.promptHistory.index = nextIndex
+ m.textarea.Reset()
+ m.textarea.InsertString(m.promptHistory.messages[nextIndex])
+ m.textarea.MoveToBegin()
+ return true
+}
+
+// historyNext changes the text area content to the next message in the history
+// it returns false if it could not find the next message.
+func (m *UI) historyNext() bool {
+ if m.promptHistory.index < 0 {
+ return false
+ }
+ nextIndex := m.promptHistory.index - 1
+ if nextIndex < 0 {
+ m.promptHistory.index = -1
+ m.textarea.Reset()
+ m.textarea.InsertString(m.promptHistory.draft)
+ return true
+ }
+ m.promptHistory.index = nextIndex
+ m.textarea.Reset()
+ m.textarea.InsertString(m.promptHistory.messages[nextIndex])
+ return true
+}
+
+// historyReset resets the history, but does not clear the message
+// it just sets the current draft to empty and the position in the history.
+func (m *UI) historyReset() {
+ m.promptHistory.index = -1
+ m.promptHistory.draft = ""
+}
+
+// isAtEditorStart returns true if we are at the 0 line and 0 col in the textarea.
+func (m *UI) isAtEditorStart() bool {
+ return m.textarea.Line() == 0 && m.textarea.LineInfo().ColumnOffset == 0
+}
+
+// isAtEditorEnd returns true if we are in the last line and the last column in the textarea.
+func (m *UI) isAtEditorEnd() bool {
+ lineCount := m.textarea.LineCount()
+ if lineCount == 0 {
+ return true
+ }
+ if m.textarea.Line() != lineCount-1 {
+ return false
+ }
+ info := m.textarea.LineInfo()
+ return info.CharOffset >= info.CharWidth-1 || info.CharWidth == 0
+}
@@ -15,6 +15,10 @@ type KeyMap struct {
AttachmentDeleteMode key.Binding
Escape key.Binding
DeleteAllAttachments key.Binding
+
+ // History navigation
+ HistoryPrev key.Binding
+ HistoryNext key.Binding
}
Chat struct {
@@ -131,6 +135,12 @@ func DefaultKeyMap() KeyMap {
key.WithKeys("r"),
key.WithHelp("ctrl+r+r", "delete all attachments"),
)
+ km.Editor.HistoryPrev = key.NewBinding(
+ key.WithKeys("up"),
+ )
+ km.Editor.HistoryNext = key.NewBinding(
+ key.WithKeys("down"),
+ )
km.Chat.NewSession = key.NewBinding(
key.WithKeys("ctrl+n"),
@@ -48,9 +48,11 @@ func (m *UI) updateInitializeView(msg tea.KeyPressMsg) (cmds []tea.Cmd) {
// initializeProject starts project initialization and transitions to the landing view.
func (m *UI) initializeProject() tea.Cmd {
// clear the session
- m.newSession()
- cfg := m.com.Config()
var cmds []tea.Cmd
+ if cmd := m.newSession(); cmd != nil {
+ cmds = append(cmds, cmd)
+ }
+ cfg := m.com.Config()
initialize := func() tea.Msg {
initPrompt, err := agent.InitializePrompt(*cfg)
@@ -211,6 +211,13 @@ type UI struct {
// mouse highlighting related state
lastClickTime time.Time
+
+ // Prompt history for up/down navigation through previous messages.
+ promptHistory struct {
+ messages []string
+ index int
+ draft string
+ }
}
// New creates a new instance of the [UI] model.
@@ -307,6 +314,8 @@ func (m *UI) Init() tea.Cmd {
}
// load the user commands async
cmds = append(cmds, m.loadCustomCommands())
+ // load prompt history async
+ cmds = append(cmds, m.loadPromptHistory())
return tea.Batch(cmds...)
}
@@ -390,6 +399,9 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
m.updateLayoutAndSize()
}
+ // Reload prompt history for the new session.
+ m.historyReset()
+ cmds = append(cmds, m.loadPromptHistory())
case sendMessageMsg:
cmds = append(cmds, m.sendMessage(msg.Content, msg.Attachments...))
@@ -417,13 +429,20 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
commands.SetMCPPrompts(m.mcpPrompts)
}
+ case promptHistoryLoadedMsg:
+ m.promptHistory.messages = msg.messages
+ m.promptHistory.index = -1
+ m.promptHistory.draft = ""
+
case closeDialogMsg:
m.dialog.CloseFrontDialog()
case pubsub.Event[session.Session]:
if msg.Type == pubsub.DeletedEvent {
if m.session != nil && m.session.ID == msg.Payload.ID {
- m.newSession()
+ if cmd := m.newSession(); cmd != nil {
+ cmds = append(cmds, cmd)
+ }
}
break
}
@@ -1095,7 +1114,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait before starting a new session..."))
break
}
- m.newSession()
+ if cmd := m.newSession(); cmd != nil {
+ cmds = append(cmds, cmd)
+ }
m.dialog.CloseDialog(dialog.CommandsID)
case dialog.ActionSummarize:
if m.isAgentBusy() {
@@ -1494,8 +1515,9 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
}
m.randomizePlaceholders()
+ m.historyReset()
- return m.sendMessage(value, attachments...)
+ return tea.Batch(m.sendMessage(value, attachments...), m.loadPromptHistory())
case key.Matches(msg, m.keyMap.Chat.NewSession):
if !m.hasSession() {
break
@@ -1504,7 +1526,9 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait before starting a new session..."))
break
}
- m.newSession()
+ if cmd := m.newSession(); cmd != nil {
+ cmds = append(cmds, cmd)
+ }
case key.Matches(msg, m.keyMap.Tab):
if m.state != uiLanding {
m.setState(m.state, uiFocusMain)
@@ -1524,6 +1548,21 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
ta, cmd := m.textarea.Update(msg)
m.textarea = ta
cmds = append(cmds, cmd)
+ case key.Matches(msg, m.keyMap.Editor.HistoryPrev):
+ cmd := m.handleHistoryUp(msg)
+ if cmd != nil {
+ cmds = append(cmds, cmd)
+ }
+ case key.Matches(msg, m.keyMap.Editor.HistoryNext):
+ cmd := m.handleHistoryDown(msg)
+ if cmd != nil {
+ cmds = append(cmds, cmd)
+ }
+ case key.Matches(msg, m.keyMap.Editor.Escape):
+ cmd := m.handleHistoryEscape(msg)
+ if cmd != nil {
+ cmds = append(cmds, cmd)
+ }
default:
if handleGlobalKeys(msg) {
// Handle global keys first before passing to textarea.
@@ -1557,6 +1596,9 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
m.textarea = ta
cmds = append(cmds, cmd)
+ // Any text modification becomes the current draft.
+ m.updateHistoryDraft(curValue)
+
// After updating textarea, check if we need to filter completions.
// Skip filtering on the initial @ keystroke since items are loading async.
if m.completionsOpen && msg.String() != "@" {
@@ -1596,7 +1638,9 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
break
}
m.focus = uiFocusEditor
- m.newSession()
+ if cmd := m.newSession(); cmd != nil {
+ cmds = append(cmds, cmd)
+ }
case key.Matches(msg, m.keyMap.Chat.Expand):
m.chat.ToggleExpandedSelectedItem()
case key.Matches(msg, m.keyMap.Chat.Up):
@@ -2772,9 +2816,10 @@ func (m *UI) handlePermissionNotification(notification permission.PermissionNoti
// newSession clears the current session state and prepares for a new session.
// The actual session creation happens when the user sends their first message.
-func (m *UI) newSession() {
+// Returns a command to reload prompt history.
+func (m *UI) newSession() tea.Cmd {
if !m.hasSession() {
- return
+ return nil
}
m.session = nil
@@ -2786,6 +2831,8 @@ func (m *UI) newSession() {
m.pillsExpanded = false
m.promptQueue = 0
m.pillsView = ""
+ m.historyReset()
+ return m.loadPromptHistory()
}
// handlePasteMsg handles a paste message.