From 68ad2981a0017a510f1df6ba89c7492ecb1b4338 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Thu, 18 Dec 2025 16:19:20 -0300 Subject: [PATCH] fix(copilot): change import to happen on demand, and also on onboarding --- internal/config/config.go | 5 ++++- internal/config/copilot.go | 7 ++++++- internal/config/load.go | 7 ------- internal/tui/components/chat/splash/splash.go | 4 ++++ internal/tui/components/dialogs/models/models.go | 5 +++++ 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index a91350b5bc894161bc5fdbd44720c27b46fc1063..e5878a5d0c999c612e3b5e2c5dd61ec4c4dc324d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "os" + "path/filepath" "slices" "strings" "time" @@ -498,7 +499,6 @@ func (c *Config) HasConfigField(key string) bool { } func (c *Config) SetConfigField(key string, value any) error { - // read the data data, err := os.ReadFile(c.dataConfigDir) if err != nil { if os.IsNotExist(err) { @@ -512,6 +512,9 @@ func (c *Config) SetConfigField(key string, value any) error { if err != nil { return fmt.Errorf("failed to set config field %s: %w", key, err) } + if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil { + return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err) + } if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil { return fmt.Errorf("failed to write config file: %w", err) } diff --git a/internal/config/copilot.go b/internal/config/copilot.go index f9ebc2f4fbddf602c67ae6fc81f5e6ca02d57b27..ee50bec43d6ce5754799adf4bfe99ba9b357d690 100644 --- a/internal/config/copilot.go +++ b/internal/config/copilot.go @@ -6,11 +6,12 @@ import ( "log/slog" "testing" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" ) -func (c *Config) importCopilot() (*oauth.Token, bool) { +func (c *Config) ImportCopilot() (*oauth.Token, bool) { if testing.Testing() { return nil, false } @@ -31,6 +32,10 @@ func (c *Config) importCopilot() (*oauth.Token, bool) { return nil, false } + if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil { + return token, false + } + if err := cmp.Or( c.SetConfigField("providers.copilot.api_key", token.AccessToken), c.SetConfigField("providers.copilot.oauth", token), diff --git a/internal/config/load.go b/internal/config/load.go index 0d16702dcdd35eb7d431ddfe4a0b35ab48e4debc..27866e5891afc8774cb5d5ac2c8fd4f979161e2f 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -133,8 +133,6 @@ func PushPopCrushEnv() func() { } func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { - c.importCopilot() - knownProviderNames := make(map[string]bool) restore := PushPopCrushEnv() defer restore() @@ -206,11 +204,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: prepared.SetupClaudeCode() case p.ID == catwalk.InferenceProviderCopilot: - if config.OAuthToken != nil { - if token, ok := c.importCopilot(); ok { - prepared.OAuthToken = token - } - } if config.OAuthToken != nil { prepared.SetupGitHubCopilot() } diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 32f30b5df7a32ff27dc011331ec4857ce36cddc8..ba512b5b3911f966b8760e6e3ccbce1ec20e92ca 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -346,6 +346,10 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { s.hyperDeviceFlow.SetWidth(min(s.width-2, 60)) return s, s.hyperDeviceFlow.Init() case catwalk.InferenceProviderCopilot: + if token, ok := config.Get().ImportCopilot(); ok { + s.selectedModel = selectedItem + return s, s.saveAPIKeyAndContinue(token, true) + } s.selectedModel = selectedItem s.showCopilotDeviceFlow = true s.copilotDeviceFlow = copilot.NewDeviceFlow() diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index ff7bd4f0409693454222b45567c23a151f8de077..06da780edd48cf689113575d39e2ed5805fa27e2 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -307,6 +307,11 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.hyperDeviceFlow.SetWidth(m.width - 2) return m, m.hyperDeviceFlow.Init() case catwalk.InferenceProviderCopilot: + if token, ok := config.Get().ImportCopilot(); ok { + m.selectedModel = selectedItem + m.selectedModelType = modelType + return m, m.saveOauthTokenAndContinue(token, true) + } m.showCopilotDeviceFlow = true m.selectedModel = selectedItem m.selectedModelType = modelType