feat: implement prompt history (#2005)

Kujtim Hoxha created

Change summary

internal/db/db.go               |  20 +++
internal/db/messages.sql.go     |  82 +++++++++++++++
internal/db/querier.go          |   2 
internal/db/sql/messages.sql    |  12 ++
internal/message/message.go     |  32 ++++++
internal/ui/model/history.go    | 184 +++++++++++++++++++++++++++++++++++
internal/ui/model/keys.go       |  10 +
internal/ui/model/onboarding.go |   6 
internal/ui/model/ui.go         |  61 ++++++++++-
9 files changed, 400 insertions(+), 9 deletions(-)

Detailed changes

internal/db/db.go 🔗

@@ -87,6 +87,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
 	if q.getUsageByModelStmt, err = db.PrepareContext(ctx, getUsageByModel); err != nil {
 		return nil, fmt.Errorf("error preparing query GetUsageByModel: %w", err)
 	}
+	if q.listAllUserMessagesStmt, err = db.PrepareContext(ctx, listAllUserMessages); err != nil {
+		return nil, fmt.Errorf("error preparing query ListAllUserMessages: %w", err)
+	}
 	if q.listFilesByPathStmt, err = db.PrepareContext(ctx, listFilesByPath); err != nil {
 		return nil, fmt.Errorf("error preparing query ListFilesByPath: %w", err)
 	}
@@ -105,6 +108,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
 	if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil {
 		return nil, fmt.Errorf("error preparing query ListSessions: %w", err)
 	}
+	if q.listUserMessagesBySessionStmt, err = db.PrepareContext(ctx, listUserMessagesBySession); err != nil {
+		return nil, fmt.Errorf("error preparing query ListUserMessagesBySession: %w", err)
+	}
 	if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
 		return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
 	}
@@ -224,6 +230,11 @@ func (q *Queries) Close() error {
 			err = fmt.Errorf("error closing getUsageByModelStmt: %w", cerr)
 		}
 	}
+	if q.listAllUserMessagesStmt != nil {
+		if cerr := q.listAllUserMessagesStmt.Close(); cerr != nil {
+			err = fmt.Errorf("error closing listAllUserMessagesStmt: %w", cerr)
+		}
+	}
 	if q.listFilesByPathStmt != nil {
 		if cerr := q.listFilesByPathStmt.Close(); cerr != nil {
 			err = fmt.Errorf("error closing listFilesByPathStmt: %w", cerr)
@@ -254,6 +265,11 @@ func (q *Queries) Close() error {
 			err = fmt.Errorf("error closing listSessionsStmt: %w", cerr)
 		}
 	}
+	if q.listUserMessagesBySessionStmt != nil {
+		if cerr := q.listUserMessagesBySessionStmt.Close(); cerr != nil {
+			err = fmt.Errorf("error closing listUserMessagesBySessionStmt: %w", cerr)
+		}
+	}
 	if q.updateMessageStmt != nil {
 		if cerr := q.updateMessageStmt.Close(); cerr != nil {
 			err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
@@ -329,12 +345,14 @@ type Queries struct {
 	getUsageByDayOfWeekStmt        *sql.Stmt
 	getUsageByHourStmt             *sql.Stmt
 	getUsageByModelStmt            *sql.Stmt
+	listAllUserMessagesStmt        *sql.Stmt
 	listFilesByPathStmt            *sql.Stmt
 	listFilesBySessionStmt         *sql.Stmt
 	listLatestSessionFilesStmt     *sql.Stmt
 	listMessagesBySessionStmt      *sql.Stmt
 	listNewFilesStmt               *sql.Stmt
 	listSessionsStmt               *sql.Stmt
+	listUserMessagesBySessionStmt  *sql.Stmt
 	updateMessageStmt              *sql.Stmt
 	updateSessionStmt              *sql.Stmt
 	updateSessionTitleAndUsageStmt *sql.Stmt
@@ -365,12 +383,14 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
 		getUsageByDayOfWeekStmt:        q.getUsageByDayOfWeekStmt,
 		getUsageByHourStmt:             q.getUsageByHourStmt,
 		getUsageByModelStmt:            q.getUsageByModelStmt,
+		listAllUserMessagesStmt:        q.listAllUserMessagesStmt,
 		listFilesByPathStmt:            q.listFilesByPathStmt,
 		listFilesBySessionStmt:         q.listFilesBySessionStmt,
 		listLatestSessionFilesStmt:     q.listLatestSessionFilesStmt,
 		listMessagesBySessionStmt:      q.listMessagesBySessionStmt,
 		listNewFilesStmt:               q.listNewFilesStmt,
 		listSessionsStmt:               q.listSessionsStmt,
+		listUserMessagesBySessionStmt:  q.listUserMessagesBySessionStmt,
 		updateMessageStmt:              q.updateMessageStmt,
 		updateSessionStmt:              q.updateSessionStmt,
 		updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt,

internal/db/messages.sql.go 🔗

@@ -107,6 +107,47 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
 	return i, err
 }
 
+const listAllUserMessages = `-- name: ListAllUserMessages :many
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+FROM messages
+WHERE role = 'user'
+ORDER BY created_at DESC
+`
+
+func (q *Queries) ListAllUserMessages(ctx context.Context) ([]Message, error) {
+	rows, err := q.query(ctx, q.listAllUserMessagesStmt, listAllUserMessages)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	items := []Message{}
+	for rows.Next() {
+		var i Message
+		if err := rows.Scan(
+			&i.ID,
+			&i.SessionID,
+			&i.Role,
+			&i.Parts,
+			&i.Model,
+			&i.CreatedAt,
+			&i.UpdatedAt,
+			&i.FinishedAt,
+			&i.Provider,
+			&i.IsSummaryMessage,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
 const listMessagesBySession = `-- name: ListMessagesBySession :many
 SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
 FROM messages
@@ -148,6 +189,47 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
 	return items, nil
 }
 
+const listUserMessagesBySession = `-- name: ListUserMessagesBySession :many
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+FROM messages
+WHERE session_id = ? AND role = 'user'
+ORDER BY created_at DESC
+`
+
+func (q *Queries) ListUserMessagesBySession(ctx context.Context, sessionID string) ([]Message, error) {
+	rows, err := q.query(ctx, q.listUserMessagesBySessionStmt, listUserMessagesBySession, sessionID)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	items := []Message{}
+	for rows.Next() {
+		var i Message
+		if err := rows.Scan(
+			&i.ID,
+			&i.SessionID,
+			&i.Role,
+			&i.Parts,
+			&i.Model,
+			&i.CreatedAt,
+			&i.UpdatedAt,
+			&i.FinishedAt,
+			&i.Provider,
+			&i.IsSummaryMessage,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
 const updateMessage = `-- name: UpdateMessage :exec
 UPDATE messages
 SET

internal/db/querier.go 🔗

@@ -30,12 +30,14 @@ type Querier interface {
 	GetUsageByDayOfWeek(ctx context.Context) ([]GetUsageByDayOfWeekRow, error)
 	GetUsageByHour(ctx context.Context) ([]GetUsageByHourRow, error)
 	GetUsageByModel(ctx context.Context) ([]GetUsageByModelRow, error)
+	ListAllUserMessages(ctx context.Context) ([]Message, error)
 	ListFilesByPath(ctx context.Context, path string) ([]File, error)
 	ListFilesBySession(ctx context.Context, sessionID string) ([]File, error)
 	ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
 	ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
 	ListNewFiles(ctx context.Context) ([]File, error)
 	ListSessions(ctx context.Context) ([]Session, error)
+	ListUserMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
 	UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
 	UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
 	UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error

internal/db/sql/messages.sql 🔗

@@ -41,3 +41,15 @@ WHERE id = ?;
 -- name: DeleteSessionMessages :exec
 DELETE FROM messages
 WHERE session_id = ?;
+
+-- name: ListUserMessagesBySession :many
+SELECT *
+FROM messages
+WHERE session_id = ? AND role = 'user'
+ORDER BY created_at DESC;
+
+-- name: ListAllUserMessages :many
+SELECT *
+FROM messages
+WHERE role = 'user'
+ORDER BY created_at DESC;

internal/message/message.go 🔗

@@ -26,6 +26,8 @@ type Service interface {
 	Update(ctx context.Context, message Message) error
 	Get(ctx context.Context, id string) (Message, error)
 	List(ctx context.Context, sessionID string) ([]Message, error)
+	ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
+	ListAllUserMessages(ctx context.Context) ([]Message, error)
 	Delete(ctx context.Context, id string) error
 	DeleteSessionMessages(ctx context.Context, sessionID string) error
 }
@@ -157,6 +159,36 @@ func (s *service) List(ctx context.Context, sessionID string) ([]Message, error)
 	return messages, nil
 }
 
+func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
+	dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
+	if err != nil {
+		return nil, err
+	}
+	messages := make([]Message, len(dbMessages))
+	for i, dbMessage := range dbMessages {
+		messages[i], err = s.fromDBItem(dbMessage)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return messages, nil
+}
+
+func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
+	dbMessages, err := s.q.ListAllUserMessages(ctx)
+	if err != nil {
+		return nil, err
+	}
+	messages := make([]Message, len(dbMessages))
+	for i, dbMessage := range dbMessages {
+		messages[i], err = s.fromDBItem(dbMessage)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return messages, nil
+}
+
 func (s *service) fromDBItem(item db.Message) (Message, error) {
 	parts, err := unmarshalParts([]byte(item.Parts))
 	if err != nil {

internal/ui/model/history.go 🔗

@@ -0,0 +1,184 @@
+package model
+
+import (
+	"context"
+	"log/slog"
+
+	tea "charm.land/bubbletea/v2"
+
+	"github.com/charmbracelet/crush/internal/message"
+)
+
+// promptHistoryLoadedMsg is sent when prompt history is loaded.
+type promptHistoryLoadedMsg struct {
+	messages []string
+}
+
+// loadPromptHistory loads user messages for history navigation.
+func (m *UI) loadPromptHistory() tea.Cmd {
+	return func() tea.Msg {
+		ctx := context.Background()
+		var messages []message.Message
+		var err error
+
+		if m.session != nil {
+			messages, err = m.com.App.Messages.ListUserMessages(ctx, m.session.ID)
+		} else {
+			messages, err = m.com.App.Messages.ListAllUserMessages(ctx)
+		}
+		if err != nil {
+			slog.Error("failed to load prompt history", "error", err)
+			return promptHistoryLoadedMsg{messages: nil}
+		}
+
+		texts := make([]string, 0, len(messages))
+		for _, msg := range messages {
+			if text := msg.Content().Text; text != "" {
+				texts = append(texts, text)
+			}
+		}
+		return promptHistoryLoadedMsg{messages: texts}
+	}
+}
+
+// handleHistoryUp handles up arrow for history navigation.
+func (m *UI) handleHistoryUp(msg tea.Msg) tea.Cmd {
+	// Navigate to older history entry from cursor position (0,0).
+	if m.textarea.Length() == 0 || m.isAtEditorStart() {
+		if m.historyPrev() {
+			// we send this so that the textarea moves the view to the correct position
+			// without this the cursor will show up in the wrong place.
+			ta, cmd := m.textarea.Update(nil)
+			m.textarea = ta
+			return cmd
+		}
+	}
+
+	// First move cursor to start before entering history.
+	if m.textarea.Line() == 0 {
+		m.textarea.CursorStart()
+		return nil
+	}
+
+	// Let textarea handle normal cursor movement.
+	ta, cmd := m.textarea.Update(msg)
+	m.textarea = ta
+	return cmd
+}
+
+// handleHistoryDown handles down arrow for history navigation.
+func (m *UI) handleHistoryDown(msg tea.Msg) tea.Cmd {
+	// Navigate to newer history entry from end of text.
+	if m.isAtEditorEnd() {
+		if m.historyNext() {
+			// we send this so that the textarea moves the view to the correct position
+			// without this the cursor will show up in the wrong place.
+			ta, cmd := m.textarea.Update(nil)
+			m.textarea = ta
+			return cmd
+		}
+	}
+
+	// First move cursor to end before navigating history.
+	if m.textarea.Line() == max(m.textarea.LineCount()-1, 0) {
+		m.textarea.MoveToEnd()
+		ta, cmd := m.textarea.Update(nil)
+		m.textarea = ta
+		return cmd
+	}
+
+	// Let textarea handle normal cursor movement.
+	ta, cmd := m.textarea.Update(msg)
+	m.textarea = ta
+	return cmd
+}
+
+// handleHistoryEscape handles escape for exiting history navigation.
+func (m *UI) handleHistoryEscape(msg tea.Msg) tea.Cmd {
+	// Return to current draft when browsing history.
+	if m.promptHistory.index >= 0 {
+		m.promptHistory.index = -1
+		m.textarea.Reset()
+		m.textarea.InsertString(m.promptHistory.draft)
+		ta, cmd := m.textarea.Update(nil)
+		m.textarea = ta
+		return cmd
+	}
+
+	// Let textarea handle escape normally.
+	ta, cmd := m.textarea.Update(msg)
+	m.textarea = ta
+	return cmd
+}
+
+// updateHistoryDraft updates history state when text is modified.
+func (m *UI) updateHistoryDraft(oldValue string) {
+	if m.textarea.Value() != oldValue {
+		m.promptHistory.draft = m.textarea.Value()
+		m.promptHistory.index = -1
+	}
+}
+
+// historyPrev changes the text area content to the previous message in the history
+// it returns false if it could not find the previous message.
+func (m *UI) historyPrev() bool {
+	if len(m.promptHistory.messages) == 0 {
+		return false
+	}
+	if m.promptHistory.index == -1 {
+		m.promptHistory.draft = m.textarea.Value()
+	}
+	nextIndex := m.promptHistory.index + 1
+	if nextIndex >= len(m.promptHistory.messages) {
+		return false
+	}
+	m.promptHistory.index = nextIndex
+	m.textarea.Reset()
+	m.textarea.InsertString(m.promptHistory.messages[nextIndex])
+	m.textarea.MoveToBegin()
+	return true
+}
+
+// historyNext changes the text area content to the next message in the history
+// it returns false if it could not find the next message.
+func (m *UI) historyNext() bool {
+	if m.promptHistory.index < 0 {
+		return false
+	}
+	nextIndex := m.promptHistory.index - 1
+	if nextIndex < 0 {
+		m.promptHistory.index = -1
+		m.textarea.Reset()
+		m.textarea.InsertString(m.promptHistory.draft)
+		return true
+	}
+	m.promptHistory.index = nextIndex
+	m.textarea.Reset()
+	m.textarea.InsertString(m.promptHistory.messages[nextIndex])
+	return true
+}
+
+// historyReset resets the history, but does not clear the message
+// it just sets the current draft to empty and the position in the history.
+func (m *UI) historyReset() {
+	m.promptHistory.index = -1
+	m.promptHistory.draft = ""
+}
+
+// isAtEditorStart returns true if we are at the 0 line and 0 col in the textarea.
+func (m *UI) isAtEditorStart() bool {
+	return m.textarea.Line() == 0 && m.textarea.LineInfo().ColumnOffset == 0
+}
+
+// isAtEditorEnd returns true if we are in the last line and the last column in the textarea.
+func (m *UI) isAtEditorEnd() bool {
+	lineCount := m.textarea.LineCount()
+	if lineCount == 0 {
+		return true
+	}
+	if m.textarea.Line() != lineCount-1 {
+		return false
+	}
+	info := m.textarea.LineInfo()
+	return info.CharOffset >= info.CharWidth-1 || info.CharWidth == 0
+}

internal/ui/model/keys.go 🔗

@@ -15,6 +15,10 @@ type KeyMap struct {
 		AttachmentDeleteMode key.Binding
 		Escape               key.Binding
 		DeleteAllAttachments key.Binding
+
+		// History navigation
+		HistoryPrev key.Binding
+		HistoryNext key.Binding
 	}
 
 	Chat struct {
@@ -131,6 +135,12 @@ func DefaultKeyMap() KeyMap {
 		key.WithKeys("r"),
 		key.WithHelp("ctrl+r+r", "delete all attachments"),
 	)
+	km.Editor.HistoryPrev = key.NewBinding(
+		key.WithKeys("up"),
+	)
+	km.Editor.HistoryNext = key.NewBinding(
+		key.WithKeys("down"),
+	)
 
 	km.Chat.NewSession = key.NewBinding(
 		key.WithKeys("ctrl+n"),

internal/ui/model/onboarding.go 🔗

@@ -48,9 +48,11 @@ func (m *UI) updateInitializeView(msg tea.KeyPressMsg) (cmds []tea.Cmd) {
 // initializeProject starts project initialization and transitions to the landing view.
 func (m *UI) initializeProject() tea.Cmd {
 	// clear the session
-	m.newSession()
-	cfg := m.com.Config()
 	var cmds []tea.Cmd
+	if cmd := m.newSession(); cmd != nil {
+		cmds = append(cmds, cmd)
+	}
+	cfg := m.com.Config()
 
 	initialize := func() tea.Msg {
 		initPrompt, err := agent.InitializePrompt(*cfg)

internal/ui/model/ui.go 🔗

@@ -211,6 +211,13 @@ type UI struct {
 
 	// mouse highlighting related state
 	lastClickTime time.Time
+
+	// Prompt history for up/down navigation through previous messages.
+	promptHistory struct {
+		messages []string
+		index    int
+		draft    string
+	}
 }
 
 // New creates a new instance of the [UI] model.
@@ -307,6 +314,8 @@ func (m *UI) Init() tea.Cmd {
 	}
 	// load the user commands async
 	cmds = append(cmds, m.loadCustomCommands())
+	// load prompt history async
+	cmds = append(cmds, m.loadPromptHistory())
 	return tea.Batch(cmds...)
 }
 
@@ -390,6 +399,9 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			}
 			m.updateLayoutAndSize()
 		}
+		// Reload prompt history for the new session.
+		m.historyReset()
+		cmds = append(cmds, m.loadPromptHistory())
 
 	case sendMessageMsg:
 		cmds = append(cmds, m.sendMessage(msg.Content, msg.Attachments...))
@@ -417,13 +429,20 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			commands.SetMCPPrompts(m.mcpPrompts)
 		}
 
+	case promptHistoryLoadedMsg:
+		m.promptHistory.messages = msg.messages
+		m.promptHistory.index = -1
+		m.promptHistory.draft = ""
+
 	case closeDialogMsg:
 		m.dialog.CloseFrontDialog()
 
 	case pubsub.Event[session.Session]:
 		if msg.Type == pubsub.DeletedEvent {
 			if m.session != nil && m.session.ID == msg.Payload.ID {
-				m.newSession()
+				if cmd := m.newSession(); cmd != nil {
+					cmds = append(cmds, cmd)
+				}
 			}
 			break
 		}
@@ -1095,7 +1114,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
 			cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait before starting a new session..."))
 			break
 		}
-		m.newSession()
+		if cmd := m.newSession(); cmd != nil {
+			cmds = append(cmds, cmd)
+		}
 		m.dialog.CloseDialog(dialog.CommandsID)
 	case dialog.ActionSummarize:
 		if m.isAgentBusy() {
@@ -1494,8 +1515,9 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 				}
 
 				m.randomizePlaceholders()
+				m.historyReset()
 
-				return m.sendMessage(value, attachments...)
+				return tea.Batch(m.sendMessage(value, attachments...), m.loadPromptHistory())
 			case key.Matches(msg, m.keyMap.Chat.NewSession):
 				if !m.hasSession() {
 					break
@@ -1504,7 +1526,9 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 					cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait before starting a new session..."))
 					break
 				}
-				m.newSession()
+				if cmd := m.newSession(); cmd != nil {
+					cmds = append(cmds, cmd)
+				}
 			case key.Matches(msg, m.keyMap.Tab):
 				if m.state != uiLanding {
 					m.setState(m.state, uiFocusMain)
@@ -1524,6 +1548,21 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 				ta, cmd := m.textarea.Update(msg)
 				m.textarea = ta
 				cmds = append(cmds, cmd)
+			case key.Matches(msg, m.keyMap.Editor.HistoryPrev):
+				cmd := m.handleHistoryUp(msg)
+				if cmd != nil {
+					cmds = append(cmds, cmd)
+				}
+			case key.Matches(msg, m.keyMap.Editor.HistoryNext):
+				cmd := m.handleHistoryDown(msg)
+				if cmd != nil {
+					cmds = append(cmds, cmd)
+				}
+			case key.Matches(msg, m.keyMap.Editor.Escape):
+				cmd := m.handleHistoryEscape(msg)
+				if cmd != nil {
+					cmds = append(cmds, cmd)
+				}
 			default:
 				if handleGlobalKeys(msg) {
 					// Handle global keys first before passing to textarea.
@@ -1557,6 +1596,9 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 				m.textarea = ta
 				cmds = append(cmds, cmd)
 
+				// Any text modification becomes the current draft.
+				m.updateHistoryDraft(curValue)
+
 				// After updating textarea, check if we need to filter completions.
 				// Skip filtering on the initial @ keystroke since items are loading async.
 				if m.completionsOpen && msg.String() != "@" {
@@ -1596,7 +1638,9 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 					break
 				}
 				m.focus = uiFocusEditor
-				m.newSession()
+				if cmd := m.newSession(); cmd != nil {
+					cmds = append(cmds, cmd)
+				}
 			case key.Matches(msg, m.keyMap.Chat.Expand):
 				m.chat.ToggleExpandedSelectedItem()
 			case key.Matches(msg, m.keyMap.Chat.Up):
@@ -2772,9 +2816,10 @@ func (m *UI) handlePermissionNotification(notification permission.PermissionNoti
 
 // newSession clears the current session state and prepares for a new session.
 // The actual session creation happens when the user sends their first message.
-func (m *UI) newSession() {
+// Returns a command to reload prompt history.
+func (m *UI) newSession() tea.Cmd {
 	if !m.hasSession() {
-		return
+		return nil
 	}
 
 	m.session = nil
@@ -2786,6 +2831,8 @@ func (m *UI) newSession() {
 	m.pillsExpanded = false
 	m.promptQueue = 0
 	m.pillsView = ""
+	m.historyReset()
+	return m.loadPromptHistory()
 }
 
 // handlePasteMsg handles a paste message.