diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 42fe1ce67f87406b3e9cb7b5c05ea7f8659a2197..9f3a73db7fe4ac7581dce6d38cc72c7ec7e55ec1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,7 +2,7 @@ name: build on: [push, pull_request] jobs: - build-go: + build: uses: charmbracelet/meta/.github/workflows/build.yml@main with: go-version: "" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 05632188b0c45704cf73f307182aeacfb2b857fa..38cc2ce2603318c339d50a5e5b412110653ab748 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,3 +20,4 @@ jobs: fury_token: ${{ secrets.FURY_TOKEN }} nfpm_gpg_key: ${{ secrets.NFPM_GPG_KEY }} nfpm_passphrase: ${{ secrets.NFPM_PASSPHRASE }} + npm_token: ${{ secrets.NPM_TOKEN }} diff --git a/.goreleaser.yml b/.goreleaser.yml index 1b215ab33318126e67bb799622e41757b124ad73..827574aab39b51ff3bb32a9faf98da7773bad605 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -45,7 +45,7 @@ builds: goos: - linux - darwin - # - windows + - windows goarch: - amd64 - arm64 @@ -111,8 +111,15 @@ homebrew_casks: owner: charmbracelet name: homebrew-tap +npms: + - name: "@charmland/crush" + repository: "git+https://github.com/charmbracelet/crush.git" + bugs: https://github.com/charmbracelet/crush/issues + access: public + nfpms: - formats: + - apk - deb - rpm - archlinux diff --git a/README.md b/README.md index 534b8c36311e1b0559796ce4df0fba8ed0773611..c693a94cf4349fdf455975b0abfd1aa7499476e7 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,12 @@ Crush is a tool for building software with AI. ## Installation +Crush has first class support for macOS, Linux, and Windows. + Nightly builds are available while Crush is in development. -- [Packages](https://github.com/charmbracelet/crush/releases/tag/nightly) are available in Debian and RPM formats -- [Binaries](https://github.com/charmbracelet/crush/releases/tag/nightly) are available for Linux and macOS +- [Packages](https://github.com/charmbracelet/crush/releases/tag/nightly) are available in Debian, RPM, APK, and PKG formats +- [Binaries](https://github.com/charmbracelet/crush/releases/tag/nightly) are available for Linux, macOS and Windows You can also just install it with go: @@ -28,7 +30,7 @@ go install
Not a developer? Here’s a quick how-to. -Download the latest [nightly release](https://github.com/charmbracelet/crush/releases) for your system. The [macOS ARM64](https://github.com/charmbracelet/crush/releases/download/nightly/crush_0.1.0-nightly_Darwin_arm64.tar.gz) is most likely what you want. +Download the latest [nightly release](https://github.com/charmbracelet/crush/releases) for your system. The [macOS ARM64 one](https://github.com/charmbracelet/crush/releases/download/nightly/crush_0.1.0-nightly_Darwin_arm64.tar.gz) is most likely what you want. Next, open a terminal and run the following commands: @@ -36,17 +38,15 @@ Next, open a terminal and run the following commands: cd ~/Downloads tar -xvzf crush_0.1.0-nightly_Darwin_arm64.tar.gz -C crush sudo mv ./crush/crush /usr/local/bin/crush -rm -rf crush +rm -rf ./crush ``` Then, run Crush by typing `crush`. -*** +---
-Note that Crush doesn't support Windows yet, however Windows support is planned and in progress. - ## Getting Started The quickest way to get started to grab an API key for your preferred @@ -108,7 +108,7 @@ Crush supports Model Context Protocol (MCP) servers through three transport type "mcp": { "filesystem": { "type": "stdio", - "command": "node", + "command": "node", "args": ["/path/to/mcp-server.js"], "env": { "NODE_ENV": "production" @@ -143,7 +143,7 @@ crush -d # View last 1000 lines crush logs -# Follow logs in real-time +# Follow logs in real-time crush logs -f # Show last 500 lines @@ -161,6 +161,31 @@ Add to your `crush.json` config file: } ``` +### Configurable Default Permissions + +Crush includes a permission system to control which tools can be executed without prompting. You can configure allowed tools in your `crush.json` config file: + +```json +{ + "permissions": { + "allowed_tools": [ + "view", + "ls", + "grep", + "edit:write", + "mcp_context7_get-library-doc" + ] + } +} +``` + +The `allowed_tools` array accepts: + +- Tool names (e.g., `"view"`) - allows all actions for that tool +- Tool:action combinations (e.g., `"edit:write"`) - allows only specific actions + +You can also skip all permission prompts entirely by running Crush with the `--yolo` flag. + ### OpenAI-Compatible APIs Crush supports all OpenAI-compatible APIs. Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment. @@ -174,7 +199,7 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D "models": [ { "id": "deepseek-chat", - "model": "Deepseek V3", + "name": "Deepseek V3", "cost_per_1m_in": 0.27, "cost_per_1m_out": 1.1, "cost_per_1m_in_cached": 0.07, diff --git a/go.mod b/go.mod index c2b6fa54ed62365230814e41aa3295a096514c17..e17354c051a21b593a385b1e3995cc543aafd0dd 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/charlievieth/fastwalk v1.0.11 github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5 github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac + github.com/charmbracelet/catwalk v0.3.1 github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112 @@ -39,6 +40,7 @@ require ( github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef github.com/stretchr/testify v1.10.0 github.com/tidwall/sjson v1.2.5 + github.com/u-root/u-root v0.14.1-0.20250722142936-bf4e78a90dfc github.com/zeebo/xxh3 v1.0.2 golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -70,7 +72,7 @@ require ( github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/charmbracelet/colorprofile v0.3.1 // indirect - github.com/charmbracelet/ultraviolet v0.0.0-20250721205647-f6ac6eda5d42 // indirect + github.com/charmbracelet/ultraviolet v0.0.0-20250723145313-809e6f5b43a1 // indirect github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20250611152503-f53cdd7e01ef github.com/charmbracelet/x/term v0.2.1 @@ -108,7 +110,7 @@ require ( github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect github.com/spf13/cast v1.7.1 // indirect - github.com/spf13/pflag v1.0.6 // indirect + github.com/spf13/pflag v1.0.7 // indirect github.com/tetratelabs/wazero v1.9.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index 0d05967cdf1a80070dab737b1c473d0c39c20611..755edeb81ead60da60196e2834c9e6354af168b7 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,8 @@ github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5 github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5/go.mod h1:6HamsBKWqEC/FVHuQMHgQL+knPyvHH55HwJDHl/adMw= github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac h1:murtkvFYxZ/73vk4Z/tpE4biB+WDZcFmmBp8je/yV6M= github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac/go.mod h1:m240IQxo1/eDQ7klblSzOCAUyc3LddHcV3Rc/YEGAgw= +github.com/charmbracelet/catwalk v0.3.1 h1:MkGWspcMyE659zDkqS+9wsaCMTKRFEDBFY2A2sap6+U= +github.com/charmbracelet/catwalk v0.3.1/go.mod h1:gUUCqqZ8bk4D7ZzGTu3I77k7cC2x4exRuJBN1H2u2pc= github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40= github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0= github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 h1:+Cz+VfxD5DO+JT1LlswXWhre0HYLj6l2HW8HVGfMuC0= @@ -82,8 +84,8 @@ github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112 github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112/go.mod h1:BXY7j7rZgAprFwzNcO698++5KTd6GKI6lU83Pr4o0r0= github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 h1:WkwO6Ks3mSIGnGuSdKl9qDSyfbYK50z2wc2gGMggegE= github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706/go.mod h1:mjJGp00cxcfvD5xdCa+bso251Jt4owrQvuimJtVmEmM= -github.com/charmbracelet/ultraviolet v0.0.0-20250721205647-f6ac6eda5d42 h1:Zqw2oP9Wo8VzMijVJbtIJcAaZviYyU07stvmCFCfn0Y= -github.com/charmbracelet/ultraviolet v0.0.0-20250721205647-f6ac6eda5d42/go.mod h1:XrrgNFfXLrFAyd9DUmrqVc3yQFVv8Uk+okj4PsNNzpc= +github.com/charmbracelet/ultraviolet v0.0.0-20250723145313-809e6f5b43a1 h1:tsw1mOuIEIKlmm614bXctvJ3aavaFhyPG+y+wrKtuKQ= +github.com/charmbracelet/ultraviolet v0.0.0-20250723145313-809e6f5b43a1/go.mod h1:XrrgNFfXLrFAyd9DUmrqVc3yQFVv8Uk+okj4PsNNzpc= github.com/charmbracelet/x/ansi v0.9.3 h1:BXt5DHS/MKF+LjuK4huWrC6NCvHtexww7dMayh6GXd0= github.com/charmbracelet/x/ansi v0.9.3/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa h1:lphz0Z3rsiOtMYiz8axkT24i9yFiueDhJbzyNUADmME= @@ -234,8 +236,9 @@ github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= -github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.7 h1:vN6T9TfwStFPFM5XzjsvmzZkLuaLX+HS+0SeFLRgU6M= +github.com/spf13/pflag v1.0.7/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q= github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ= @@ -258,6 +261,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/u-root/u-root v0.14.1-0.20250722142936-bf4e78a90dfc h1:HjI/UCF4dRyzizePQrhGUSQvuU7z4tOqMqz6GRGlFCM= +github.com/u-root/u-root v0.14.1-0.20250722142936-bf4e78a90dfc/go.mod h1:/0Qr7qJeDwWxoKku2xKQ4Szc+SwBE3g9VE8jNiamsmc= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/internal/ansiext/ansi.go b/internal/ansiext/ansi.go new file mode 100644 index 0000000000000000000000000000000000000000..4ec76a70ebf6f4edd963e1542ec83eaa09bd8ebf --- /dev/null +++ b/internal/ansiext/ansi.go @@ -0,0 +1,25 @@ +package ansiext + +import ( + "strings" + + "github.com/charmbracelet/x/ansi" +) + +// Escape replaces control characters with their Unicode Control Picture +// representations to ensure they are displayed correctly in the UI. +func Escape(content string) string { + var sb strings.Builder + sb.Grow(len(content)) + for _, r := range content { + switch { + case r >= 0 && r <= 0x1f: // Control characters 0x00-0x1F + sb.WriteRune('\u2400' + r) + case r == ansi.DEL: + sb.WriteRune('\u2421') + default: + sb.WriteRune(r) + } + } + return sb.String() +} diff --git a/internal/app/app.go b/internal/app/app.go index 170debad340e2cb33fcb7a1c9fe814c184573c9b..50e117ea1ae272156dbd11baa1a5f157a74333f1 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -59,13 +59,17 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { sessions := session.NewService(q) messages := message.NewService(q) files := history.NewService(q, conn) - skipPermissionsRequests := cfg.Options != nil && cfg.Options.SkipPermissionsRequests + skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests + allowedTools := []string{} + if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil { + allowedTools = cfg.Permissions.AllowedTools + } app := &App{ Sessions: sessions, Messages: messages, History: files, - Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests), + Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools), LSPClients: make(map[string]*lsp.Client), globalCtx: ctx, @@ -157,16 +161,20 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool if result.Error != nil { if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) { - slog.Info("Agent processing cancelled", "session_id", sess.ID) + slog.Info("Non-interactive: agent processing cancelled", "session_id", sess.ID) return nil } return fmt.Errorf("agent processing failed: %w", result.Error) } - part := result.Message.Content().String()[readBts:] - fmt.Println(part) + msgContent := result.Message.Content().String() + if len(msgContent) < readBts { + slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(msgContent), "read_bytes", readBts) + return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(msgContent), readBts) + } + fmt.Println(msgContent[readBts:]) - slog.Info("Non-interactive run completed", "session_id", sess.ID) + slog.Info("Non-interactive: run completed", "session_id", sess.ID) return nil case event := <-messageEvents: diff --git a/internal/cmd/root.go b/internal/cmd/root.go index d63160992141da26b6a26610b06f1b601213e00d..c6c24d5963c57981b1e91911146c1893728ffe37 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -73,7 +73,10 @@ to assist developers in writing, debugging, and understanding code directly from if err != nil { return err } - cfg.Options.SkipPermissionsRequests = yolo + if cfg.Permissions == nil { + cfg.Permissions = &config.Permissions{} + } + cfg.Permissions.SkipRequests = yolo ctx := cmd.Context() @@ -85,14 +88,14 @@ to assist developers in writing, debugging, and understanding code directly from app, err := app.New(ctx, conn, cfg) if err != nil { - slog.Error(fmt.Sprintf("Failed to create app instance: %v", err)) + slog.Error("Failed to create app instance", "error", err) return err } defer app.Shutdown() prompt, err = maybePrependStdin(prompt) if err != nil { - slog.Error(fmt.Sprintf("Failed to read from stdin: %v", err)) + slog.Error("Failed to read from stdin", "error", err) return err } @@ -114,7 +117,7 @@ to assist developers in writing, debugging, and understanding code directly from go app.Subscribe(program) if _, err := program.Run(); err != nil { - slog.Error(fmt.Sprintf("TUI run error: %v", err)) + slog.Error("TUI run error", "error", err) return fmt.Errorf("TUI error: %v", err) } return nil diff --git a/internal/config/config.go b/internal/config/config.go index 18eca04912189415606599c5849e8a7beb592cb4..b9d44bc87448d3244d27c426bf0f70dc98ce064a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,9 +9,9 @@ import ( "strings" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/tidwall/sjson" "golang.org/x/exp/slog" ) @@ -71,7 +71,7 @@ type ProviderConfig struct { // The provider's API endpoint. BaseURL string `json:"base_url,omitempty"` // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai. - Type provider.Type `json:"type,omitempty"` + Type catwalk.Type `json:"type,omitempty"` // The provider's API key. APIKey string `json:"api_key,omitempty"` // Marks the provider as disabled. @@ -86,7 +86,7 @@ type ProviderConfig struct { ExtraParams map[string]string `json:"-"` // The provider models - Models []provider.Model `json:"models,omitempty"` + Models []catwalk.Model `json:"models,omitempty"` } type MCPType string @@ -121,14 +121,18 @@ type TUIOptions struct { // Here we can add themes later or any TUI related options } +type Permissions struct { + AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts + SkipRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode) +} + type Options struct { - ContextPaths []string `json:"context_paths,omitempty"` - TUI *TUIOptions `json:"tui,omitempty"` - Debug bool `json:"debug,omitempty"` - DebugLSP bool `json:"debug_lsp,omitempty"` - DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"` - DataDirectory string `json:"data_directory,omitempty"` // Relative to the cwd - SkipPermissionsRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode) + ContextPaths []string `json:"context_paths,omitempty"` + TUI *TUIOptions `json:"tui,omitempty"` + Debug bool `json:"debug,omitempty"` + DebugLSP bool `json:"debug_lsp,omitempty"` + DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"` + DataDirectory string `json:"data_directory,omitempty"` // Relative to the cwd } type MCPs map[string]MCPConfig @@ -245,14 +249,16 @@ type Config struct { Options *Options `json:"options,omitempty"` + Permissions *Permissions `json:"permissions,omitempty"` + // Internal workingDir string `json:"-"` // TODO: most likely remove this concept when I come back to it Agents map[string]Agent `json:"-"` // TODO: find a better way to do this this should probably not be part of the config resolver VariableResolver - dataConfigDir string `json:"-"` - knownProviders []provider.Provider `json:"-"` + dataConfigDir string `json:"-"` + knownProviders []catwalk.Provider `json:"-"` } func (c *Config) WorkingDir() string { @@ -274,7 +280,7 @@ func (c *Config) IsConfigured() bool { return len(c.EnabledProviders()) > 0 } -func (c *Config) GetModel(provider, model string) *provider.Model { +func (c *Config) GetModel(provider, model string) *catwalk.Model { if providerConfig, ok := c.Providers.Get(provider); ok { for _, m := range providerConfig.Models { if m.ID == model { @@ -296,7 +302,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi return nil } -func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model { +func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model { model, ok := c.Models[modelType] if !ok { return nil @@ -304,7 +310,7 @@ func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model { return c.GetModel(model.Provider, model.Model) } -func (c *Config) LargeModel() *provider.Model { +func (c *Config) LargeModel() *catwalk.Model { model, ok := c.Models[SelectedModelTypeLarge] if !ok { return nil @@ -312,7 +318,7 @@ func (c *Config) LargeModel() *provider.Model { return c.GetModel(model.Provider, model.Model) } -func (c *Config) SmallModel() *provider.Model { +func (c *Config) SmallModel() *catwalk.Model { model, ok := c.Models[SelectedModelTypeSmall] if !ok { return nil @@ -378,7 +384,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return nil } - var foundProvider *provider.Provider + var foundProvider *catwalk.Provider for _, p := range c.knownProviders { if string(p.ID) == providerID { foundProvider = &p @@ -447,14 +453,14 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { headers := make(map[string]string) apiKey, _ := resolver.ResolveValue(c.APIKey) switch c.Type { - case provider.TypeOpenAI: + case catwalk.TypeOpenAI: baseURL, _ := resolver.ResolveValue(c.BaseURL) if baseURL == "" { baseURL = "https://api.openai.com/v1" } testURL = baseURL + "/models" headers["Authorization"] = "Bearer " + apiKey - case provider.TypeAnthropic: + case catwalk.TypeAnthropic: baseURL, _ := resolver.ResolveValue(c.BaseURL) if baseURL == "" { baseURL = "https://api.anthropic.com/v1" diff --git a/internal/config/load.go b/internal/config/load.go index 48ef9b1caf1e5d9ec1877f7fc9c3a53ab996d129..5d11901dd4d169041d315eb139d27e6dbec736de 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -12,12 +12,14 @@ import ( "strings" "sync" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/log" ) +const catwalkURL = "https://catwalk.charm.sh" + // LoadReader config via io.Reader. func LoadReader(fd io.Reader) (*Config, error) { data, err := io.ReadAll(fd) @@ -61,7 +63,7 @@ func Load(workingDir string, debug bool) (*Config, error) { cfg.Options.Debug, ) - // Load known providers, this loads the config from fur + // Load known providers, this loads the config from catwalk providers, err := Providers() if err != nil || len(providers) == 0 { return nil, fmt.Errorf("failed to load providers: %w", err) @@ -97,7 +99,7 @@ func (c *Config) removeUnresponsiveProviders() { slog.Info("Testing provider connections") defer slog.Info("Provider connection tests completed") for _, p := range c.Providers.Seq2() { - if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { + if p.Type == catwalk.TypeOpenAI || p.Type == catwalk.TypeAnthropic { wg.Add(1) go func(provider ProviderConfig) { defer wg.Done() @@ -122,7 +124,7 @@ func (c *Config) removeUnresponsiveProviders() { }) } -func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error { +func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true @@ -141,7 +143,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know p.APIKey = config.APIKey } if len(config.Models) > 0 { - models := []provider.Model{} + models := []catwalk.Model{} seen := make(map[string]bool) for _, model := range config.Models { @@ -149,8 +151,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know continue } seen[model.ID] = true - if model.Model == "" { - model.Model = model.ID + if model.Name == "" { + model.Name = model.ID } models = append(models, model) } @@ -159,8 +161,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know continue } seen[model.ID] = true - if model.Model == "" { - model.Model = model.ID + if model.Name == "" { + model.Name = model.ID } models = append(models, model) } @@ -183,7 +185,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know switch p.ID { // Handle specific providers that require additional configuration - case provider.InferenceProviderVertexAI: + case catwalk.InferenceProviderVertexAI: if !hasVertexCredentials(env) { if configExists { slog.Warn("Skipping Vertex AI provider due to missing credentials") @@ -193,7 +195,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT") prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION") - case provider.InferenceProviderAzure: + case catwalk.InferenceProviderAzure: endpoint, err := resolver.ResolveValue(p.APIEndpoint) if err != nil || endpoint == "" { if configExists { @@ -204,7 +206,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } prepared.BaseURL = endpoint prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION") - case provider.InferenceProviderBedrock: + case catwalk.InferenceProviderBedrock: if !hasAWSCredentials(env) { if configExists { slog.Warn("Skipping Bedrock provider due to missing AWS credentials") @@ -244,7 +246,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } // default to OpenAI if not set if providerConfig.Type == "" { - providerConfig.Type = provider.TypeOpenAI + providerConfig.Type = catwalk.TypeOpenAI } if providerConfig.Disable { @@ -265,7 +267,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know c.Providers.Del(id) continue } - if providerConfig.Type != provider.TypeOpenAI { + if providerConfig.Type != catwalk.TypeOpenAI { slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type) c.Providers.Del(id) continue @@ -320,7 +322,7 @@ func (c *Config) setDefaults(workingDir string) { c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths) } -func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { +func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { if len(knownProviders) == 0 && c.Providers.Len() == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return @@ -389,7 +391,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg return } -func (c *Config) configureSelectedModels(knownProviders []provider.Provider) error { +func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error { defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders) if err != nil { return fmt.Errorf("failed to select default models: %w", err) diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 86a2356da2021dc22de88de05a80717e95aa492a..5a52426f51ace9ee9e26bb42208511a72009dc3b 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -8,9 +8,9 @@ import ( "strings" "testing" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/stretchr/testify/assert" ) @@ -56,12 +56,12 @@ func TestConfig_setDefaults(t *testing.T) { } func TestConfig_configureProviders(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -83,12 +83,12 @@ func TestConfig_configureProviders(t *testing.T) { } func TestConfig_configureProvidersWithOverride(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -100,10 +100,10 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { cfg.Providers.Set("openai", ProviderConfig{ APIKey: "xyz", BaseURL: "https://api.openai.com/v2", - Models: []provider.Model{ + Models: []catwalk.Model{ { - ID: "test-model", - Model: "Updated", + ID: "test-model", + Name: "Updated", }, { ID: "another-model", @@ -125,16 +125,16 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { assert.Equal(t, "xyz", pc.APIKey) assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) assert.Len(t, pc.Models, 2) - assert.Equal(t, "Updated", pc.Models[0].Model) + assert.Equal(t, "Updated", pc.Models[0].Name) } func TestConfig_configureProvidersWithNewProvider(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -145,7 +145,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { "custom": { APIKey: "xyz", BaseURL: "https://api.someendpoint.com/v2", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "test-model", }, @@ -176,12 +176,12 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { } func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderBedrock, + ID: catwalk.InferenceProviderBedrock, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "anthropic.claude-sonnet-4-20250514-v1:0", }}, }, @@ -205,12 +205,12 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { } func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderBedrock, + ID: catwalk.InferenceProviderBedrock, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "anthropic.claude-sonnet-4-20250514-v1:0", }}, }, @@ -227,12 +227,12 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { } func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderBedrock, + ID: catwalk.InferenceProviderBedrock, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "some-random-model", }}, }, @@ -250,12 +250,12 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { } func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderVertexAI, + ID: catwalk.InferenceProviderVertexAI, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "gemini-pro", }}, }, @@ -282,12 +282,12 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { } func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderVertexAI, + ID: catwalk.InferenceProviderVertexAI, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "gemini-pro", }}, }, @@ -308,12 +308,12 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { } func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderVertexAI, + ID: catwalk.InferenceProviderVertexAI, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "gemini-pro", }}, }, @@ -333,12 +333,12 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { } func TestConfig_configureProvidersSetProviderID(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -455,12 +455,12 @@ func TestConfig_IsConfigured(t *testing.T) { } func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -494,7 +494,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -507,7 +507,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 1) @@ -520,7 +520,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -530,7 +530,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 0) @@ -544,7 +544,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{}, + Models: []catwalk.Model{}, }, }), } @@ -552,7 +552,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 0) @@ -567,7 +567,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Type: "unsupported", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -577,7 +577,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 0) @@ -591,8 +591,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Type: provider.TypeOpenAI, - Models: []provider.Model{{ + Type: catwalk.TypeOpenAI, + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -602,7 +602,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 1) @@ -619,9 +619,9 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Type: provider.TypeOpenAI, + Type: catwalk.TypeOpenAI, Disable: true, - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -631,7 +631,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 0) @@ -642,12 +642,12 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderVertexAI, + ID: catwalk.InferenceProviderVertexAI, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "gemini-pro", }}, }, @@ -675,12 +675,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }) t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderBedrock, + ID: catwalk.InferenceProviderBedrock, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "anthropic.claude-sonnet-4-20250514-v1:0", }}, }, @@ -706,12 +706,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }) t.Run("provider removed when API key missing with existing config", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$MISSING_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -737,12 +737,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }) t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "$MISSING_ENDPOINT", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -772,13 +772,13 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { func TestConfig_defaultModelSelection(t *testing.T) { t.Run("default behavior uses the default models for given provider", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -808,13 +808,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { assert.Equal(t, int64(500), small.MaxTokens) }) t.Run("should error if no providers configured", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$MISSING_KEY", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -838,13 +838,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { assert.Error(t, err) }) t.Run("should error if model is missing", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "not-large-model", DefaultMaxTokens: 1000, @@ -868,13 +868,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { }) t.Run("should configure the default models with a custom provider", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$MISSING", // will not be included in the config DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "not-large-model", DefaultMaxTokens: 1000, @@ -892,7 +892,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "model", DefaultMaxTokens: 600, @@ -917,13 +917,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { }) t.Run("should fail if no model configured", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$MISSING", // will not be included in the config DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "not-large-model", DefaultMaxTokens: 1000, @@ -941,7 +941,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{}, + Models: []catwalk.Model{}, }, }), } @@ -954,13 +954,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { assert.Error(t, err) }) t.Run("should use the default provider first", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "set", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -978,7 +978,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -1005,13 +1005,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { func TestConfig_configureSelectedModels(t *testing.T) { t.Run("should override defaults", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "larger-model", DefaultMaxTokens: 2000, @@ -1053,13 +1053,13 @@ func TestConfig_configureSelectedModels(t *testing.T) { assert.Equal(t, int64(500), small.MaxTokens) }) t.Run("should be possible to use multiple providers", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -1075,7 +1075,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { APIKey: "abc", DefaultLargeModelID: "a-large-model", DefaultSmallModelID: "a-small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "a-large-model", DefaultMaxTokens: 1000, @@ -1116,13 +1116,13 @@ func TestConfig_configureSelectedModels(t *testing.T) { }) t.Run("should override the max tokens only", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, diff --git a/internal/config/provider.go b/internal/config/provider.go index caeba48707be933d222313729934cc69c819f68e..98235cd84794812128082533f3a501bfce952cb8 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -10,17 +10,16 @@ import ( "sync" "time" - "github.com/charmbracelet/crush/internal/fur/client" - "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/catwalk/pkg/catwalk" ) type ProviderClient interface { - GetProviders() ([]provider.Provider, error) + GetProviders() ([]catwalk.Provider, error) } var ( providerOnce sync.Once - providerList []provider.Provider + providerList []catwalk.Provider ) // file to cache provider data @@ -44,7 +43,7 @@ func providerCacheFileData() string { return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json") } -func saveProvidersInCache(path string, providers []provider.Provider) error { +func saveProvidersInCache(path string, providers []catwalk.Provider) error { slog.Info("Caching provider data") if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return fmt.Errorf("failed to create directory for provider cache: %w", err) @@ -61,26 +60,26 @@ func saveProvidersInCache(path string, providers []provider.Provider) error { return nil } -func loadProvidersFromCache(path string) ([]provider.Provider, error) { +func loadProvidersFromCache(path string) ([]catwalk.Provider, error) { data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("failed to read provider cache file: %w", err) } - var providers []provider.Provider + var providers []catwalk.Provider if err := json.Unmarshal(data, &providers); err != nil { return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err) } return providers, nil } -func Providers() ([]provider.Provider, error) { - client := client.New() +func Providers() ([]catwalk.Provider, error) { + client := catwalk.NewWithURL(catwalkURL) path := providerCacheFileData() return loadProvidersOnce(client, path) } -func loadProvidersOnce(client ProviderClient, path string) ([]provider.Provider, error) { +func loadProvidersOnce(client ProviderClient, path string) ([]catwalk.Provider, error) { var err error providerOnce.Do(func() { providerList, err = loadProviders(client, path) @@ -91,7 +90,7 @@ func loadProvidersOnce(client ProviderClient, path string) ([]provider.Provider, return providerList, nil } -func loadProviders(client ProviderClient, path string) (providerList []provider.Provider, err error) { +func loadProviders(client ProviderClient, path string) (providerList []catwalk.Provider, err error) { // if cache is not stale, load from it stale, exists := isCacheStale(path) if !stale { diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go index 480869d98e4d69087aefc5759de0776f7910ebec..cb71cabfa5a01cb16b6ef2b6708d1780e31951a9 100644 --- a/internal/config/provider_empty_test.go +++ b/internal/config/provider_empty_test.go @@ -5,14 +5,14 @@ import ( "os" "testing" - "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/stretchr/testify/require" ) type emptyProviderClient struct{} -func (m *emptyProviderClient) GetProviders() ([]provider.Provider, error) { - return []provider.Provider{}, nil +func (m *emptyProviderClient) GetProviders() ([]catwalk.Provider, error) { + return []catwalk.Provider{}, nil } func TestProvider_loadProvidersEmptyResult(t *testing.T) { @@ -33,7 +33,7 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) { tmpPath := t.TempDir() + "/providers.json" // Create an empty cache file - emptyProviders := []provider.Provider{} + emptyProviders := []catwalk.Provider{} data, err := json.Marshal(emptyProviders) require.NoError(t, err) require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index abfb6592bcd5e46a7cbf40dba54a10722ee69980..e6a1f331716d88285ef4c9929a23a474ed3597a0 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -6,7 +6,7 @@ import ( "os" "testing" - "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/stretchr/testify/assert" ) @@ -14,11 +14,11 @@ type mockProviderClient struct { shouldFail bool } -func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) { +func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { if m.shouldFail { return nil, errors.New("failed to load providers") } - return []provider.Provider{ + return []catwalk.Provider{ { Name: "Mock", }, @@ -43,7 +43,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" // store providers to a temporary file - oldProviders := []provider.Provider{ + oldProviders := []catwalk.Provider{ { Name: "OldProvider", }, diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go deleted file mode 100644 index d007c9aee18f77c8b03fe804726b4196e474d0b4..0000000000000000000000000000000000000000 --- a/internal/fur/client/client.go +++ /dev/null @@ -1,63 +0,0 @@ -// Package client provides a client for interacting with the fur service. -package client - -import ( - "encoding/json" - "fmt" - "net/http" - "os" - - "github.com/charmbracelet/crush/internal/fur/provider" -) - -const defaultURL = "https://fur.charm.sh" - -// Client represents a client for the fur service. -type Client struct { - baseURL string - httpClient *http.Client -} - -// New creates a new client instance -// Uses FUR_URL environment variable or falls back to localhost:8080. -func New() *Client { - baseURL := os.Getenv("FUR_URL") - if baseURL == "" { - baseURL = defaultURL - } - - return &Client{ - baseURL: baseURL, - httpClient: &http.Client{}, - } -} - -// NewWithURL creates a new client with a specific URL. -func NewWithURL(url string) *Client { - return &Client{ - baseURL: url, - httpClient: &http.Client{}, - } -} - -// GetProviders retrieves all available providers from the service. -func (c *Client) GetProviders() ([]provider.Provider, error) { - url := fmt.Sprintf("%s/providers", c.baseURL) - - resp, err := c.httpClient.Get(url) //nolint:noctx - if err != nil { - return nil, fmt.Errorf("failed to make request: %w", err) - } - defer resp.Body.Close() //nolint:errcheck - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - var providers []provider.Provider - if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - return providers, nil -} diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go deleted file mode 100644 index 2bfe95a5bc3db4f1e52feebcaf7d484f4d5de948..0000000000000000000000000000000000000000 --- a/internal/fur/provider/provider.go +++ /dev/null @@ -1,75 +0,0 @@ -// Package provider provides types and constants for AI providers. -package provider - -// Type represents the type of AI provider. -type Type string - -// All the supported AI provider types. -const ( - TypeOpenAI Type = "openai" - TypeAnthropic Type = "anthropic" - TypeGemini Type = "gemini" - TypeAzure Type = "azure" - TypeBedrock Type = "bedrock" - TypeVertexAI Type = "vertexai" - TypeXAI Type = "xai" -) - -// InferenceProvider represents the inference provider identifier. -type InferenceProvider string - -// All the inference providers supported by the system. -const ( - InferenceProviderOpenAI InferenceProvider = "openai" - InferenceProviderAnthropic InferenceProvider = "anthropic" - InferenceProviderGemini InferenceProvider = "gemini" - InferenceProviderAzure InferenceProvider = "azure" - InferenceProviderBedrock InferenceProvider = "bedrock" - InferenceProviderVertexAI InferenceProvider = "vertexai" - InferenceProviderXAI InferenceProvider = "xai" - InferenceProviderGROQ InferenceProvider = "groq" - InferenceProviderOpenRouter InferenceProvider = "openrouter" -) - -// Provider represents an AI provider configuration. -type Provider struct { - Name string `json:"name"` - ID InferenceProvider `json:"id"` - APIKey string `json:"api_key,omitempty"` - APIEndpoint string `json:"api_endpoint,omitempty"` - Type Type `json:"type,omitempty"` - DefaultLargeModelID string `json:"default_large_model_id,omitempty"` - DefaultSmallModelID string `json:"default_small_model_id,omitempty"` - Models []Model `json:"models,omitempty"` -} - -// Model represents an AI model configuration. -type Model struct { - ID string `json:"id"` - Model string `json:"model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - HasReasoningEffort bool `json:"has_reasoning_efforts"` - DefaultReasoningEffort string `json:"default_reasoning_effort,omitempty"` - SupportsImages bool `json:"supports_attachments"` -} - -// KnownProviders returns all the known inference providers. -func KnownProviders() []InferenceProvider { - return []InferenceProvider{ - InferenceProviderOpenAI, - InferenceProviderAnthropic, - InferenceProviderGemini, - InferenceProviderAzure, - InferenceProviderBedrock, - InferenceProviderVertexAI, - InferenceProviderXAI, - InferenceProviderGROQ, - InferenceProviderOpenRouter, - } -} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 72697cb0ac801f013a094dc5c44a3152f1443af1..75f1f545929cc2422461ed0e775775689f8567d2 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -10,9 +10,9 @@ import ( "sync" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" - fur "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" @@ -23,6 +23,7 @@ import ( "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" + "github.com/charmbracelet/crush/internal/shell" ) // Common errors @@ -52,7 +53,7 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] - Model() fur.Model + Model() catwalk.Model Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() @@ -226,7 +227,7 @@ func NewAgent( }, nil } -func (a *agent) Model() fur.Model { +func (a *agent) Model() catwalk.Model { return *config.Get().GetModelByType(a.agentCfg.Model) } @@ -234,7 +235,7 @@ func (a *agent) Cancel(sessionID string) { // Cancel regular requests if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { if cancel, ok := cancelFunc.(context.CancelFunc); ok { - slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) + slog.Info("Request cancellation initiated", "session_id", sessionID) cancel() } } @@ -242,7 +243,7 @@ func (a *agent) Cancel(sessionID string) { // Also check for summarize requests if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists { if cancel, ok := cancelFunc.(context.CancelFunc); ok { - slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID)) + slog.Info("Summarize cancellation initiated", "session_id", sessionID) cancel() } } @@ -372,7 +373,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string }) titleErr := a.generateTitle(context.Background(), sessionID, content) if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) { - slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr)) + slog.Error("failed to generate title", "error", titleErr) } }() } @@ -645,7 +646,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg return nil } -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error { +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error { sess, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -770,6 +771,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { a.Publish(pubsub.CreatedEvent, event) return } + shell := shell.GetPersistentShell(config.Get().WorkingDir()) + summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir() event = AgentEvent{ Type: AgentEventTypeSummarize, Progress: "Creating new session...", diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index f4284faccee052e82e8ed82a820b16af58ccc64c..2ffbf2111931ad111751af1bfcd492422da205ee 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -9,17 +9,17 @@ import ( "runtime" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" ) func CoderPrompt(p string, contextFiles ...string) string { var basePrompt string switch p { - case string(provider.InferenceProviderOpenAI): + case string(catwalk.InferenceProviderOpenAI): basePrompt = baseOpenAICoderPrompt - case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI): + case string(catwalk.InferenceProviderGemini), string(catwalk.InferenceProviderVertexAI): basePrompt = baseGeminiCoderPrompt default: basePrompt = baseAnthropicCoderPrompt diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 1e8364b08cb76ec7210d9937302cd1c647857b2d..0765389a05ecaf33c6c521770e1880a24210d35f 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -15,8 +15,8 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -71,7 +71,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic var contentBlocks []anthropic.ContentBlockParamUnion contentBlocks = append(contentBlocks, content) for _, binaryContent := range msg.BinaryContent() { - base64Image := binaryContent.String(provider.InferenceProviderAnthropic) + base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic) imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image) contentBlocks = append(contentBlocks, imageBlock) } @@ -248,7 +248,7 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message, return nil, retryErr } if retry { - slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries)) + slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries) select { case <-ctx.Done(): return nil, ctx.Err() @@ -401,7 +401,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message return } if retry { - slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries)) + slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries) select { case <-ctx.Done(): // context cancelled @@ -529,6 +529,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { } } -func (a *anthropicClient) Model() provider.Model { +func (a *anthropicClient) Model() catwalk.Model { return a.providerOptions.model(a.providerOptions.modelType) } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 0c0ccdbab2d642f139a2b1ab2f19f6298f1ac73d..8b5b21c36a390e80843504c7c9f6c257156f6379 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -6,8 +6,8 @@ import ( "fmt" "strings" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -32,7 +32,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { } } - opts.model = func(modelType config.SelectedModelType) provider.Model { + opts.model = func(modelType config.SelectedModelType) catwalk.Model { model := config.Get().GetModelByType(modelType) // Prefix the model name with region @@ -88,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, return b.childProvider.stream(ctx, messages, tools) } -func (b *bedrockClient) Model() provider.Model { +func (b *bedrockClient) Model() catwalk.Model { return b.providerOptions.model(b.providerOptions.modelType) } diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index d2aee5090029e207ef1bdf5e0dad8e011e763267..4fa0cff4d17c28da16528d33ff54e2a905521387 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -10,8 +10,8 @@ import ( "strings" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/google/uuid" @@ -210,7 +210,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too return nil, retryErr } if retry { - slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries)) + slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries) select { case <-ctx.Done(): return nil, ctx.Err() @@ -323,7 +323,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t return } if retry { - slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries)) + slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries) select { case <-ctx.Done(): if ctx.Err() != nil { @@ -463,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { } } -func (g *geminiClient) Model() provider.Model { +func (g *geminiClient) Model() catwalk.Model { return g.providerOptions.model(g.providerOptions.modelType) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index f55914520774e2fcf5e6283e22365f4ce3621dc1..397d6954d0a5c8f3dbe25f4a34115ade4c242012 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -9,8 +9,8 @@ import ( "log/slog" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/openai/openai-go" @@ -66,7 +66,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()} content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock}) for _, binaryContent := range msg.BinaryContent() { - imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)} + imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)} imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL} content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock}) @@ -222,7 +222,7 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too return nil, retryErr } if retry { - slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries)) + slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries) select { case <-ctx.Done(): return nil, ctx.Err() @@ -395,7 +395,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t return } if retry { - slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries)) + slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries) select { case <-ctx.Done(): // context cancelled @@ -486,6 +486,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { } } -func (o *openaiClient) Model() provider.Model { +func (o *openaiClient) Model() catwalk.Model { return o.providerOptions.model(o.providerOptions.modelType) } diff --git a/internal/llm/provider/openai_test.go b/internal/llm/provider/openai_test.go index c11e8ff14d7995859cccd3c95eeae4008fb20ac9..26c4d85ae35bbf4681719a12b568befccd8012af 100644 --- a/internal/llm/provider/openai_test.go +++ b/internal/llm/provider/openai_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/openai/openai-go" @@ -55,10 +55,10 @@ func TestOpenAIClientStreamChoices(t *testing.T) { modelType: config.SelectedModelTypeLarge, apiKey: "test-key", systemMessage: "test", - model: func(config.SelectedModelType) provider.Model { - return provider.Model{ - ID: "test-model", - Model: "test-model", + model: func(config.SelectedModelType) catwalk.Model { + return catwalk.Model{ + ID: "test-model", + Name: "test-model", } }, }, diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 412093334169b4c0d59fdd4f3f72b1e427651307..062c2aa977c6ff101d1d8ab6f32809845bd48ff3 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -4,8 +4,8 @@ import ( "context" "fmt" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -57,7 +57,7 @@ type Provider interface { StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() provider.Model + Model() catwalk.Model } type providerClientOptions struct { @@ -65,7 +65,7 @@ type providerClientOptions struct { config config.ProviderConfig apiKey string modelType config.SelectedModelType - model func(config.SelectedModelType) provider.Model + model func(config.SelectedModelType) catwalk.Model disableCache bool systemMessage string maxTokens int64 @@ -80,7 +80,7 @@ type ProviderClient interface { send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() provider.Model + Model() catwalk.Model } type baseProvider[C ProviderClient] struct { @@ -109,7 +109,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message return p.client.stream(ctx, messages, tools) } -func (p *baseProvider[C]) Model() provider.Model { +func (p *baseProvider[C]) Model() catwalk.Model { return p.client.Model() } @@ -149,7 +149,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi apiKey: resolvedAPIKey, extraHeaders: cfg.ExtraHeaders, extraBody: cfg.ExtraBody, - model: func(tp config.SelectedModelType) provider.Model { + model: func(tp config.SelectedModelType) catwalk.Model { return *config.Get().GetModelByType(tp) }, } @@ -157,37 +157,37 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi o(&clientOptions) } switch cfg.Type { - case provider.TypeAnthropic: + case catwalk.TypeAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, client: newAnthropicClient(clientOptions, false), }, nil - case provider.TypeOpenAI: + case catwalk.TypeOpenAI: return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil - case provider.TypeGemini: + case catwalk.TypeGemini: return &baseProvider[GeminiClient]{ options: clientOptions, client: newGeminiClient(clientOptions), }, nil - case provider.TypeBedrock: + case catwalk.TypeBedrock: return &baseProvider[BedrockClient]{ options: clientOptions, client: newBedrockClient(clientOptions), }, nil - case provider.TypeAzure: + case catwalk.TypeAzure: return &baseProvider[AzureClient]{ options: clientOptions, client: newAzureClient(clientOptions), }, nil - case provider.TypeVertexAI: + case catwalk.TypeVertexAI: return &baseProvider[VertexAIClient]{ options: clientOptions, client: newVertexAIClient(clientOptions), }, nil - case provider.TypeXAI: + case catwalk.TypeXAI: clientOptions.baseURL = "https://api.x.ai/v1" return &baseProvider[OpenAIClient]{ options: clientOptions, diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 10051a24bce881b9bdc4a8990364d30dec92bc85..99ab86068a5effa1e631037f3340ba814055d709 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "log/slog" "strings" "time" @@ -23,8 +22,10 @@ type BashPermissionsParams struct { } type BashResponseMetadata struct { - StartTime int64 `json:"start_time"` - EndTime int64 `json:"end_time"` + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` + Output string `json:"output"` + WorkingDirectory string `json:"working_directory"` } type bashTool struct { permissions permission.Service @@ -146,6 +147,7 @@ Before executing the command, please follow these steps: 5. Return Result: - Provide the processed output of the command. - If any errors occurred during execution, include those in the output. + - The result will also have metadata like the cwd (current working directory) at the end, included with tags. Usage notes: - The command argument is required. @@ -389,9 +391,12 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond) defer cancel() } - stdout, stderr, err := shell. - GetPersistentShell(b.workingDir). - Exec(ctx, params.Command) + + persistentShell := shell.GetPersistentShell(b.workingDir) + stdout, stderr, err := persistentShell.Exec(ctx, params.Command) + + // Get the current working directory after command execution + currentWorkingDir := persistentShell.GetWorkingDir() interrupted := shell.IsInterrupt(err) exitCode := shell.ExitCode(err) if exitCode == 0 && !interrupted && err != nil { @@ -401,15 +406,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) - slog.Info("Bash command executed", - "command", params.Command, - "stdout", stdout, - "stderr", stderr, - "exit_code", exitCode, - "interrupted", interrupted, - "err", err, - ) - errorMessage := stderr if errorMessage == "" && err != nil { errorMessage = err.Error() @@ -438,9 +434,12 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } metadata := BashResponseMetadata{ - StartTime: startTime.UnixMilli(), - EndTime: time.Now().UnixMilli(), + StartTime: startTime.UnixMilli(), + EndTime: time.Now().UnixMilli(), + Output: stdout, + WorkingDirectory: currentWorkingDir, } + stdout += fmt.Sprintf("\n\n%s", currentWorkingDir) if stdout == "" { return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil } diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index c70c76b7d2dbd798118a54859e5672dacc6e1304..5af6e055574e3f85b68d8616b44d361790c0a3fb 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -146,7 +146,7 @@ func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) { if err == nil { return matches, len(matches) >= limit && limit > 0, nil } - slog.Warn(fmt.Sprintf("Ripgrep execution failed: %v. Falling back to doublestar.", err)) + slog.Warn("Ripgrep execution failed, falling back to doublestar", "error", err) } return fsext.GlobWithDoubleStar(pattern, searchPath, limit) diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index 976bbb291b08b84af578b7e05e2d568cd2ad5d04..080870937bee98be852979748dab456fa6a53b66 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -90,7 +90,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc slog.Debug("BaseURI", "baseURI", u) } default: - slog.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v)) + slog.Debug("GlobPattern unknown type", "type", fmt.Sprintf("%T", v)) } // Log WatchKind diff --git a/internal/message/content.go b/internal/message/content.go index bdaf1577e34a4667bdb5c8cd2683865ec5cd08ac..b3f212187c86fb57667d95943fd15b8c6e3cccdb 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -5,7 +5,7 @@ import ( "slices" "time" - "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/catwalk/pkg/catwalk" ) type MessageRole string @@ -74,9 +74,9 @@ type BinaryContent struct { Data []byte } -func (bc BinaryContent) String(p provider.InferenceProvider) string { +func (bc BinaryContent) String(p catwalk.InferenceProvider) string { base64Encoded := base64.StdEncoding.EncodeToString(bc.Data) - if p == provider.InferenceProviderOpenAI { + if p == catwalk.InferenceProviderOpenAI { return "data:" + bc.MIMEType + ";base64," + base64Encoded } return base64Encoded diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 72dca2db9ccdb5b09ee4ff4794bbe5b51e893b40..cd149a49890b54086bd52e562eed0d44f00c407e 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -50,6 +50,7 @@ type permissionService struct { autoApproveSessions []string autoApproveSessionsMu sync.RWMutex skip bool + allowedTools []string } func (s *permissionService) GrantPersistent(permission PermissionRequest) { @@ -82,6 +83,12 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { return true } + // Check if the tool/action combination is in the allowlist + commandKey := opts.ToolName + ":" + opts.Action + if slices.Contains(s.allowedTools, commandKey) || slices.Contains(s.allowedTools, opts.ToolName) { + return true + } + s.autoApproveSessionsMu.RLock() autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID) s.autoApproveSessionsMu.RUnlock() @@ -130,11 +137,12 @@ func (s *permissionService) AutoApproveSession(sessionID string) { s.autoApproveSessionsMu.Unlock() } -func NewPermissionService(workingDir string, skip bool) Service { +func NewPermissionService(workingDir string, skip bool, allowedTools []string) Service { return &permissionService{ Broker: pubsub.NewBroker[PermissionRequest](), workingDir: workingDir, sessionPermissions: make([]PermissionRequest, 0), skip: skip, + allowedTools: allowedTools, } } diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5d10fbd240da6a171e345938cb3382a7f7fcf19b --- /dev/null +++ b/internal/permission/permission_test.go @@ -0,0 +1,92 @@ +package permission + +import ( + "testing" +) + +func TestPermissionService_AllowedCommands(t *testing.T) { + tests := []struct { + name string + allowedTools []string + toolName string + action string + expected bool + }{ + { + name: "tool in allowlist", + allowedTools: []string{"bash", "view"}, + toolName: "bash", + action: "execute", + expected: true, + }, + { + name: "tool:action in allowlist", + allowedTools: []string{"bash:execute", "edit:create"}, + toolName: "bash", + action: "execute", + expected: true, + }, + { + name: "tool not in allowlist", + allowedTools: []string{"view", "ls"}, + toolName: "bash", + action: "execute", + expected: false, + }, + { + name: "tool:action not in allowlist", + allowedTools: []string{"bash:read", "edit:create"}, + toolName: "bash", + action: "execute", + expected: false, + }, + { + name: "empty allowlist", + allowedTools: []string{}, + toolName: "bash", + action: "execute", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + service := NewPermissionService("/tmp", false, tt.allowedTools) + + // Create a channel to capture the permission request + // Since we're testing the allowlist logic, we need to simulate the request + ps := service.(*permissionService) + + // Test the allowlist logic directly + commandKey := tt.toolName + ":" + tt.action + allowed := false + for _, cmd := range ps.allowedTools { + if cmd == commandKey || cmd == tt.toolName { + allowed = true + break + } + } + + if allowed != tt.expected { + t.Errorf("expected %v, got %v for tool %s action %s with allowlist %v", + tt.expected, allowed, tt.toolName, tt.action, tt.allowedTools) + } + }) + } +} + +func TestPermissionService_SkipMode(t *testing.T) { + service := NewPermissionService("/tmp", true, []string{}) + + result := service.Request(CreatePermissionRequest{ + SessionID: "test-session", + ToolName: "bash", + Action: "execute", + Description: "test command", + Path: "/tmp", + }) + + if !result { + t.Error("expected permission to be granted in skip mode") + } +} diff --git a/internal/shell/coreutils.go b/internal/shell/coreutils.go new file mode 100644 index 0000000000000000000000000000000000000000..5669d578987ba5a8792430c96e6fc869d8b5cf55 --- /dev/null +++ b/internal/shell/coreutils.go @@ -0,0 +1,59 @@ +package shell + +import ( + "context" + + "github.com/u-root/u-root/pkg/core" + "github.com/u-root/u-root/pkg/core/cat" + "github.com/u-root/u-root/pkg/core/chmod" + "github.com/u-root/u-root/pkg/core/cp" + "github.com/u-root/u-root/pkg/core/find" + "github.com/u-root/u-root/pkg/core/ls" + "github.com/u-root/u-root/pkg/core/mkdir" + "github.com/u-root/u-root/pkg/core/mv" + "github.com/u-root/u-root/pkg/core/rm" + "github.com/u-root/u-root/pkg/core/touch" + "github.com/u-root/u-root/pkg/core/xargs" + "mvdan.cc/sh/v3/interp" +) + +var coreUtils = map[string]func() core.Command{ + "cat": func() core.Command { return cat.New() }, + "chmod": func() core.Command { return chmod.New() }, + "cp": func() core.Command { return cp.New() }, + "find": func() core.Command { return find.New() }, + "ls": func() core.Command { return ls.New() }, + "mkdir": func() core.Command { return mkdir.New() }, + "mv": func() core.Command { return mv.New() }, + "rm": func() core.Command { return rm.New() }, + "touch": func() core.Command { return touch.New() }, + "xargs": func() core.Command { return xargs.New() }, +} + +func (s *Shell) coreUtilsHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + + program, programArgs := args[0], args[1:] + + newCoreUtil, ok := coreUtils[program] + if !ok { + return next(ctx, args) + } + + c := interp.HandlerCtx(ctx) + + cmd := newCoreUtil() + cmd.SetIO(c.Stdin, c.Stdout, c.Stderr) + cmd.SetWorkingDir(c.Dir) + cmd.SetLookupEnv(func(key string) (string, bool) { + v := c.Env.Get(key) + return v.Str, v.Set + }) + return cmd.RunContext(ctx, programArgs...) + } + } +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index e6e6d47e644e4569c4dc04e927a66817a9fc1a28..d76f9bdcb355cc9314e570761147ab0bce1fd219 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -221,7 +221,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), - interp.ExecHandlers(s.blockHandler()), + interp.ExecHandlers(s.blockHandler(), s.coreUtilsHandler()), ) if err != nil { return "", "", fmt.Errorf("could not run command: %w", err) diff --git a/internal/tui/components/anim/anim.go b/internal/tui/components/anim/anim.go index 07d02483d0b470b6b4cadf36fbe5acd52e8857ba..241522c8989c89bf8eb877c69b9a72f01508c5f4 100644 --- a/internal/tui/components/anim/anim.go +++ b/internal/tui/components/anim/anim.go @@ -289,7 +289,7 @@ func (a Anim) View() string { var b strings.Builder for i := range a.width { switch { - case !a.initialized && time.Since(a.startTime) < a.birthOffsets[i]: + case !a.initialized && i < len(a.birthOffsets) && time.Since(a.startTime) < a.birthOffsets[i]: // Birth offset not reached: render initial character. b.WriteString(a.initialFrames[a.step][i]) case i < a.cyclingCharWidth: diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index 242075d98e99da0117430a26df34357f58c18d10..55a5e7525a430039b314cd810cb94856185cf5af 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -187,9 +187,11 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { value = value[:m.completionsStartIndex] value += item.Path m.textarea.SetValue(value) - m.isCompletionsOpen = false - m.currentQuery = "" - m.completionsStartIndex = 0 + if !msg.Insert { + m.isCompletionsOpen = false + m.currentQuery = "" + m.completionsStartIndex = 0 + } return m, nil } case openEditorMsg: diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index 2ffa1601f84fcb9028faf67bd94d70920a193864..d5aca88108cad83115cad5bd046c72e146935f78 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -8,11 +8,11 @@ import ( "github.com/charmbracelet/bubbles/v2/viewport" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/x/ansi" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/tui/components/anim" "github.com/charmbracelet/crush/internal/tui/components/core" @@ -369,11 +369,11 @@ func (m *assistantSectionModel) View() string { model := config.Get().GetModel(m.message.Provider, m.message.Model) if model == nil { // This means the model is not configured anymore - model = &provider.Model{ - Model: "Unknown Model", + model = &catwalk.Model{ + Name: "Unknown Model", } } - modelFormatted := t.S().Muted.Render(model.Model) + modelFormatted := t.S().Muted.Render(model.Name) assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg) return t.S().Base.PaddingLeft(2).Render( core.Section(assistant, m.width-2), diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index 4483b4ed2a9d2bf0b87a6eeae4565049edfa303e..ace42420a26a47854313029e48ca4b3f495525c4 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/charmbracelet/crush/internal/ansiext" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/llm/tools" @@ -212,10 +213,19 @@ func (br bashRenderer) Render(v *toolCallCmp) string { args := newParamBuilder().addMain(cmd).build() return br.renderWithParams(v, "Bash", args, func() string { - if v.result.Content == tools.BashNoOutput { + var meta tools.BashResponseMetadata + if err := br.unmarshalParams(v.result.Metadata, &meta); err != nil { + return renderPlainContent(v, v.result.Content) + } + // for backwards compatibility with older tool calls. + if meta.Output == "" && v.result.Content != tools.BashNoOutput { + meta.Output = v.result.Content + } + + if meta.Output == "" { return "" } - return renderPlainContent(v, v.result.Content) + return renderPlainContent(v, meta.Output) }) } @@ -693,7 +703,7 @@ func renderPlainContent(v *toolCallCmp, content string) string { if i >= responseContextHeight { break } - ln = escapeContent(ln) + ln = ansiext.Escape(ln) ln = " " + ln // left padding if len(ln) > width { ln = v.fit(ln, width) @@ -731,7 +741,7 @@ func renderCodeContent(v *toolCallCmp, path, content string, offset int) string lines := strings.Split(truncated, "\n") for i, ln := range lines { - lines[i] = escapeContent(ln) + lines[i] = ansiext.Escape(ln) } highlighted, _ := highlight.SyntaxHighlight(strings.Join(lines, "\n"), path, t.BgBase) @@ -807,20 +817,3 @@ func prettifyToolName(name string) string { return name } } - -// escapeContent replaces control characters with their Unicode Control Picture -// representations to ensure they are displayed correctly in the UI. -func escapeContent(content string) string { - var sb strings.Builder - for _, r := range content { - switch { - case r >= 0 && r <= 0x1f: // Control characters 0x00-0x1F - sb.WriteRune('\u2400' + r) - case r == ansi.DEL: - sb.WriteRune('\u2421') - default: - sb.WriteRune(r) - } - } - return sb.String() -} diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 3d9e572b5192354bd97fd6274c482057646ad41c..1aa239bdc15cec6898a4cba1e4dc7a867b5e4ce0 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -9,10 +9,10 @@ import ( "sync" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" @@ -897,7 +897,7 @@ func (s *sidebarCmp) currentModelBlock() string { t := styles.CurrentTheme() modelIcon := t.S().Base.Foreground(t.FgSubtle).Render(styles.ModelIcon) - modelName := t.S().Text.Render(model.Model) + modelName := t.S().Text.Render(model.Name) modelInfo := fmt.Sprintf("%s %s", modelIcon, modelName) parts := []string{ modelInfo, @@ -905,14 +905,14 @@ func (s *sidebarCmp) currentModelBlock() string { if model.CanReason { reasoningInfoStyle := t.S().Subtle.PaddingLeft(2) switch modelProvider.Type { - case provider.TypeOpenAI: + case catwalk.TypeOpenAI: reasoningEffort := model.DefaultReasoningEffort if selectedModel.ReasoningEffort != "" { reasoningEffort = selectedModel.ReasoningEffort } formatter := cases.Title(language.English, cases.NoLower) parts = append(parts, reasoningInfoStyle.Render(formatter.String(fmt.Sprintf("Reasoning %s", reasoningEffort)))) - case provider.TypeAnthropic: + case catwalk.TypeAnthropic: formatter := cases.Title(language.English, cases.NoLower) if selectedModel.Think { parts = append(parts, reasoningInfoStyle.Render(formatter.String("Thinking on"))) diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 2a2c47a171a7ac685d644005e61e507a3964389f..b7291b3b59ae2bec879739e384e495776bb84f23 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -10,8 +10,8 @@ import ( "github.com/charmbracelet/bubbles/v2/key" "github.com/charmbracelet/bubbles/v2/spinner" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/completions" @@ -109,7 +109,7 @@ func (s *splashCmp) SetOnboarding(onboarding bool) { if err != nil { return } - filteredProviders := []provider.Provider{} + filteredProviders := []catwalk.Provider{} simpleProviders := []string{ "anthropic", "openai", @@ -407,7 +407,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd { return nil } -func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) { +func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { providers, err := config.Providers() if err != nil { return nil, err diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index bd30dc394d47ad80421e8c78d3a0f84730518a9c..6c63afd22e982e5ba40f5d175fc71449bcd0879e 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -36,7 +36,8 @@ type CompletionsOpenedMsg struct{} type CloseCompletionsMsg struct{} type SelectCompletionMsg struct { - Value any // The value of the selected completion item + Value any // The value of the selected completion item + Insert bool } type Completions interface { @@ -115,6 +116,30 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { d, cmd := c.list.Update(msg) c.list = d.(list.ListModel) return c, cmd + case key.Matches(msg, c.keyMap.UpInsert): + selectedItemInx := c.list.SelectedIndex() - 1 + items := c.list.Items() + if selectedItemInx == list.NoSelection || selectedItemInx < 0 { + return c, nil // No item selected, do nothing + } + selectedItem := items[selectedItemInx].(CompletionItem).Value() + c.list.SetSelected(selectedItemInx) + return c, util.CmdHandler(SelectCompletionMsg{ + Value: selectedItem, + Insert: true, + }) + case key.Matches(msg, c.keyMap.DownInsert): + selectedItemInx := c.list.SelectedIndex() + 1 + items := c.list.Items() + if selectedItemInx == list.NoSelection || selectedItemInx >= len(items) { + return c, nil // No item selected, do nothing + } + selectedItem := items[selectedItemInx].(CompletionItem).Value() + c.list.SetSelected(selectedItemInx) + return c, util.CmdHandler(SelectCompletionMsg{ + Value: selectedItem, + Insert: true, + }) case key.Matches(msg, c.keyMap.Select): selectedItemInx := c.list.SelectedIndex() if selectedItemInx == list.NoSelection { diff --git a/internal/tui/components/completions/keys.go b/internal/tui/components/completions/keys.go index 530b429fe32ffd89d73c6cec1723c27de1ddd459..82372358028aec2b1384f1b4b6bff90be4a05eb8 100644 --- a/internal/tui/components/completions/keys.go +++ b/internal/tui/components/completions/keys.go @@ -9,6 +9,8 @@ type KeyMap struct { Up, Select, Cancel key.Binding + DownInsert, + UpInsert key.Binding } func DefaultKeyMap() KeyMap { @@ -29,6 +31,14 @@ func DefaultKeyMap() KeyMap { key.WithKeys("esc"), key.WithHelp("esc", "cancel"), ), + DownInsert: key.NewBinding( + key.WithKeys("ctrl+n"), + key.WithHelp("ctrl+n", "insert next"), + ), + UpInsert: key.NewBinding( + key.WithKeys("ctrl+p"), + key.WithHelp("ctrl+p", "insert previous"), + ), } } diff --git a/internal/tui/components/core/status/status.go b/internal/tui/components/core/status/status.go index 59d873a94b0fc6713951b82caebd75a3a79a9623..4cbe8727f41f2a8c0f866b635573036735434781 100644 --- a/internal/tui/components/core/status/status.go +++ b/internal/tui/components/core/status/status.go @@ -1,14 +1,12 @@ package status import ( - "strings" "time" "github.com/charmbracelet/bubbles/v2/help" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" - "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/x/ansi" ) @@ -74,18 +72,15 @@ func (m *statusCmp) infoMsg() string { switch m.info.Type { case util.InfoTypeError: infoType = t.S().Base.Background(t.Red).Padding(0, 1).Render("ERROR") - width := m.width - lipgloss.Width(infoType) - message = t.S().Base.Background(t.Error).Foreground(t.White).Padding(0, 1).Width(width).Render(ansi.Truncate(m.info.Msg, width, "…")) + message = t.S().Base.Background(t.Error).Width(m.width).Foreground(t.White).Padding(0, 1).Render(m.info.Msg) case util.InfoTypeWarn: infoType = t.S().Base.Foreground(t.BgOverlay).Background(t.Yellow).Padding(0, 1).Render("WARNING") - width := m.width - lipgloss.Width(infoType) - message = t.S().Base.Foreground(t.BgOverlay).Background(t.Warning).Padding(0, 1).Width(width).Render(ansi.Truncate(m.info.Msg, width, "…")) + message = t.S().Base.Foreground(t.BgOverlay).Width(m.width).Background(t.Warning).Padding(0, 1).Render(m.info.Msg) default: infoType = t.S().Base.Foreground(t.BgOverlay).Background(t.Green).Padding(0, 1).Render("OKAY!") - width := m.width - lipgloss.Width(infoType) - message = t.S().Base.Background(t.Success).Foreground(t.White).Padding(0, 1).Width(width).Render(ansi.Truncate(m.info.Msg, width, "…")) + message = t.S().Base.Background(t.Success).Width(m.width).Foreground(t.White).Padding(0, 1).Render(m.info.Msg) } - return strings.Join([]string{infoType, message}, "") + return ansi.Truncate(infoType+message, m.width, "…") } func (m *statusCmp) ToggleFullHelp() { diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index a14138ff51ecf8164cf0fc595c758b0247aa3277..c1b96f0bac7d0b665aad77794392b7417d60457a 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -4,10 +4,10 @@ import ( "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/completions" @@ -270,7 +270,7 @@ func (c *commandDialogCmp) defaultCommands() []Command { providerCfg := cfg.GetProviderForModel(agentCfg.Model) model := cfg.GetModelByType(agentCfg.Model) if providerCfg != nil && model != nil && - providerCfg.Type == provider.TypeAnthropic && model.CanReason { + providerCfg.Type == catwalk.TypeAnthropic && model.CanReason { selectedModel := cfg.Models[agentCfg.Model] status := "Enable" if selectedModel.Think { diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 5f558364eec801d77a250c891a80110e0c9a3b86..5a36ab736351f2c92154da997f01ba7360470d8a 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -5,8 +5,8 @@ import ( "slices" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core/list" "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands" @@ -18,7 +18,7 @@ import ( type ModelListComponent struct { list list.ListModel modelType int - providers []provider.Provider + providers []catwalk.Provider } func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent { @@ -109,19 +109,19 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } // Check if this provider is not in the known providers list - if !slices.ContainsFunc(knownProviders, func(p provider.Provider) bool { return p.ID == provider.InferenceProvider(providerID) }) { + if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) { // Convert config provider to provider.Provider format - configProvider := provider.Provider{ + configProvider := catwalk.Provider{ Name: providerConfig.Name, - ID: provider.InferenceProvider(providerID), - Models: make([]provider.Model, len(providerConfig.Models)), + ID: catwalk.InferenceProvider(providerID), + Models: make([]catwalk.Model, len(providerConfig.Models)), } // Convert models for i, model := range providerConfig.Models { - configProvider.Models[i] = provider.Model{ + configProvider.Models[i] = catwalk.Model{ ID: model.ID, - Model: model.Model, + Name: model.Name, CostPer1MIn: model.CostPer1MIn, CostPer1MOut: model.CostPer1MOut, CostPer1MInCached: model.CostPer1MInCached, @@ -144,7 +144,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { section.SetInfo(configured) modelItems = append(modelItems, section) for _, model := range configProvider.Models { - modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{ + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ Provider: configProvider, Model: model, })) @@ -179,7 +179,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } modelItems = append(modelItems, section) for _, model := range provider.Models { - modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{ + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ Provider: provider, Model: model, })) @@ -201,6 +201,6 @@ func (m *ModelListComponent) SetInputPlaceholder(placeholder string) { m.list.SetFilterPlaceholder(placeholder) } -func (m *ModelListComponent) SetProviders(providers []provider.Provider) { +func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) { m.providers = providers } diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index b53388d16f17bbae8612cc66d1525e3e0e616db5..795e2585760391bcd711491533a156f9b2c810ba 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -8,8 +8,8 @@ import ( "github.com/charmbracelet/bubbles/v2/key" "github.com/charmbracelet/bubbles/v2/spinner" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/list" @@ -48,8 +48,8 @@ type ModelDialog interface { } type ModelOption struct { - Provider provider.Provider - Model provider.Model + Provider catwalk.Provider + Model catwalk.Model } type modelDialogCmp struct { @@ -363,7 +363,7 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { return false } -func (m *modelDialogCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) { +func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { providers, err := config.Providers() if err != nil { return nil, err diff --git a/internal/tui/exp/diffview/chroma.go b/internal/tui/exp/diffview/chroma.go index e4d6b2dbaa12651b28ace04e2e051c7a64522899..72e286c6cbab0a2080bcb54043083bf253171158 100644 --- a/internal/tui/exp/diffview/chroma.go +++ b/internal/tui/exp/diffview/chroma.go @@ -4,8 +4,10 @@ import ( "fmt" "image/color" "io" + "strings" "github.com/alecthomas/chroma/v2" + "github.com/charmbracelet/crush/internal/ansiext" "github.com/charmbracelet/lipgloss/v2" ) @@ -20,9 +22,12 @@ type chromaFormatter struct { // Format implements the chroma.Formatter interface. func (c chromaFormatter) Format(w io.Writer, style *chroma.Style, it chroma.Iterator) error { for token := it(); token != chroma.EOF; token = it() { + value := strings.TrimRight(token.Value, "\n") + value = ansiext.Escape(value) + entry := style.Get(token.Type) if entry.IsZero() { - if _, err := fmt.Fprint(w, token.Value); err != nil { + if _, err := fmt.Fprint(w, value); err != nil { return err } continue @@ -44,7 +49,7 @@ func (c chromaFormatter) Format(w io.Writer, style *chroma.Style, it chroma.Iter s = s.Foreground(lipgloss.Color(entry.Colour.String())) } - if _, err := fmt.Fprint(w, s.Render(token.Value)); err != nil { + if _, err := fmt.Fprint(w, s.Render(value)); err != nil { return err } } diff --git a/internal/tui/exp/diffview/diffview.go b/internal/tui/exp/diffview/diffview.go index ddac14296984cb31ce7f0b179950b2832280d3d1..eaea2837fcaa7522294143f0385bcbb0879316bd 100644 --- a/internal/tui/exp/diffview/diffview.go +++ b/internal/tui/exp/diffview/diffview.go @@ -193,6 +193,7 @@ func (dv *DiffView) clearSyntaxCache() { // String returns the string representation of the DiffView. func (dv *DiffView) String() string { + dv.normalizeLineEndings() dv.replaceTabs() if err := dv.computeDiff(); err != nil { return err.Error() @@ -227,6 +228,12 @@ func (dv *DiffView) String() string { } } +// normalizeLineEndings ensures the file contents use Unix-style line endings. +func (dv *DiffView) normalizeLineEndings() { + dv.before.content = strings.ReplaceAll(dv.before.content, "\r\n", "\n") + dv.after.content = strings.ReplaceAll(dv.after.content, "\r\n", "\n") +} + // replaceTabs replaces tabs in the before and after file contents with spaces // according to the specified tab width. func (dv *DiffView) replaceTabs() { @@ -396,8 +403,7 @@ func (dv *DiffView) renderUnified() string { shouldWrite := func() bool { return printedLines >= 0 } getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) { - content = strings.ReplaceAll(in, "\r\n", "\n") - content = strings.TrimSuffix(content, "\n") + content = strings.TrimSuffix(in, "\n") content = dv.hightlightCode(content, ls.Code.GetBackground()) content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content)) content = ansi.Truncate(content, dv.codeWidth, "…") @@ -520,8 +526,7 @@ func (dv *DiffView) renderSplit() string { shouldWrite := func() bool { return printedLines >= 0 } getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) { - content = strings.ReplaceAll(in, "\r\n", "\n") - content = strings.TrimSuffix(content, "\n") + content = strings.TrimSuffix(in, "\n") content = dv.hightlightCode(content, ls.Code.GetBackground()) content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content)) content = ansi.Truncate(content, dv.codeWidth, "…") diff --git a/internal/tui/exp/diffview/diffview_test.go b/internal/tui/exp/diffview/diffview_test.go index 595e0fa83260fbc981998aad0171fd4a1dcb25b8..c77c9e0e945072862aecdfbe5802e8093c5def83 100644 --- a/internal/tui/exp/diffview/diffview_test.go +++ b/internal/tui/exp/diffview/diffview_test.go @@ -36,6 +36,12 @@ var TestTabsBefore string //go:embed testdata/TestTabs.after var TestTabsAfter string +//go:embed testdata/TestLineBreakIssue.before +var TestLineBreakIssueBefore string + +//go:embed testdata/TestLineBreakIssue.after +var TestLineBreakIssueAfter string + type ( TestFunc func(dv *diffview.DiffView) *diffview.DiffView TestFuncs map[string]TestFunc @@ -177,6 +183,26 @@ func TestDiffViewTabs(t *testing.T) { } } +func TestDiffViewLineBreakIssue(t *testing.T) { + t.Parallel() + + for layoutName, layoutFunc := range LayoutFuncs { + t.Run(layoutName, func(t *testing.T) { + t.Parallel() + + dv := diffview.New(). + Before("index.js", TestLineBreakIssueBefore). + After("index.js", TestLineBreakIssueAfter). + Style(diffview.DefaultLightStyle()). + ChromaStyle(styles.Get("catppuccin-latte")) + dv = layoutFunc(dv) + + output := dv.String() + golden.RequireEqual(t, []byte(output)) + }) + } +} + func TestDiffViewWidth(t *testing.T) { for layoutName, layoutFunc := range LayoutFuncs { t.Run(layoutName, func(t *testing.T) { diff --git a/internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Split.golden b/internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Split.golden new file mode 100644 index 0000000000000000000000000000000000000000..41ec3a82928b33ab5e27b100d71abded6fd8b305 --- /dev/null +++ b/internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Split.golden @@ -0,0 +1,9 @@ +  …  @@ -1,6 +1,8 @@    …    +  1 - // this is   1 + /**  +       2 +  * this is  +  2 - // a regular   3 +  * a block  +  3 - // comment   4 +  * comment  +       5 +  */  +  4  $(function() {   6  $(function() {  +  5   console.log("Hello, world!");   7   console.log("Hello, world!");  +  6  });   8  });  \ No newline at end of file diff --git a/internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Unified.golden b/internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Unified.golden new file mode 100644 index 0000000000000000000000000000000000000000..c3bded237b15f4207dc76d99ec52869124759506 --- /dev/null +++ b/internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Unified.golden @@ -0,0 +1,12 @@ +  …   …  @@ -1,6 +1,8 @@   +  1    - // this is  +     1 + /**  +     2 +  * this is  +  2    - // a regular  +     3 +  * a block  +  3    - // comment  +     4 +  * comment  +     5 +  */  +  4   6  $(function() {  +  5   7   console.log("Hello, world!");  +  6   8  });  \ No newline at end of file diff --git a/internal/tui/exp/diffview/testdata/TestLineBreakIssue.after b/internal/tui/exp/diffview/testdata/TestLineBreakIssue.after new file mode 100644 index 0000000000000000000000000000000000000000..b26198ffbc2f9f7b3817d4aa486b1d3d56c752e0 --- /dev/null +++ b/internal/tui/exp/diffview/testdata/TestLineBreakIssue.after @@ -0,0 +1,8 @@ +/** + * this is + * a block + * comment + */ +$(function() { + console.log("Hello, world!"); +}); diff --git a/internal/tui/exp/diffview/testdata/TestLineBreakIssue.before b/internal/tui/exp/diffview/testdata/TestLineBreakIssue.before new file mode 100644 index 0000000000000000000000000000000000000000..7dea269c551b41906d6eb2b7a83d652250476e47 --- /dev/null +++ b/internal/tui/exp/diffview/testdata/TestLineBreakIssue.before @@ -0,0 +1,6 @@ +// this is +// a regular +// comment +$(function() { + console.log("Hello, world!"); +}); diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 0d28f13f3ca0a42c9ae15612f21678cdeb8f4bf2..9deac1e9e48c1cff576e84746d3976b4b670a700 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -279,7 +279,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if model.SupportsImages { return p, util.CmdHandler(OpenFilePickerMsg{}) } else { - return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Model) + return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name) } case key.Matches(msg, p.keyMap.Tab): if p.session.ID == "" { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index dda0ce2b9a5190953cf2bc288001a74c8c763b09..0e2587666f5a8c58be1466149a6b6f7a9dfb2a59 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -3,6 +3,7 @@ package tui import ( "context" "fmt" + "strings" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" @@ -112,6 +113,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return a, tea.Batch(cmds...) case tea.WindowSizeMsg: + a.wWidth, a.wHeight = msg.Width, msg.Height a.completions.Update(msg) return a, a.handleWindowResize(msg.Width, msg.Height) @@ -290,7 +292,6 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // handleWindowResize processes window resize events and updates all components. func (a *appModel) handleWindowResize(width, height int) tea.Cmd { var cmds []tea.Cmd - a.wWidth, a.wHeight = width, height if a.showingFullHelp { height -= 5 } else { @@ -319,26 +320,20 @@ func (a *appModel) handleWindowResize(width, height int) tea.Cmd { // handleKeyPressMsg processes keyboard input and routes to appropriate handlers. func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { + if a.completions.Open() { + // completions + keyMap := a.completions.KeyMap() + switch { + case key.Matches(msg, keyMap.Up), key.Matches(msg, keyMap.Down), + key.Matches(msg, keyMap.Select), key.Matches(msg, keyMap.Cancel), + key.Matches(msg, keyMap.UpInsert), key.Matches(msg, keyMap.DownInsert): + u, cmd := a.completions.Update(msg) + a.completions = u.(completions.Completions) + return cmd + } + } switch { - // completions - case a.completions.Open() && key.Matches(msg, a.completions.KeyMap().Up): - u, cmd := a.completions.Update(msg) - a.completions = u.(completions.Completions) - return cmd - - case a.completions.Open() && key.Matches(msg, a.completions.KeyMap().Down): - u, cmd := a.completions.Update(msg) - a.completions = u.(completions.Completions) - return cmd - case a.completions.Open() && key.Matches(msg, a.completions.KeyMap().Select): - u, cmd := a.completions.Update(msg) - a.completions = u.(completions.Completions) - return cmd - case a.completions.Open() && key.Matches(msg, a.completions.KeyMap().Cancel): - u, cmd := a.completions.Update(msg) - a.completions = u.(completions.Completions) - return cmd - // help + // help case key.Matches(msg, a.keyMap.Help): a.status.ToggleFullHelp() a.showingFullHelp = !a.showingFullHelp @@ -429,6 +424,27 @@ func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { // View renders the complete application interface including pages, dialogs, and overlays. func (a *appModel) View() tea.View { + var view tea.View + t := styles.CurrentTheme() + view.BackgroundColor = t.BgBase + if a.wWidth < 25 || a.wHeight < 15 { + view.Layer = lipgloss.NewCanvas( + lipgloss.NewLayer( + t.S().Base.Width(a.wWidth).Height(a.wHeight). + Align(lipgloss.Center, lipgloss.Center). + Render( + t.S().Base. + Padding(1, 4). + Foreground(t.White). + BorderStyle(lipgloss.RoundedBorder()). + BorderForeground(t.Primary). + Render("Window too small!"), + ), + ), + ) + return view + } + page := a.pages[a.currentPage] if withHelp, ok := page.(core.KeyMapHelp); ok { a.status.SetKeyMap(withHelp.Help()) @@ -453,6 +469,11 @@ func (a *appModel) View() tea.View { var cursor *tea.Cursor if v, ok := page.(util.Cursor); ok { cursor = v.Cursor() + // Hide the cursor if it's positioned outside the textarea + statusHeight := a.height - strings.Count(pageView, "\n") + 1 + if cursor != nil && cursor.Y+statusHeight+chat.EditorHeight-2 <= a.height { // 2 for the top and bottom app padding + cursor = nil + } } activeView := a.dialog.ActiveModel() if activeView != nil { @@ -475,10 +496,7 @@ func (a *appModel) View() tea.View { layers..., ) - var view tea.View - t := styles.CurrentTheme() view.Layer = canvas - view.BackgroundColor = t.BgBase view.Cursor = cursor return view } diff --git a/main.go b/main.go index ba1fdcd5e443e354d8ee288648055ccced01a653..072e3b35d2a2f408d8ed6a09423712b324df8b96 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,9 @@ package main import ( - "fmt" "log/slog" "net/http" "os" - "runtime" - "strings" _ "net/http/pprof" // profiling @@ -14,14 +11,9 @@ import ( "github.com/charmbracelet/crush/internal/cmd" "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/lipgloss/v2" ) func main() { - if runtime.GOOS == "windows" { - showWindowsWarning() - } - defer log.RecoverPanic("main", func() { slog.Error("Application terminated due to unhandled panic") }) @@ -30,22 +22,10 @@ func main() { go func() { slog.Info("Serving pprof at localhost:6060") if httpErr := http.ListenAndServe("localhost:6060", nil); httpErr != nil { - slog.Error(fmt.Sprintf("Failed to pprof listen: %v", httpErr)) + slog.Error("Failed to pprof listen", "error", httpErr) } }() } cmd.Execute() } - -func showWindowsWarning() { - content := strings.Join([]string{ - lipgloss.NewStyle().Bold(true).Render("WARNING:") + " Crush is experimental on Windows!", - "While we work on it, we recommend WSL2 for a better experience.", - lipgloss.NewStyle().Italic(true).Render("Press Enter to continue..."), - }, "\n") - fmt.Print(content) - - var input string - fmt.Scanln(&input) -}