From 3c7ee39258d0b47f7602160096103baad73bf2b4 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Wed, 10 Dec 2025 12:08:14 +0100 Subject: [PATCH 1/5] fix(noninteractive): cancel on signal (#1584) --- internal/cmd/run.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/internal/cmd/run.go b/internal/cmd/run.go index d3748d0f327d0db0913384100a5f885eb19a8c2e..3cebabb1c78a356c55636ec11f1645db072c6e54 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -1,9 +1,11 @@ package cmd import ( + "context" "fmt" "log/slog" "os" + "os/signal" "strings" "github.com/spf13/cobra" @@ -30,6 +32,10 @@ crush run --quiet "Generate a README for this project" RunE: func(cmd *cobra.Command, args []string) error { quiet, _ := cmd.Flags().GetBool("quiet") + // Cancel on SIGINT or SIGTERM. + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + defer cancel() + app, err := setupApp(cmd) if err != nil { return err @@ -57,8 +63,7 @@ crush run --quiet "Generate a README for this project" // crush run "Do something fancy" > output.txt // echo "Do something fancy" | crush run > output.txt // - // TODO: We currently need to press ^c twice to cancel. Fix that. - return app.RunNonInteractive(cmd.Context(), os.Stdout, prompt, quiet) + return app.RunNonInteractive(ctx, os.Stdout, prompt, quiet) }, } From 37c8e3fbae77159f4588f4552a1f8db6465647c6 Mon Sep 17 00:00:00 2001 From: Charm <124303983+charmcli@users.noreply.github.com> Date: Wed, 10 Dec 2025 09:04:59 -0300 Subject: [PATCH 2/5] chore(legal): @mengwong has signed the CLA --- .github/cla-signatures.json | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/cla-signatures.json b/.github/cla-signatures.json index ce031c351d2ce89248df538665ba5f1880f05c65..5679ea22271cf769f15fe3f4871fb27858010250 100644 --- a/.github/cla-signatures.json +++ b/.github/cla-signatures.json @@ -935,6 +935,14 @@ "created_at": "2025-12-06T18:13:11Z", "repoId": 987670088, "pullRequestNo": 1560 + }, + { + "name": "mengwong", + "id": 1480631, + "comment_id": 3636765109, + "created_at": "2025-12-10T12:04:50Z", + "repoId": 987670088, + "pullRequestNo": 1592 } ] } \ No newline at end of file From 409773299ad33b436e10f93d28c27e0a3c25d612 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Wed, 10 Dec 2025 09:10:40 -0300 Subject: [PATCH 3/5] perf: improve startup, specifically provider updates (#1577) Signed-off-by: Carlos Alexandro Becker --- go.mod | 4 +- go.sum | 4 +- internal/cmd/update_providers.go | 6 +- internal/config/provider.go | 109 ++++++++++++------------ internal/config/provider_empty_test.go | 30 ++----- internal/config/provider_test.go | 111 ++++++++++++++++++++----- 6 files changed, 161 insertions(+), 103 deletions(-) diff --git a/go.mod b/go.mod index cbfaa3077007fce91f5c7602c77c7f2353a4682a..b419bebcbb13d92ead60b4e7b444cfd53845715a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/charmbracelet/crush -go 1.25.0 +go 1.25.5 require ( charm.land/bubbles/v2 v2.0.0-rc.1 @@ -16,7 +16,7 @@ require ( github.com/aymanbagabas/go-udiff v0.3.1 github.com/bmatcuk/doublestar/v4 v4.9.1 github.com/charlievieth/fastwalk v1.0.14 - github.com/charmbracelet/catwalk v0.9.5 + github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74 github.com/charmbracelet/colorprofile v0.3.3 github.com/charmbracelet/fang v0.4.4 github.com/charmbracelet/glamour/v2 v2.0.0-20251106195642-800eb8175930 diff --git a/go.sum b/go.sum index 238b46788328c5ba0ada41bc882fd0d40949f09e..c406f84cbc6d629b765101bfaa27ba71f5ed7c30 100644 --- a/go.sum +++ b/go.sum @@ -88,8 +88,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4= -github.com/charmbracelet/catwalk v0.9.5 h1:QLqajLJfjGTVh2MIVIdAhww2XvPklxmu+p0Z4wrT7KU= -github.com/charmbracelet/catwalk v0.9.5/go.mod h1:ReU4SdrLfe63jkEjWMdX2wlZMV3k9r11oQAmzN0m+KY= +github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74 h1:pJI8ivIgSWOeCNnZFXeZzL6px1vZVS3LU4Cqk3Gx37I= +github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74/go.mod h1:ReU4SdrLfe63jkEjWMdX2wlZMV3k9r11oQAmzN0m+KY= github.com/charmbracelet/colorprofile v0.3.3 h1:DjJzJtLP6/NZ8p7Cgjno0CKGr7wwRJGxWUwh2IyhfAI= github.com/charmbracelet/colorprofile v0.3.3/go.mod h1:nB1FugsAbzq284eJcjfah2nhdSLppN2NqvfotkfRYP4= github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY= diff --git a/internal/cmd/update_providers.go b/internal/cmd/update_providers.go index 4949c31e2e8b87f212d8ac0ed94e2416f363a53b..599d2c90954ca43888197961ec3bea4372285071 100644 --- a/internal/cmd/update_providers.go +++ b/internal/cmd/update_providers.go @@ -31,12 +31,12 @@ crush update-providers embedded // NOTE(@andreynering): We want to skip logging output do stdout here. slog.SetDefault(slog.New(slog.DiscardHandler)) - var pathOrUrl string + var pathOrURL string if len(args) > 0 { - pathOrUrl = args[0] + pathOrURL = args[0] } - if err := config.UpdateProviders(pathOrUrl); err != nil { + if err := config.UpdateProviders(pathOrURL); err != nil { return err } diff --git a/internal/config/provider.go b/internal/config/provider.go index 7a98caaaa92a1c0401adccd79c3e945f813fbfc7..e9d4dfcc9d0947eebe51d5ae303106404ab47292 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -2,7 +2,9 @@ package config import ( "cmp" + "context" "encoding/json" + "errors" "fmt" "log/slog" "os" @@ -17,7 +19,7 @@ import ( ) type ProviderClient interface { - GetProviders() ([]catwalk.Provider, error) + GetProviders(context.Context, string) ([]catwalk.Provider, error) } var ( @@ -53,7 +55,7 @@ func saveProvidersInCache(path string, providers []catwalk.Provider) error { return fmt.Errorf("failed to create directory for provider cache: %w", err) } - data, err := json.MarshalIndent(providers, "", " ") + data, err := json.Marshal(providers) if err != nil { return fmt.Errorf("failed to marshal provider data: %w", err) } @@ -64,34 +66,35 @@ func saveProvidersInCache(path string, providers []catwalk.Provider) error { return nil } -func loadProvidersFromCache(path string) ([]catwalk.Provider, error) { +func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) { data, err := os.ReadFile(path) if err != nil { - return nil, fmt.Errorf("failed to read provider cache file: %w", err) + return nil, "", fmt.Errorf("failed to read provider cache file: %w", err) } var providers []catwalk.Provider if err := json.Unmarshal(data, &providers); err != nil { - return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err) + return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err) } - return providers, nil + + return providers, catwalk.Etag(data), nil } -func UpdateProviders(pathOrUrl string) error { +func UpdateProviders(pathOrURL string) error { var providers []catwalk.Provider - pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL) + pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL) switch { - case pathOrUrl == "embedded": + case pathOrURL == "embedded": providers = embedded.GetAll() - case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"): + case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"): var err error - providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders() + providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "") if err != nil { return fmt.Errorf("failed to fetch providers from Catwalk: %w", err) } default: - content, err := os.ReadFile(pathOrUrl) + content, err := os.ReadFile(pathOrURL) if err != nil { return fmt.Errorf("failed to read file: %w", err) } @@ -108,61 +111,61 @@ func UpdateProviders(pathOrUrl string) error { return fmt.Errorf("failed to save providers to cache: %w", err) } - slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath) + slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath) return nil } +// Providers returns the list of providers, taking into account cached results +// and whether or not auto update is enabled. +// +// It will: +// 1. if auto update is disabled, it'll return the embedded providers at the +// time of release. +// 2. load the cached providers +// 3. try to get the fresh list of providers, and return either this new list, +// the cached list, or the embedded list if all others fail. func Providers(cfg *Config) ([]catwalk.Provider, error) { providerOnce.Do(func() { catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) client := catwalk.NewWithURL(catwalkURL) path := providerCacheFileData() - autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate - providerList, providerErr = loadProviders(autoUpdateDisabled, client, path) - }) - return providerList, providerErr -} - -func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) { - catwalkGetAndSave := func() ([]catwalk.Provider, error) { - providers, err := client.GetProviders() - if err != nil { - return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err) + if cfg.Options.DisableProviderAutoUpdate { + slog.Info("Using embedded Catwalk providers") + providerList, providerErr = embedded.GetAll(), nil + return } - if len(providers) == 0 { - return nil, fmt.Errorf("empty providers list from catwalk") - } - if err := saveProvidersInCache(path, providers); err != nil { - return nil, err - } - return providers, nil - } - - switch { - case autoUpdateDisabled: - slog.Warn("Providers auto-update is disabled") - if _, err := os.Stat(path); err == nil { - slog.Warn("Using locally cached providers") - return loadProvidersFromCache(path) + cached, etag, cachedErr := loadProvidersFromCache(path) + if len(cached) == 0 || cachedErr != nil { + // if cached file is empty, default to embedded providers + cached = embedded.GetAll() } - slog.Warn("Saving embedded providers to cache") - providers := embedded.GetAll() - if err := saveProvidersInCache(path, providers); err != nil { - return nil, err + providerList, providerErr = loadProviders(client, etag, path) + if errors.Is(providerErr, catwalk.ErrNotModified) { + slog.Info("Catwalk providers not modified") + providerList, providerErr = cached, nil } - return providers, nil - - default: - slog.Info("Fetching providers from Catwalk.", "path", path) + }) + if providerErr != nil { + catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) + return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck + } + return providerList, nil +} - providers, err := catwalkGetAndSave() - if err != nil { - catwalkUrl := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) - return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck - } - return providers, nil +func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) { + slog.Info("Fetching providers from Catwalk.", "path", path) + providers, err := client.GetProviders(context.Background(), etag) + if err != nil { + return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err) + } + if len(providers) == 0 { + return nil, errors.New("empty providers list from catwalk") } + if err := saveProvidersInCache(path, providers); err != nil { + return nil, err + } + return providers, nil } diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go index f3691c320ad4e3509b327374c8ce7f5285c39590..5e889f82a4997d9f5761e6fc4a01ec6d54a0623a 100644 --- a/internal/config/provider_empty_test.go +++ b/internal/config/provider_empty_test.go @@ -1,8 +1,7 @@ package config import ( - "encoding/json" - "os" + "context" "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -11,37 +10,22 @@ import ( type emptyProviderClient struct{} -func (m *emptyProviderClient) GetProviders() ([]catwalk.Provider, error) { +func (m *emptyProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) { return []catwalk.Provider{}, nil } +// TestProvider_loadProvidersEmptyResult tests that loadProviders returns an +// error when the client returns an empty list. This ensures we don't cache +// empty provider lists. func TestProvider_loadProvidersEmptyResult(t *testing.T) { client := &emptyProviderClient{} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) - require.Contains(t, err.Error(), "Crush was unable to fetch an updated list of providers") + providers, err := loadProviders(client, "", tmpPath) + require.Contains(t, err.Error(), "empty providers list from catwalk") require.Empty(t, providers) require.Len(t, providers, 0) // Check that no cache file was created for empty results require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results") } - -func TestProvider_loadProvidersEmptyCache(t *testing.T) { - client := &mockProviderClient{shouldFail: false} - tmpPath := t.TempDir() + "/providers.json" - - // Create an empty cache file - emptyProviders := []catwalk.Provider{} - data, err := json.Marshal(emptyProviders) - require.NoError(t, err) - require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) - - // Should refresh and get real providers instead of using empty cache - providers, err := loadProviders(false, client, tmpPath) - require.NoError(t, err) - require.NotNil(t, providers) - require.Len(t, providers, 1) - require.Equal(t, "Mock", providers[0].Name) -} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 1262b60ef42050b9061c9f7c6be4dc431efe3548..f101dd8de8ef624ed041ff88115c6c286f902659 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,9 +1,11 @@ package config import ( + "context" "encoding/json" "errors" "os" + "sync" "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -11,10 +13,14 @@ import ( ) type mockProviderClient struct { - shouldFail bool + shouldFail bool + shouldReturnErr error } -func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { +func (m *mockProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) { + if m.shouldReturnErr != nil { + return nil, m.shouldReturnErr + } if m.shouldFail { return nil, errors.New("failed to load providers") } @@ -25,10 +31,16 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { }, nil } +func resetProviderState() { + providerOnce = sync.Once{} + providerList = nil + providerErr = nil +} + func TestProvider_loadProvidersNoIssues(t *testing.T) { client := &mockProviderClient{shouldFail: false} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) + providers, err := loadProviders(client, "", tmpPath) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) @@ -39,35 +51,94 @@ func TestProvider_loadProvidersNoIssues(t *testing.T) { require.False(t, fileInfo.IsDir(), "Expected a file, not a directory") } -func TestProvider_loadProvidersWithIssues(t *testing.T) { - client := &mockProviderClient{shouldFail: true} - tmpPath := t.TempDir() + "/providers.json" - // store providers to a temporary file - oldProviders := []catwalk.Provider{ - { - Name: "OldProvider", +func TestProvider_DisableAutoUpdate(t *testing.T) { + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + resetProviderState() + defer resetProviderState() + + cfg := &Config{ + Options: &Options{ + DisableProviderAutoUpdate: true, }, } - data, err := json.Marshal(oldProviders) - if err != nil { - t.Fatalf("Failed to marshal old providers: %v", err) - } - err = os.WriteFile(tmpPath, data, 0o644) - if err != nil { - t.Fatalf("Failed to write old providers to file: %v", err) + providers, err := Providers(cfg) + require.NoError(t, err) + require.NotNil(t, providers) + require.Greater(t, len(providers), 5, "Expected embedded providers") +} + +func TestProvider_WithValidCache(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + resetProviderState() + defer resetProviderState() + + cachePath := tmpDir + "/crush/providers.json" + require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) + cachedProviders := []catwalk.Provider{ + {Name: "Cached"}, } - providers, err := loadProviders(true, client, tmpPath) + data, err := json.Marshal(cachedProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + + mockClient := &mockProviderClient{shouldFail: false} + + providers, err := loadProviders(mockClient, "", cachePath) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) - require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails") + require.Equal(t, "Mock", providers[0].Name, "Expected fresh provider from fetch") +} + +func TestProvider_NotModifiedUsesCached(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + resetProviderState() + defer resetProviderState() + + cachePath := tmpDir + "/crush/providers.json" + require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) + cachedProviders := []catwalk.Provider{ + {Name: "Cached"}, + } + data, err := json.Marshal(cachedProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + + mockClient := &mockProviderClient{shouldReturnErr: catwalk.ErrNotModified} + providers, err := loadProviders(mockClient, "", cachePath) + require.ErrorIs(t, err, catwalk.ErrNotModified) + require.Nil(t, providers) +} + +func TestProvider_EmptyCacheDefaultsToEmbedded(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + resetProviderState() + defer resetProviderState() + + cachePath := tmpDir + "/crush/providers.json" + require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) + emptyProviders := []catwalk.Provider{} + data, err := json.Marshal(emptyProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + + cached, _, err := loadProvidersFromCache(cachePath) + require.NoError(t, err) + require.Empty(t, cached, "Expected empty cache") } func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) + providers, err := loadProviders(client, "", tmpPath) require.Error(t, err) require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") } From ab55cb6ce2636a7d6fd2dfce1cd6aba7c2e6c23e Mon Sep 17 00:00:00 2001 From: Kieran Klukas Date: Wed, 10 Dec 2025 02:32:08 -1000 Subject: [PATCH 4/5] fix(claude): add authentication refresh on 401 errors (#1581) Co-authored-by: Andrey Nering --- internal/agent/coordinator.go | 91 ++++++++++++++++++++++++++--------- internal/config/config.go | 7 ++- internal/config/load.go | 23 +-------- 3 files changed, 73 insertions(+), 48 deletions(-) diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 91463fe4c24be90b743bcdb654f865ce60ecf2af..436aa27d95e4b86f83c20c3f46b2e1434986e89d 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -10,6 +10,7 @@ import ( "io" "log/slog" "maps" + "net/http" "os" "slices" "strings" @@ -130,32 +131,42 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg) - if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() { - slog.Info("Detected expired OAuth token, attempting refresh", "provider", providerCfg.ID) - if refreshErr := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); refreshErr != nil { - slog.Error("Failed to refresh OAuth token", "provider", providerCfg.ID, "error", refreshErr) - return nil, refreshErr - } - - // Rebuild models with refreshed token - if updateErr := c.UpdateModels(ctx); updateErr != nil { - slog.Error("Failed to update models after token refresh", "error", updateErr) - return nil, updateErr + run := func() (*fantasy.AgentResult, error) { + return c.currentAgent.Run(ctx, SessionAgentCall{ + SessionID: sessionID, + Prompt: prompt, + Attachments: attachments, + MaxOutputTokens: maxTokens, + ProviderOptions: mergedOptions, + Temperature: temp, + TopP: topP, + TopK: topK, + FrequencyPenalty: freqPenalty, + PresencePenalty: presPenalty, + }) + } + result, originalErr := run() + + if c.isUnauthorized(originalErr) { + switch { + case providerCfg.OAuthToken != nil: + slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID) + if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil { + return nil, originalErr + } + slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID) + return run() + case strings.Contains(providerCfg.APIKeyTemplate, "$"): + slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID) + if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil { + return nil, originalErr + } + slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID) + return run() } } - result, err := c.currentAgent.Run(ctx, SessionAgentCall{ - SessionID: sessionID, - Prompt: prompt, - Attachments: attachments, - MaxOutputTokens: maxTokens, - ProviderOptions: mergedOptions, - Temperature: temp, - TopP: topP, - TopK: topK, - FrequencyPenalty: freqPenalty, - PresencePenalty: presPenalty, - }) - return result, err + + return result, originalErr } func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions { @@ -773,3 +784,35 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { } return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg)) } + +func (c *coordinator) isUnauthorized(err error) bool { + var providerErr *fantasy.ProviderError + return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized +} + +func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error { + if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil { + slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err) + return err + } + if err := c.UpdateModels(ctx); err != nil { + return err + } + return nil +} + +func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error { + newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate) + if err != nil { + slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err) + return err + } + + providerCfg.APIKey = newAPIKey + c.cfg.Providers.Set(providerCfg.ID, providerCfg) + + if err := c.UpdateModels(ctx); err != nil { + return err + } + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 464dc14bc8c6d12cdf1db17c681c4faa68a59339..4c9dc7bafe83ff0b75b0a0238fcd71ba9e63a3bf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -95,6 +95,8 @@ type ProviderConfig struct { Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=openai-compat,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"` // The provider's API key. APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"` + // The original API key template before resolution (for re-resolution on auth errors). + APIKeyTemplate string `json:"-"` // OAuthToken for providers that use OAuth2 authentication. OAuthToken *oauth.Token `json:"oauth,omitempty" jsonschema:"description=OAuth2 token for authentication with the provider"` // Marks the provider as disabled. @@ -469,6 +471,7 @@ func (c *Config) SetConfigField(key string, value any) error { return nil } +// RefreshOAuthToken refreshes the OAuth token for the given provider. func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error { providerConfig, exists := c.Providers.Get(providerID) if !exists { @@ -479,7 +482,7 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error return fmt.Errorf("provider %s does not have an OAuth token", providerID) } - // Only Anthropic provider uses OAuth for now + // Only Anthropic provider uses OAuth for now. if providerID != string(catwalk.InferenceProviderAnthropic) { return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) } @@ -489,7 +492,7 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err) } - slog.Info("Successfully refreshed OAuth token in background", "provider", providerID) + slog.Info("Successfully refreshed OAuth token", "provider", providerID) providerConfig.OAuthToken = newToken providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken) providerConfig.SetupClaudeCode() diff --git a/internal/config/load.go b/internal/config/load.go index 7645861198eefbceb1e283ee7815d3f130b0b868..8f3ad171d2ae2e196584223e55d24d42b200e073 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -1,7 +1,6 @@ package config import ( - "cmp" "context" "encoding/json" "fmt" @@ -19,11 +18,9 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/event" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/home" "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/crush/internal/oauth/claude" powernapConfig "github.com/charmbracelet/x/powernap/pkg/config" ) @@ -189,6 +186,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know Name: p.Name, BaseURL: p.APIEndpoint, APIKey: p.APIKey, + APIKeyTemplate: p.APIKey, // Store original template for re-resolution OAuthToken: config.OAuthToken, Type: p.Type, Disable: config.Disable, @@ -200,25 +198,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } if p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil { - if config.OAuthToken.IsExpired() { - newToken, err := claude.RefreshToken(context.TODO(), config.OAuthToken.RefreshToken) - if err == nil { - slog.Info("Successfully refreshed Anthropic OAuth token") - config.OAuthToken = newToken - prepared.OAuthToken = newToken - if err := cmp.Or( - c.SetConfigField("providers.anthropic.api_key", newToken.AccessToken), - c.SetConfigField("providers.anthropic.oauth", newToken), - ); err != nil { - return err - } - } else { - slog.Error("Failed to refresh Anthropic OAuth token", "error", err) - event.Error(err) - } - } else { - slog.Info("Using existing non-expired Anthropic OAuth token") - } prepared.SetupClaudeCode() }