Detailed changes
@@ -135,7 +135,6 @@ Crush supports Model Context Protocol (MCP) servers through three transport type
### Logging
Enable debug logging with the `-d` flag or in config. View logs with `crush logs`. Logs are stored in `.crush/logs/crush.log`.
-
```bash
# Run with debug logging
crush -d
@@ -186,16 +185,21 @@ The `allowed_tools` array accepts:
You can also skip all permission prompts entirely by running Crush with the `--yolo` flag.
-### OpenAI-Compatible APIs
+### Custom Providers
+
+Crush supports custom provider configurations for both OpenAI-compatible and Anthropic-compatible APIs.
+
+#### OpenAI-Compatible APIs
-Crush supports all OpenAI-compatible APIs. Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment.
+Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment.
```json
{
"providers": {
"deepseek": {
- "provider_type": "openai",
+ "type": "openai",
"base_url": "https://api.deepseek.com/v1",
+ "api_key": "$DEEPSEEK_API_KEY",
"models": [
{
"id": "deepseek-chat",
@@ -213,6 +217,38 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D
}
```
+#### Anthropic-Compatible APIs
+
+You can also configure custom Anthropic-compatible providers:
+
+```json
+{
+ "providers": {
+ "custom-anthropic": {
+ "type": "anthropic",
+ "base_url": "https://api.anthropic.com/v1",
+ "api_key": "$ANTHROPIC_API_KEY",
+ "extra_headers": {
+ "anthropic-version": "2023-06-01"
+ },
+ "models": [
+ {
+ "id": "claude-3-sonnet",
+ "model": "Claude 3 Sonnet",
+ "cost_per_1m_in": 3000,
+ "cost_per_1m_out": 15000,
+ "cost_per_1m_in_cached": 300,
+ "cost_per_1m_out_cached": 15000,
+ "context_window": 200000,
+ "default_max_tokens": 4096,
+ "supports_attachments": true
+ }
+ ]
+ }
+ }
+}
+```
+
## Whatcha think?
We’d love to hear your thoughts on this project. Feel free to drop us a note!
@@ -77,6 +77,9 @@ type ProviderConfig struct {
// Marks the provider as disabled.
Disable bool `json:"disable,omitempty"`
+ // Custom system prompt prefix.
+ SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`
+
// Extra headers to send with each request to the provider.
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
// Extra body
@@ -232,7 +232,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
c.Providers.Del(id)
continue
}
- if providerConfig.Type != catwalk.TypeOpenAI {
+ if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic {
slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
c.Providers.Del(id)
continue
@@ -613,6 +613,35 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
})
+ t.Run("custom anthropic provider is supported", func(t *testing.T) {
+ cfg := &Config{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ "custom-anthropic": {
+ APIKey: "test-key",
+ BaseURL: "https://api.anthropic.com/v1",
+ Type: catwalk.TypeAnthropic,
+ Models: []catwalk.Model{{
+ ID: "claude-3-sonnet",
+ }},
+ },
+ }),
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ assert.NoError(t, err)
+
+ assert.Equal(t, cfg.Providers.Len(), 1)
+ customProvider, exists := cfg.Providers.Get("custom-anthropic")
+ assert.True(t, exists)
+ assert.Equal(t, "custom-anthropic", customProvider.ID)
+ assert.Equal(t, "test-key", customProvider.APIKey)
+ assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
+ assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
+ })
+
t.Run("disabled custom provider is removed", func(t *testing.T) {
cfg := &Config{
Providers: csync.NewMapFrom(map[string]ProviderConfig{
@@ -35,34 +35,120 @@ func NewShellVariableResolver(env env.Env) VariableResolver {
}
// ResolveValue is a method for resolving values, such as environment variables.
-// it will expect strings that start with `$` to be resolved as environment variables or shell commands.
-// if the string does not start with `$`, it will return the string as is.
+// it will resolve shell-like variable substitution anywhere in the string, including:
+// - $(command) for command substitution
+// - $VAR or ${VAR} for environment variables
func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
- if !strings.HasPrefix(value, "$") {
+ // Special case: lone $ is an error (backward compatibility)
+ if value == "$" {
+ return "", fmt.Errorf("invalid value format: %s", value)
+ }
+
+ // If no $ found, return as-is
+ if !strings.Contains(value, "$") {
return value, nil
}
- if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") {
- command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")")
+ result := value
+
+ // Handle command substitution: $(command)
+ for {
+ start := strings.Index(result, "$(")
+ if start == -1 {
+ break
+ }
+
+ // Find matching closing parenthesis
+ depth := 0
+ end := -1
+ for i := start + 2; i < len(result); i++ {
+ if result[i] == '(' {
+ depth++
+ } else if result[i] == ')' {
+ if depth == 0 {
+ end = i
+ break
+ }
+ depth--
+ }
+ }
+
+ if end == -1 {
+ return "", fmt.Errorf("unmatched $( in value: %s", value)
+ }
+
+ command := result[start+2 : end]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
- defer cancel()
stdout, _, err := r.shell.Exec(ctx, command)
+ cancel()
if err != nil {
- return "", fmt.Errorf("command execution failed: %w", err)
+ return "", fmt.Errorf("command execution failed for '%s': %w", command, err)
}
- return strings.TrimSpace(stdout), nil
+
+ // Replace the $(command) with the output
+ replacement := strings.TrimSpace(stdout)
+ result = result[:start] + replacement + result[end+1:]
}
- if after, ok := strings.CutPrefix(value, "$"); ok {
- varName := after
- value = r.env.Get(varName)
- if value == "" {
+ // Handle environment variables: $VAR and ${VAR}
+ searchStart := 0
+ for {
+ start := strings.Index(result[searchStart:], "$")
+ if start == -1 {
+ break
+ }
+ start += searchStart // Adjust for the offset
+
+ // Skip if this is part of $( which we already handled
+ if start+1 < len(result) && result[start+1] == '(' {
+ // Skip past this $(...)
+ searchStart = start + 1
+ continue
+ }
+ var varName string
+ var end int
+
+ if start+1 < len(result) && result[start+1] == '{' {
+ // Handle ${VAR} format
+ closeIdx := strings.Index(result[start+2:], "}")
+ if closeIdx == -1 {
+ return "", fmt.Errorf("unmatched ${ in value: %s", value)
+ }
+ varName = result[start+2 : start+2+closeIdx]
+ end = start + 2 + closeIdx + 1
+ } else {
+ // Handle $VAR format - variable names must start with letter or underscore
+ if start+1 >= len(result) {
+ return "", fmt.Errorf("incomplete variable reference at end of string: %s", value)
+ }
+
+ if result[start+1] != '_' &&
+ (result[start+1] < 'a' || result[start+1] > 'z') &&
+ (result[start+1] < 'A' || result[start+1] > 'Z') {
+ return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value)
+ }
+
+ end = start + 1
+ for end < len(result) && (result[end] == '_' ||
+ (result[end] >= 'a' && result[end] <= 'z') ||
+ (result[end] >= 'A' && result[end] <= 'Z') ||
+ (result[end] >= '0' && result[end] <= '9')) {
+ end++
+ }
+ varName = result[start+1 : end]
+ }
+
+ envValue := r.env.Get(varName)
+ if envValue == "" {
return "", fmt.Errorf("environment variable %q not set", varName)
}
- return value, nil
+
+ result = result[:start] + envValue + result[end:]
+ searchStart = start + len(envValue) // Continue searching after the replacement
}
- return "", fmt.Errorf("invalid value format: %s", value)
+
+ return result, nil
}
type environmentVariableResolver struct {
@@ -47,17 +47,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
envVars: map[string]string{},
expectError: true,
},
- {
- name: "shell command execution",
- value: "$(echo hello)",
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- if command == "echo hello" {
- return "hello\n", "", nil
- }
- return "", "", errors.New("unexpected command")
- },
- expected: "hello",
- },
+
{
name: "shell command with whitespace trimming",
value: "$(echo ' spaced ')",
@@ -104,6 +94,171 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
}
}
+func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ envVars map[string]string
+ shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
+ expected string
+ expectError bool
+ }{
+ {
+ name: "command substitution within string",
+ value: "Bearer $(echo token123)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "echo token123" {
+ return "token123\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "Bearer token123",
+ },
+ {
+ name: "environment variable within string",
+ value: "Bearer $TOKEN",
+ envVars: map[string]string{"TOKEN": "sk-ant-123"},
+ expected: "Bearer sk-ant-123",
+ },
+ {
+ name: "environment variable with braces within string",
+ value: "Bearer ${TOKEN}",
+ envVars: map[string]string{"TOKEN": "sk-ant-456"},
+ expected: "Bearer sk-ant-456",
+ },
+ {
+ name: "mixed command and environment substitution",
+ value: "$USER-$(date +%Y)-$HOST",
+ envVars: map[string]string{
+ "USER": "testuser",
+ "HOST": "localhost",
+ },
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "date +%Y" {
+ return "2024\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "testuser-2024-localhost",
+ },
+ {
+ name: "multiple command substitutions",
+ value: "$(echo hello) $(echo world)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ switch command {
+ case "echo hello":
+ return "hello\n", "", nil
+ case "echo world":
+ return "world\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "hello world",
+ },
+ {
+ name: "nested parentheses in command",
+ value: "$(echo $(echo inner))",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "echo $(echo inner)" {
+ return "nested\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "nested",
+ },
+ {
+ name: "lone dollar with non-variable chars",
+ value: "prefix$123suffix", // Numbers can't start variable names
+ expectError: true,
+ },
+ {
+ name: "dollar with special chars",
+ value: "a$@b$#c", // Special chars aren't valid in variable names
+ expectError: true,
+ },
+ {
+ name: "empty environment variable substitution",
+ value: "Bearer $EMPTY_VAR",
+ envVars: map[string]string{},
+ expectError: true,
+ },
+ {
+ name: "unmatched command substitution opening",
+ value: "Bearer $(echo test",
+ expectError: true,
+ },
+ {
+ name: "unmatched environment variable braces",
+ value: "Bearer ${TOKEN",
+ expectError: true,
+ },
+ {
+ name: "command substitution with error",
+ value: "Bearer $(false)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ return "", "", errors.New("command failed")
+ },
+ expectError: true,
+ },
+ {
+ name: "complex real-world example",
+ value: "Bearer $(cat /tmp/token.txt | base64 -w 0)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "cat /tmp/token.txt | base64 -w 0" {
+ return "c2stYW50LXRlc3Q=\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "Bearer c2stYW50LXRlc3Q=",
+ },
+ {
+ name: "environment variable with underscores and numbers",
+ value: "Bearer $API_KEY_V2",
+ envVars: map[string]string{"API_KEY_V2": "sk-test-123"},
+ expected: "Bearer sk-test-123",
+ },
+ {
+ name: "no substitution needed",
+ value: "Bearer sk-ant-static-token",
+ expected: "Bearer sk-ant-static-token",
+ },
+ {
+ name: "incomplete variable at end",
+ value: "Bearer $",
+ expectError: true,
+ },
+ {
+ name: "variable with invalid character",
+ value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names
+ expectError: true,
+ },
+ {
+ name: "multiple invalid variables",
+ value: "$1$2$3",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ testEnv := env.NewFromMap(tt.envVars)
+ resolver := &shellVariableResolver{
+ shell: &mockShell{execFunc: tt.shellFunc},
+ env: testEnv,
+ }
+
+ result, err := resolver.ResolveValue(tt.value)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {
tests := []struct {
name string
@@ -39,8 +39,30 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl
func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
anthropicClientOptions := []option.RequestOption{}
- if opts.apiKey != "" {
- anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
+
+ // Check if Authorization header is provided in extra headers
+ hasBearerAuth := false
+ if opts.extraHeaders != nil {
+ for key := range opts.extraHeaders {
+ if strings.ToLower(key) == "authorization" {
+ hasBearerAuth = true
+ break
+ }
+ }
+ }
+
+ isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ")
+
+ if opts.apiKey != "" && !hasBearerAuth {
+ if isBearerToken {
+ slog.Debug("API key starts with 'Bearer ', using as Authorization header")
+ anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey))
+ } else {
+ // Use standard X-Api-Key header
+ anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
+ }
+ } else if hasBearerAuth {
+ slog.Debug("Skipping X-Api-Key header because Authorization header is provided")
}
if useBedrock {
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
@@ -200,6 +222,25 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
maxTokens = int64(a.adjustedMaxTokens)
}
+ systemBlocks := []anthropic.TextBlockParam{}
+
+ // Add custom system prompt prefix if configured
+ if a.providerOptions.systemPromptPrefix != "" {
+ systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
+ Text: a.providerOptions.systemPromptPrefix,
+ CacheControl: anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ },
+ })
+ }
+
+ systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
+ Text: a.providerOptions.systemMessage,
+ CacheControl: anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ },
+ })
+
return anthropic.MessageNewParams{
Model: anthropic.Model(model.ID),
MaxTokens: maxTokens,
@@ -207,14 +248,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
Messages: messages,
Tools: tools,
Thinking: thinkingParam,
- System: []anthropic.TextBlockParam{
- {
- Text: a.providerOptions.systemMessage,
- CacheControl: anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- },
- },
- },
+ System: systemBlocks,
}
}
@@ -393,6 +427,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
close(eventChan)
return
}
+
// If there is an error we are going to see if we can retry the call
retry, after, retryErr := a.shouldRetry(attempts, err)
if retryErr != nil {
@@ -180,12 +180,16 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if modelConfig.MaxTokens > 0 {
maxTokens = modelConfig.MaxTokens
}
+ systemMessage := g.providerOptions.systemMessage
+ if g.providerOptions.systemPromptPrefix != "" {
+ systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
+ }
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
MaxOutputTokens: int32(maxTokens),
SystemInstruction: &genai.Content{
- Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
+ Parts: []*genai.Part{{Text: systemMessage}},
},
}
config.Tools = g.convertTools(tools)
@@ -280,12 +284,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if g.providerOptions.maxTokens > 0 {
maxTokens = g.providerOptions.maxTokens
}
+ systemMessage := g.providerOptions.systemMessage
+ if g.providerOptions.systemPromptPrefix != "" {
+ systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
+ }
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
MaxOutputTokens: int32(maxTokens),
SystemInstruction: &genai.Content{
- Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
+ Parts: []*genai.Part{{Text: systemMessage}},
},
}
config.Tools = g.convertTools(tools)
@@ -57,7 +57,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
// Add system message first
- openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
+ systemMessage := o.providerOptions.systemMessage
+ if o.providerOptions.systemPromptPrefix != "" {
+ systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
+ }
+ openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))
for _, msg := range messages {
switch msg.Role {
@@ -61,17 +61,18 @@ type Provider interface {
}
type providerClientOptions struct {
- baseURL string
- config config.ProviderConfig
- apiKey string
- modelType config.SelectedModelType
- model func(config.SelectedModelType) catwalk.Model
- disableCache bool
- systemMessage string
- maxTokens int64
- extraHeaders map[string]string
- extraBody map[string]any
- extraParams map[string]string
+ baseURL string
+ config config.ProviderConfig
+ apiKey string
+ modelType config.SelectedModelType
+ model func(config.SelectedModelType) catwalk.Model
+ disableCache bool
+ systemMessage string
+ systemPromptPrefix string
+ maxTokens int64
+ extraHeaders map[string]string
+ extraBody map[string]any
+ extraParams map[string]string
}
type ProviderClientOption func(*providerClientOptions)
@@ -143,12 +144,23 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
}
+ // Resolve extra headers
+ resolvedExtraHeaders := make(map[string]string)
+ for key, value := range cfg.ExtraHeaders {
+ resolvedValue, err := config.Get().Resolve(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
+ }
+ resolvedExtraHeaders[key] = resolvedValue
+ }
+
clientOptions := providerClientOptions{
- baseURL: cfg.BaseURL,
- config: cfg,
- apiKey: resolvedAPIKey,
- extraHeaders: cfg.ExtraHeaders,
- extraBody: cfg.ExtraBody,
+ baseURL: cfg.BaseURL,
+ config: cfg,
+ apiKey: resolvedAPIKey,
+ extraHeaders: resolvedExtraHeaders,
+ extraBody: cfg.ExtraBody,
+ systemPromptPrefix: cfg.SystemPromptPrefix,
model: func(tp config.SelectedModelType) catwalk.Model {
return *config.Get().GetModelByType(tp)
},
@@ -161,10 +161,17 @@ func (m *editorCmp) send() tea.Cmd {
)
}
+func (m *editorCmp) repositionCompletions() tea.Msg {
+ x, y := m.completionsPosition()
+ return completions.RepositionCompletionsMsg{X: x, Y: y}
+}
+
func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
var cmds []tea.Cmd
switch msg := msg.(type) {
+ case tea.WindowSizeMsg:
+ return m, m.repositionCompletions
case filepicker.FilePickedMsg:
if len(m.attachments) >= maxAttachments {
return m, util.ReportError(fmt.Errorf("cannot add more than %d images", maxAttachments))
@@ -182,32 +189,37 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
if item, ok := msg.Value.(FileCompletionItem); ok {
+ word := m.textarea.Word()
// If the selected item is a file, insert its path into the textarea
value := m.textarea.Value()
- value = value[:m.completionsStartIndex]
- value += item.Path
+ value = value[:m.completionsStartIndex] + // Remove the current query
+ item.Path + // Insert the file path
+ value[m.completionsStartIndex+len(word):] // Append the rest of the value
+ // XXX: This will always move the cursor to the end of the textarea.
m.textarea.SetValue(value)
+ m.textarea.MoveToEnd()
if !msg.Insert {
m.isCompletionsOpen = false
m.currentQuery = ""
m.completionsStartIndex = 0
}
- return m, nil
}
case openEditorMsg:
m.textarea.SetValue(msg.Text)
m.textarea.MoveToEnd()
case tea.KeyPressMsg:
+ cur := m.textarea.Cursor()
+ curIdx := m.textarea.Width()*cur.Y + cur.X
switch {
// Completions
case msg.String() == "/" && !m.isCompletionsOpen &&
- // only show if beginning of prompt, or if previous char is a space:
- (len(m.textarea.Value()) == 0 || m.textarea.Value()[len(m.textarea.Value())-1] == ' '):
+ // only show if beginning of prompt, or if previous char is a space or newline:
+ (len(m.textarea.Value()) == 0 || unicode.IsSpace(rune(m.textarea.Value()[len(m.textarea.Value())-1]))):
m.isCompletionsOpen = true
m.currentQuery = ""
- m.completionsStartIndex = len(m.textarea.Value())
+ m.completionsStartIndex = curIdx
cmds = append(cmds, m.startCompletions)
- case m.isCompletionsOpen && m.textarea.Cursor().X <= m.completionsStartIndex:
+ case m.isCompletionsOpen && curIdx <= m.completionsStartIndex:
cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{}))
}
if key.Matches(msg, DeleteKeyMaps.AttachmentDeleteMode) {
@@ -244,6 +256,7 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
if key.Matches(msg, m.keyMap.Newline) {
m.textarea.InsertRune('\n')
+ cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{}))
}
// Handle Enter key
if m.textarea.Focused() && key.Matches(msg, m.keyMap.SendMessage) {
@@ -275,12 +288,18 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// XXX: wont' work if editing in the middle of the field.
m.completionsStartIndex = strings.LastIndex(m.textarea.Value(), word)
m.currentQuery = word[1:]
+ x, y := m.completionsPosition()
+ x -= len(m.currentQuery)
m.isCompletionsOpen = true
- cmds = append(cmds, util.CmdHandler(completions.FilterCompletionsMsg{
- Query: m.currentQuery,
- Reopen: m.isCompletionsOpen,
- }))
- } else {
+ cmds = append(cmds,
+ util.CmdHandler(completions.FilterCompletionsMsg{
+ Query: m.currentQuery,
+ Reopen: m.isCompletionsOpen,
+ X: x,
+ Y: y,
+ }),
+ )
+ } else if m.isCompletionsOpen {
m.isCompletionsOpen = false
m.currentQuery = ""
m.completionsStartIndex = 0
@@ -293,6 +312,16 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(cmds...)
}
+func (m *editorCmp) completionsPosition() (int, int) {
+ cur := m.textarea.Cursor()
+ if cur == nil {
+ return m.x, m.y + 1 // adjust for padding
+ }
+ x := cur.X + m.x
+ y := cur.Y + m.y + 1 // adjust for padding
+ return x, y
+}
+
func (m *editorCmp) Cursor() *tea.Cursor {
cursor := m.textarea.Cursor()
if cursor != nil {
@@ -373,9 +402,7 @@ func (m *editorCmp) startCompletions() tea.Msg {
})
}
- cur := m.textarea.Cursor()
- x := cur.X + m.x // adjust for padding
- y := cur.Y + m.y + 1
+ x, y := m.completionsPosition()
return completions.OpenCompletionsMsg{
Completions: completionItems,
X: x,
@@ -27,6 +27,12 @@ type OpenCompletionsMsg struct {
type FilterCompletionsMsg struct {
Query string // The query to filter completions
Reopen bool
+ X int // X position for the completions popup
+ Y int // Y position for the completions popup
+}
+
+type RepositionCompletionsMsg struct {
+ X, Y int
}
type CompletionsClosedMsg struct{}
@@ -53,18 +59,24 @@ type Completions interface {
type listModel = list.FilterableList[list.CompletionItem[any]]
type completionsCmp struct {
- width int
- height int // Height of the completions component`
- x int // X position for the completions popup
- y int // Y position for the completions popup
- open bool // Indicates if the completions are open
- keyMap KeyMap
+ wWidth int // The window width
+ wHeight int // The window height
+ width int
+ lastWidth int
+ height int // Height of the completions component`
+ x, xorig int // X position for the completions popup
+ y int // Y position for the completions popup
+ open bool // Indicates if the completions are open
+ keyMap KeyMap
list listModel
query string // The current filter query
}
-const maxCompletionsWidth = 80 // Maximum width for the completions popup
+const (
+ maxCompletionsWidth = 80 // Maximum width for the completions popup
+ minCompletionsWidth = 20 // Minimum width for the completions popup
+)
func New() Completions {
completionsKeyMap := DefaultKeyMap()
@@ -88,7 +100,7 @@ func New() Completions {
)
return &completionsCmp{
width: 0,
- height: 0,
+ height: maxCompletionsHeight,
list: l,
query: "",
keyMap: completionsKeyMap,
@@ -107,8 +119,7 @@ func (c *completionsCmp) Init() tea.Cmd {
func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
- c.width = min(msg.Width-c.x, maxCompletionsWidth)
- c.height = min(msg.Height-c.y, 15)
+ c.wWidth, c.wHeight = msg.Width, msg.Height
return c, nil
case tea.KeyPressMsg:
switch {
@@ -129,7 +140,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
selectedItem := *s
c.list.SetSelected(selectedItem.ID())
return c, util.CmdHandler(SelectCompletionMsg{
- Value: selectedItem,
+ Value: selectedItem.Value(),
Insert: true,
})
case key.Matches(msg, c.keyMap.DownInsert):
@@ -140,7 +151,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
selectedItem := *s
c.list.SetSelected(selectedItem.ID())
return c, util.CmdHandler(SelectCompletionMsg{
- Value: selectedItem,
+ Value: selectedItem.Value(),
Insert: true,
})
case key.Matches(msg, c.keyMap.Select):
@@ -151,18 +162,21 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
selectedItem := *s
c.open = false // Close completions after selection
return c, util.CmdHandler(SelectCompletionMsg{
- Value: selectedItem,
+ Value: selectedItem.Value(),
})
case key.Matches(msg, c.keyMap.Cancel):
return c, util.CmdHandler(CloseCompletionsMsg{})
}
+ case RepositionCompletionsMsg:
+ c.x, c.y = msg.X, msg.Y
+ c.adjustPosition()
case CloseCompletionsMsg:
c.open = false
return c, util.CmdHandler(CompletionsClosedMsg{})
case OpenCompletionsMsg:
c.open = true
c.query = ""
- c.x = msg.X
+ c.x, c.xorig = msg.X, msg.X
c.y = msg.Y
items := []list.CompletionItem[any]{}
t := styles.CurrentTheme()
@@ -174,10 +188,18 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
)
items = append(items, item)
}
- c.height = max(min(c.height, len(items)), 1) // Ensure at least 1 item height
+ width := listWidth(items)
+ if len(items) == 0 {
+ width = listWidth(c.list.Items())
+ }
+ if c.x+width >= c.wWidth {
+ c.x = c.wWidth - width - 1
+ }
+ c.width = width
+ c.height = max(min(maxCompletionsHeight, len(items)), 1) // Ensure at least 1 item height
return c, tea.Batch(
- c.list.SetSize(c.width, c.height),
c.list.SetItems(items),
+ c.list.SetSize(c.width, c.height),
util.CmdHandler(CompletionsOpenedMsg{}),
)
case FilterCompletionsMsg:
@@ -201,8 +223,11 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
c.query = msg.Query
var cmds []tea.Cmd
cmds = append(cmds, c.list.Filter(msg.Query))
- itemsLen := len(c.list.Items())
- c.height = max(min(maxCompletionsHeight, itemsLen), 1)
+ items := c.list.Items()
+ itemsLen := len(items)
+ c.xorig = msg.X
+ c.x, c.y = msg.X, msg.Y
+ c.adjustPosition()
cmds = append(cmds, c.list.SetSize(c.width, c.height))
if itemsLen == 0 {
cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{}))
@@ -215,21 +240,50 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return c, nil
}
+func (c *completionsCmp) adjustPosition() {
+ items := c.list.Items()
+ itemsLen := len(items)
+ width := listWidth(items)
+ c.lastWidth = c.width
+ if c.x < 0 || width < c.lastWidth {
+ c.x = c.xorig
+ } else if c.x+width >= c.wWidth {
+ c.x = c.wWidth - width - 1
+ }
+ c.width = width
+ c.height = max(min(maxCompletionsHeight, itemsLen), 1)
+}
+
// View implements Completions.
func (c *completionsCmp) View() string {
if !c.open || len(c.list.Items()) == 0 {
return ""
}
- return c.style().Render(c.list.View())
-}
-
-func (c *completionsCmp) style() lipgloss.Style {
t := styles.CurrentTheme()
- return t.S().Base.
+ style := t.S().Base.
Width(c.width).
Height(c.height).
Background(t.BgSubtle)
+
+ return style.Render(c.list.View())
+}
+
+// listWidth returns the width of the last 10 items in the list, which is used
+// to determine the width of the completions popup.
+// Note this only works for [completionItemCmp] items.
+func listWidth(items []list.CompletionItem[any]) int {
+ var width int
+ if len(items) == 0 {
+ return width
+ }
+
+ for i := len(items) - 1; i >= 0 && i >= len(items)-10; i-- {
+ itemWidth := lipgloss.Width(items[i].Text()) + 2 // +2 for padding
+ width = max(width, itemWidth)
+ }
+
+ return width
}
func (c *completionsCmp) Open() bool {
@@ -23,6 +23,7 @@ type CompletionItem[T any] interface {
layout.Sizeable
HasMatchIndexes
Value() T
+ Text() string
}
type completionItemCmp[T any] struct {
@@ -312,6 +313,10 @@ func (c *completionItemCmp[T]) ID() string {
return c.id
}
+func (c *completionItemCmp[T]) Text() string {
+ return c.text
+}
+
type ItemSection interface {
Item
layout.Sizeable
@@ -173,7 +173,9 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
return p, nil
case tea.WindowSizeMsg:
- return p, p.SetSize(msg.Width, msg.Height)
+ u, cmd := p.editor.Update(msg)
+ p.editor = u.(editor.Editor)
+ return p, tea.Batch(p.SetSize(msg.Width, msg.Height), cmd)
case CancelTimerExpiredMsg:
p.isCanceling = false
return p, nil
@@ -111,19 +111,10 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return a, a.handleWindowResize(msg.Width, msg.Height)
// Completions messages
- case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg:
+ case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg,
+ completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg:
u, completionCmd := a.completions.Update(msg)
a.completions = u.(completions.Completions)
- switch msg := msg.(type) {
- case completions.OpenCompletionsMsg:
- x, _ := a.completions.Position()
- if a.completions.Width()+x >= a.wWidth {
- // Adjust X position to fit in the window.
- msg.X = a.wWidth - a.completions.Width() - 1
- u, completionCmd = a.completions.Update(msg)
- a.completions = u.(completions.Completions)
- }
- }
return a, completionCmd
// Dialog messages