copilot: Decouple authentication from the lifetime of any single Copilot instance (#47473)

Piotr Osiewicz , dino , and Zed Zippy created

Users had trouble signing in due to us relying on the Copilot::global
being set, which was never the case. We've decided to use a dedicated
LSP instance just for handling auth of Copilot Chat and other goodies.
That instance is subscribed to by local Copilot instances for projects.
When the Auth instance changes it's state, local instances are prompted
to re-check their own sign in status.

Closes #47352

Co-authored-by: dino <dinojoaocosta@gmail.com>

Release Notes:

- Fixed authentication issues with Copilot.

---------

Co-authored-by: dino <dinojoaocosta@gmail.com>
Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com>

Change summary

Cargo.lock                                              |   3 
crates/copilot/Cargo.toml                               |   1 
crates/copilot/src/copilot.rs                           |  96 ++++-
crates/copilot_ui/Cargo.toml                            |   2 
crates/copilot_ui/src/copilot_ui.rs                     |  40 +-
crates/copilot_ui/src/sign_in.rs                        | 172 +++++-----
crates/edit_prediction/src/edit_prediction.rs           |  25 +
crates/edit_prediction/src/onboarding_modal.rs          |   8 
crates/edit_prediction_ui/src/edit_prediction_button.rs |  49 ++-
crates/language_models/src/provider/copilot_chat.rs     |   6 
crates/zed/src/main.rs                                  |   2 
11 files changed, 258 insertions(+), 146 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3703,6 +3703,7 @@ dependencies = [
  "sum_tree",
  "theme",
  "util",
+ "workspace",
  "zlog",
 ]
 
@@ -3733,10 +3734,12 @@ dependencies = [
  "anyhow",
  "copilot",
  "gpui",
+ "language",
  "log",
  "lsp",
  "menu",
  "serde_json",
+ "settings",
  "ui",
  "util",
  "workspace",

crates/copilot/Cargo.toml 🔗

@@ -45,6 +45,7 @@ serde_json.workspace = true
 settings.workspace = true
 sum_tree.workspace = true
 util.workspace = true
+workspace.workspace = true
 
 [target.'cfg(windows)'.dependencies]
 async-std = { version = "1.12.0", features = ["unstable"] }

crates/copilot/src/copilot.rs 🔗

@@ -12,8 +12,8 @@ use command_palette_hooks::CommandPaletteFilter;
 use futures::future;
 use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared, select_biased};
 use gpui::{
-    App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
-    WeakEntity, actions,
+    App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Subscription,
+    Task, WeakEntity, actions,
 };
 use language::language_settings::{AllLanguageSettings, CopilotSettings};
 use language::{
@@ -41,6 +41,7 @@ use std::{
 };
 use sum_tree::Dimensions;
 use util::{ResultExt, fs::remove_matching};
+use workspace::AppState;
 
 pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate;
 
@@ -251,20 +252,50 @@ pub struct Copilot {
     server: CopilotServer,
     buffers: HashSet<WeakEntity<Buffer>>,
     server_id: LanguageServerId,
-    _subscriptions: [gpui::Subscription; 2],
+    _subscriptions: Vec<Subscription>,
 }
 
 pub enum Event {
-    CopilotLanguageServerStarted,
     CopilotAuthSignedIn,
     CopilotAuthSignedOut,
 }
 
 impl EventEmitter<Event> for Copilot {}
 
-struct GlobalCopilot(Entity<Copilot>);
+#[derive(Clone)]
+pub struct GlobalCopilotAuth(pub Entity<Copilot>);
 
-impl Global for GlobalCopilot {}
+impl GlobalCopilotAuth {
+    pub fn set_global(
+        server_id: LanguageServerId,
+        fs: Arc<dyn Fs>,
+        node_runtime: NodeRuntime,
+        cx: &mut App,
+    ) {
+        let auth =
+            GlobalCopilotAuth(cx.new(|cx| Copilot::new(None, server_id, fs, node_runtime, cx)));
+        cx.set_global(auth);
+    }
+    pub fn try_global(cx: &mut App) -> Option<&GlobalCopilotAuth> {
+        cx.try_global()
+    }
+
+    pub fn get_or_init(cx: &mut App) -> Option<GlobalCopilotAuth> {
+        if let Some(copilot) = cx.try_global::<Self>() {
+            Some(copilot.clone())
+        } else {
+            let app_state = AppState::global(cx).upgrade()?;
+            Self::set_global(
+                app_state.languages.next_language_server_id(),
+                app_state.fs.clone(),
+                app_state.node_runtime.clone(),
+                cx,
+            );
+            cx.try_global::<Self>().cloned()
+        }
+    }
+}
+impl Global for GlobalCopilotAuth {}
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
 pub(crate) enum CompletionSource {
@@ -284,23 +315,14 @@ pub(crate) struct CopilotEditPrediction {
 }
 
 impl Copilot {
-    pub fn global(cx: &App) -> Option<Entity<Self>> {
-        cx.try_global::<GlobalCopilot>()
-            .map(|model| model.0.clone())
-    }
-
-    pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
-        cx.set_global(GlobalCopilot(copilot));
-    }
-
     pub fn new(
-        project: Entity<Project>,
+        project: Option<Entity<Project>>,
         new_server_id: LanguageServerId,
         fs: Arc<dyn Fs>,
         node_runtime: NodeRuntime,
         cx: &mut Context<Self>,
     ) -> Self {
-        let send_focus_notification =
+        let send_focus_notification = project.map(|project| {
             cx.subscribe(&project, |this, project, e: &project::Event, cx| {
                 if let project::Event::ActiveEntryChanged(new_entry) = e
                     && let Ok(running) = this.server.as_authenticated()
@@ -312,11 +334,40 @@ impl Copilot {
 
                     _ = running.lsp.notify::<DidFocus>(DidFocusParams { uri });
                 }
+            })
+        });
+        let global_authentication_events =
+            cx.try_global::<GlobalCopilotAuth>().cloned().map(|auth| {
+                cx.subscribe(&auth.0, |_, _, _: &Event, cx| {
+                    cx.spawn(async move |this, cx| {
+                        let Some(server) = this
+                            .update(cx, |this, _| this.language_server().cloned())
+                            .ok()
+                            .flatten()
+                        else {
+                            return;
+                        };
+                        let status = server
+                            .request::<request::CheckStatus>(request::CheckStatusParams {
+                                local_checks_only: false,
+                            })
+                            .await
+                            .into_response()
+                            .ok();
+                        if let Some(status) = status {
+                            this.update(cx, |copilot, cx| {
+                                copilot.update_sign_in_status(status, cx);
+                            })
+                            .ok();
+                        }
+                    })
+                    .detach()
+                })
             });
-        let _subscriptions = [
-            cx.on_app_quit(Self::shutdown_language_server),
-            send_focus_notification,
-        ];
+        let _subscriptions = std::iter::once(cx.on_app_quit(Self::shutdown_language_server))
+            .chain(send_focus_notification)
+            .chain(global_authentication_events)
+            .collect();
         let mut this = Self {
             server_id: new_server_id,
             fs,
@@ -455,7 +506,7 @@ impl Copilot {
                 sign_in_status: SignInStatus::Authorized,
                 registered_buffers: Default::default(),
             }),
-            _subscriptions: [
+            _subscriptions: vec![
                 send_focus_notification,
                 cx.on_app_quit(Self::shutdown_language_server),
             ],
@@ -619,7 +670,6 @@ impl Copilot {
                         },
                         registered_buffers: Default::default(),
                     });
-                    cx.emit(Event::CopilotLanguageServerStarted);
                     this.update_sign_in_status(status, cx);
                 }
                 Err(error) => {

crates/copilot_ui/Cargo.toml 🔗

@@ -23,10 +23,12 @@ test-support = [
 anyhow.workspace = true
 copilot.workspace = true
 gpui.workspace = true
+language.workspace = true
 log.workspace = true
 lsp.workspace = true
 menu.workspace = true
 serde_json.workspace = true
+settings.workspace = true
 ui.workspace = true
 util.workspace = true
 workspace.workspace = true

crates/copilot_ui/src/copilot_ui.rs 🔗

@@ -1,25 +1,31 @@
 mod sign_in;
 
-use copilot::{Reinstall, SignIn, SignOut};
-use gpui::App;
-use workspace::Workspace;
+use std::sync::Arc;
 
+use copilot::GlobalCopilotAuth;
+use gpui::AppContext;
+use language::language_settings::AllLanguageSettings;
+use settings::SettingsStore;
 pub use sign_in::{
     ConfigurationMode, ConfigurationView, CopilotCodeVerification, initiate_sign_in,
-    reinstall_and_sign_in,
+    initiate_sign_out, reinstall_and_sign_in,
 };
+use ui::App;
+use workspace::AppState;
 
-pub fn init(cx: &mut App) {
-    cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
-        workspace.register_action(|_, _: &SignIn, window, cx| {
-            sign_in::initiate_sign_in(window, cx);
-        });
-        workspace.register_action(|_, _: &Reinstall, window, cx| {
-            sign_in::reinstall_and_sign_in(window, cx);
-        });
-        workspace.register_action(|_, _: &SignOut, window, cx| {
-            sign_in::initiate_sign_out(window, cx);
-        });
-    })
-    .detach();
+pub fn init(app_state: &Arc<AppState>, cx: &mut App) {
+    let provider = cx.read_global(|settings: &SettingsStore, _| {
+        settings
+            .get::<AllLanguageSettings>(None)
+            .edit_predictions
+            .provider
+    });
+    if provider == settings::EditPredictionProvider::Copilot {
+        GlobalCopilotAuth::set_global(
+            app_state.languages.next_language_server_id(),
+            app_state.fs.clone(),
+            app_state.node_runtime.clone(),
+            cx,
+        );
+    }
 }

crates/copilot_ui/src/sign_in.rs 🔗

@@ -1,5 +1,8 @@
 use anyhow::Context as _;
-use copilot::{Copilot, Status, request, request::PromptUserDeviceFlow};
+use copilot::{
+    Copilot, GlobalCopilotAuth, Status,
+    request::{self, PromptUserDeviceFlow},
+};
 use gpui::{
     App, ClipboardItem, Context, DismissEvent, Element, Entity, EventEmitter, FocusHandle,
     Focusable, InteractiveElement, IntoElement, MouseDownEvent, ParentElement, Render, Styled,
@@ -15,16 +18,12 @@ const ERROR_LABEL: &str =
 
 struct CopilotStatusToast;
 
-pub fn initiate_sign_in(window: &mut Window, cx: &mut App) {
+pub fn initiate_sign_in(copilot: Entity<Copilot>, window: &mut Window, cx: &mut App) {
     let is_reinstall = false;
-    initiate_sign_in_impl(is_reinstall, window, cx)
+    initiate_sign_in_impl(copilot, is_reinstall, window, cx)
 }
 
-pub fn initiate_sign_out(window: &mut Window, cx: &mut App) {
-    let Some(copilot) = Copilot::global(cx) else {
-        return;
-    };
-
+pub fn initiate_sign_out(copilot: Entity<Copilot>, window: &mut Window, cx: &mut App) {
     copilot_toast(Some("Signing out of Copilot…"), window, cx);
 
     let sign_out_task = copilot.update(cx, |copilot, cx| copilot.sign_out(cx));
@@ -46,13 +45,10 @@ pub fn initiate_sign_out(window: &mut Window, cx: &mut App) {
         .detach();
 }
 
-pub fn reinstall_and_sign_in(window: &mut Window, cx: &mut App) {
-    let Some(copilot) = Copilot::global(cx) else {
-        return;
-    };
+pub fn reinstall_and_sign_in(copilot: Entity<Copilot>, window: &mut Window, cx: &mut App) {
     let _ = copilot.update(cx, |copilot, cx| copilot.reinstall(cx));
     let is_reinstall = true;
-    initiate_sign_in_impl(is_reinstall, window, cx);
+    initiate_sign_in_impl(copilot, is_reinstall, window, cx);
 }
 
 fn open_copilot_code_verification_window(copilot: &Entity<Copilot>, window: &Window, cx: &mut App) {
@@ -96,10 +92,12 @@ fn copilot_toast(message: Option<&'static str>, window: &Window, cx: &mut App) {
     })
 }
 
-pub fn initiate_sign_in_impl(is_reinstall: bool, window: &mut Window, cx: &mut App) {
-    let Some(copilot) = Copilot::global(cx) else {
-        return;
-    };
+pub fn initiate_sign_in_impl(
+    copilot: Entity<Copilot>,
+    is_reinstall: bool,
+    window: &mut Window,
+    cx: &mut App,
+) {
     if matches!(copilot.read(cx).status(), Status::Disabled) {
         copilot.update(cx, |copilot, cx| copilot.start_copilot(false, true, cx));
     }
@@ -118,21 +116,16 @@ pub fn initiate_sign_in_impl(is_reinstall: bool, window: &mut Window, cx: &mut A
             window
                 .spawn(cx, async move |cx| {
                     task.await;
-                    cx.update(|window, cx| {
-                        let Some(copilot) = Copilot::global(cx) else {
-                            return;
-                        };
-                        match copilot.read(cx).status() {
-                            Status::Authorized => {
-                                copilot_toast(Some("Copilot has started."), window, cx)
-                            }
-                            _ => {
-                                copilot_toast(None, window, cx);
-                                copilot
-                                    .update(cx, |copilot, cx| copilot.sign_in(cx))
-                                    .detach_and_log_err(cx);
-                                open_copilot_code_verification_window(&copilot, window, cx);
-                            }
+                    cx.update(|window, cx| match copilot.read(cx).status() {
+                        Status::Authorized => {
+                            copilot_toast(Some("Copilot has started."), window, cx)
+                        }
+                        _ => {
+                            copilot_toast(None, window, cx);
+                            copilot
+                                .update(cx, |copilot, cx| copilot.sign_in(cx))
+                                .detach_and_log_err(cx);
+                            open_copilot_code_verification_window(&copilot, window, cx);
                         }
                     })
                     .log_err();
@@ -237,6 +230,7 @@ impl CopilotCodeVerification {
     }
 
     fn render_prompting_modal(
+        copilot: Entity<Copilot>,
         connect_clicked: bool,
         data: &PromptUserDeviceFlow,
         cx: &mut Context<Self>,
@@ -274,47 +268,44 @@ impl CopilotCodeVerification {
                             .on_click({
                                 let command = data.command.clone();
                                 cx.listener(move |this, _, _window, cx| {
-                                    if let Some(copilot) = Copilot::global(cx) {
-                                        let command = command.clone();
-                                        let copilot_clone = copilot.clone();
-                                        copilot.update(cx, |copilot, cx| {
-                                            if let Some(server) = copilot.language_server() {
-                                                let server = server.clone();
-                                                cx.spawn(async move |_, cx| {
-                                                    let result = server
-                                                        .request::<lsp::request::ExecuteCommand>(
-                                                            lsp::ExecuteCommandParams {
-                                                                command: command.command.clone(),
-                                                                arguments: command
-                                                                    .arguments
-                                                                    .clone()
-                                                                    .unwrap_or_default(),
-                                                                ..Default::default()
-                                                            },
-                                                        )
-                                                        .await
-                                                        .into_response()
-                                                        .ok()
-                                                        .flatten();
-                                                    if let Some(value) = result {
-                                                        if let Ok(status) =
-                                                            serde_json::from_value::<
-                                                                request::SignInStatus,
-                                                            >(value)
-                                                        {
-                                                            copilot_clone
-                                                                .update(cx, |copilot, cx| {
-                                                                    copilot.update_sign_in_status(
-                                                                        status, cx,
-                                                                    );
-                                                                });
-                                                        }
+                                    let command = command.clone();
+                                    let copilot_clone = copilot.clone();
+                                    copilot.update(cx, |copilot, cx| {
+                                        if let Some(server) = copilot.language_server() {
+                                            let server = server.clone();
+                                            cx.spawn(async move |_, cx| {
+                                                let result = server
+                                                    .request::<lsp::request::ExecuteCommand>(
+                                                        lsp::ExecuteCommandParams {
+                                                            command: command.command.clone(),
+                                                            arguments: command
+                                                                .arguments
+                                                                .clone()
+                                                                .unwrap_or_default(),
+                                                            ..Default::default()
+                                                        },
+                                                    )
+                                                    .await
+                                                    .into_response()
+                                                    .ok()
+                                                    .flatten();
+                                                if let Some(value) = result {
+                                                    if let Ok(status) = serde_json::from_value::<
+                                                        request::SignInStatus,
+                                                    >(
+                                                        value
+                                                    ) {
+                                                        copilot_clone.update(cx, |copilot, cx| {
+                                                            copilot
+                                                                .update_sign_in_status(status, cx);
+                                                        });
                                                     }
-                                                })
-                                                .detach();
-                                            }
-                                        });
-                                    }
+                                                }
+                                            })
+                                            .detach();
+                                        }
+                                    });
+
                                     this.connect_clicked = true;
                                 })
                             }),
@@ -378,7 +369,7 @@ impl CopilotCodeVerification {
             )
     }
 
-    fn render_error_modal(_cx: &mut Context<Self>) -> impl Element {
+    fn render_error_modal(copilot: Entity<Copilot>, _cx: &mut Context<Self>) -> impl Element {
         v_flex()
             .gap_2()
             .text_center()
@@ -394,7 +385,9 @@ impl CopilotCodeVerification {
                     .icon_color(Color::Muted)
                     .icon_position(IconPosition::Start)
                     .icon_size(IconSize::Small)
-                    .on_click(|_, window, cx| reinstall_and_sign_in(window, cx)),
+                    .on_click(move |_, window, cx| {
+                        reinstall_and_sign_in(copilot.clone(), window, cx)
+                    }),
             )
     }
 
@@ -420,7 +413,10 @@ impl Render for CopilotCodeVerification {
                 .into_any_element(),
             Status::SigningIn {
                 prompt: Some(prompt),
-            } => Self::render_prompting_modal(self.connect_clicked, prompt, cx).into_any_element(),
+            } => {
+                Self::render_prompting_modal(self.copilot.clone(), self.connect_clicked, prompt, cx)
+                    .into_any_element()
+            }
             Status::Unauthorized => {
                 self.connect_clicked = false;
                 self.render_unauthorized_modal(cx).into_any_element()
@@ -429,7 +425,9 @@ impl Render for CopilotCodeVerification {
                 self.connect_clicked = false;
                 Self::render_enabled_modal(cx).into_any_element()
             }
-            Status::Error(..) => Self::render_error_modal(cx).into_any_element(),
+            Status::Error(..) => {
+                Self::render_error_modal(self.copilot.clone(), cx).into_any_element()
+            }
             _ => div().into_any_element(),
         };
 
@@ -475,14 +473,14 @@ impl ConfigurationView {
         mode: ConfigurationMode,
         cx: &mut Context<Self>,
     ) -> Self {
-        let copilot = Copilot::global(cx);
+        let copilot = GlobalCopilotAuth::try_global(cx).cloned();
 
         Self {
-            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
+            copilot_status: copilot.as_ref().map(|copilot| copilot.0.read(cx).status()),
             is_authenticated: Box::new(is_authenticated),
             edit_prediction: matches!(mode, ConfigurationMode::EditPrediction),
             _subscription: copilot.as_ref().map(|copilot| {
-                cx.observe(copilot, |this, model, cx| {
+                cx.observe(&copilot.0, |this, model, cx| {
                     this.copilot_status = Some(model.read(cx).status());
                     cx.notify();
                 })
@@ -568,7 +566,11 @@ impl ConfigurationView {
             .icon_color(Color::Muted)
             .icon_position(IconPosition::Start)
             .icon_size(IconSize::Small)
-            .on_click(|_, window, cx| initiate_sign_in(window, cx))
+            .on_click(|_, window, cx| {
+                if let Some(copilot) = GlobalCopilotAuth::get_or_init(cx) {
+                    initiate_sign_in(copilot.0, window, cx)
+                }
+            })
     }
 
     fn render_reinstall_button(&self, edit_prediction: bool) -> impl IntoElement {
@@ -591,7 +593,11 @@ impl ConfigurationView {
             .icon_color(Color::Muted)
             .icon_position(IconPosition::Start)
             .icon_size(IconSize::Small)
-            .on_click(|_, window, cx| reinstall_and_sign_in(window, cx))
+            .on_click(|_, window, cx| {
+                if let Some(copilot) = GlobalCopilotAuth::get_or_init(cx) {
+                    reinstall_and_sign_in(copilot.0, window, cx)
+                }
+            })
     }
 
     fn render_for_edit_prediction(&self) -> impl IntoElement {
@@ -684,7 +690,9 @@ impl Render for ConfigurationView {
             return ConfiguredApiCard::new("Authorized")
                 .button_label("Sign Out")
                 .on_click(|_, window, cx| {
-                    initiate_sign_out(window, cx);
+                    if let Some(auth) = GlobalCopilotAuth::try_global(cx) {
+                        initiate_sign_out(auth.0.clone(), window, cx);
+                    }
                 })
                 .into_any_element();
         }

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -10,7 +10,7 @@ use cloud_llm_client::{
     PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
 };
 use collections::{HashMap, HashSet};
-use copilot::Copilot;
+use copilot::{Copilot, Reinstall, SignIn, SignOut};
 use db::kvp::{Dismissable, KEY_VALUE_STORE};
 use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
 use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
@@ -754,7 +754,7 @@ impl EditPredictionStore {
             let next_id = project.languages().next_language_server_id();
             let fs = project.fs().clone();
 
-            let copilot = cx.new(|cx| Copilot::new(_project, next_id, fs, node, cx));
+            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
             state.copilot = Some(copilot.clone());
             Some(copilot)
         } else {
@@ -2332,6 +2332,27 @@ pub fn init(cx: &mut App) {
                     .edit_prediction_provider = Some(EditPredictionProvider::None)
             });
         });
+        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
+            EditPredictionStore::try_global(cx).and_then(|store| {
+                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
+            })
+        }
+
+        workspace.register_action(|workspace, _: &SignIn, window, cx| {
+            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
+                copilot_ui::initiate_sign_in(copilot, window, cx);
+            }
+        });
+        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
+            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
+                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
+            }
+        });
+        workspace.register_action(|workspace, _: &SignOut, window, cx| {
+            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
+                copilot_ui::initiate_sign_out(copilot, window, cx);
+            }
+        });
     })
     .detach();
 }

crates/edit_prediction/src/onboarding_modal.rs 🔗

@@ -60,7 +60,9 @@ impl ZedPredictModal {
                     EditPredictionOnboarding::new(
                         user_store.clone(),
                         client.clone(),
-                        copilot.is_some_and(|copilot| copilot.read(cx).status().is_configured()),
+                        copilot
+                            .as_ref()
+                            .is_some_and(|copilot| copilot.read(cx).status().is_configured()),
                         Arc::new({
                             let this = weak_entity.clone();
                             move |_window, cx| {
@@ -75,7 +77,9 @@ impl ZedPredictModal {
                                 ZedPredictUpsell::set_dismissed(true, cx);
                                 set_edit_prediction_provider(EditPredictionProvider::Copilot, cx);
                                 this.update(cx, |_, cx| cx.emit(DismissEvent)).ok();
-                                copilot_ui::initiate_sign_in(window, cx);
+                                if let Some(copilot) = copilot.clone() {
+                                    copilot_ui::initiate_sign_in(copilot, window, cx);
+                                }
                             }
                         }),
                         cx,

crates/edit_prediction_ui/src/edit_prediction_button.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::Result;
 use client::{Client, UserStore, zed_urls};
 use cloud_llm_client::UsageLimit;
 use codestral::CodestralEditPredictionDelegate;
-use copilot::{Copilot, Status};
+use copilot::Status;
 use edit_prediction::{
     EditPredictionStore, MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag,
 };
@@ -124,6 +124,7 @@ impl Render for EditPredictionButton {
                             .on_click(cx.listener(move |_, _, window, cx| {
                                 if let Some(workspace) = window.root::<Workspace>().flatten() {
                                     workspace.update(cx, |workspace, cx| {
+                                        let copilot = copilot.clone();
                                         workspace.show_toast(
                                             Toast::new(
                                                 NotificationId::unique::<CopilotErrorToast>(),
@@ -131,8 +132,12 @@ impl Render for EditPredictionButton {
                                             )
                                             .on_click(
                                                 "Reinstall Copilot",
-                                                |window, cx| {
-                                                    copilot_ui::reinstall_and_sign_in(window, cx)
+                                                move |window, cx| {
+                                                    copilot_ui::reinstall_and_sign_in(
+                                                        copilot.clone(),
+                                                        window,
+                                                        cx,
+                                                    )
                                                 },
                                             ),
                                             cx,
@@ -489,7 +494,10 @@ impl EditPredictionButton {
         project: Entity<Project>,
         cx: &mut Context<Self>,
     ) -> Self {
-        if let Some(copilot) = Copilot::global(cx) {
+        let copilot = EditPredictionStore::try_global(cx).and_then(|store| {
+            store.update(cx, |this, cx| this.start_copilot_for_project(&project, cx))
+        });
+        if let Some(copilot) = copilot {
             cx.observe(&copilot, |_, _, cx| cx.notify()).detach()
         }
 
@@ -638,19 +646,28 @@ impl EditPredictionButton {
         cx: &mut Context<Self>,
     ) -> Entity<ContextMenu> {
         let fs = self.fs.clone();
+        let project = self.project.clone();
         ContextMenu::build(window, cx, |menu, _, _| {
-            menu.entry("Sign In to Copilot", None, copilot_ui::initiate_sign_in)
-                .entry("Disable Copilot", None, {
-                    let fs = fs.clone();
-                    move |_window, cx| hide_copilot(fs.clone(), cx)
-                })
-                .separator()
-                .entry("Use Zed AI", None, {
-                    let fs = fs.clone();
-                    move |_window, cx| {
-                        set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed)
-                    }
-                })
+            menu.entry("Sign In to Copilot", None, move |window, cx| {
+                if let Some(copilot) = EditPredictionStore::try_global(cx).and_then(|store| {
+                    store.update(cx, |this, cx| {
+                        this.start_copilot_for_project(&project.upgrade()?, cx)
+                    })
+                }) {
+                    copilot_ui::initiate_sign_in(copilot, window, cx);
+                }
+            })
+            .entry("Disable Copilot", None, {
+                let fs = fs.clone();
+                move |_window, cx| hide_copilot(fs.clone(), cx)
+            })
+            .separator()
+            .entry("Use Zed AI", None, {
+                let fs = fs.clone();
+                move |_window, cx| {
+                    set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed)
+                }
+            })
         })
     }
 

crates/language_models/src/provider/copilot_chat.rs 🔗

@@ -5,7 +5,7 @@ use std::sync::Arc;
 use anyhow::{Result, anyhow};
 use cloud_llm_client::CompletionIntent;
 use collections::HashMap;
-use copilot::{Copilot, Status};
+use copilot::{GlobalCopilotAuth, Status};
 use copilot_chat::responses as copilot_responses;
 use copilot_chat::{
     ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, CopilotChatConfiguration,
@@ -141,7 +141,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
             return Task::ready(Ok(()));
         };
 
-        let Some(copilot) = Copilot::global(cx) else {
+        let Some(copilot) = GlobalCopilotAuth::try_global(cx).cloned() else {
             return Task::ready(Err(anyhow!(concat!(
                 "Copilot must be enabled for Copilot Chat to work. ",
                 "Please enable Copilot and try again."
@@ -149,7 +149,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
             .into()));
         };
 
-        let err = match copilot.read(cx).status() {
+        let err = match copilot.0.read(cx).status() {
             Status::Authorized => return Task::ready(Ok(())),
             Status::Disabled => anyhow!(
                 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."

crates/zed/src/main.rs 🔗

@@ -610,7 +610,7 @@ fn main() {
             cx,
         );
 
-        copilot_ui::init(cx);
+        copilot_ui::init(&app_state, cx);
         supermaven::init(app_state.client.clone(), cx);
         language_model::init(app_state.client.clone(), cx);
         language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);