feat: steering (#605)

Kujtim Hoxha created

* feat: enable steering the model while in the loop
* feat: show queued messages in the TUI

Change summary

internal/llm/agent/agent.go                   | 62 ++++++++++++++++++++
internal/tui/components/chat/chat.go          | 19 ++++++
internal/tui/components/chat/editor/editor.go |  7 --
internal/tui/components/chat/queue.go         | 28 +++++++++
internal/tui/page/chat/chat.go                | 10 +++
internal/tui/styles/theme.go                  | 54 +++++++++--------
6 files changed, 146 insertions(+), 34 deletions(-)

Detailed changes

internal/llm/agent/agent.go πŸ”—

@@ -60,6 +60,8 @@ type Service interface {
 	IsBusy() bool
 	Summarize(ctx context.Context, sessionID string) error
 	UpdateModel() error
+	QueuedPrompts(sessionID string) int
+	ClearQueue(sessionID string)
 }
 
 type agent struct {
@@ -79,6 +81,8 @@ type agent struct {
 	summarizeProviderID string
 
 	activeRequests *csync.Map[string, context.CancelFunc]
+
+	promptQueue *csync.Map[string, []string]
 }
 
 var agentPromptMap = map[string]prompt.PromptID{
@@ -228,6 +232,7 @@ func NewAgent(
 		summarizeProviderID: string(providerCfg.ID),
 		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 		tools:               csync.NewLazySlice(toolFn),
+		promptQueue:         csync.NewMap[string, []string](),
 	}, nil
 }
 
@@ -247,6 +252,11 @@ func (a *agent) Cancel(sessionID string) {
 		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 		cancel()
 	}
+
+	if a.QueuedPrompts(sessionID) > 0 {
+		slog.Info("Clearing queued prompts", "session_id", sessionID)
+		a.promptQueue.Del(sessionID)
+	}
 }
 
 func (a *agent) IsBusy() bool {
@@ -265,6 +275,14 @@ func (a *agent) IsSessionBusy(sessionID string) bool {
 	return busy
 }
 
+func (a *agent) QueuedPrompts(sessionID string) int {
+	l, ok := a.promptQueue.Get(sessionID)
+	if !ok {
+		return 0
+	}
+	return len(l)
+}
+
 func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 	if content == "" {
 		return nil
@@ -327,7 +345,13 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
 	}
 	events := make(chan AgentEvent)
 	if a.IsSessionBusy(sessionID) {
-		return nil, ErrSessionBusy
+		existing, ok := a.promptQueue.Get(sessionID)
+		if !ok {
+			existing = []string{}
+		}
+		existing = append(existing, content)
+		a.promptQueue.Set(sessionID, existing)
+		return nil, nil
 	}
 
 	genCtx, cancel := context.WithCancel(ctx)
@@ -422,7 +446,36 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 			// We are not done, we need to respond with the tool response
 			msgHistory = append(msgHistory, agentMessage, *toolResults)
+			// If there are queued prompts, process the next one
+			nextPrompt, ok := a.promptQueue.Take(sessionID)
+			if ok {
+				for _, prompt := range nextPrompt {
+					// Create a new user message for the queued prompt
+					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
+					if err != nil {
+						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
+					}
+					// Append the new user message to the conversation history
+					msgHistory = append(msgHistory, userMsg)
+				}
+			}
+
 			continue
+		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
+			queuePrompts, ok := a.promptQueue.Take(sessionID)
+			if ok {
+				for _, prompt := range queuePrompts {
+					if prompt == "" {
+						continue
+					}
+					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
+					if err != nil {
+						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
+					}
+					msgHistory = append(msgHistory, userMsg)
+				}
+				continue
+			}
 		}
 		if agentMessage.FinishReason() == "" {
 			// Kujtim: could not track down where this is happening but this means its cancelled
@@ -852,6 +905,13 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 	return nil
 }
 
+func (a *agent) ClearQueue(sessionID string) {
+	if a.QueuedPrompts(sessionID) > 0 {
+		slog.Info("Clearing queued prompts", "session_id", sessionID)
+		a.promptQueue.Del(sessionID)
+	}
+}
+
 func (a *agent) CancelAll() {
 	if !a.IsBusy() {
 		return

internal/tui/components/chat/chat.go πŸ”—

@@ -18,6 +18,7 @@ import (
 	"github.com/charmbracelet/crush/internal/tui/exp/list"
 	"github.com/charmbracelet/crush/internal/tui/styles"
 	"github.com/charmbracelet/crush/internal/tui/util"
+	"github.com/charmbracelet/lipgloss/v2"
 )
 
 type SendMsg struct {
@@ -198,13 +199,29 @@ func (m *messageListCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 // View renders the message list or an initial screen if empty.
 func (m *messageListCmp) View() string {
 	t := styles.CurrentTheme()
-	return t.S().Base.
+	listView := t.S().Base.
 		Padding(1, 1, 0, 1).
 		Width(m.width).
 		Height(m.height).
 		Render(
 			m.listCmp.View(),
 		)
+
+	if m.app.CoderAgent != nil && m.app.CoderAgent.QueuedPrompts(m.session.ID) > 0 {
+		queue := m.app.CoderAgent.QueuedPrompts(m.session.ID)
+		queuePill := queuePill(queue, t)
+		layers := []*lipgloss.Layer{
+			lipgloss.NewLayer(listView),
+			lipgloss.NewLayer(
+				queuePill,
+			).X(4).Y(m.height - 3),
+		}
+		canvas := lipgloss.NewCanvas(
+			layers...,
+		)
+		return canvas.Render()
+	}
+	return listView
 }
 
 func (m *messageListCmp) handlePermissionRequest(permission permission.PermissionNotification) tea.Cmd {

internal/tui/components/chat/editor/editor.go πŸ”—

@@ -138,13 +138,6 @@ func (m *editorCmp) Init() tea.Cmd {
 }
 
 func (m *editorCmp) send() tea.Cmd {
-	if m.app.CoderAgent == nil {
-		return util.ReportError(fmt.Errorf("coder agent is not initialized"))
-	}
-	if m.app.CoderAgent.IsSessionBusy(m.session.ID) {
-		return util.ReportWarn("Agent is working, please wait...")
-	}
-
 	value := m.textarea.Value()
 	value = strings.TrimSpace(value)
 

internal/tui/components/chat/queue.go πŸ”—

@@ -0,0 +1,28 @@
+package chat
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/tui/styles"
+	"github.com/charmbracelet/lipgloss/v2"
+)
+
+func queuePill(queue int, t *styles.Theme) string {
+	if queue <= 0 {
+		return ""
+	}
+	triangles := styles.ForegroundGrad("β–Άβ–Άβ–Άβ–Άβ–Άβ–Άβ–Άβ–Άβ–Ά", false, t.RedDark, t.Accent)
+	if queue < 10 {
+		triangles = triangles[:queue]
+	}
+
+	allTriangles := strings.Join(triangles, "")
+
+	return t.S().Base.
+		BorderStyle(lipgloss.RoundedBorder()).
+		BorderForeground(t.BgOverlay).
+		PaddingLeft(1).
+		PaddingRight(1).
+		Render(fmt.Sprintf("%s %d Queued", allTriangles, queue))
+}

internal/tui/page/chat/chat.go πŸ”—

@@ -653,6 +653,10 @@ func (p *chatPage) cancel() tea.Cmd {
 		return nil
 	}
 
+	if p.app.CoderAgent != nil && p.app.CoderAgent.QueuedPrompts(p.session.ID) > 0 {
+		p.app.CoderAgent.ClearQueue(p.session.ID)
+		return nil
+	}
 	p.isCanceling = true
 	return cancelTimerCmd()
 }
@@ -828,6 +832,12 @@ func (p *chatPage) Help() help.KeyMap {
 					key.WithHelp("esc", "press again to cancel"),
 				)
 			}
+			if p.app.CoderAgent.QueuedPrompts(p.session.ID) > 0 {
+				cancelBinding = key.NewBinding(
+					key.WithKeys("esc"),
+					key.WithHelp("esc", "clear queue"),
+				)
+			}
 			shortList = append(shortList, cancelBinding)
 			fullList = append(fullList,
 				[]key.Binding{

internal/tui/styles/theme.go πŸ”—

@@ -591,18 +591,18 @@ func Lighten(c color.Color, percent float64) color.Color {
 	}
 }
 
-// ApplyForegroundGrad renders a given string with a horizontal gradient
-// foreground.
-func ApplyForegroundGrad(input string, color1, color2 color.Color) string {
+func ForegroundGrad(input string, bold bool, color1, color2 color.Color) []string {
 	if input == "" {
-		return ""
+		return []string{""}
 	}
-
-	var o strings.Builder
+	t := CurrentTheme()
 	if len(input) == 1 {
-		return lipgloss.NewStyle().Foreground(color1).Render(input)
+		style := t.S().Base.Foreground(color1)
+		if bold {
+			style.Bold(true)
+		}
+		return []string{style.Render(input)}
 	}
-
 	var clusters []string
 	gr := uniseg.NewGraphemes(input)
 	for gr.Next() {
@@ -611,9 +611,26 @@ func ApplyForegroundGrad(input string, color1, color2 color.Color) string {
 
 	ramp := blendColors(len(clusters), color1, color2)
 	for i, c := range ramp {
-		fmt.Fprint(&o, CurrentTheme().S().Base.Foreground(c).Render(clusters[i]))
+		style := t.S().Base.Foreground(c)
+		if bold {
+			style.Bold(true)
+		}
+		clusters[i] = style.Render(clusters[i])
 	}
+	return clusters
+}
 
+// ApplyForegroundGrad renders a given string with a horizontal gradient
+// foreground.
+func ApplyForegroundGrad(input string, color1, color2 color.Color) string {
+	if input == "" {
+		return ""
+	}
+	var o strings.Builder
+	clusters := ForegroundGrad(input, false, color1, color2)
+	for _, c := range clusters {
+		fmt.Fprint(&o, c)
+	}
 	return o.String()
 }
 
@@ -623,24 +640,11 @@ func ApplyBoldForegroundGrad(input string, color1, color2 color.Color) string {
 	if input == "" {
 		return ""
 	}
-	t := CurrentTheme()
-
 	var o strings.Builder
-	if len(input) == 1 {
-		return t.S().Base.Bold(true).Foreground(color1).Render(input)
-	}
-
-	var clusters []string
-	gr := uniseg.NewGraphemes(input)
-	for gr.Next() {
-		clusters = append(clusters, string(gr.Runes()))
-	}
-
-	ramp := blendColors(len(clusters), color1, color2)
-	for i, c := range ramp {
-		fmt.Fprint(&o, t.S().Base.Bold(true).Foreground(c).Render(clusters[i]))
+	clusters := ForegroundGrad(input, true, color1, color2)
+	for _, c := range clusters {
+		fmt.Fprint(&o, c)
 	}
-
 	return o.String()
 }