From 7c98f1732edc04458b8a96221d956fb8c50c13f0 Mon Sep 17 00:00:00 2001 From: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:20:16 +0100 Subject: [PATCH] copilot: Decouple authentication from the lifetime of any single Copilot instance (#47473) Users had trouble signing in due to us relying on the Copilot::global being set, which was never the case. We've decided to use a dedicated LSP instance just for handling auth of Copilot Chat and other goodies. That instance is subscribed to by local Copilot instances for projects. When the Auth instance changes it's state, local instances are prompted to re-check their own sign in status. Closes #47352 Co-authored-by: dino Release Notes: - Fixed authentication issues with Copilot. --------- Co-authored-by: dino Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com> --- Cargo.lock | 3 + crates/copilot/Cargo.toml | 1 + crates/copilot/src/copilot.rs | 96 +++++++--- crates/copilot_ui/Cargo.toml | 2 + crates/copilot_ui/src/copilot_ui.rs | 40 ++-- crates/copilot_ui/src/sign_in.rs | 172 +++++++++--------- crates/edit_prediction/src/edit_prediction.rs | 25 ++- .../edit_prediction/src/onboarding_modal.rs | 8 +- .../src/edit_prediction_button.rs | 49 +++-- .../src/provider/copilot_chat.rs | 6 +- crates/zed/src/main.rs | 2 +- 11 files changed, 258 insertions(+), 146 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f5abe2d4df9f346c35dcb952e231951cd503c5e9..1d43c57a302368647f4f5cb9ccc54a78d4cabaa2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3703,6 +3703,7 @@ dependencies = [ "sum_tree", "theme", "util", + "workspace", "zlog", ] @@ -3733,10 +3734,12 @@ dependencies = [ "anyhow", "copilot", "gpui", + "language", "log", "lsp", "menu", "serde_json", + "settings", "ui", "util", "workspace", diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 1402bb8d6ffd82d5cd8a2225c8336ee30de3e49e..a07811cd2e3a6b17a1bf23c24ae7af376ef9b37a 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -45,6 +45,7 @@ serde_json.workspace = true settings.workspace = true sum_tree.workspace = true util.workspace = true +workspace.workspace = true [target.'cfg(windows)'.dependencies] async-std = { version = "1.12.0", features = ["unstable"] } diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index c35546b5a9c24816dcb36a86bd73b977aa3f5d29..8177a509bc60ecb459515e02a0f0d9f75d09cdf9 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -12,8 +12,8 @@ use command_palette_hooks::CommandPaletteFilter; use futures::future; use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared, select_biased}; use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task, - WeakEntity, actions, + App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Subscription, + Task, WeakEntity, actions, }; use language::language_settings::{AllLanguageSettings, CopilotSettings}; use language::{ @@ -41,6 +41,7 @@ use std::{ }; use sum_tree::Dimensions; use util::{ResultExt, fs::remove_matching}; +use workspace::AppState; pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate; @@ -251,20 +252,50 @@ pub struct Copilot { server: CopilotServer, buffers: HashSet>, server_id: LanguageServerId, - _subscriptions: [gpui::Subscription; 2], + _subscriptions: Vec, } pub enum Event { - CopilotLanguageServerStarted, CopilotAuthSignedIn, CopilotAuthSignedOut, } impl EventEmitter for Copilot {} -struct GlobalCopilot(Entity); +#[derive(Clone)] +pub struct GlobalCopilotAuth(pub Entity); -impl Global for GlobalCopilot {} +impl GlobalCopilotAuth { + pub fn set_global( + server_id: LanguageServerId, + fs: Arc, + node_runtime: NodeRuntime, + cx: &mut App, + ) { + let auth = + GlobalCopilotAuth(cx.new(|cx| Copilot::new(None, server_id, fs, node_runtime, cx))); + cx.set_global(auth); + } + pub fn try_global(cx: &mut App) -> Option<&GlobalCopilotAuth> { + cx.try_global() + } + + pub fn get_or_init(cx: &mut App) -> Option { + if let Some(copilot) = cx.try_global::() { + Some(copilot.clone()) + } else { + let app_state = AppState::global(cx).upgrade()?; + Self::set_global( + app_state.languages.next_language_server_id(), + app_state.fs.clone(), + app_state.node_runtime.clone(), + cx, + ); + cx.try_global::().cloned() + } + } +} +impl Global for GlobalCopilotAuth {} #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub(crate) enum CompletionSource { @@ -284,23 +315,14 @@ pub(crate) struct CopilotEditPrediction { } impl Copilot { - pub fn global(cx: &App) -> Option> { - cx.try_global::() - .map(|model| model.0.clone()) - } - - pub fn set_global(copilot: Entity, cx: &mut App) { - cx.set_global(GlobalCopilot(copilot)); - } - pub fn new( - project: Entity, + project: Option>, new_server_id: LanguageServerId, fs: Arc, node_runtime: NodeRuntime, cx: &mut Context, ) -> Self { - let send_focus_notification = + let send_focus_notification = project.map(|project| { cx.subscribe(&project, |this, project, e: &project::Event, cx| { if let project::Event::ActiveEntryChanged(new_entry) = e && let Ok(running) = this.server.as_authenticated() @@ -312,11 +334,40 @@ impl Copilot { _ = running.lsp.notify::(DidFocusParams { uri }); } + }) + }); + let global_authentication_events = + cx.try_global::().cloned().map(|auth| { + cx.subscribe(&auth.0, |_, _, _: &Event, cx| { + cx.spawn(async move |this, cx| { + let Some(server) = this + .update(cx, |this, _| this.language_server().cloned()) + .ok() + .flatten() + else { + return; + }; + let status = server + .request::(request::CheckStatusParams { + local_checks_only: false, + }) + .await + .into_response() + .ok(); + if let Some(status) = status { + this.update(cx, |copilot, cx| { + copilot.update_sign_in_status(status, cx); + }) + .ok(); + } + }) + .detach() + }) }); - let _subscriptions = [ - cx.on_app_quit(Self::shutdown_language_server), - send_focus_notification, - ]; + let _subscriptions = std::iter::once(cx.on_app_quit(Self::shutdown_language_server)) + .chain(send_focus_notification) + .chain(global_authentication_events) + .collect(); let mut this = Self { server_id: new_server_id, fs, @@ -455,7 +506,7 @@ impl Copilot { sign_in_status: SignInStatus::Authorized, registered_buffers: Default::default(), }), - _subscriptions: [ + _subscriptions: vec![ send_focus_notification, cx.on_app_quit(Self::shutdown_language_server), ], @@ -619,7 +670,6 @@ impl Copilot { }, registered_buffers: Default::default(), }); - cx.emit(Event::CopilotLanguageServerStarted); this.update_sign_in_status(status, cx); } Err(error) => { diff --git a/crates/copilot_ui/Cargo.toml b/crates/copilot_ui/Cargo.toml index 9f2668d9fb40d12631bff6af3291bdb3a40dea15..fbd16c5db0fa936cf39de8b61170f1631335873c 100644 --- a/crates/copilot_ui/Cargo.toml +++ b/crates/copilot_ui/Cargo.toml @@ -23,10 +23,12 @@ test-support = [ anyhow.workspace = true copilot.workspace = true gpui.workspace = true +language.workspace = true log.workspace = true lsp.workspace = true menu.workspace = true serde_json.workspace = true +settings.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true diff --git a/crates/copilot_ui/src/copilot_ui.rs b/crates/copilot_ui/src/copilot_ui.rs index e22c2800c4beff1debb31aea9ce4ddca811f2bf0..f318acb5d547cd5652d6ad08acfad66a5056894b 100644 --- a/crates/copilot_ui/src/copilot_ui.rs +++ b/crates/copilot_ui/src/copilot_ui.rs @@ -1,25 +1,31 @@ mod sign_in; -use copilot::{Reinstall, SignIn, SignOut}; -use gpui::App; -use workspace::Workspace; +use std::sync::Arc; +use copilot::GlobalCopilotAuth; +use gpui::AppContext; +use language::language_settings::AllLanguageSettings; +use settings::SettingsStore; pub use sign_in::{ ConfigurationMode, ConfigurationView, CopilotCodeVerification, initiate_sign_in, - reinstall_and_sign_in, + initiate_sign_out, reinstall_and_sign_in, }; +use ui::App; +use workspace::AppState; -pub fn init(cx: &mut App) { - cx.observe_new(|workspace: &mut Workspace, _window, _cx| { - workspace.register_action(|_, _: &SignIn, window, cx| { - sign_in::initiate_sign_in(window, cx); - }); - workspace.register_action(|_, _: &Reinstall, window, cx| { - sign_in::reinstall_and_sign_in(window, cx); - }); - workspace.register_action(|_, _: &SignOut, window, cx| { - sign_in::initiate_sign_out(window, cx); - }); - }) - .detach(); +pub fn init(app_state: &Arc, cx: &mut App) { + let provider = cx.read_global(|settings: &SettingsStore, _| { + settings + .get::(None) + .edit_predictions + .provider + }); + if provider == settings::EditPredictionProvider::Copilot { + GlobalCopilotAuth::set_global( + app_state.languages.next_language_server_id(), + app_state.fs.clone(), + app_state.node_runtime.clone(), + cx, + ); + } } diff --git a/crates/copilot_ui/src/sign_in.rs b/crates/copilot_ui/src/sign_in.rs index a9eda9d3c9b182fcb880c0d4d9812063578b4d1e..cbc252270e118885f117cfd9667802147bcc9f26 100644 --- a/crates/copilot_ui/src/sign_in.rs +++ b/crates/copilot_ui/src/sign_in.rs @@ -1,5 +1,8 @@ use anyhow::Context as _; -use copilot::{Copilot, Status, request, request::PromptUserDeviceFlow}; +use copilot::{ + Copilot, GlobalCopilotAuth, Status, + request::{self, PromptUserDeviceFlow}, +}; use gpui::{ App, ClipboardItem, Context, DismissEvent, Element, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, MouseDownEvent, ParentElement, Render, Styled, @@ -15,16 +18,12 @@ const ERROR_LABEL: &str = struct CopilotStatusToast; -pub fn initiate_sign_in(window: &mut Window, cx: &mut App) { +pub fn initiate_sign_in(copilot: Entity, window: &mut Window, cx: &mut App) { let is_reinstall = false; - initiate_sign_in_impl(is_reinstall, window, cx) + initiate_sign_in_impl(copilot, is_reinstall, window, cx) } -pub fn initiate_sign_out(window: &mut Window, cx: &mut App) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; - +pub fn initiate_sign_out(copilot: Entity, window: &mut Window, cx: &mut App) { copilot_toast(Some("Signing out of Copilot…"), window, cx); let sign_out_task = copilot.update(cx, |copilot, cx| copilot.sign_out(cx)); @@ -46,13 +45,10 @@ pub fn initiate_sign_out(window: &mut Window, cx: &mut App) { .detach(); } -pub fn reinstall_and_sign_in(window: &mut Window, cx: &mut App) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; +pub fn reinstall_and_sign_in(copilot: Entity, window: &mut Window, cx: &mut App) { let _ = copilot.update(cx, |copilot, cx| copilot.reinstall(cx)); let is_reinstall = true; - initiate_sign_in_impl(is_reinstall, window, cx); + initiate_sign_in_impl(copilot, is_reinstall, window, cx); } fn open_copilot_code_verification_window(copilot: &Entity, window: &Window, cx: &mut App) { @@ -96,10 +92,12 @@ fn copilot_toast(message: Option<&'static str>, window: &Window, cx: &mut App) { }) } -pub fn initiate_sign_in_impl(is_reinstall: bool, window: &mut Window, cx: &mut App) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; +pub fn initiate_sign_in_impl( + copilot: Entity, + is_reinstall: bool, + window: &mut Window, + cx: &mut App, +) { if matches!(copilot.read(cx).status(), Status::Disabled) { copilot.update(cx, |copilot, cx| copilot.start_copilot(false, true, cx)); } @@ -118,21 +116,16 @@ pub fn initiate_sign_in_impl(is_reinstall: bool, window: &mut Window, cx: &mut A window .spawn(cx, async move |cx| { task.await; - cx.update(|window, cx| { - let Some(copilot) = Copilot::global(cx) else { - return; - }; - match copilot.read(cx).status() { - Status::Authorized => { - copilot_toast(Some("Copilot has started."), window, cx) - } - _ => { - copilot_toast(None, window, cx); - copilot - .update(cx, |copilot, cx| copilot.sign_in(cx)) - .detach_and_log_err(cx); - open_copilot_code_verification_window(&copilot, window, cx); - } + cx.update(|window, cx| match copilot.read(cx).status() { + Status::Authorized => { + copilot_toast(Some("Copilot has started."), window, cx) + } + _ => { + copilot_toast(None, window, cx); + copilot + .update(cx, |copilot, cx| copilot.sign_in(cx)) + .detach_and_log_err(cx); + open_copilot_code_verification_window(&copilot, window, cx); } }) .log_err(); @@ -237,6 +230,7 @@ impl CopilotCodeVerification { } fn render_prompting_modal( + copilot: Entity, connect_clicked: bool, data: &PromptUserDeviceFlow, cx: &mut Context, @@ -274,47 +268,44 @@ impl CopilotCodeVerification { .on_click({ let command = data.command.clone(); cx.listener(move |this, _, _window, cx| { - if let Some(copilot) = Copilot::global(cx) { - let command = command.clone(); - let copilot_clone = copilot.clone(); - copilot.update(cx, |copilot, cx| { - if let Some(server) = copilot.language_server() { - let server = server.clone(); - cx.spawn(async move |_, cx| { - let result = server - .request::( - lsp::ExecuteCommandParams { - command: command.command.clone(), - arguments: command - .arguments - .clone() - .unwrap_or_default(), - ..Default::default() - }, - ) - .await - .into_response() - .ok() - .flatten(); - if let Some(value) = result { - if let Ok(status) = - serde_json::from_value::< - request::SignInStatus, - >(value) - { - copilot_clone - .update(cx, |copilot, cx| { - copilot.update_sign_in_status( - status, cx, - ); - }); - } + let command = command.clone(); + let copilot_clone = copilot.clone(); + copilot.update(cx, |copilot, cx| { + if let Some(server) = copilot.language_server() { + let server = server.clone(); + cx.spawn(async move |_, cx| { + let result = server + .request::( + lsp::ExecuteCommandParams { + command: command.command.clone(), + arguments: command + .arguments + .clone() + .unwrap_or_default(), + ..Default::default() + }, + ) + .await + .into_response() + .ok() + .flatten(); + if let Some(value) = result { + if let Ok(status) = serde_json::from_value::< + request::SignInStatus, + >( + value + ) { + copilot_clone.update(cx, |copilot, cx| { + copilot + .update_sign_in_status(status, cx); + }); } - }) - .detach(); - } - }); - } + } + }) + .detach(); + } + }); + this.connect_clicked = true; }) }), @@ -378,7 +369,7 @@ impl CopilotCodeVerification { ) } - fn render_error_modal(_cx: &mut Context) -> impl Element { + fn render_error_modal(copilot: Entity, _cx: &mut Context) -> impl Element { v_flex() .gap_2() .text_center() @@ -394,7 +385,9 @@ impl CopilotCodeVerification { .icon_color(Color::Muted) .icon_position(IconPosition::Start) .icon_size(IconSize::Small) - .on_click(|_, window, cx| reinstall_and_sign_in(window, cx)), + .on_click(move |_, window, cx| { + reinstall_and_sign_in(copilot.clone(), window, cx) + }), ) } @@ -420,7 +413,10 @@ impl Render for CopilotCodeVerification { .into_any_element(), Status::SigningIn { prompt: Some(prompt), - } => Self::render_prompting_modal(self.connect_clicked, prompt, cx).into_any_element(), + } => { + Self::render_prompting_modal(self.copilot.clone(), self.connect_clicked, prompt, cx) + .into_any_element() + } Status::Unauthorized => { self.connect_clicked = false; self.render_unauthorized_modal(cx).into_any_element() @@ -429,7 +425,9 @@ impl Render for CopilotCodeVerification { self.connect_clicked = false; Self::render_enabled_modal(cx).into_any_element() } - Status::Error(..) => Self::render_error_modal(cx).into_any_element(), + Status::Error(..) => { + Self::render_error_modal(self.copilot.clone(), cx).into_any_element() + } _ => div().into_any_element(), }; @@ -475,14 +473,14 @@ impl ConfigurationView { mode: ConfigurationMode, cx: &mut Context, ) -> Self { - let copilot = Copilot::global(cx); + let copilot = GlobalCopilotAuth::try_global(cx).cloned(); Self { - copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()), + copilot_status: copilot.as_ref().map(|copilot| copilot.0.read(cx).status()), is_authenticated: Box::new(is_authenticated), edit_prediction: matches!(mode, ConfigurationMode::EditPrediction), _subscription: copilot.as_ref().map(|copilot| { - cx.observe(copilot, |this, model, cx| { + cx.observe(&copilot.0, |this, model, cx| { this.copilot_status = Some(model.read(cx).status()); cx.notify(); }) @@ -568,7 +566,11 @@ impl ConfigurationView { .icon_color(Color::Muted) .icon_position(IconPosition::Start) .icon_size(IconSize::Small) - .on_click(|_, window, cx| initiate_sign_in(window, cx)) + .on_click(|_, window, cx| { + if let Some(copilot) = GlobalCopilotAuth::get_or_init(cx) { + initiate_sign_in(copilot.0, window, cx) + } + }) } fn render_reinstall_button(&self, edit_prediction: bool) -> impl IntoElement { @@ -591,7 +593,11 @@ impl ConfigurationView { .icon_color(Color::Muted) .icon_position(IconPosition::Start) .icon_size(IconSize::Small) - .on_click(|_, window, cx| reinstall_and_sign_in(window, cx)) + .on_click(|_, window, cx| { + if let Some(copilot) = GlobalCopilotAuth::get_or_init(cx) { + reinstall_and_sign_in(copilot.0, window, cx) + } + }) } fn render_for_edit_prediction(&self) -> impl IntoElement { @@ -684,7 +690,9 @@ impl Render for ConfigurationView { return ConfiguredApiCard::new("Authorized") .button_label("Sign Out") .on_click(|_, window, cx| { - initiate_sign_out(window, cx); + if let Some(auth) = GlobalCopilotAuth::try_global(cx) { + initiate_sign_out(auth.0.clone(), window, cx); + } }) .into_any_element(); } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 5813cd40e9bc32463c172f3ad8bfef9299be0d49..098136b42f2c92ddb80a43a46bb29ed7518aff34 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -10,7 +10,7 @@ use cloud_llm_client::{ PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME, }; use collections::{HashMap, HashSet}; -use copilot::Copilot; +use copilot::{Copilot, Reinstall, SignIn, SignOut}; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile}; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; @@ -754,7 +754,7 @@ impl EditPredictionStore { let next_id = project.languages().next_language_server_id(); let fs = project.fs().clone(); - let copilot = cx.new(|cx| Copilot::new(_project, next_id, fs, node, cx)); + let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx)); state.copilot = Some(copilot.clone()); Some(copilot) } else { @@ -2332,6 +2332,27 @@ pub fn init(cx: &mut App) { .edit_prediction_provider = Some(EditPredictionProvider::None) }); }); + fn copilot_for_project(project: &Entity, cx: &mut App) -> Option> { + EditPredictionStore::try_global(cx).and_then(|store| { + store.update(cx, |this, cx| this.start_copilot_for_project(project, cx)) + }) + } + + workspace.register_action(|workspace, _: &SignIn, window, cx| { + if let Some(copilot) = copilot_for_project(workspace.project(), cx) { + copilot_ui::initiate_sign_in(copilot, window, cx); + } + }); + workspace.register_action(|workspace, _: &Reinstall, window, cx| { + if let Some(copilot) = copilot_for_project(workspace.project(), cx) { + copilot_ui::reinstall_and_sign_in(copilot, window, cx); + } + }); + workspace.register_action(|workspace, _: &SignOut, window, cx| { + if let Some(copilot) = copilot_for_project(workspace.project(), cx) { + copilot_ui::initiate_sign_out(copilot, window, cx); + } + }); }) .detach(); } diff --git a/crates/edit_prediction/src/onboarding_modal.rs b/crates/edit_prediction/src/onboarding_modal.rs index 14f3ce4e1daddc8a2be37a3a18729f8ae85572e0..e4c4b2c973457665b2df5288ff71533d0aa1edef 100644 --- a/crates/edit_prediction/src/onboarding_modal.rs +++ b/crates/edit_prediction/src/onboarding_modal.rs @@ -60,7 +60,9 @@ impl ZedPredictModal { EditPredictionOnboarding::new( user_store.clone(), client.clone(), - copilot.is_some_and(|copilot| copilot.read(cx).status().is_configured()), + copilot + .as_ref() + .is_some_and(|copilot| copilot.read(cx).status().is_configured()), Arc::new({ let this = weak_entity.clone(); move |_window, cx| { @@ -75,7 +77,9 @@ impl ZedPredictModal { ZedPredictUpsell::set_dismissed(true, cx); set_edit_prediction_provider(EditPredictionProvider::Copilot, cx); this.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); - copilot_ui::initiate_sign_in(window, cx); + if let Some(copilot) = copilot.clone() { + copilot_ui::initiate_sign_in(copilot, window, cx); + } } }), cx, diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index dde6cc7192b121ff76c63568c5d070105121d6bc..c27492c982454e4db1f00d7227231bd783fd19da 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -2,7 +2,7 @@ use anyhow::Result; use client::{Client, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; use codestral::CodestralEditPredictionDelegate; -use copilot::{Copilot, Status}; +use copilot::Status; use edit_prediction::{ EditPredictionStore, MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag, }; @@ -124,6 +124,7 @@ impl Render for EditPredictionButton { .on_click(cx.listener(move |_, _, window, cx| { if let Some(workspace) = window.root::().flatten() { workspace.update(cx, |workspace, cx| { + let copilot = copilot.clone(); workspace.show_toast( Toast::new( NotificationId::unique::(), @@ -131,8 +132,12 @@ impl Render for EditPredictionButton { ) .on_click( "Reinstall Copilot", - |window, cx| { - copilot_ui::reinstall_and_sign_in(window, cx) + move |window, cx| { + copilot_ui::reinstall_and_sign_in( + copilot.clone(), + window, + cx, + ) }, ), cx, @@ -489,7 +494,10 @@ impl EditPredictionButton { project: Entity, cx: &mut Context, ) -> Self { - if let Some(copilot) = Copilot::global(cx) { + let copilot = EditPredictionStore::try_global(cx).and_then(|store| { + store.update(cx, |this, cx| this.start_copilot_for_project(&project, cx)) + }); + if let Some(copilot) = copilot { cx.observe(&copilot, |_, _, cx| cx.notify()).detach() } @@ -638,19 +646,28 @@ impl EditPredictionButton { cx: &mut Context, ) -> Entity { let fs = self.fs.clone(); + let project = self.project.clone(); ContextMenu::build(window, cx, |menu, _, _| { - menu.entry("Sign In to Copilot", None, copilot_ui::initiate_sign_in) - .entry("Disable Copilot", None, { - let fs = fs.clone(); - move |_window, cx| hide_copilot(fs.clone(), cx) - }) - .separator() - .entry("Use Zed AI", None, { - let fs = fs.clone(); - move |_window, cx| { - set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed) - } - }) + menu.entry("Sign In to Copilot", None, move |window, cx| { + if let Some(copilot) = EditPredictionStore::try_global(cx).and_then(|store| { + store.update(cx, |this, cx| { + this.start_copilot_for_project(&project.upgrade()?, cx) + }) + }) { + copilot_ui::initiate_sign_in(copilot, window, cx); + } + }) + .entry("Disable Copilot", None, { + let fs = fs.clone(); + move |_window, cx| hide_copilot(fs.clone(), cx) + }) + .separator() + .entry("Use Zed AI", None, { + let fs = fs.clone(); + move |_window, cx| { + set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed) + } + }) }) } diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 8c9b84417d33f823da80d221072a766d48bc59ce..9ef9dc10f614b9b4f15d6489410e0a6fb5248789 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use anyhow::{Result, anyhow}; use cloud_llm_client::CompletionIntent; use collections::HashMap; -use copilot::{Copilot, Status}; +use copilot::{GlobalCopilotAuth, Status}; use copilot_chat::responses as copilot_responses; use copilot_chat::{ ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, CopilotChatConfiguration, @@ -141,7 +141,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { return Task::ready(Ok(())); }; - let Some(copilot) = Copilot::global(cx) else { + let Some(copilot) = GlobalCopilotAuth::try_global(cx).cloned() else { return Task::ready(Err(anyhow!(concat!( "Copilot must be enabled for Copilot Chat to work. ", "Please enable Copilot and try again." @@ -149,7 +149,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { .into())); }; - let err = match copilot.read(cx).status() { + let err = match copilot.0.read(cx).status() { Status::Authorized => return Task::ready(Ok(())), Status::Disabled => anyhow!( "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again." diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index f58cfd3413b1f000f1fe88e0bf27d31fe980d59b..a248fd11c71b92893f8b5849e14286bb5627d924 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -610,7 +610,7 @@ fn main() { cx, ); - copilot_ui::init(cx); + copilot_ui::init(&app_state, cx); supermaven::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);