fix(copilot): change import to happen on demand, and also on onboarding

Andrey Nering created

Change summary

internal/config/config.go                        | 5 ++++-
internal/config/copilot.go                       | 7 ++++++-
internal/config/load.go                          | 7 -------
internal/tui/components/chat/splash/splash.go    | 4 ++++
internal/tui/components/dialogs/models/models.go | 5 +++++
5 files changed, 19 insertions(+), 9 deletions(-)

Detailed changes

internal/config/config.go 🔗

@@ -9,6 +9,7 @@ import (
 	"net/http"
 	"net/url"
 	"os"
+	"path/filepath"
 	"slices"
 	"strings"
 	"time"
@@ -498,7 +499,6 @@ func (c *Config) HasConfigField(key string) bool {
 }
 
 func (c *Config) SetConfigField(key string, value any) error {
-	// read the data
 	data, err := os.ReadFile(c.dataConfigDir)
 	if err != nil {
 		if os.IsNotExist(err) {
@@ -512,6 +512,9 @@ func (c *Config) SetConfigField(key string, value any) error {
 	if err != nil {
 		return fmt.Errorf("failed to set config field %s: %w", key, err)
 	}
+	if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil {
+		return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err)
+	}
 	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil {
 		return fmt.Errorf("failed to write config file: %w", err)
 	}

internal/config/copilot.go 🔗

@@ -6,11 +6,12 @@ import (
 	"log/slog"
 	"testing"
 
+	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/oauth"
 	"github.com/charmbracelet/crush/internal/oauth/copilot"
 )
 
-func (c *Config) importCopilot() (*oauth.Token, bool) {
+func (c *Config) ImportCopilot() (*oauth.Token, bool) {
 	if testing.Testing() {
 		return nil, false
 	}
@@ -31,6 +32,10 @@ func (c *Config) importCopilot() (*oauth.Token, bool) {
 		return nil, false
 	}
 
+	if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil {
+		return token, false
+	}
+
 	if err := cmp.Or(
 		c.SetConfigField("providers.copilot.api_key", token.AccessToken),
 		c.SetConfigField("providers.copilot.oauth", token),

internal/config/load.go 🔗

@@ -133,8 +133,6 @@ func PushPopCrushEnv() func() {
 }
 
 func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
-	c.importCopilot()
-
 	knownProviderNames := make(map[string]bool)
 	restore := PushPopCrushEnv()
 	defer restore()
@@ -206,11 +204,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
 		case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil:
 			prepared.SetupClaudeCode()
 		case p.ID == catwalk.InferenceProviderCopilot:
-			if config.OAuthToken != nil {
-				if token, ok := c.importCopilot(); ok {
-					prepared.OAuthToken = token
-				}
-			}
 			if config.OAuthToken != nil {
 				prepared.SetupGitHubCopilot()
 			}

internal/tui/components/chat/splash/splash.go 🔗

@@ -346,6 +346,10 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
 						s.hyperDeviceFlow.SetWidth(min(s.width-2, 60))
 						return s, s.hyperDeviceFlow.Init()
 					case catwalk.InferenceProviderCopilot:
+						if token, ok := config.Get().ImportCopilot(); ok {
+							s.selectedModel = selectedItem
+							return s, s.saveAPIKeyAndContinue(token, true)
+						}
 						s.selectedModel = selectedItem
 						s.showCopilotDeviceFlow = true
 						s.copilotDeviceFlow = copilot.NewDeviceFlow()

internal/tui/components/dialogs/models/models.go 🔗

@@ -307,6 +307,11 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
 				m.hyperDeviceFlow.SetWidth(m.width - 2)
 				return m, m.hyperDeviceFlow.Init()
 			case catwalk.InferenceProviderCopilot:
+				if token, ok := config.Get().ImportCopilot(); ok {
+					m.selectedModel = selectedItem
+					m.selectedModelType = modelType
+					return m, m.saveOauthTokenAndContinue(token, true)
+				}
 				m.showCopilotDeviceFlow = true
 				m.selectedModel = selectedItem
 				m.selectedModelType = modelType