perf: shell commands using exponential backoff

Raphael Amorim created

Change summary

internal/llm/provider/openai.go             |  2 
internal/llm/tools/shell/comparison_test.go | 83 +++++++++++++++++++++++
internal/llm/tools/shell/shell.go           | 18 ++++
internal/llm/tools/shell/shell_test.go      | 54 ++++++++++++++
4 files changed, 155 insertions(+), 2 deletions(-)

Detailed changes

internal/llm/provider/openai.go 🔗

@@ -166,7 +166,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
 		Tools:    tools,
 	}
 
-	if o.providerOptions.model.CanReason == true {
+	if o.providerOptions.model.CanReason {
 		params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
 		switch o.options.reasoningEffort {
 		case "low":

internal/llm/tools/shell/comparison_test.go 🔗

@@ -0,0 +1,83 @@
+package shell
+
+import (
+	"context"
+	"os"
+	"runtime"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestShellPerformanceComparison(t *testing.T) {
+	tmpDir, err := os.MkdirTemp("", "shell-test")
+	require.NoError(t, err)
+	defer os.RemoveAll(tmpDir)
+
+	shell := GetPersistentShell(tmpDir)
+	defer shell.Close()
+
+	// Test quick command
+	start := time.Now()
+	stdout, stderr, exitCode, _, err := shell.Exec(context.Background(), "echo 'hello'", 0)
+	duration := time.Since(start)
+
+	require.NoError(t, err)
+	assert.Equal(t, 0, exitCode)
+	assert.Contains(t, stdout, "hello")
+	assert.Empty(t, stderr)
+	
+	t.Logf("Quick command took: %v", duration)
+}
+
+func TestShellCPUUsageComparison(t *testing.T) {
+	tmpDir, err := os.MkdirTemp("", "shell-test")
+	require.NoError(t, err)
+	defer os.RemoveAll(tmpDir)
+
+	shell := GetPersistentShell(tmpDir)
+	defer shell.Close()
+
+	// Measure CPU and memory usage during a longer command
+	var m1, m2 runtime.MemStats
+	runtime.GC()
+	runtime.ReadMemStats(&m1)
+
+	start := time.Now()
+	_, stderr, exitCode, _, err := shell.Exec(context.Background(), "sleep 0.1", 1000)
+	duration := time.Since(start)
+	
+	runtime.ReadMemStats(&m2)
+
+	require.NoError(t, err)
+	assert.Equal(t, 0, exitCode)
+	assert.Empty(t, stderr)
+	
+	memGrowth := m2.Alloc - m1.Alloc
+	t.Logf("Sleep 0.1s command took: %v", duration)
+	t.Logf("Memory growth during polling: %d bytes", memGrowth)
+	t.Logf("GC cycles during test: %d", m2.NumGC-m1.NumGC)
+}
+
+// Benchmark CPU usage during polling
+func BenchmarkShellPolling(b *testing.B) {
+	tmpDir, err := os.MkdirTemp("", "shell-bench")
+	require.NoError(b, err)
+	defer os.RemoveAll(tmpDir)
+
+	shell := GetPersistentShell(tmpDir)
+	defer shell.Close()
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	
+	for i := 0; i < b.N; i++ {
+		// Use a short sleep to measure polling overhead
+		_, _, exitCode, _, err := shell.Exec(context.Background(), "sleep 0.02", 500)
+		if err != nil || exitCode != 0 {
+			b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
+		}
+	}
+}

internal/llm/tools/shell/shell.go 🔗

@@ -189,6 +189,13 @@ echo $EXEC_EXIT_CODE > %s
 
 	done := make(chan bool)
 	go func() {
+		// Use exponential backoff polling
+		pollInterval := 1 * time.Millisecond
+		maxPollInterval := 100 * time.Millisecond
+		
+		ticker := time.NewTicker(pollInterval)
+		defer ticker.Stop()
+		
 		for {
 			select {
 			case <-ctx.Done():
@@ -197,7 +204,7 @@ echo $EXEC_EXIT_CODE > %s
 				done <- true
 				return
 
-			case <-time.After(10 * time.Millisecond):
+			case <-ticker.C:
 				if fileExists(statusFile) && fileSize(statusFile) > 0 {
 					done <- true
 					return
@@ -212,6 +219,15 @@ echo $EXEC_EXIT_CODE > %s
 						return
 					}
 				}
+				
+				// Exponential backoff to reduce CPU usage for longer-running commands
+				if pollInterval < maxPollInterval {
+					pollInterval = time.Duration(float64(pollInterval) * 1.5)
+					if pollInterval > maxPollInterval {
+						pollInterval = maxPollInterval
+					}
+					ticker.Reset(pollInterval)
+				}
 			}
 		}
 	}()

internal/llm/tools/shell/shell_test.go 🔗

@@ -0,0 +1,54 @@
+package shell
+
+import (
+	"context"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestShellPerformanceImprovement(t *testing.T) {
+	// Create a temporary directory for the shell
+	tmpDir, err := os.MkdirTemp("", "shell-test")
+	require.NoError(t, err)
+	defer os.RemoveAll(tmpDir)
+
+	shell := GetPersistentShell(tmpDir)
+	defer shell.Close()
+
+	// Test that quick commands complete fast
+	start := time.Now()
+	stdout, stderr, exitCode, _, err := shell.Exec(context.Background(), "echo 'hello world'", 0)
+	duration := time.Since(start)
+
+	require.NoError(t, err)
+	assert.Equal(t, 0, exitCode)
+	assert.Contains(t, stdout, "hello world")
+	assert.Empty(t, stderr)
+	
+	// Quick commands should complete very fast with our exponential backoff
+	assert.Less(t, duration, 50*time.Millisecond, "Quick command should complete fast with exponential backoff")
+}
+
+// Benchmark to measure CPU efficiency
+func BenchmarkShellQuickCommands(b *testing.B) {
+	tmpDir, err := os.MkdirTemp("", "shell-bench")
+	require.NoError(b, err)
+	defer os.RemoveAll(tmpDir)
+
+	shell := GetPersistentShell(tmpDir)
+	defer shell.Close()
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	
+	for i := 0; i < b.N; i++ {
+		_, _, exitCode, _, err := shell.Exec(context.Background(), "echo test", 0)
+		if err != nil || exitCode != 0 {
+			b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
+		}
+	}
+}