diff --git a/README.md b/README.md index 26cb7308bb7a614603c61c3f4f4f5d1cee3fe40f..5fed716c8c6bf437e75ca65401c15e5be64441d5 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,6 @@ Crush supports Model Context Protocol (MCP) servers through three transport type ### Logging Enable debug logging with the `-d` flag or in config. View logs with `crush logs`. Logs are stored in `.crush/logs/crush.log`. - ```bash # Run with debug logging crush -d @@ -186,16 +185,21 @@ The `allowed_tools` array accepts: You can also skip all permission prompts entirely by running Crush with the `--yolo` flag. -### OpenAI-Compatible APIs +### Custom Providers + +Crush supports custom provider configurations for both OpenAI-compatible and Anthropic-compatible APIs. + +#### OpenAI-Compatible APIs -Crush supports all OpenAI-compatible APIs. Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment. +Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment. ```json { "providers": { "deepseek": { - "provider_type": "openai", + "type": "openai", "base_url": "https://api.deepseek.com/v1", + "api_key": "$DEEPSEEK_API_KEY", "models": [ { "id": "deepseek-chat", @@ -213,6 +217,38 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D } ``` +#### Anthropic-Compatible APIs + +You can also configure custom Anthropic-compatible providers: + +```json +{ + "providers": { + "custom-anthropic": { + "type": "anthropic", + "base_url": "https://api.anthropic.com/v1", + "api_key": "$ANTHROPIC_API_KEY", + "extra_headers": { + "anthropic-version": "2023-06-01" + }, + "models": [ + { + "id": "claude-3-sonnet", + "model": "Claude 3 Sonnet", + "cost_per_1m_in": 3000, + "cost_per_1m_out": 15000, + "cost_per_1m_in_cached": 300, + "cost_per_1m_out_cached": 15000, + "context_window": 200000, + "default_max_tokens": 4096, + "supports_attachments": true + } + ] + } + } +} +``` + ## Whatcha think? We’d love to hear your thoughts on this project. Feel free to drop us a note! diff --git a/internal/config/config.go b/internal/config/config.go index 9709c11a0636d91cb492b7735b63e46e5e843c74..9a0da2a376abc88c5e584d7d39744da6f1890ce3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -77,6 +77,9 @@ type ProviderConfig struct { // Marks the provider as disabled. Disable bool `json:"disable,omitempty"` + // Custom system prompt prefix. + SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"` + // Extra headers to send with each request to the provider. ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Extra body diff --git a/internal/config/load.go b/internal/config/load.go index 98569d41be810dd0b9382c4df56cfb3e9c1c5842..44bcf8e3ce87953b9c3589cacaf2fe8a248e97aa 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -232,7 +232,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know c.Providers.Del(id) continue } - if providerConfig.Type != catwalk.TypeOpenAI { + if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic { slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type) c.Providers.Del(id) continue diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 5a52426f51ace9ee9e26bb42208511a72009dc3b..8c2735bd15fb3b52fe0c87401f57534e9b007e5b 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -613,6 +613,35 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL) }) + t.Run("custom anthropic provider is supported", func(t *testing.T) { + cfg := &Config{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ + "custom-anthropic": { + APIKey: "test-key", + BaseURL: "https://api.anthropic.com/v1", + Type: catwalk.TypeAnthropic, + Models: []catwalk.Model{{ + ID: "claude-3-sonnet", + }}, + }, + }), + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + assert.NoError(t, err) + + assert.Equal(t, cfg.Providers.Len(), 1) + customProvider, exists := cfg.Providers.Get("custom-anthropic") + assert.True(t, exists) + assert.Equal(t, "custom-anthropic", customProvider.ID) + assert.Equal(t, "test-key", customProvider.APIKey) + assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL) + assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type) + }) + t.Run("disabled custom provider is removed", func(t *testing.T) { cfg := &Config{ Providers: csync.NewMapFrom(map[string]ProviderConfig{ diff --git a/internal/config/resolve.go b/internal/config/resolve.go index 3c97a6456cf7fe5968311746d62b2772b21d6aaa..3ef3522b09e504d3c57105e8bbe393b0f7c38b2b 100644 --- a/internal/config/resolve.go +++ b/internal/config/resolve.go @@ -35,34 +35,120 @@ func NewShellVariableResolver(env env.Env) VariableResolver { } // ResolveValue is a method for resolving values, such as environment variables. -// it will expect strings that start with `$` to be resolved as environment variables or shell commands. -// if the string does not start with `$`, it will return the string as is. +// it will resolve shell-like variable substitution anywhere in the string, including: +// - $(command) for command substitution +// - $VAR or ${VAR} for environment variables func (r *shellVariableResolver) ResolveValue(value string) (string, error) { - if !strings.HasPrefix(value, "$") { + // Special case: lone $ is an error (backward compatibility) + if value == "$" { + return "", fmt.Errorf("invalid value format: %s", value) + } + + // If no $ found, return as-is + if !strings.Contains(value, "$") { return value, nil } - if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") { - command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")") + result := value + + // Handle command substitution: $(command) + for { + start := strings.Index(result, "$(") + if start == -1 { + break + } + + // Find matching closing parenthesis + depth := 0 + end := -1 + for i := start + 2; i < len(result); i++ { + if result[i] == '(' { + depth++ + } else if result[i] == ')' { + if depth == 0 { + end = i + break + } + depth-- + } + } + + if end == -1 { + return "", fmt.Errorf("unmatched $( in value: %s", value) + } + + command := result[start+2 : end] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() stdout, _, err := r.shell.Exec(ctx, command) + cancel() if err != nil { - return "", fmt.Errorf("command execution failed: %w", err) + return "", fmt.Errorf("command execution failed for '%s': %w", command, err) } - return strings.TrimSpace(stdout), nil + + // Replace the $(command) with the output + replacement := strings.TrimSpace(stdout) + result = result[:start] + replacement + result[end+1:] } - if after, ok := strings.CutPrefix(value, "$"); ok { - varName := after - value = r.env.Get(varName) - if value == "" { + // Handle environment variables: $VAR and ${VAR} + searchStart := 0 + for { + start := strings.Index(result[searchStart:], "$") + if start == -1 { + break + } + start += searchStart // Adjust for the offset + + // Skip if this is part of $( which we already handled + if start+1 < len(result) && result[start+1] == '(' { + // Skip past this $(...) + searchStart = start + 1 + continue + } + var varName string + var end int + + if start+1 < len(result) && result[start+1] == '{' { + // Handle ${VAR} format + closeIdx := strings.Index(result[start+2:], "}") + if closeIdx == -1 { + return "", fmt.Errorf("unmatched ${ in value: %s", value) + } + varName = result[start+2 : start+2+closeIdx] + end = start + 2 + closeIdx + 1 + } else { + // Handle $VAR format - variable names must start with letter or underscore + if start+1 >= len(result) { + return "", fmt.Errorf("incomplete variable reference at end of string: %s", value) + } + + if result[start+1] != '_' && + (result[start+1] < 'a' || result[start+1] > 'z') && + (result[start+1] < 'A' || result[start+1] > 'Z') { + return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value) + } + + end = start + 1 + for end < len(result) && (result[end] == '_' || + (result[end] >= 'a' && result[end] <= 'z') || + (result[end] >= 'A' && result[end] <= 'Z') || + (result[end] >= '0' && result[end] <= '9')) { + end++ + } + varName = result[start+1 : end] + } + + envValue := r.env.Get(varName) + if envValue == "" { return "", fmt.Errorf("environment variable %q not set", varName) } - return value, nil + + result = result[:start] + envValue + result[end:] + searchStart = start + len(envValue) // Continue searching after the replacement } - return "", fmt.Errorf("invalid value format: %s", value) + + return result, nil } type environmentVariableResolver struct { diff --git a/internal/config/resolve_test.go b/internal/config/resolve_test.go index 7cdcd2a7913cb581e5312f787791e8e89e699281..26ab184b26f82e70bf95320492b900a080f3e015 100644 --- a/internal/config/resolve_test.go +++ b/internal/config/resolve_test.go @@ -47,17 +47,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) { envVars: map[string]string{}, expectError: true, }, - { - name: "shell command execution", - value: "$(echo hello)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "echo hello" { - return "hello\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "hello", - }, + { name: "shell command with whitespace trimming", value: "$(echo ' spaced ')", @@ -104,6 +94,171 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) { } } +func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { + tests := []struct { + name string + value string + envVars map[string]string + shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) + expected string + expectError bool + }{ + { + name: "command substitution within string", + value: "Bearer $(echo token123)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo token123" { + return "token123\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "Bearer token123", + }, + { + name: "environment variable within string", + value: "Bearer $TOKEN", + envVars: map[string]string{"TOKEN": "sk-ant-123"}, + expected: "Bearer sk-ant-123", + }, + { + name: "environment variable with braces within string", + value: "Bearer ${TOKEN}", + envVars: map[string]string{"TOKEN": "sk-ant-456"}, + expected: "Bearer sk-ant-456", + }, + { + name: "mixed command and environment substitution", + value: "$USER-$(date +%Y)-$HOST", + envVars: map[string]string{ + "USER": "testuser", + "HOST": "localhost", + }, + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "date +%Y" { + return "2024\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "testuser-2024-localhost", + }, + { + name: "multiple command substitutions", + value: "$(echo hello) $(echo world)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + switch command { + case "echo hello": + return "hello\n", "", nil + case "echo world": + return "world\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "hello world", + }, + { + name: "nested parentheses in command", + value: "$(echo $(echo inner))", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo $(echo inner)" { + return "nested\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "nested", + }, + { + name: "lone dollar with non-variable chars", + value: "prefix$123suffix", // Numbers can't start variable names + expectError: true, + }, + { + name: "dollar with special chars", + value: "a$@b$#c", // Special chars aren't valid in variable names + expectError: true, + }, + { + name: "empty environment variable substitution", + value: "Bearer $EMPTY_VAR", + envVars: map[string]string{}, + expectError: true, + }, + { + name: "unmatched command substitution opening", + value: "Bearer $(echo test", + expectError: true, + }, + { + name: "unmatched environment variable braces", + value: "Bearer ${TOKEN", + expectError: true, + }, + { + name: "command substitution with error", + value: "Bearer $(false)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + return "", "", errors.New("command failed") + }, + expectError: true, + }, + { + name: "complex real-world example", + value: "Bearer $(cat /tmp/token.txt | base64 -w 0)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "cat /tmp/token.txt | base64 -w 0" { + return "c2stYW50LXRlc3Q=\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "Bearer c2stYW50LXRlc3Q=", + }, + { + name: "environment variable with underscores and numbers", + value: "Bearer $API_KEY_V2", + envVars: map[string]string{"API_KEY_V2": "sk-test-123"}, + expected: "Bearer sk-test-123", + }, + { + name: "no substitution needed", + value: "Bearer sk-ant-static-token", + expected: "Bearer sk-ant-static-token", + }, + { + name: "incomplete variable at end", + value: "Bearer $", + expectError: true, + }, + { + name: "variable with invalid character", + value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names + expectError: true, + }, + { + name: "multiple invalid variables", + value: "$1$2$3", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testEnv := env.NewFromMap(tt.envVars) + resolver := &shellVariableResolver{ + shell: &mockShell{execFunc: tt.shellFunc}, + env: testEnv, + } + + result, err := resolver.ResolveValue(tt.value) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) { tests := []struct { name string diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 0765389a05ecaf33c6c521770e1880a24210d35f..3de8c805b3f0cfa08b1b2bb6b60577742ce8cc1d 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -39,8 +39,30 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client { anthropicClientOptions := []option.RequestOption{} - if opts.apiKey != "" { - anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) + + // Check if Authorization header is provided in extra headers + hasBearerAuth := false + if opts.extraHeaders != nil { + for key := range opts.extraHeaders { + if strings.ToLower(key) == "authorization" { + hasBearerAuth = true + break + } + } + } + + isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ") + + if opts.apiKey != "" && !hasBearerAuth { + if isBearerToken { + slog.Debug("API key starts with 'Bearer ', using as Authorization header") + anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey)) + } else { + // Use standard X-Api-Key header + anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) + } + } else if hasBearerAuth { + slog.Debug("Skipping X-Api-Key header because Authorization header is provided") } if useBedrock { anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) @@ -200,6 +222,25 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to maxTokens = int64(a.adjustedMaxTokens) } + systemBlocks := []anthropic.TextBlockParam{} + + // Add custom system prompt prefix if configured + if a.providerOptions.systemPromptPrefix != "" { + systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ + Text: a.providerOptions.systemPromptPrefix, + CacheControl: anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + }, + }) + } + + systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ + Text: a.providerOptions.systemMessage, + CacheControl: anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + }, + }) + return anthropic.MessageNewParams{ Model: anthropic.Model(model.ID), MaxTokens: maxTokens, @@ -207,14 +248,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to Messages: messages, Tools: tools, Thinking: thinkingParam, - System: []anthropic.TextBlockParam{ - { - Text: a.providerOptions.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, - }, + System: systemBlocks, } } @@ -393,6 +427,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message close(eventChan) return } + // If there is an error we are going to see if we can retry the call retry, after, retryErr := a.shouldRetry(attempts, err) if retryErr != nil { diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index b2d1da11148e74362e7b529b9ec78dc1810d0f0d..0070d246012547a691f8c6a8cbd8de2234cd93ec 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -180,12 +180,16 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too if modelConfig.MaxTokens > 0 { maxTokens = modelConfig.MaxTokens } + systemMessage := g.providerOptions.systemMessage + if g.providerOptions.systemPromptPrefix != "" { + systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage + } history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] config := &genai.GenerateContentConfig{ MaxOutputTokens: int32(maxTokens), SystemInstruction: &genai.Content{ - Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, + Parts: []*genai.Part{{Text: systemMessage}}, }, } config.Tools = g.convertTools(tools) @@ -280,12 +284,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t if g.providerOptions.maxTokens > 0 { maxTokens = g.providerOptions.maxTokens } + systemMessage := g.providerOptions.systemMessage + if g.providerOptions.systemPromptPrefix != "" { + systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage + } history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] config := &genai.GenerateContentConfig{ MaxOutputTokens: int32(maxTokens), SystemInstruction: &genai.Content{ - Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, + Parts: []*genai.Part{{Text: systemMessage}}, }, } config.Tools = g.convertTools(tools) diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 397d6954d0a5c8f3dbe25f4a34115ade4c242012..23e247830a48ba1860ba7bde5059da69fab6d3ac 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -57,7 +57,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client { func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { // Add system message first - openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage)) + systemMessage := o.providerOptions.systemMessage + if o.providerOptions.systemPromptPrefix != "" { + systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage + } + openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage)) for _, msg := range messages { switch msg.Role { diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 062c2aa977c6ff101d1d8ab6f32809845bd48ff3..c236c10f0b0e9bf9b4db50544ca664291ef13b65 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -61,17 +61,18 @@ type Provider interface { } type providerClientOptions struct { - baseURL string - config config.ProviderConfig - apiKey string - modelType config.SelectedModelType - model func(config.SelectedModelType) catwalk.Model - disableCache bool - systemMessage string - maxTokens int64 - extraHeaders map[string]string - extraBody map[string]any - extraParams map[string]string + baseURL string + config config.ProviderConfig + apiKey string + modelType config.SelectedModelType + model func(config.SelectedModelType) catwalk.Model + disableCache bool + systemMessage string + systemPromptPrefix string + maxTokens int64 + extraHeaders map[string]string + extraBody map[string]any + extraParams map[string]string } type ProviderClientOption func(*providerClientOptions) @@ -143,12 +144,23 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err) } + // Resolve extra headers + resolvedExtraHeaders := make(map[string]string) + for key, value := range cfg.ExtraHeaders { + resolvedValue, err := config.Get().Resolve(value) + if err != nil { + return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err) + } + resolvedExtraHeaders[key] = resolvedValue + } + clientOptions := providerClientOptions{ - baseURL: cfg.BaseURL, - config: cfg, - apiKey: resolvedAPIKey, - extraHeaders: cfg.ExtraHeaders, - extraBody: cfg.ExtraBody, + baseURL: cfg.BaseURL, + config: cfg, + apiKey: resolvedAPIKey, + extraHeaders: resolvedExtraHeaders, + extraBody: cfg.ExtraBody, + systemPromptPrefix: cfg.SystemPromptPrefix, model: func(tp config.SelectedModelType) catwalk.Model { return *config.Get().GetModelByType(tp) }, diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index 55a5e7525a430039b314cd810cb94856185cf5af..4e5f0bc431eb466cea5c6c7d436234c7a5e8531b 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -161,10 +161,17 @@ func (m *editorCmp) send() tea.Cmd { ) } +func (m *editorCmp) repositionCompletions() tea.Msg { + x, y := m.completionsPosition() + return completions.RepositionCompletionsMsg{X: x, Y: y} +} + func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd var cmds []tea.Cmd switch msg := msg.(type) { + case tea.WindowSizeMsg: + return m, m.repositionCompletions case filepicker.FilePickedMsg: if len(m.attachments) >= maxAttachments { return m, util.ReportError(fmt.Errorf("cannot add more than %d images", maxAttachments)) @@ -182,32 +189,37 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } if item, ok := msg.Value.(FileCompletionItem); ok { + word := m.textarea.Word() // If the selected item is a file, insert its path into the textarea value := m.textarea.Value() - value = value[:m.completionsStartIndex] - value += item.Path + value = value[:m.completionsStartIndex] + // Remove the current query + item.Path + // Insert the file path + value[m.completionsStartIndex+len(word):] // Append the rest of the value + // XXX: This will always move the cursor to the end of the textarea. m.textarea.SetValue(value) + m.textarea.MoveToEnd() if !msg.Insert { m.isCompletionsOpen = false m.currentQuery = "" m.completionsStartIndex = 0 } - return m, nil } case openEditorMsg: m.textarea.SetValue(msg.Text) m.textarea.MoveToEnd() case tea.KeyPressMsg: + cur := m.textarea.Cursor() + curIdx := m.textarea.Width()*cur.Y + cur.X switch { // Completions case msg.String() == "/" && !m.isCompletionsOpen && - // only show if beginning of prompt, or if previous char is a space: - (len(m.textarea.Value()) == 0 || m.textarea.Value()[len(m.textarea.Value())-1] == ' '): + // only show if beginning of prompt, or if previous char is a space or newline: + (len(m.textarea.Value()) == 0 || unicode.IsSpace(rune(m.textarea.Value()[len(m.textarea.Value())-1]))): m.isCompletionsOpen = true m.currentQuery = "" - m.completionsStartIndex = len(m.textarea.Value()) + m.completionsStartIndex = curIdx cmds = append(cmds, m.startCompletions) - case m.isCompletionsOpen && m.textarea.Cursor().X <= m.completionsStartIndex: + case m.isCompletionsOpen && curIdx <= m.completionsStartIndex: cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{})) } if key.Matches(msg, DeleteKeyMaps.AttachmentDeleteMode) { @@ -244,6 +256,7 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } if key.Matches(msg, m.keyMap.Newline) { m.textarea.InsertRune('\n') + cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{})) } // Handle Enter key if m.textarea.Focused() && key.Matches(msg, m.keyMap.SendMessage) { @@ -275,12 +288,18 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // XXX: wont' work if editing in the middle of the field. m.completionsStartIndex = strings.LastIndex(m.textarea.Value(), word) m.currentQuery = word[1:] + x, y := m.completionsPosition() + x -= len(m.currentQuery) m.isCompletionsOpen = true - cmds = append(cmds, util.CmdHandler(completions.FilterCompletionsMsg{ - Query: m.currentQuery, - Reopen: m.isCompletionsOpen, - })) - } else { + cmds = append(cmds, + util.CmdHandler(completions.FilterCompletionsMsg{ + Query: m.currentQuery, + Reopen: m.isCompletionsOpen, + X: x, + Y: y, + }), + ) + } else if m.isCompletionsOpen { m.isCompletionsOpen = false m.currentQuery = "" m.completionsStartIndex = 0 @@ -293,6 +312,16 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(cmds...) } +func (m *editorCmp) completionsPosition() (int, int) { + cur := m.textarea.Cursor() + if cur == nil { + return m.x, m.y + 1 // adjust for padding + } + x := cur.X + m.x + y := cur.Y + m.y + 1 // adjust for padding + return x, y +} + func (m *editorCmp) Cursor() *tea.Cursor { cursor := m.textarea.Cursor() if cursor != nil { @@ -373,9 +402,7 @@ func (m *editorCmp) startCompletions() tea.Msg { }) } - cur := m.textarea.Cursor() - x := cur.X + m.x // adjust for padding - y := cur.Y + m.y + 1 + x, y := m.completionsPosition() return completions.OpenCompletionsMsg{ Completions: completionItems, X: x, diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index 0d5b814952dcdb8b6fdabc2f9e6aa8873936babc..42c05bd9acae8fb099d5ae09c754114541b1f280 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -27,6 +27,12 @@ type OpenCompletionsMsg struct { type FilterCompletionsMsg struct { Query string // The query to filter completions Reopen bool + X int // X position for the completions popup + Y int // Y position for the completions popup +} + +type RepositionCompletionsMsg struct { + X, Y int } type CompletionsClosedMsg struct{} @@ -53,18 +59,24 @@ type Completions interface { type listModel = list.FilterableList[list.CompletionItem[any]] type completionsCmp struct { - width int - height int // Height of the completions component` - x int // X position for the completions popup - y int // Y position for the completions popup - open bool // Indicates if the completions are open - keyMap KeyMap + wWidth int // The window width + wHeight int // The window height + width int + lastWidth int + height int // Height of the completions component` + x, xorig int // X position for the completions popup + y int // Y position for the completions popup + open bool // Indicates if the completions are open + keyMap KeyMap list listModel query string // The current filter query } -const maxCompletionsWidth = 80 // Maximum width for the completions popup +const ( + maxCompletionsWidth = 80 // Maximum width for the completions popup + minCompletionsWidth = 20 // Minimum width for the completions popup +) func New() Completions { completionsKeyMap := DefaultKeyMap() @@ -88,7 +100,7 @@ func New() Completions { ) return &completionsCmp{ width: 0, - height: 0, + height: maxCompletionsHeight, list: l, query: "", keyMap: completionsKeyMap, @@ -107,8 +119,7 @@ func (c *completionsCmp) Init() tea.Cmd { func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: - c.width = min(msg.Width-c.x, maxCompletionsWidth) - c.height = min(msg.Height-c.y, 15) + c.wWidth, c.wHeight = msg.Width, msg.Height return c, nil case tea.KeyPressMsg: switch { @@ -156,13 +167,16 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case key.Matches(msg, c.keyMap.Cancel): return c, util.CmdHandler(CloseCompletionsMsg{}) } + case RepositionCompletionsMsg: + c.x, c.y = msg.X, msg.Y + c.adjustPosition() case CloseCompletionsMsg: c.open = false return c, util.CmdHandler(CompletionsClosedMsg{}) case OpenCompletionsMsg: c.open = true c.query = "" - c.x = msg.X + c.x, c.xorig = msg.X, msg.X c.y = msg.Y items := []list.CompletionItem[any]{} t := styles.CurrentTheme() @@ -174,10 +188,18 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { ) items = append(items, item) } - c.height = max(min(c.height, len(items)), 1) // Ensure at least 1 item height + width := listWidth(items) + if len(items) == 0 { + width = listWidth(c.list.Items()) + } + if c.x+width >= c.wWidth { + c.x = c.wWidth - width - 1 + } + c.width = width + c.height = max(min(maxCompletionsHeight, len(items)), 1) // Ensure at least 1 item height return c, tea.Batch( - c.list.SetSize(c.width, c.height), c.list.SetItems(items), + c.list.SetSize(c.width, c.height), util.CmdHandler(CompletionsOpenedMsg{}), ) case FilterCompletionsMsg: @@ -201,8 +223,11 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { c.query = msg.Query var cmds []tea.Cmd cmds = append(cmds, c.list.Filter(msg.Query)) - itemsLen := len(c.list.Items()) - c.height = max(min(maxCompletionsHeight, itemsLen), 1) + items := c.list.Items() + itemsLen := len(items) + c.xorig = msg.X + c.x, c.y = msg.X, msg.Y + c.adjustPosition() cmds = append(cmds, c.list.SetSize(c.width, c.height)) if itemsLen == 0 { cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{})) @@ -215,21 +240,54 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return c, nil } +func (c *completionsCmp) adjustPosition() { + items := c.list.Items() + itemsLen := len(items) + width := listWidth(items) + c.lastWidth = c.width + if c.x < 0 || width < c.lastWidth { + c.x = c.xorig + } else if c.x+width >= c.wWidth { + c.x = c.wWidth - width - 1 + } + c.width = width + c.height = max(min(maxCompletionsHeight, itemsLen), 1) +} + // View implements Completions. func (c *completionsCmp) View() string { if !c.open || len(c.list.Items()) == 0 { return "" } - return c.style().Render(c.list.View()) -} - -func (c *completionsCmp) style() lipgloss.Style { t := styles.CurrentTheme() - return t.S().Base. + style := t.S().Base. Width(c.width). Height(c.height). Background(t.BgSubtle) + + return style.Render(c.list.View()) +} + +// listWidth returns the width of the last 10 items in the list, which is used +// to determine the width of the completions popup. +// Note this only works for [completionItemCmp] items. +func listWidth[T any](items []T) int { + var width int + if len(items) == 0 { + return width + } + + for i := len(items) - 1; i >= 0 && i >= len(items)-10; i-- { + item, ok := any(items[i]).(*completionItemCmp) + if !ok { + continue + } + itemWidth := lipgloss.Width(item.text) + 2 // +2 for padding + width = max(width, itemWidth) + } + + return width } func (c *completionsCmp) Open() bool { diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 073ac869bb5f3916e5eccbb37da135c0b012f251..711789f782e9a7bc9206a504d382b890a75a6cec 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -172,7 +172,9 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return p, nil case tea.WindowSizeMsg: - return p, p.SetSize(msg.Width, msg.Height) + u, cmd := p.editor.Update(msg) + p.editor = u.(editor.Editor) + return p, tea.Batch(p.SetSize(msg.Width, msg.Height), cmd) case CancelTimerExpiredMsg: p.isCanceling = false return p, nil diff --git a/internal/tui/tui.go b/internal/tui/tui.go index c4c88199de49fd9145dcf21fc78d452b8de14e9a..7aa6b9c4e1b599457ea99fcc23c88ed281c962aa 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -111,19 +111,10 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, a.handleWindowResize(msg.Width, msg.Height) // Completions messages - case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg: + case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, + completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg: u, completionCmd := a.completions.Update(msg) a.completions = u.(completions.Completions) - switch msg := msg.(type) { - case completions.OpenCompletionsMsg: - x, _ := a.completions.Position() - if a.completions.Width()+x >= a.wWidth { - // Adjust X position to fit in the window. - msg.X = a.wWidth - a.completions.Width() - 1 - u, completionCmd = a.completions.Update(msg) - a.completions = u.(completions.Completions) - } - } return a, completionCmd // Dialog messages