Merge pull request #176 from charmbracelet/escape-sequences

Kujtim Hoxha created

fix(tui): escape ANSI escape sequences and control characters in tool

Change summary

internal/llm/tools/view.go                             | 24 -----
internal/tui/components/chat/messages/renderer.go      | 45 ++++++++++
internal/tui/components/chat/messages/renderer_test.go | 52 ++++++++++++
3 files changed, 98 insertions(+), 23 deletions(-)

Detailed changes

internal/llm/tools/view.go ๐Ÿ”—

@@ -238,7 +238,7 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
 
 	lineCount := 0
 
-	scanner := NewLineScanner(file)
+	scanner := bufio.NewScanner(file)
 	if offset > 0 {
 		for lineCount < offset && scanner.Scan() {
 			lineCount++
@@ -298,25 +298,3 @@ func isImageFile(filePath string) (bool, string) {
 		return false, ""
 	}
 }
-
-type LineScanner struct {
-	scanner *bufio.Scanner
-}
-
-func NewLineScanner(r io.Reader) *LineScanner {
-	return &LineScanner{
-		scanner: bufio.NewScanner(r),
-	}
-}
-
-func (s *LineScanner) Scan() bool {
-	return s.scanner.Scan()
-}
-
-func (s *LineScanner) Text() string {
-	return s.scanner.Text()
-}
-
-func (s *LineScanner) Err() error {
-	return s.scanner.Err()
-}

internal/tui/components/chat/messages/renderer.go ๐Ÿ”—

@@ -3,6 +3,7 @@ package messages
 import (
 	"encoding/json"
 	"fmt"
+	"strconv"
 	"strings"
 	"time"
 
@@ -656,6 +657,7 @@ func joinHeaderBody(header, body string) string {
 func renderPlainContent(v *toolCallCmp, content string) string {
 	t := styles.CurrentTheme()
 	content = strings.TrimSpace(content)
+	content = escapeContent(t, content)
 	lines := strings.Split(content, "\n")
 
 	width := v.textWidth() - 2 // -2 for left padding
@@ -694,6 +696,7 @@ func pad(v any, width int) string {
 
 func renderCodeContent(v *toolCallCmp, path, content string, offset int) string {
 	t := styles.CurrentTheme()
+	content = escapeContent(t, content)
 	truncated := truncateHeight(content, responseContextHeight)
 
 	highlighted, _ := highlight.SyntaxHighlight(truncated, path, t.BgBase)
@@ -766,3 +769,45 @@ func prettifyToolName(name string) string {
 		return name
 	}
 }
+
+// escapeContent escapes ANSI escape sequences and control characters in the
+// content and styles it for display in the terminal.
+func escapeContent(t *styles.Theme, content string) string {
+	content = strings.ReplaceAll(content, "\r\n", "\n")
+	lines := strings.Split(content, "\n")
+	for i, line := range lines {
+		lines[i] = escapeLine(t, line)
+	}
+
+	content = strings.Join(lines, "\n")
+	return content
+}
+
+// escapeLine escapes ANSI escape sequences and control characters and styles
+// them for display in the terminal.
+func escapeLine(t *styles.Theme, text string) string {
+	var (
+		sb    strings.Builder
+		state byte
+		seq   string
+		n     int
+		w     int
+	)
+	var faint lipgloss.Style
+	if t != nil {
+		faint = t.S().Muted.Faint(true)
+	}
+	for len(text) > 0 {
+		seq, w, n, state = ansi.DecodeSequence(text, state, nil)
+		if w > 0 {
+			sb.WriteString(seq)
+		} else {
+			quote := strconv.Quote(seq)
+			quote = strings.TrimPrefix(quote, "\"")
+			quote = strings.TrimSuffix(quote, "\"")
+			sb.WriteString(faint.Render(quote))
+		}
+		text = text[n:]
+	}
+	return sb.String()
+}

internal/tui/components/chat/messages/renderer_test.go ๐Ÿ”—

@@ -0,0 +1,52 @@
+package messages
+
+import (
+	"testing"
+)
+
+func TestEscapeContent(t *testing.T) {
+	cases := []struct {
+		name     string
+		input    string
+		expected string
+	}{
+		{
+			name:     "nothing to escape",
+			input:    "Hello, World!",
+			expected: "Hello, World!",
+		},
+		{
+			name:     "escape csi sequences",
+			input:    "\x1b[31mRed Text\x1b[0m",
+			expected: "\\x1b[31mRed Text\\x1b[0m",
+		},
+		{
+			name:     "escape control characters",
+			input:    "Hello\x00World\x7f!",
+			expected: "Hello\\x00World\\x7f!",
+		},
+		{
+			name:     "escape csi sequences with control characters",
+			input:    "\x1b[31mHello\x00World\x7f!\x1b[0m",
+			expected: "\\x1b[31mHello\\x00World\\x7f!\\x1b[0m",
+		},
+		{
+			name:     "just unicode",
+			input:    "ใ“ใ‚“ใซใกใฏ", // "Hello" in Japanese
+			expected: "ใ“ใ‚“ใซใกใฏ",
+		},
+		{
+			name:     "unicode with csi sequences and control characters",
+			input:    "\x1b[31mใ“ใ‚“ใซใกใฏ\x00World\x7f!\x1b[0m",
+			expected: "\\x1b[31mใ“ใ‚“ใซใกใฏ\\x00World\\x7f!\\x1b[0m",
+		},
+	}
+	for i, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			result := escapeContent(nil, c.input)
+			if result != c.expected {
+				t.Errorf("case %d, expected %q, got %q", i+1, c.expected, result)
+			}
+		})
+	}
+}