Merge remote-tracking branch 'origin/list' into tool_improvements

Kujtim Hoxha created

Change summary

README.md                                          |  44 +++
internal/config/config.go                          |   3 
internal/config/load.go                            |   2 
internal/config/load_test.go                       |  29 ++
internal/config/resolve.go                         | 114 +++++++++-
internal/config/resolve_test.go                    | 177 +++++++++++++++
internal/llm/provider/anthropic.go                 |  55 ++++
internal/llm/provider/gemini.go                    |  12 
internal/llm/provider/openai.go                    |   6 
internal/llm/provider/provider.go                  |  44 ++-
internal/tui/components/chat/editor/editor.go      |  57 +++-
internal/tui/components/completions/completions.go | 100 ++++++--
internal/tui/exp/list/items.go                     |   5 
internal/tui/page/chat/chat.go                     |   4 
internal/tui/tui.go                                |  13 
15 files changed, 556 insertions(+), 109 deletions(-)

Detailed changes

README.md 🔗

@@ -135,7 +135,6 @@ Crush supports Model Context Protocol (MCP) servers through three transport type
 ### Logging
 
 Enable debug logging with the `-d` flag or in config. View logs with `crush logs`. Logs are stored in `.crush/logs/crush.log`.
-
 ```bash
 # Run with debug logging
 crush -d
@@ -186,16 +185,21 @@ The `allowed_tools` array accepts:
 
 You can also skip all permission prompts entirely by running Crush with the `--yolo` flag.
 
-### OpenAI-Compatible APIs
+### Custom Providers
+
+Crush supports custom provider configurations for both OpenAI-compatible and Anthropic-compatible APIs.
+
+#### OpenAI-Compatible APIs
 
-Crush supports all OpenAI-compatible APIs. Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment.
+Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment.
 
 ```json
 {
   "providers": {
     "deepseek": {
-      "provider_type": "openai",
+      "type": "openai",
       "base_url": "https://api.deepseek.com/v1",
+      "api_key": "$DEEPSEEK_API_KEY",
       "models": [
         {
           "id": "deepseek-chat",
@@ -213,6 +217,38 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D
 }
 ```
 
+#### Anthropic-Compatible APIs
+
+You can also configure custom Anthropic-compatible providers:
+
+```json
+{
+  "providers": {
+    "custom-anthropic": {
+      "type": "anthropic",
+      "base_url": "https://api.anthropic.com/v1",
+      "api_key": "$ANTHROPIC_API_KEY",
+      "extra_headers": {
+        "anthropic-version": "2023-06-01"
+      },
+      "models": [
+        {
+          "id": "claude-3-sonnet",
+          "model": "Claude 3 Sonnet",
+          "cost_per_1m_in": 3000,
+          "cost_per_1m_out": 15000,
+          "cost_per_1m_in_cached": 300,
+          "cost_per_1m_out_cached": 15000,
+          "context_window": 200000,
+          "default_max_tokens": 4096,
+          "supports_attachments": true
+        }
+      ]
+    }
+  }
+}
+```
+
 ## Whatcha think?
 
 We’d love to hear your thoughts on this project. Feel free to drop us a note!

internal/config/config.go 🔗

@@ -77,6 +77,9 @@ type ProviderConfig struct {
 	// Marks the provider as disabled.
 	Disable bool `json:"disable,omitempty"`
 
+	// Custom system prompt prefix.
+	SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`
+
 	// Extra headers to send with each request to the provider.
 	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
 	// Extra body

internal/config/load.go 🔗

@@ -232,7 +232,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
 			c.Providers.Del(id)
 			continue
 		}
-		if providerConfig.Type != catwalk.TypeOpenAI {
+		if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic {
 			slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
 			c.Providers.Del(id)
 			continue

internal/config/load_test.go 🔗

@@ -613,6 +613,35 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 		assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 	})
 
+	t.Run("custom anthropic provider is supported", func(t *testing.T) {
+		cfg := &Config{
+			Providers: csync.NewMapFrom(map[string]ProviderConfig{
+				"custom-anthropic": {
+					APIKey:  "test-key",
+					BaseURL: "https://api.anthropic.com/v1",
+					Type:    catwalk.TypeAnthropic,
+					Models: []catwalk.Model{{
+						ID: "claude-3-sonnet",
+					}},
+				},
+			}),
+		}
+		cfg.setDefaults("/tmp")
+
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+		assert.NoError(t, err)
+
+		assert.Equal(t, cfg.Providers.Len(), 1)
+		customProvider, exists := cfg.Providers.Get("custom-anthropic")
+		assert.True(t, exists)
+		assert.Equal(t, "custom-anthropic", customProvider.ID)
+		assert.Equal(t, "test-key", customProvider.APIKey)
+		assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
+		assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
+	})
+
 	t.Run("disabled custom provider is removed", func(t *testing.T) {
 		cfg := &Config{
 			Providers: csync.NewMapFrom(map[string]ProviderConfig{

internal/config/resolve.go 🔗

@@ -35,34 +35,120 @@ func NewShellVariableResolver(env env.Env) VariableResolver {
 }
 
 // ResolveValue is a method for resolving values, such as environment variables.
-// it will expect strings that start with `$` to be resolved as environment variables or shell commands.
-// if the string does not start with `$`, it will return the string as is.
+// it will resolve shell-like variable substitution anywhere in the string, including:
+// - $(command) for command substitution
+// - $VAR or ${VAR} for environment variables
 func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
-	if !strings.HasPrefix(value, "$") {
+	// Special case: lone $ is an error (backward compatibility)
+	if value == "$" {
+		return "", fmt.Errorf("invalid value format: %s", value)
+	}
+
+	// If no $ found, return as-is
+	if !strings.Contains(value, "$") {
 		return value, nil
 	}
 
-	if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") {
-		command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")")
+	result := value
+
+	// Handle command substitution: $(command)
+	for {
+		start := strings.Index(result, "$(")
+		if start == -1 {
+			break
+		}
+
+		// Find matching closing parenthesis
+		depth := 0
+		end := -1
+		for i := start + 2; i < len(result); i++ {
+			if result[i] == '(' {
+				depth++
+			} else if result[i] == ')' {
+				if depth == 0 {
+					end = i
+					break
+				}
+				depth--
+			}
+		}
+
+		if end == -1 {
+			return "", fmt.Errorf("unmatched $( in value: %s", value)
+		}
+
+		command := result[start+2 : end]
 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
-		defer cancel()
 
 		stdout, _, err := r.shell.Exec(ctx, command)
+		cancel()
 		if err != nil {
-			return "", fmt.Errorf("command execution failed: %w", err)
+			return "", fmt.Errorf("command execution failed for '%s': %w", command, err)
 		}
-		return strings.TrimSpace(stdout), nil
+
+		// Replace the $(command) with the output
+		replacement := strings.TrimSpace(stdout)
+		result = result[:start] + replacement + result[end+1:]
 	}
 
-	if after, ok := strings.CutPrefix(value, "$"); ok {
-		varName := after
-		value = r.env.Get(varName)
-		if value == "" {
+	// Handle environment variables: $VAR and ${VAR}
+	searchStart := 0
+	for {
+		start := strings.Index(result[searchStart:], "$")
+		if start == -1 {
+			break
+		}
+		start += searchStart // Adjust for the offset
+
+		// Skip if this is part of $( which we already handled
+		if start+1 < len(result) && result[start+1] == '(' {
+			// Skip past this $(...)
+			searchStart = start + 1
+			continue
+		}
+		var varName string
+		var end int
+
+		if start+1 < len(result) && result[start+1] == '{' {
+			// Handle ${VAR} format
+			closeIdx := strings.Index(result[start+2:], "}")
+			if closeIdx == -1 {
+				return "", fmt.Errorf("unmatched ${ in value: %s", value)
+			}
+			varName = result[start+2 : start+2+closeIdx]
+			end = start + 2 + closeIdx + 1
+		} else {
+			// Handle $VAR format - variable names must start with letter or underscore
+			if start+1 >= len(result) {
+				return "", fmt.Errorf("incomplete variable reference at end of string: %s", value)
+			}
+
+			if result[start+1] != '_' &&
+				(result[start+1] < 'a' || result[start+1] > 'z') &&
+				(result[start+1] < 'A' || result[start+1] > 'Z') {
+				return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value)
+			}
+
+			end = start + 1
+			for end < len(result) && (result[end] == '_' ||
+				(result[end] >= 'a' && result[end] <= 'z') ||
+				(result[end] >= 'A' && result[end] <= 'Z') ||
+				(result[end] >= '0' && result[end] <= '9')) {
+				end++
+			}
+			varName = result[start+1 : end]
+		}
+
+		envValue := r.env.Get(varName)
+		if envValue == "" {
 			return "", fmt.Errorf("environment variable %q not set", varName)
 		}
-		return value, nil
+
+		result = result[:start] + envValue + result[end:]
+		searchStart = start + len(envValue) // Continue searching after the replacement
 	}
-	return "", fmt.Errorf("invalid value format: %s", value)
+
+	return result, nil
 }
 
 type environmentVariableResolver struct {

internal/config/resolve_test.go 🔗

@@ -47,17 +47,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
 			envVars:     map[string]string{},
 			expectError: true,
 		},
-		{
-			name:  "shell command execution",
-			value: "$(echo hello)",
-			shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
-				if command == "echo hello" {
-					return "hello\n", "", nil
-				}
-				return "", "", errors.New("unexpected command")
-			},
-			expected: "hello",
-		},
+
 		{
 			name:  "shell command with whitespace trimming",
 			value: "$(echo '  spaced  ')",
@@ -104,6 +94,171 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
 	}
 }
 
+func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
+	tests := []struct {
+		name        string
+		value       string
+		envVars     map[string]string
+		shellFunc   func(ctx context.Context, command string) (stdout, stderr string, err error)
+		expected    string
+		expectError bool
+	}{
+		{
+			name:  "command substitution within string",
+			value: "Bearer $(echo token123)",
+			shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+				if command == "echo token123" {
+					return "token123\n", "", nil
+				}
+				return "", "", errors.New("unexpected command")
+			},
+			expected: "Bearer token123",
+		},
+		{
+			name:     "environment variable within string",
+			value:    "Bearer $TOKEN",
+			envVars:  map[string]string{"TOKEN": "sk-ant-123"},
+			expected: "Bearer sk-ant-123",
+		},
+		{
+			name:     "environment variable with braces within string",
+			value:    "Bearer ${TOKEN}",
+			envVars:  map[string]string{"TOKEN": "sk-ant-456"},
+			expected: "Bearer sk-ant-456",
+		},
+		{
+			name:  "mixed command and environment substitution",
+			value: "$USER-$(date +%Y)-$HOST",
+			envVars: map[string]string{
+				"USER": "testuser",
+				"HOST": "localhost",
+			},
+			shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+				if command == "date +%Y" {
+					return "2024\n", "", nil
+				}
+				return "", "", errors.New("unexpected command")
+			},
+			expected: "testuser-2024-localhost",
+		},
+		{
+			name:  "multiple command substitutions",
+			value: "$(echo hello) $(echo world)",
+			shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+				switch command {
+				case "echo hello":
+					return "hello\n", "", nil
+				case "echo world":
+					return "world\n", "", nil
+				}
+				return "", "", errors.New("unexpected command")
+			},
+			expected: "hello world",
+		},
+		{
+			name:  "nested parentheses in command",
+			value: "$(echo $(echo inner))",
+			shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+				if command == "echo $(echo inner)" {
+					return "nested\n", "", nil
+				}
+				return "", "", errors.New("unexpected command")
+			},
+			expected: "nested",
+		},
+		{
+			name:        "lone dollar with non-variable chars",
+			value:       "prefix$123suffix", // Numbers can't start variable names
+			expectError: true,
+		},
+		{
+			name:        "dollar with special chars",
+			value:       "a$@b$#c", // Special chars aren't valid in variable names
+			expectError: true,
+		},
+		{
+			name:        "empty environment variable substitution",
+			value:       "Bearer $EMPTY_VAR",
+			envVars:     map[string]string{},
+			expectError: true,
+		},
+		{
+			name:        "unmatched command substitution opening",
+			value:       "Bearer $(echo test",
+			expectError: true,
+		},
+		{
+			name:        "unmatched environment variable braces",
+			value:       "Bearer ${TOKEN",
+			expectError: true,
+		},
+		{
+			name:  "command substitution with error",
+			value: "Bearer $(false)",
+			shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+				return "", "", errors.New("command failed")
+			},
+			expectError: true,
+		},
+		{
+			name:  "complex real-world example",
+			value: "Bearer $(cat /tmp/token.txt | base64 -w 0)",
+			shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+				if command == "cat /tmp/token.txt | base64 -w 0" {
+					return "c2stYW50LXRlc3Q=\n", "", nil
+				}
+				return "", "", errors.New("unexpected command")
+			},
+			expected: "Bearer c2stYW50LXRlc3Q=",
+		},
+		{
+			name:     "environment variable with underscores and numbers",
+			value:    "Bearer $API_KEY_V2",
+			envVars:  map[string]string{"API_KEY_V2": "sk-test-123"},
+			expected: "Bearer sk-test-123",
+		},
+		{
+			name:     "no substitution needed",
+			value:    "Bearer sk-ant-static-token",
+			expected: "Bearer sk-ant-static-token",
+		},
+		{
+			name:        "incomplete variable at end",
+			value:       "Bearer $",
+			expectError: true,
+		},
+		{
+			name:        "variable with invalid character",
+			value:       "Bearer $VAR-NAME", // Hyphen not allowed in variable names
+			expectError: true,
+		},
+		{
+			name:        "multiple invalid variables",
+			value:       "$1$2$3",
+			expectError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			testEnv := env.NewFromMap(tt.envVars)
+			resolver := &shellVariableResolver{
+				shell: &mockShell{execFunc: tt.shellFunc},
+				env:   testEnv,
+			}
+
+			result, err := resolver.ResolveValue(tt.value)
+
+			if tt.expectError {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+				assert.Equal(t, tt.expected, result)
+			}
+		})
+	}
+}
+
 func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {
 	tests := []struct {
 		name        string

internal/llm/provider/anthropic.go 🔗

@@ -39,8 +39,30 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl
 
 func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
 	anthropicClientOptions := []option.RequestOption{}
-	if opts.apiKey != "" {
-		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
+
+	// Check if Authorization header is provided in extra headers
+	hasBearerAuth := false
+	if opts.extraHeaders != nil {
+		for key := range opts.extraHeaders {
+			if strings.ToLower(key) == "authorization" {
+				hasBearerAuth = true
+				break
+			}
+		}
+	}
+
+	isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ")
+
+	if opts.apiKey != "" && !hasBearerAuth {
+		if isBearerToken {
+			slog.Debug("API key starts with 'Bearer ', using as Authorization header")
+			anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey))
+		} else {
+			// Use standard X-Api-Key header
+			anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
+		}
+	} else if hasBearerAuth {
+		slog.Debug("Skipping X-Api-Key header because Authorization header is provided")
 	}
 	if useBedrock {
 		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
@@ -200,6 +222,25 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 		maxTokens = int64(a.adjustedMaxTokens)
 	}
 
+	systemBlocks := []anthropic.TextBlockParam{}
+
+	// Add custom system prompt prefix if configured
+	if a.providerOptions.systemPromptPrefix != "" {
+		systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
+			Text: a.providerOptions.systemPromptPrefix,
+			CacheControl: anthropic.CacheControlEphemeralParam{
+				Type: "ephemeral",
+			},
+		})
+	}
+
+	systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
+		Text: a.providerOptions.systemMessage,
+		CacheControl: anthropic.CacheControlEphemeralParam{
+			Type: "ephemeral",
+		},
+	})
+
 	return anthropic.MessageNewParams{
 		Model:       anthropic.Model(model.ID),
 		MaxTokens:   maxTokens,
@@ -207,14 +248,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 		Messages:    messages,
 		Tools:       tools,
 		Thinking:    thinkingParam,
-		System: []anthropic.TextBlockParam{
-			{
-				Text: a.providerOptions.systemMessage,
-				CacheControl: anthropic.CacheControlEphemeralParam{
-					Type: "ephemeral",
-				},
-			},
-		},
+		System:      systemBlocks,
 	}
 }
 
@@ -393,6 +427,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
 				close(eventChan)
 				return
 			}
+
 			// If there is an error we are going to see if we can retry the call
 			retry, after, retryErr := a.shouldRetry(attempts, err)
 			if retryErr != nil {

internal/llm/provider/gemini.go 🔗

@@ -180,12 +180,16 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
 	if modelConfig.MaxTokens > 0 {
 		maxTokens = modelConfig.MaxTokens
 	}
+	systemMessage := g.providerOptions.systemMessage
+	if g.providerOptions.systemPromptPrefix != "" {
+		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
+	}
 	history := geminiMessages[:len(geminiMessages)-1] // All but last message
 	lastMsg := geminiMessages[len(geminiMessages)-1]
 	config := &genai.GenerateContentConfig{
 		MaxOutputTokens: int32(maxTokens),
 		SystemInstruction: &genai.Content{
-			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
+			Parts: []*genai.Part{{Text: systemMessage}},
 		},
 	}
 	config.Tools = g.convertTools(tools)
@@ -280,12 +284,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 	if g.providerOptions.maxTokens > 0 {
 		maxTokens = g.providerOptions.maxTokens
 	}
+	systemMessage := g.providerOptions.systemMessage
+	if g.providerOptions.systemPromptPrefix != "" {
+		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
+	}
 	history := geminiMessages[:len(geminiMessages)-1] // All but last message
 	lastMsg := geminiMessages[len(geminiMessages)-1]
 	config := &genai.GenerateContentConfig{
 		MaxOutputTokens: int32(maxTokens),
 		SystemInstruction: &genai.Content{
-			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
+			Parts: []*genai.Part{{Text: systemMessage}},
 		},
 	}
 	config.Tools = g.convertTools(tools)

internal/llm/provider/openai.go 🔗

@@ -57,7 +57,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
 
 func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
 	// Add system message first
-	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
+	systemMessage := o.providerOptions.systemMessage
+	if o.providerOptions.systemPromptPrefix != "" {
+		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
+	}
+	openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))
 
 	for _, msg := range messages {
 		switch msg.Role {

internal/llm/provider/provider.go 🔗

@@ -61,17 +61,18 @@ type Provider interface {
 }
 
 type providerClientOptions struct {
-	baseURL       string
-	config        config.ProviderConfig
-	apiKey        string
-	modelType     config.SelectedModelType
-	model         func(config.SelectedModelType) catwalk.Model
-	disableCache  bool
-	systemMessage string
-	maxTokens     int64
-	extraHeaders  map[string]string
-	extraBody     map[string]any
-	extraParams   map[string]string
+	baseURL            string
+	config             config.ProviderConfig
+	apiKey             string
+	modelType          config.SelectedModelType
+	model              func(config.SelectedModelType) catwalk.Model
+	disableCache       bool
+	systemMessage      string
+	systemPromptPrefix string
+	maxTokens          int64
+	extraHeaders       map[string]string
+	extraBody          map[string]any
+	extraParams        map[string]string
 }
 
 type ProviderClientOption func(*providerClientOptions)
@@ -143,12 +144,23 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
 		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
 	}
 
+	// Resolve extra headers
+	resolvedExtraHeaders := make(map[string]string)
+	for key, value := range cfg.ExtraHeaders {
+		resolvedValue, err := config.Get().Resolve(value)
+		if err != nil {
+			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
+		}
+		resolvedExtraHeaders[key] = resolvedValue
+	}
+
 	clientOptions := providerClientOptions{
-		baseURL:      cfg.BaseURL,
-		config:       cfg,
-		apiKey:       resolvedAPIKey,
-		extraHeaders: cfg.ExtraHeaders,
-		extraBody:    cfg.ExtraBody,
+		baseURL:            cfg.BaseURL,
+		config:             cfg,
+		apiKey:             resolvedAPIKey,
+		extraHeaders:       resolvedExtraHeaders,
+		extraBody:          cfg.ExtraBody,
+		systemPromptPrefix: cfg.SystemPromptPrefix,
 		model: func(tp config.SelectedModelType) catwalk.Model {
 			return *config.Get().GetModelByType(tp)
 		},

internal/tui/components/chat/editor/editor.go 🔗

@@ -161,10 +161,17 @@ func (m *editorCmp) send() tea.Cmd {
 	)
 }
 
+func (m *editorCmp) repositionCompletions() tea.Msg {
+	x, y := m.completionsPosition()
+	return completions.RepositionCompletionsMsg{X: x, Y: y}
+}
+
 func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	var cmd tea.Cmd
 	var cmds []tea.Cmd
 	switch msg := msg.(type) {
+	case tea.WindowSizeMsg:
+		return m, m.repositionCompletions
 	case filepicker.FilePickedMsg:
 		if len(m.attachments) >= maxAttachments {
 			return m, util.ReportError(fmt.Errorf("cannot add more than %d images", maxAttachments))
@@ -182,32 +189,37 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			return m, nil
 		}
 		if item, ok := msg.Value.(FileCompletionItem); ok {
+			word := m.textarea.Word()
 			// If the selected item is a file, insert its path into the textarea
 			value := m.textarea.Value()
-			value = value[:m.completionsStartIndex]
-			value += item.Path
+			value = value[:m.completionsStartIndex] + // Remove the current query
+				item.Path + // Insert the file path
+				value[m.completionsStartIndex+len(word):] // Append the rest of the value
+			// XXX: This will always move the cursor to the end of the textarea.
 			m.textarea.SetValue(value)
+			m.textarea.MoveToEnd()
 			if !msg.Insert {
 				m.isCompletionsOpen = false
 				m.currentQuery = ""
 				m.completionsStartIndex = 0
 			}
-			return m, nil
 		}
 	case openEditorMsg:
 		m.textarea.SetValue(msg.Text)
 		m.textarea.MoveToEnd()
 	case tea.KeyPressMsg:
+		cur := m.textarea.Cursor()
+		curIdx := m.textarea.Width()*cur.Y + cur.X
 		switch {
 		// Completions
 		case msg.String() == "/" && !m.isCompletionsOpen &&
-			// only show if beginning of prompt, or if previous char is a space:
-			(len(m.textarea.Value()) == 0 || m.textarea.Value()[len(m.textarea.Value())-1] == ' '):
+			// only show if beginning of prompt, or if previous char is a space or newline:
+			(len(m.textarea.Value()) == 0 || unicode.IsSpace(rune(m.textarea.Value()[len(m.textarea.Value())-1]))):
 			m.isCompletionsOpen = true
 			m.currentQuery = ""
-			m.completionsStartIndex = len(m.textarea.Value())
+			m.completionsStartIndex = curIdx
 			cmds = append(cmds, m.startCompletions)
-		case m.isCompletionsOpen && m.textarea.Cursor().X <= m.completionsStartIndex:
+		case m.isCompletionsOpen && curIdx <= m.completionsStartIndex:
 			cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{}))
 		}
 		if key.Matches(msg, DeleteKeyMaps.AttachmentDeleteMode) {
@@ -244,6 +256,7 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		}
 		if key.Matches(msg, m.keyMap.Newline) {
 			m.textarea.InsertRune('\n')
+			cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{}))
 		}
 		// Handle Enter key
 		if m.textarea.Focused() && key.Matches(msg, m.keyMap.SendMessage) {
@@ -275,12 +288,18 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 					// XXX: wont' work if editing in the middle of the field.
 					m.completionsStartIndex = strings.LastIndex(m.textarea.Value(), word)
 					m.currentQuery = word[1:]
+					x, y := m.completionsPosition()
+					x -= len(m.currentQuery)
 					m.isCompletionsOpen = true
-					cmds = append(cmds, util.CmdHandler(completions.FilterCompletionsMsg{
-						Query:  m.currentQuery,
-						Reopen: m.isCompletionsOpen,
-					}))
-				} else {
+					cmds = append(cmds,
+						util.CmdHandler(completions.FilterCompletionsMsg{
+							Query:  m.currentQuery,
+							Reopen: m.isCompletionsOpen,
+							X:      x,
+							Y:      y,
+						}),
+					)
+				} else if m.isCompletionsOpen {
 					m.isCompletionsOpen = false
 					m.currentQuery = ""
 					m.completionsStartIndex = 0
@@ -293,6 +312,16 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	return m, tea.Batch(cmds...)
 }
 
+func (m *editorCmp) completionsPosition() (int, int) {
+	cur := m.textarea.Cursor()
+	if cur == nil {
+		return m.x, m.y + 1 // adjust for padding
+	}
+	x := cur.X + m.x
+	y := cur.Y + m.y + 1 // adjust for padding
+	return x, y
+}
+
 func (m *editorCmp) Cursor() *tea.Cursor {
 	cursor := m.textarea.Cursor()
 	if cursor != nil {
@@ -373,9 +402,7 @@ func (m *editorCmp) startCompletions() tea.Msg {
 		})
 	}
 
-	cur := m.textarea.Cursor()
-	x := cur.X + m.x // adjust for padding
-	y := cur.Y + m.y + 1
+	x, y := m.completionsPosition()
 	return completions.OpenCompletionsMsg{
 		Completions: completionItems,
 		X:           x,

internal/tui/components/completions/completions.go 🔗

@@ -27,6 +27,12 @@ type OpenCompletionsMsg struct {
 type FilterCompletionsMsg struct {
 	Query  string // The query to filter completions
 	Reopen bool
+	X      int // X position for the completions popup
+	Y      int // Y position for the completions popup
+}
+
+type RepositionCompletionsMsg struct {
+	X, Y int
 }
 
 type CompletionsClosedMsg struct{}
@@ -53,18 +59,24 @@ type Completions interface {
 type listModel = list.FilterableList[list.CompletionItem[any]]
 
 type completionsCmp struct {
-	width  int
-	height int  // Height of the completions component`
-	x      int  // X position for the completions popup
-	y      int  // Y position for the completions popup
-	open   bool // Indicates if the completions are open
-	keyMap KeyMap
+	wWidth    int // The window width
+	wHeight   int // The window height
+	width     int
+	lastWidth int
+	height    int  // Height of the completions component`
+	x, xorig  int  // X position for the completions popup
+	y         int  // Y position for the completions popup
+	open      bool // Indicates if the completions are open
+	keyMap    KeyMap
 
 	list  listModel
 	query string // The current filter query
 }
 
-const maxCompletionsWidth = 80 // Maximum width for the completions popup
+const (
+	maxCompletionsWidth = 80 // Maximum width for the completions popup
+	minCompletionsWidth = 20 // Minimum width for the completions popup
+)
 
 func New() Completions {
 	completionsKeyMap := DefaultKeyMap()
@@ -88,7 +100,7 @@ func New() Completions {
 	)
 	return &completionsCmp{
 		width:  0,
-		height: 0,
+		height: maxCompletionsHeight,
 		list:   l,
 		query:  "",
 		keyMap: completionsKeyMap,
@@ -107,8 +119,7 @@ func (c *completionsCmp) Init() tea.Cmd {
 func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	switch msg := msg.(type) {
 	case tea.WindowSizeMsg:
-		c.width = min(msg.Width-c.x, maxCompletionsWidth)
-		c.height = min(msg.Height-c.y, 15)
+		c.wWidth, c.wHeight = msg.Width, msg.Height
 		return c, nil
 	case tea.KeyPressMsg:
 		switch {
@@ -129,7 +140,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			selectedItem := *s
 			c.list.SetSelected(selectedItem.ID())
 			return c, util.CmdHandler(SelectCompletionMsg{
-				Value:  selectedItem,
+				Value:  selectedItem.Value(),
 				Insert: true,
 			})
 		case key.Matches(msg, c.keyMap.DownInsert):
@@ -140,7 +151,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			selectedItem := *s
 			c.list.SetSelected(selectedItem.ID())
 			return c, util.CmdHandler(SelectCompletionMsg{
-				Value:  selectedItem,
+				Value:  selectedItem.Value(),
 				Insert: true,
 			})
 		case key.Matches(msg, c.keyMap.Select):
@@ -151,18 +162,21 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			selectedItem := *s
 			c.open = false // Close completions after selection
 			return c, util.CmdHandler(SelectCompletionMsg{
-				Value: selectedItem,
+				Value: selectedItem.Value(),
 			})
 		case key.Matches(msg, c.keyMap.Cancel):
 			return c, util.CmdHandler(CloseCompletionsMsg{})
 		}
+	case RepositionCompletionsMsg:
+		c.x, c.y = msg.X, msg.Y
+		c.adjustPosition()
 	case CloseCompletionsMsg:
 		c.open = false
 		return c, util.CmdHandler(CompletionsClosedMsg{})
 	case OpenCompletionsMsg:
 		c.open = true
 		c.query = ""
-		c.x = msg.X
+		c.x, c.xorig = msg.X, msg.X
 		c.y = msg.Y
 		items := []list.CompletionItem[any]{}
 		t := styles.CurrentTheme()
@@ -174,10 +188,18 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			)
 			items = append(items, item)
 		}
-		c.height = max(min(c.height, len(items)), 1) // Ensure at least 1 item height
+		width := listWidth(items)
+		if len(items) == 0 {
+			width = listWidth(c.list.Items())
+		}
+		if c.x+width >= c.wWidth {
+			c.x = c.wWidth - width - 1
+		}
+		c.width = width
+		c.height = max(min(maxCompletionsHeight, len(items)), 1) // Ensure at least 1 item height
 		return c, tea.Batch(
-			c.list.SetSize(c.width, c.height),
 			c.list.SetItems(items),
+			c.list.SetSize(c.width, c.height),
 			util.CmdHandler(CompletionsOpenedMsg{}),
 		)
 	case FilterCompletionsMsg:
@@ -201,8 +223,11 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		c.query = msg.Query
 		var cmds []tea.Cmd
 		cmds = append(cmds, c.list.Filter(msg.Query))
-		itemsLen := len(c.list.Items())
-		c.height = max(min(maxCompletionsHeight, itemsLen), 1)
+		items := c.list.Items()
+		itemsLen := len(items)
+		c.xorig = msg.X
+		c.x, c.y = msg.X, msg.Y
+		c.adjustPosition()
 		cmds = append(cmds, c.list.SetSize(c.width, c.height))
 		if itemsLen == 0 {
 			cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{}))
@@ -215,21 +240,50 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	return c, nil
 }
 
+func (c *completionsCmp) adjustPosition() {
+	items := c.list.Items()
+	itemsLen := len(items)
+	width := listWidth(items)
+	c.lastWidth = c.width
+	if c.x < 0 || width < c.lastWidth {
+		c.x = c.xorig
+	} else if c.x+width >= c.wWidth {
+		c.x = c.wWidth - width - 1
+	}
+	c.width = width
+	c.height = max(min(maxCompletionsHeight, itemsLen), 1)
+}
+
 // View implements Completions.
 func (c *completionsCmp) View() string {
 	if !c.open || len(c.list.Items()) == 0 {
 		return ""
 	}
 
-	return c.style().Render(c.list.View())
-}
-
-func (c *completionsCmp) style() lipgloss.Style {
 	t := styles.CurrentTheme()
-	return t.S().Base.
+	style := t.S().Base.
 		Width(c.width).
 		Height(c.height).
 		Background(t.BgSubtle)
+
+	return style.Render(c.list.View())
+}
+
+// listWidth returns the width of the last 10 items in the list, which is used
+// to determine the width of the completions popup.
+// Note this only works for [completionItemCmp] items.
+func listWidth(items []list.CompletionItem[any]) int {
+	var width int
+	if len(items) == 0 {
+		return width
+	}
+
+	for i := len(items) - 1; i >= 0 && i >= len(items)-10; i-- {
+		itemWidth := lipgloss.Width(items[i].Text()) + 2 // +2 for padding
+		width = max(width, itemWidth)
+	}
+
+	return width
 }
 
 func (c *completionsCmp) Open() bool {

internal/tui/exp/list/items.go 🔗

@@ -23,6 +23,7 @@ type CompletionItem[T any] interface {
 	layout.Sizeable
 	HasMatchIndexes
 	Value() T
+	Text() string
 }
 
 type completionItemCmp[T any] struct {
@@ -312,6 +313,10 @@ func (c *completionItemCmp[T]) ID() string {
 	return c.id
 }
 
+func (c *completionItemCmp[T]) Text() string {
+	return c.text
+}
+
 type ItemSection interface {
 	Item
 	layout.Sizeable

internal/tui/page/chat/chat.go 🔗

@@ -173,7 +173,9 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		}
 		return p, nil
 	case tea.WindowSizeMsg:
-		return p, p.SetSize(msg.Width, msg.Height)
+		u, cmd := p.editor.Update(msg)
+		p.editor = u.(editor.Editor)
+		return p, tea.Batch(p.SetSize(msg.Width, msg.Height), cmd)
 	case CancelTimerExpiredMsg:
 		p.isCanceling = false
 		return p, nil

internal/tui/tui.go 🔗

@@ -111,19 +111,10 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		return a, a.handleWindowResize(msg.Width, msg.Height)
 
 	// Completions messages
-	case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg:
+	case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg,
+		completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg:
 		u, completionCmd := a.completions.Update(msg)
 		a.completions = u.(completions.Completions)
-		switch msg := msg.(type) {
-		case completions.OpenCompletionsMsg:
-			x, _ := a.completions.Position()
-			if a.completions.Width()+x >= a.wWidth {
-				// Adjust X position to fit in the window.
-				msg.X = a.wWidth - a.completions.Width() - 1
-				u, completionCmd = a.completions.Update(msg)
-				a.completions = u.(completions.Completions)
-			}
-		}
 		return a, completionCmd
 
 	// Dialog messages