diff --git a/Cargo.lock b/Cargo.lock index f5abe2d4df9f346c35dcb952e231951cd503c5e9..1d43c57a302368647f4f5cb9ccc54a78d4cabaa2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3703,6 +3703,7 @@ dependencies = [ "sum_tree", "theme", "util", + "workspace", "zlog", ] @@ -3733,10 +3734,12 @@ dependencies = [ "anyhow", "copilot", "gpui", + "language", "log", "lsp", "menu", "serde_json", + "settings", "ui", "util", "workspace", diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 1402bb8d6ffd82d5cd8a2225c8336ee30de3e49e..a07811cd2e3a6b17a1bf23c24ae7af376ef9b37a 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -45,6 +45,7 @@ serde_json.workspace = true settings.workspace = true sum_tree.workspace = true util.workspace = true +workspace.workspace = true [target.'cfg(windows)'.dependencies] async-std = { version = "1.12.0", features = ["unstable"] } diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index c35546b5a9c24816dcb36a86bd73b977aa3f5d29..8177a509bc60ecb459515e02a0f0d9f75d09cdf9 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -12,8 +12,8 @@ use command_palette_hooks::CommandPaletteFilter; use futures::future; use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared, select_biased}; use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task, - WeakEntity, actions, + App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Subscription, + Task, WeakEntity, actions, }; use language::language_settings::{AllLanguageSettings, CopilotSettings}; use language::{ @@ -41,6 +41,7 @@ use std::{ }; use sum_tree::Dimensions; use util::{ResultExt, fs::remove_matching}; +use workspace::AppState; pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate; @@ -251,20 +252,50 @@ pub struct Copilot { server: CopilotServer, buffers: HashSet>, server_id: LanguageServerId, - _subscriptions: [gpui::Subscription; 2], + _subscriptions: Vec, } pub enum Event { - CopilotLanguageServerStarted, CopilotAuthSignedIn, CopilotAuthSignedOut, } impl EventEmitter for Copilot {} -struct GlobalCopilot(Entity); +#[derive(Clone)] +pub struct GlobalCopilotAuth(pub Entity); -impl Global for GlobalCopilot {} +impl GlobalCopilotAuth { + pub fn set_global( + server_id: LanguageServerId, + fs: Arc, + node_runtime: NodeRuntime, + cx: &mut App, + ) { + let auth = + GlobalCopilotAuth(cx.new(|cx| Copilot::new(None, server_id, fs, node_runtime, cx))); + cx.set_global(auth); + } + pub fn try_global(cx: &mut App) -> Option<&GlobalCopilotAuth> { + cx.try_global() + } + + pub fn get_or_init(cx: &mut App) -> Option { + if let Some(copilot) = cx.try_global::() { + Some(copilot.clone()) + } else { + let app_state = AppState::global(cx).upgrade()?; + Self::set_global( + app_state.languages.next_language_server_id(), + app_state.fs.clone(), + app_state.node_runtime.clone(), + cx, + ); + cx.try_global::().cloned() + } + } +} +impl Global for GlobalCopilotAuth {} #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub(crate) enum CompletionSource { @@ -284,23 +315,14 @@ pub(crate) struct CopilotEditPrediction { } impl Copilot { - pub fn global(cx: &App) -> Option> { - cx.try_global::() - .map(|model| model.0.clone()) - } - - pub fn set_global(copilot: Entity, cx: &mut App) { - cx.set_global(GlobalCopilot(copilot)); - } - pub fn new( - project: Entity, + project: Option>, new_server_id: LanguageServerId, fs: Arc, node_runtime: NodeRuntime, cx: &mut Context, ) -> Self { - let send_focus_notification = + let send_focus_notification = project.map(|project| { cx.subscribe(&project, |this, project, e: &project::Event, cx| { if let project::Event::ActiveEntryChanged(new_entry) = e && let Ok(running) = this.server.as_authenticated() @@ -312,11 +334,40 @@ impl Copilot { _ = running.lsp.notify::(DidFocusParams { uri }); } + }) + }); + let global_authentication_events = + cx.try_global::().cloned().map(|auth| { + cx.subscribe(&auth.0, |_, _, _: &Event, cx| { + cx.spawn(async move |this, cx| { + let Some(server) = this + .update(cx, |this, _| this.language_server().cloned()) + .ok() + .flatten() + else { + return; + }; + let status = server + .request::(request::CheckStatusParams { + local_checks_only: false, + }) + .await + .into_response() + .ok(); + if let Some(status) = status { + this.update(cx, |copilot, cx| { + copilot.update_sign_in_status(status, cx); + }) + .ok(); + } + }) + .detach() + }) }); - let _subscriptions = [ - cx.on_app_quit(Self::shutdown_language_server), - send_focus_notification, - ]; + let _subscriptions = std::iter::once(cx.on_app_quit(Self::shutdown_language_server)) + .chain(send_focus_notification) + .chain(global_authentication_events) + .collect(); let mut this = Self { server_id: new_server_id, fs, @@ -455,7 +506,7 @@ impl Copilot { sign_in_status: SignInStatus::Authorized, registered_buffers: Default::default(), }), - _subscriptions: [ + _subscriptions: vec![ send_focus_notification, cx.on_app_quit(Self::shutdown_language_server), ], @@ -619,7 +670,6 @@ impl Copilot { }, registered_buffers: Default::default(), }); - cx.emit(Event::CopilotLanguageServerStarted); this.update_sign_in_status(status, cx); } Err(error) => { diff --git a/crates/copilot_ui/Cargo.toml b/crates/copilot_ui/Cargo.toml index 9f2668d9fb40d12631bff6af3291bdb3a40dea15..fbd16c5db0fa936cf39de8b61170f1631335873c 100644 --- a/crates/copilot_ui/Cargo.toml +++ b/crates/copilot_ui/Cargo.toml @@ -23,10 +23,12 @@ test-support = [ anyhow.workspace = true copilot.workspace = true gpui.workspace = true +language.workspace = true log.workspace = true lsp.workspace = true menu.workspace = true serde_json.workspace = true +settings.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true diff --git a/crates/copilot_ui/src/copilot_ui.rs b/crates/copilot_ui/src/copilot_ui.rs index e22c2800c4beff1debb31aea9ce4ddca811f2bf0..f318acb5d547cd5652d6ad08acfad66a5056894b 100644 --- a/crates/copilot_ui/src/copilot_ui.rs +++ b/crates/copilot_ui/src/copilot_ui.rs @@ -1,25 +1,31 @@ mod sign_in; -use copilot::{Reinstall, SignIn, SignOut}; -use gpui::App; -use workspace::Workspace; +use std::sync::Arc; +use copilot::GlobalCopilotAuth; +use gpui::AppContext; +use language::language_settings::AllLanguageSettings; +use settings::SettingsStore; pub use sign_in::{ ConfigurationMode, ConfigurationView, CopilotCodeVerification, initiate_sign_in, - reinstall_and_sign_in, + initiate_sign_out, reinstall_and_sign_in, }; +use ui::App; +use workspace::AppState; -pub fn init(cx: &mut App) { - cx.observe_new(|workspace: &mut Workspace, _window, _cx| { - workspace.register_action(|_, _: &SignIn, window, cx| { - sign_in::initiate_sign_in(window, cx); - }); - workspace.register_action(|_, _: &Reinstall, window, cx| { - sign_in::reinstall_and_sign_in(window, cx); - }); - workspace.register_action(|_, _: &SignOut, window, cx| { - sign_in::initiate_sign_out(window, cx); - }); - }) - .detach(); +pub fn init(app_state: &Arc, cx: &mut App) { + let provider = cx.read_global(|settings: &SettingsStore, _| { + settings + .get::(None) + .edit_predictions + .provider + }); + if provider == settings::EditPredictionProvider::Copilot { + GlobalCopilotAuth::set_global( + app_state.languages.next_language_server_id(), + app_state.fs.clone(), + app_state.node_runtime.clone(), + cx, + ); + } } diff --git a/crates/copilot_ui/src/sign_in.rs b/crates/copilot_ui/src/sign_in.rs index a9eda9d3c9b182fcb880c0d4d9812063578b4d1e..cbc252270e118885f117cfd9667802147bcc9f26 100644 --- a/crates/copilot_ui/src/sign_in.rs +++ b/crates/copilot_ui/src/sign_in.rs @@ -1,5 +1,8 @@ use anyhow::Context as _; -use copilot::{Copilot, Status, request, request::PromptUserDeviceFlow}; +use copilot::{ + Copilot, GlobalCopilotAuth, Status, + request::{self, PromptUserDeviceFlow}, +}; use gpui::{ App, ClipboardItem, Context, DismissEvent, Element, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, MouseDownEvent, ParentElement, Render, Styled, @@ -15,16 +18,12 @@ const ERROR_LABEL: &str = struct CopilotStatusToast; -pub fn initiate_sign_in(window: &mut Window, cx: &mut App) { +pub fn initiate_sign_in(copilot: Entity, window: &mut Window, cx: &mut App) { let is_reinstall = false; - initiate_sign_in_impl(is_reinstall, window, cx) + initiate_sign_in_impl(copilot, is_reinstall, window, cx) } -pub fn initiate_sign_out(window: &mut Window, cx: &mut App) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; - +pub fn initiate_sign_out(copilot: Entity, window: &mut Window, cx: &mut App) { copilot_toast(Some("Signing out of Copilot…"), window, cx); let sign_out_task = copilot.update(cx, |copilot, cx| copilot.sign_out(cx)); @@ -46,13 +45,10 @@ pub fn initiate_sign_out(window: &mut Window, cx: &mut App) { .detach(); } -pub fn reinstall_and_sign_in(window: &mut Window, cx: &mut App) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; +pub fn reinstall_and_sign_in(copilot: Entity, window: &mut Window, cx: &mut App) { let _ = copilot.update(cx, |copilot, cx| copilot.reinstall(cx)); let is_reinstall = true; - initiate_sign_in_impl(is_reinstall, window, cx); + initiate_sign_in_impl(copilot, is_reinstall, window, cx); } fn open_copilot_code_verification_window(copilot: &Entity, window: &Window, cx: &mut App) { @@ -96,10 +92,12 @@ fn copilot_toast(message: Option<&'static str>, window: &Window, cx: &mut App) { }) } -pub fn initiate_sign_in_impl(is_reinstall: bool, window: &mut Window, cx: &mut App) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; +pub fn initiate_sign_in_impl( + copilot: Entity, + is_reinstall: bool, + window: &mut Window, + cx: &mut App, +) { if matches!(copilot.read(cx).status(), Status::Disabled) { copilot.update(cx, |copilot, cx| copilot.start_copilot(false, true, cx)); } @@ -118,21 +116,16 @@ pub fn initiate_sign_in_impl(is_reinstall: bool, window: &mut Window, cx: &mut A window .spawn(cx, async move |cx| { task.await; - cx.update(|window, cx| { - let Some(copilot) = Copilot::global(cx) else { - return; - }; - match copilot.read(cx).status() { - Status::Authorized => { - copilot_toast(Some("Copilot has started."), window, cx) - } - _ => { - copilot_toast(None, window, cx); - copilot - .update(cx, |copilot, cx| copilot.sign_in(cx)) - .detach_and_log_err(cx); - open_copilot_code_verification_window(&copilot, window, cx); - } + cx.update(|window, cx| match copilot.read(cx).status() { + Status::Authorized => { + copilot_toast(Some("Copilot has started."), window, cx) + } + _ => { + copilot_toast(None, window, cx); + copilot + .update(cx, |copilot, cx| copilot.sign_in(cx)) + .detach_and_log_err(cx); + open_copilot_code_verification_window(&copilot, window, cx); } }) .log_err(); @@ -237,6 +230,7 @@ impl CopilotCodeVerification { } fn render_prompting_modal( + copilot: Entity, connect_clicked: bool, data: &PromptUserDeviceFlow, cx: &mut Context, @@ -274,47 +268,44 @@ impl CopilotCodeVerification { .on_click({ let command = data.command.clone(); cx.listener(move |this, _, _window, cx| { - if let Some(copilot) = Copilot::global(cx) { - let command = command.clone(); - let copilot_clone = copilot.clone(); - copilot.update(cx, |copilot, cx| { - if let Some(server) = copilot.language_server() { - let server = server.clone(); - cx.spawn(async move |_, cx| { - let result = server - .request::( - lsp::ExecuteCommandParams { - command: command.command.clone(), - arguments: command - .arguments - .clone() - .unwrap_or_default(), - ..Default::default() - }, - ) - .await - .into_response() - .ok() - .flatten(); - if let Some(value) = result { - if let Ok(status) = - serde_json::from_value::< - request::SignInStatus, - >(value) - { - copilot_clone - .update(cx, |copilot, cx| { - copilot.update_sign_in_status( - status, cx, - ); - }); - } + let command = command.clone(); + let copilot_clone = copilot.clone(); + copilot.update(cx, |copilot, cx| { + if let Some(server) = copilot.language_server() { + let server = server.clone(); + cx.spawn(async move |_, cx| { + let result = server + .request::( + lsp::ExecuteCommandParams { + command: command.command.clone(), + arguments: command + .arguments + .clone() + .unwrap_or_default(), + ..Default::default() + }, + ) + .await + .into_response() + .ok() + .flatten(); + if let Some(value) = result { + if let Ok(status) = serde_json::from_value::< + request::SignInStatus, + >( + value + ) { + copilot_clone.update(cx, |copilot, cx| { + copilot + .update_sign_in_status(status, cx); + }); } - }) - .detach(); - } - }); - } + } + }) + .detach(); + } + }); + this.connect_clicked = true; }) }), @@ -378,7 +369,7 @@ impl CopilotCodeVerification { ) } - fn render_error_modal(_cx: &mut Context) -> impl Element { + fn render_error_modal(copilot: Entity, _cx: &mut Context) -> impl Element { v_flex() .gap_2() .text_center() @@ -394,7 +385,9 @@ impl CopilotCodeVerification { .icon_color(Color::Muted) .icon_position(IconPosition::Start) .icon_size(IconSize::Small) - .on_click(|_, window, cx| reinstall_and_sign_in(window, cx)), + .on_click(move |_, window, cx| { + reinstall_and_sign_in(copilot.clone(), window, cx) + }), ) } @@ -420,7 +413,10 @@ impl Render for CopilotCodeVerification { .into_any_element(), Status::SigningIn { prompt: Some(prompt), - } => Self::render_prompting_modal(self.connect_clicked, prompt, cx).into_any_element(), + } => { + Self::render_prompting_modal(self.copilot.clone(), self.connect_clicked, prompt, cx) + .into_any_element() + } Status::Unauthorized => { self.connect_clicked = false; self.render_unauthorized_modal(cx).into_any_element() @@ -429,7 +425,9 @@ impl Render for CopilotCodeVerification { self.connect_clicked = false; Self::render_enabled_modal(cx).into_any_element() } - Status::Error(..) => Self::render_error_modal(cx).into_any_element(), + Status::Error(..) => { + Self::render_error_modal(self.copilot.clone(), cx).into_any_element() + } _ => div().into_any_element(), }; @@ -475,14 +473,14 @@ impl ConfigurationView { mode: ConfigurationMode, cx: &mut Context, ) -> Self { - let copilot = Copilot::global(cx); + let copilot = GlobalCopilotAuth::try_global(cx).cloned(); Self { - copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()), + copilot_status: copilot.as_ref().map(|copilot| copilot.0.read(cx).status()), is_authenticated: Box::new(is_authenticated), edit_prediction: matches!(mode, ConfigurationMode::EditPrediction), _subscription: copilot.as_ref().map(|copilot| { - cx.observe(copilot, |this, model, cx| { + cx.observe(&copilot.0, |this, model, cx| { this.copilot_status = Some(model.read(cx).status()); cx.notify(); }) @@ -568,7 +566,11 @@ impl ConfigurationView { .icon_color(Color::Muted) .icon_position(IconPosition::Start) .icon_size(IconSize::Small) - .on_click(|_, window, cx| initiate_sign_in(window, cx)) + .on_click(|_, window, cx| { + if let Some(copilot) = GlobalCopilotAuth::get_or_init(cx) { + initiate_sign_in(copilot.0, window, cx) + } + }) } fn render_reinstall_button(&self, edit_prediction: bool) -> impl IntoElement { @@ -591,7 +593,11 @@ impl ConfigurationView { .icon_color(Color::Muted) .icon_position(IconPosition::Start) .icon_size(IconSize::Small) - .on_click(|_, window, cx| reinstall_and_sign_in(window, cx)) + .on_click(|_, window, cx| { + if let Some(copilot) = GlobalCopilotAuth::get_or_init(cx) { + reinstall_and_sign_in(copilot.0, window, cx) + } + }) } fn render_for_edit_prediction(&self) -> impl IntoElement { @@ -684,7 +690,9 @@ impl Render for ConfigurationView { return ConfiguredApiCard::new("Authorized") .button_label("Sign Out") .on_click(|_, window, cx| { - initiate_sign_out(window, cx); + if let Some(auth) = GlobalCopilotAuth::try_global(cx) { + initiate_sign_out(auth.0.clone(), window, cx); + } }) .into_any_element(); } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 5813cd40e9bc32463c172f3ad8bfef9299be0d49..098136b42f2c92ddb80a43a46bb29ed7518aff34 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -10,7 +10,7 @@ use cloud_llm_client::{ PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME, }; use collections::{HashMap, HashSet}; -use copilot::Copilot; +use copilot::{Copilot, Reinstall, SignIn, SignOut}; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile}; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; @@ -754,7 +754,7 @@ impl EditPredictionStore { let next_id = project.languages().next_language_server_id(); let fs = project.fs().clone(); - let copilot = cx.new(|cx| Copilot::new(_project, next_id, fs, node, cx)); + let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx)); state.copilot = Some(copilot.clone()); Some(copilot) } else { @@ -2332,6 +2332,27 @@ pub fn init(cx: &mut App) { .edit_prediction_provider = Some(EditPredictionProvider::None) }); }); + fn copilot_for_project(project: &Entity, cx: &mut App) -> Option> { + EditPredictionStore::try_global(cx).and_then(|store| { + store.update(cx, |this, cx| this.start_copilot_for_project(project, cx)) + }) + } + + workspace.register_action(|workspace, _: &SignIn, window, cx| { + if let Some(copilot) = copilot_for_project(workspace.project(), cx) { + copilot_ui::initiate_sign_in(copilot, window, cx); + } + }); + workspace.register_action(|workspace, _: &Reinstall, window, cx| { + if let Some(copilot) = copilot_for_project(workspace.project(), cx) { + copilot_ui::reinstall_and_sign_in(copilot, window, cx); + } + }); + workspace.register_action(|workspace, _: &SignOut, window, cx| { + if let Some(copilot) = copilot_for_project(workspace.project(), cx) { + copilot_ui::initiate_sign_out(copilot, window, cx); + } + }); }) .detach(); } diff --git a/crates/edit_prediction/src/onboarding_modal.rs b/crates/edit_prediction/src/onboarding_modal.rs index 14f3ce4e1daddc8a2be37a3a18729f8ae85572e0..e4c4b2c973457665b2df5288ff71533d0aa1edef 100644 --- a/crates/edit_prediction/src/onboarding_modal.rs +++ b/crates/edit_prediction/src/onboarding_modal.rs @@ -60,7 +60,9 @@ impl ZedPredictModal { EditPredictionOnboarding::new( user_store.clone(), client.clone(), - copilot.is_some_and(|copilot| copilot.read(cx).status().is_configured()), + copilot + .as_ref() + .is_some_and(|copilot| copilot.read(cx).status().is_configured()), Arc::new({ let this = weak_entity.clone(); move |_window, cx| { @@ -75,7 +77,9 @@ impl ZedPredictModal { ZedPredictUpsell::set_dismissed(true, cx); set_edit_prediction_provider(EditPredictionProvider::Copilot, cx); this.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); - copilot_ui::initiate_sign_in(window, cx); + if let Some(copilot) = copilot.clone() { + copilot_ui::initiate_sign_in(copilot, window, cx); + } } }), cx, diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index dde6cc7192b121ff76c63568c5d070105121d6bc..c27492c982454e4db1f00d7227231bd783fd19da 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -2,7 +2,7 @@ use anyhow::Result; use client::{Client, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; use codestral::CodestralEditPredictionDelegate; -use copilot::{Copilot, Status}; +use copilot::Status; use edit_prediction::{ EditPredictionStore, MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag, }; @@ -124,6 +124,7 @@ impl Render for EditPredictionButton { .on_click(cx.listener(move |_, _, window, cx| { if let Some(workspace) = window.root::().flatten() { workspace.update(cx, |workspace, cx| { + let copilot = copilot.clone(); workspace.show_toast( Toast::new( NotificationId::unique::(), @@ -131,8 +132,12 @@ impl Render for EditPredictionButton { ) .on_click( "Reinstall Copilot", - |window, cx| { - copilot_ui::reinstall_and_sign_in(window, cx) + move |window, cx| { + copilot_ui::reinstall_and_sign_in( + copilot.clone(), + window, + cx, + ) }, ), cx, @@ -489,7 +494,10 @@ impl EditPredictionButton { project: Entity, cx: &mut Context, ) -> Self { - if let Some(copilot) = Copilot::global(cx) { + let copilot = EditPredictionStore::try_global(cx).and_then(|store| { + store.update(cx, |this, cx| this.start_copilot_for_project(&project, cx)) + }); + if let Some(copilot) = copilot { cx.observe(&copilot, |_, _, cx| cx.notify()).detach() } @@ -638,19 +646,28 @@ impl EditPredictionButton { cx: &mut Context, ) -> Entity { let fs = self.fs.clone(); + let project = self.project.clone(); ContextMenu::build(window, cx, |menu, _, _| { - menu.entry("Sign In to Copilot", None, copilot_ui::initiate_sign_in) - .entry("Disable Copilot", None, { - let fs = fs.clone(); - move |_window, cx| hide_copilot(fs.clone(), cx) - }) - .separator() - .entry("Use Zed AI", None, { - let fs = fs.clone(); - move |_window, cx| { - set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed) - } - }) + menu.entry("Sign In to Copilot", None, move |window, cx| { + if let Some(copilot) = EditPredictionStore::try_global(cx).and_then(|store| { + store.update(cx, |this, cx| { + this.start_copilot_for_project(&project.upgrade()?, cx) + }) + }) { + copilot_ui::initiate_sign_in(copilot, window, cx); + } + }) + .entry("Disable Copilot", None, { + let fs = fs.clone(); + move |_window, cx| hide_copilot(fs.clone(), cx) + }) + .separator() + .entry("Use Zed AI", None, { + let fs = fs.clone(); + move |_window, cx| { + set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed) + } + }) }) } diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 8c9b84417d33f823da80d221072a766d48bc59ce..9ef9dc10f614b9b4f15d6489410e0a6fb5248789 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use anyhow::{Result, anyhow}; use cloud_llm_client::CompletionIntent; use collections::HashMap; -use copilot::{Copilot, Status}; +use copilot::{GlobalCopilotAuth, Status}; use copilot_chat::responses as copilot_responses; use copilot_chat::{ ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, CopilotChatConfiguration, @@ -141,7 +141,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { return Task::ready(Ok(())); }; - let Some(copilot) = Copilot::global(cx) else { + let Some(copilot) = GlobalCopilotAuth::try_global(cx).cloned() else { return Task::ready(Err(anyhow!(concat!( "Copilot must be enabled for Copilot Chat to work. ", "Please enable Copilot and try again." @@ -149,7 +149,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { .into())); }; - let err = match copilot.read(cx).status() { + let err = match copilot.0.read(cx).status() { Status::Authorized => return Task::ready(Ok(())), Status::Disabled => anyhow!( "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again." diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index f58cfd3413b1f000f1fe88e0bf27d31fe980d59b..a248fd11c71b92893f8b5849e14286bb5627d924 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -610,7 +610,7 @@ fn main() { cx, ); - copilot_ui::init(cx); + copilot_ui::init(&app_state, cx); supermaven::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);