diff --git a/Cargo.lock b/Cargo.lock index b139b679dcc981f2c4a19de935cdeaa55f2a3ccd..f8a3a2bbfde31dbe2cf8cd0993915ee6f8cbf7d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3669,13 +3669,12 @@ version = "0.1.0" dependencies = [ "anyhow", "async-std", - "chrono", "client", "clock", "collections", "command_palette_hooks", + "copilot_chat", "ctor", - "dirs 4.0.0", "edit_prediction_types", "editor", "fs", @@ -3683,11 +3682,9 @@ dependencies = [ "gpui", "http_client", "indoc", - "itertools 0.14.0", "language", "log", "lsp", - "menu", "node_runtime", "parking_lot", "paths", @@ -3698,13 +3695,45 @@ dependencies = [ "serde_json", "settings", "sum_tree", - "task", "theme", + "util", + "zlog", +] + +[[package]] +name = "copilot_chat" +version = "0.1.0" +dependencies = [ + "anyhow", + "chrono", + "collections", + "dirs 4.0.0", + "fs", + "futures 0.3.31", + "gpui", + "http_client", + "itertools 0.14.0", + "log", + "paths", + "serde", + "serde_json", + "settings", +] + +[[package]] +name = "copilot_ui" +version = "0.1.0" +dependencies = [ + "anyhow", + "copilot", + "gpui", + "log", + "lsp", + "menu", + "serde_json", "ui", - "url", "util", "workspace", - "zlog", ] [[package]] @@ -5199,6 +5228,7 @@ dependencies = [ "cloud_llm_client", "collections", "copilot", + "copilot_ui", "ctor", "db", "edit_prediction_context", @@ -5349,6 +5379,8 @@ dependencies = [ "collections", "command_palette_hooks", "copilot", + "copilot_chat", + "copilot_ui", "edit_prediction", "edit_prediction_types", "editor", @@ -8960,6 +8992,8 @@ dependencies = [ "component", "convert_case 0.8.0", "copilot", + "copilot_chat", + "copilot_ui", "credentials_provider", "deepseek", "editor", @@ -9041,7 +9075,7 @@ dependencies = [ "client", "collections", "command_palette_hooks", - "copilot", + "edit_prediction", "editor", "futures 0.3.31", "gpui", @@ -14939,7 +14973,7 @@ dependencies = [ "assets", "bm25", "client", - "copilot", + "copilot_ui", "edit_prediction", "editor", "feature_flags", @@ -20732,6 +20766,8 @@ dependencies = [ "component", "component_preview", "copilot", + "copilot_chat", + "copilot_ui", "crashes", "dap", "dap_adapters", diff --git a/Cargo.toml b/Cargo.toml index 1b76ca469c70c8bf5fd6d272d32cd550a5908d0f..0331b9aa71d1a01e97ca31f823d3af8bae6f015c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ members = [ "crates/component_preview", "crates/context_server", "crates/copilot", + "crates/copilot_chat", "crates/crashes", "crates/credentials_provider", "crates/dap", @@ -280,6 +281,8 @@ component = { path = "crates/component" } component_preview = { path = "crates/component_preview" } context_server = { path = "crates/context_server" } copilot = { path = "crates/copilot" } +copilot_chat = { path = "crates/copilot_chat" } +copilot_ui = { path = "crates/copilot_ui" } crashes = { path = "crates/crashes" } credentials_provider = { path = "crates/credentials_provider" } crossbeam = "0.8.4" diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 3a1706a7a679fbc14eafbeac953d842cda9f65c8..1402bb8d6ffd82d5cd8a2225c8336ee30de3e49e 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -25,19 +25,16 @@ test-support = [ [dependencies] anyhow.workspace = true -chrono.workspace = true collections.workspace = true command_palette_hooks.workspace = true -dirs.workspace = true +copilot_chat.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true -http_client.workspace = true edit_prediction_types.workspace = true language.workspace = true log.workspace = true lsp.workspace = true -menu.workspace = true node_runtime.workspace = true parking_lot.workspace = true paths.workspace = true @@ -47,12 +44,7 @@ serde.workspace = true serde_json.workspace = true settings.workspace = true sum_tree.workspace = true -task.workspace = true -ui.workspace = true util.workspace = true -workspace.workspace = true -itertools.workspace = true -url.workspace = true [target.'cfg(windows)'.dependencies] async-std = { version = "1.12.0", features = ["unstable"] } @@ -76,5 +68,4 @@ serde_json.workspace = true settings = { workspace = true, features = ["test-support"] } theme = { workspace = true, features = ["test-support"] } util = { workspace = true, features = ["test-support"] } -workspace = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index c86c249a6788027ef2550390a60e91529a222a20..759c2d9d53496b1f9a2313c5a9ea9e2ef8acbb1b 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -1,21 +1,19 @@ -pub mod copilot_chat; mod copilot_edit_prediction_delegate; -pub mod copilot_responses; pub mod request; -mod sign_in; -use crate::request::NextEditSuggestions; -use crate::sign_in::initiate_sign_out; +use crate::request::{ + DidFocus, DidFocusParams, FormattingOptions, InlineCompletionContext, + InlineCompletionTriggerKind, InlineCompletions, NextEditSuggestions, +}; use ::fs::Fs; use anyhow::{Context as _, Result, anyhow}; use collections::{HashMap, HashSet}; use command_palette_hooks::CommandPaletteFilter; -use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared}; +use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared, select_biased}; use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task, WeakEntity, actions, }; -use http_client::HttpClient; use language::language_settings::CopilotSettings; use language::{ Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16, @@ -25,8 +23,8 @@ use language::{ use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName}; use node_runtime::{NodeRuntime, VersionStrategy}; use parking_lot::Mutex; -use project::DisableAiSettings; -use request::StatusNotification; +use project::{DisableAiSettings, Project}; +use request::DidChangeStatus; use semver::Version; use serde_json::json; use settings::{Settings, SettingsStore}; @@ -42,13 +40,8 @@ use std::{ }; use sum_tree::Dimensions; use util::{ResultExt, fs::remove_matching}; -use workspace::Workspace; pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate; -pub use crate::sign_in::{ - ConfigurationMode, ConfigurationView, CopilotCodeVerification, initiate_sign_in, - reinstall_and_sign_in, -}; actions!( copilot, @@ -68,50 +61,6 @@ actions!( ] ); -pub fn init( - new_server_id: LanguageServerId, - fs: Arc, - http: Arc, - node_runtime: NodeRuntime, - cx: &mut App, -) { - let language_settings = all_language_settings(None, cx); - let configuration = copilot_chat::CopilotChatConfiguration { - enterprise_uri: language_settings - .edit_predictions - .copilot - .enterprise_uri - .clone(), - }; - copilot_chat::init(fs.clone(), http.clone(), configuration, cx); - - let copilot = cx.new(move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)); - Copilot::set_global(copilot.clone(), cx); - cx.observe(&copilot, |copilot, cx| { - copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); - }) - .detach(); - cx.observe_global::(|cx| { - if let Some(copilot) = Copilot::global(cx) { - copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); - } - }) - .detach(); - - cx.observe_new(|workspace: &mut Workspace, _window, _cx| { - workspace.register_action(|_, _: &SignIn, window, cx| { - initiate_sign_in(window, cx); - }); - workspace.register_action(|_, _: &Reinstall, window, cx| { - reinstall_and_sign_in(window, cx); - }); - workspace.register_action(|_, _: &SignOut, window, cx| { - initiate_sign_out(window, cx); - }); - }) - .detach(); -} - enum CopilotServer { Disabled, Starting { task: Shared> }, @@ -301,7 +250,7 @@ pub struct Copilot { server: CopilotServer, buffers: HashSet>, server_id: LanguageServerId, - _subscription: gpui::Subscription, + _subscriptions: [gpui::Subscription; 2], } pub enum Event { @@ -316,13 +265,21 @@ struct GlobalCopilot(Entity); impl Global for GlobalCopilot {} +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum CompletionSource { + NextEditSuggestion, + InlineCompletion, +} + /// Copilot's NextEditSuggestion response, with coordinates converted to Anchors. -struct CopilotEditPrediction { - buffer: Entity, - range: Range, - text: String, - command: Option, - snapshot: BufferSnapshot, +#[derive(Clone)] +pub(crate) struct CopilotEditPrediction { + pub(crate) buffer: Entity, + pub(crate) range: Range, + pub(crate) text: String, + pub(crate) command: Option, + pub(crate) snapshot: BufferSnapshot, + pub(crate) source: CompletionSource, } impl Copilot { @@ -335,19 +292,37 @@ impl Copilot { cx.set_global(GlobalCopilot(copilot)); } - fn start( + pub fn new( + project: Entity, new_server_id: LanguageServerId, fs: Arc, node_runtime: NodeRuntime, cx: &mut Context, ) -> Self { + let send_focus_notification = + cx.subscribe(&project, |this, project, e: &project::Event, cx| { + if let project::Event::ActiveEntryChanged(new_entry) = e + && let Ok(running) = this.server.as_authenticated() + { + let uri = new_entry + .and_then(|id| project.read(cx).path_for_entry(id, cx)) + .and_then(|entry| project.read(cx).absolute_path(&entry, cx)) + .and_then(|abs_path| lsp::Uri::from_file_path(abs_path).ok()); + + _ = running.lsp.notify::(DidFocusParams { uri }); + } + }); + let _subscriptions = [ + cx.on_app_quit(Self::shutdown_language_server), + send_focus_notification, + ]; let mut this = Self { server_id: new_server_id, fs, node_runtime, server: CopilotServer::Disabled, buffers: Default::default(), - _subscription: cx.on_app_quit(Self::shutdown_language_server), + _subscriptions, }; this.start_copilot(true, false, cx); cx.observe_global::(move |this, cx| { @@ -357,6 +332,11 @@ impl Copilot { .context("copilot setting change: did change configuration") .log_err(); } + this.update_action_visibilities(cx); + }) + .detach(); + cx.observe_self(|copilot, cx| { + copilot.update_action_visibilities(cx); }) .detach(); this @@ -448,6 +428,7 @@ impl Copilot { #[cfg(any(test, feature = "test-support"))] pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity, lsp::FakeLanguageServer) { use fs::FakeFs; + use gpui::Subscription; use lsp::FakeLanguageServer; use node_runtime::NodeRuntime; @@ -463,6 +444,7 @@ impl Copilot { &mut cx.to_async(), ); let node_runtime = NodeRuntime::unavailable(); + let send_focus_notification = Subscription::new(|| {}); let this = cx.new(|cx| Self { server_id: LanguageServerId(0), fs: FakeFs::new(cx.background_executor().clone()), @@ -472,7 +454,10 @@ impl Copilot { sign_in_status: SignInStatus::Authorized, registered_buffers: Default::default(), }), - _subscription: cx.on_app_quit(Self::shutdown_language_server), + _subscriptions: [ + send_focus_notification, + cx.on_app_quit(Self::shutdown_language_server), + ], buffers: Default::default(), }); (this, fake_server) @@ -522,7 +507,51 @@ impl Copilot { )?; server - .on_notification::(|_, _| { /* Silence the notification */ }) + .on_notification::({ + let this = this.clone(); + move |params, cx| { + if params.kind == request::StatusKind::Normal { + let this = this.clone(); + cx.spawn(async move |cx| { + let lsp = this + .read_with(cx, |copilot, _| { + if let CopilotServer::Running(server) = &copilot.server { + Some(server.lsp.clone()) + } else { + None + } + }) + .ok() + .flatten(); + let Some(lsp) = lsp else { return }; + let status = lsp + .request::(request::CheckStatusParams { + local_checks_only: false, + }) + .await + .into_response() + .ok(); + if let Some(status) = status { + this.update(cx, |copilot, cx| { + copilot.update_sign_in_status(status, cx); + }) + .ok(); + } + }) + .detach(); + } + } + }) + .detach(); + + server + .on_request::(move |params, cx| { + if params.external.unwrap_or(false) { + let url = params.uri.to_string(); + cx.update(|cx| cx.open_url(&url)); + } + async move { Ok(lsp::ShowDocumentResult { success: true }) } + }) .detach(); let configuration = lsp::DidChangeConfigurationParams { @@ -545,6 +574,12 @@ impl Copilot { .update(|cx| { let mut params = server.default_initialize_params(false, cx); params.initialization_options = Some(editor_info_json); + params + .capabilities + .window + .get_or_insert_with(Default::default) + .show_document = + Some(lsp::ShowDocumentClientCapabilities { support: true }); server.initialize(params, configuration.into(), cx) }) .await?; @@ -615,55 +650,37 @@ impl Copilot { } SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized => { let lsp = server.lsp.clone(); + let task = cx .spawn(async move |this, cx| { let sign_in = async { - let sign_in = lsp - .request::( - request::SignInInitiateParams {}, - ) + let flow = lsp + .request::(request::SignInParams {}) .await .into_response() .context("copilot sign-in")?; - match sign_in { - request::SignInInitiateResult::AlreadySignedIn { user } => { - Ok(request::SignInStatus::Ok { user: Some(user) }) - } - request::SignInInitiateResult::PromptUserDeviceFlow(flow) => { - this.update(cx, |this, cx| { - if let CopilotServer::Running(RunningCopilotServer { - sign_in_status: status, - .. - }) = &mut this.server - && let SignInStatus::SigningIn { - prompt: prompt_flow, - .. - } = status - { - *prompt_flow = Some(flow.clone()); - cx.notify(); - } - })?; - let response = lsp - .request::( - request::SignInConfirmParams { - user_code: flow.user_code, - }, - ) - .await - .into_response() - .context("copilot: sign in confirm")?; - Ok(response) + + this.update(cx, |this, cx| { + if let CopilotServer::Running(RunningCopilotServer { + sign_in_status: status, + .. + }) = &mut this.server + && let SignInStatus::SigningIn { + prompt: prompt_flow, + .. + } = status + { + *prompt_flow = Some(flow.clone()); + cx.notify(); } - } + })?; + + anyhow::Ok(()) }; let sign_in = sign_in.await; this.update(cx, |this, cx| match sign_in { - Ok(status) => { - this.update_sign_in_status(status, cx); - Ok(()) - } + Ok(()) => Ok(()), Err(error) => { this.update_sign_in_status( request::SignInStatus::NotSignedIn, @@ -691,7 +708,7 @@ impl Copilot { } } - pub(crate) fn sign_out(&mut self, cx: &mut Context) -> Task> { + pub fn sign_out(&mut self, cx: &mut Context) -> Task> { self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); match &self.server { CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => { @@ -713,7 +730,7 @@ impl Copilot { } } - pub(crate) fn reinstall(&mut self, cx: &mut Context) -> Shared> { + pub fn reinstall(&mut self, cx: &mut Context) -> Shared> { let language_settings = all_language_settings(None, cx); let env = self.build_env(&language_settings.edit_predictions.copilot); let start_task = cx @@ -901,39 +918,127 @@ impl Copilot { .registered_buffers .get_mut(&buffer.entity_id()) .unwrap(); - let snapshot = registered_buffer.report_changes(buffer, cx); + let pending_snapshot = registered_buffer.report_changes(buffer, cx); let buffer = buffer.read(cx); let uri = registered_buffer.uri.clone(); let position = position.to_point_utf16(buffer); + let snapshot = buffer.snapshot(); + let settings = snapshot.settings_at(0, cx); + let tab_size = settings.tab_size.get(); + let hard_tabs = settings.hard_tabs; + drop(settings); cx.background_spawn(async move { - let (version, snapshot) = snapshot.await?; - let result = lsp + let (version, snapshot) = pending_snapshot.await?; + let lsp_position = point_to_lsp(position); + + let nes_request = lsp .request::(request::NextEditSuggestionsParams { - text_document: lsp::VersionedTextDocumentIdentifier { uri, version }, - position: point_to_lsp(position), + text_document: lsp::VersionedTextDocumentIdentifier { + uri: uri.clone(), + version, + }, + position: lsp_position, }) - .await - .into_response() - .context("copilot: get completions")?; - let completions = result - .edits - .into_iter() - .map(|completion| { - let start = snapshot - .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left); - let end = - snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left); - CopilotEditPrediction { - buffer: buffer_entity.clone(), - range: snapshot.anchor_before(start)..snapshot.anchor_after(end), - text: completion.text, - command: completion.command, - snapshot: snapshot.clone(), - } + .fuse(); + + let inline_request = lsp + .request::(request::InlineCompletionsParams { + text_document: lsp::VersionedTextDocumentIdentifier { + uri: uri.clone(), + version, + }, + position: lsp_position, + context: InlineCompletionContext { + trigger_kind: InlineCompletionTriggerKind::Automatic, + }, + formatting_options: Some(FormattingOptions { + tab_size, + insert_spaces: !hard_tabs, + }), }) - .collect(); - anyhow::Ok(completions) + .fuse(); + + futures::pin_mut!(nes_request, inline_request); + + let convert_nes = + |result: request::NextEditSuggestionsResult| -> Vec { + result + .edits + .into_iter() + .map(|completion| { + let start = snapshot.clip_point_utf16( + point_from_lsp(completion.range.start), + Bias::Left, + ); + let end = snapshot + .clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left); + CopilotEditPrediction { + buffer: buffer_entity.clone(), + range: snapshot.anchor_before(start)..snapshot.anchor_after(end), + text: completion.text, + command: completion.command, + snapshot: snapshot.clone(), + source: CompletionSource::NextEditSuggestion, + } + }) + .collect() + }; + + let convert_inline = + |result: request::InlineCompletionsResult| -> Vec { + result + .items + .into_iter() + .map(|item| { + let start = snapshot + .clip_point_utf16(point_from_lsp(item.range.start), Bias::Left); + let end = snapshot + .clip_point_utf16(point_from_lsp(item.range.end), Bias::Left); + CopilotEditPrediction { + buffer: buffer_entity.clone(), + range: snapshot.anchor_before(start)..snapshot.anchor_after(end), + text: item.insert_text, + command: item.command, + snapshot: snapshot.clone(), + source: CompletionSource::InlineCompletion, + } + }) + .collect() + }; + + let mut nes_result: Option> = None; + let mut inline_result: Option> = None; + + loop { + select_biased! { + nes = nes_request => { + let completions = nes.into_response().ok().map(convert_nes).unwrap_or_default(); + if !completions.is_empty() { + return Ok(completions); + } + nes_result = Some(completions); + } + inline = inline_request => { + let completions = inline.into_response().ok().map(convert_inline).unwrap_or_default(); + if !completions.is_empty() && nes_result.is_some() { + return Ok(completions); + } + inline_result = Some(completions); + } + complete => break, + } + + if let (Some(nes), Some(inline)) = (&nes_result, &inline_result) { + return if !nes.is_empty() { + Ok(nes.clone()) + } else { + Ok(inline.clone()) + }; + } + } + + Ok(nes_result.or(inline_result).unwrap_or_default()) }) } @@ -988,7 +1093,11 @@ impl Copilot { } } - fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context) { + pub fn update_sign_in_status( + &mut self, + lsp_status: request::SignInStatus, + cx: &mut Context, + ) { self.buffers.retain(|buffer| buffer.is_upgradable()); if let Ok(server) = self.server.as_running() { @@ -1320,9 +1429,14 @@ mod tests { ); // Ensure all previously-registered buffers are re-opened when signing in. - lsp.set_request_handler::(|_, _| async { - Ok(request::SignInInitiateResult::AlreadySignedIn { - user: "user-1".into(), + lsp.set_request_handler::(|_, _| async { + Ok(request::PromptUserDeviceFlow { + user_code: "test-code".into(), + command: lsp::Command { + title: "Sign in".into(), + command: "github.copilot.finishDeviceFlow".into(), + arguments: None, + }, }) }); copilot @@ -1330,6 +1444,16 @@ mod tests { .await .unwrap(); + // Simulate auth completion by directly updating sign-in status + copilot.update(cx, |copilot, cx| { + copilot.update_sign_in_status( + request::SignInStatus::Ok { + user: Some("user-1".into()), + }, + cx, + ); + }); + assert_eq!( lsp.receive_notification::() .await, diff --git a/crates/copilot/src/copilot_edit_prediction_delegate.rs b/crates/copilot/src/copilot_edit_prediction_delegate.rs index fe26979f655418f74efc29c6c0ad0757895261ef..ffd4414a49066175cc58ad1c59dacb8d31a94bff 100644 --- a/crates/copilot/src/copilot_edit_prediction_delegate.rs +++ b/crates/copilot/src/copilot_edit_prediction_delegate.rs @@ -1,8 +1,14 @@ -use crate::{Copilot, CopilotEditPrediction}; +use crate::{ + CompletionSource, Copilot, CopilotEditPrediction, + request::{ + DidShowCompletion, DidShowCompletionParams, DidShowInlineEdit, DidShowInlineEditParams, + InlineCompletionItem, + }, +}; use anyhow::Result; use edit_prediction_types::{EditPrediction, EditPredictionDelegate, interpolate_edits}; use gpui::{App, Context, Entity, Task}; -use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt}; +use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToPointUtf16}; use std::{ops::Range, sync::Arc, time::Duration}; pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); @@ -137,7 +143,37 @@ impl EditPredictionDelegate for CopilotEditPredictionDelegate { )]; let edits = interpolate_edits(&completion.snapshot, &buffer.snapshot(), &edits) .filter(|edits| !edits.is_empty())?; - + self.copilot.update(cx, |this, _| { + if let Ok(server) = this.server.as_authenticated() { + match completion.source { + CompletionSource::NextEditSuggestion => { + if let Some(cmd) = completion.command.as_ref() { + _ = server + .lsp + .notify::(DidShowInlineEditParams { + item: serde_json::json!({"command": {"arguments": cmd.arguments}}), + }); + } + } + CompletionSource::InlineCompletion => { + _ = server.lsp.notify::(DidShowCompletionParams { + item: InlineCompletionItem { + insert_text: completion.text.clone(), + range: lsp::Range::new( + language::point_to_lsp( + completion.range.start.to_point_utf16(&completion.snapshot), + ), + language::point_to_lsp( + completion.range.end.to_point_utf16(&completion.snapshot), + ), + ), + command: completion.command.clone(), + }, + }); + } + } + } + }); Some(EditPrediction::Local { id: None, edits, diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 2f97fb72a42904b1fefdd3999f680fca12559ecd..a8739438f3708e4bd003a0893f22430547ff7884 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -1,4 +1,4 @@ -use lsp::VersionedTextDocumentIdentifier; +use lsp::{Uri, VersionedTextDocumentIdentifier}; use serde::{Deserialize, Serialize}; pub enum CheckStatus {} @@ -15,37 +15,22 @@ impl lsp::request::Request for CheckStatus { const METHOD: &'static str = "checkStatus"; } -pub enum SignInInitiate {} +pub enum SignIn {} #[derive(Debug, Serialize, Deserialize)] -pub struct SignInInitiateParams {} +pub struct SignInParams {} -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "status")] -pub enum SignInInitiateResult { - AlreadySignedIn { user: String }, - PromptUserDeviceFlow(PromptUserDeviceFlow), -} - -#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptUserDeviceFlow { pub user_code: String, - pub verification_uri: String, + pub command: lsp::Command, } -impl lsp::request::Request for SignInInitiate { - type Params = SignInInitiateParams; - type Result = SignInInitiateResult; - const METHOD: &'static str = "signInInitiate"; -} - -pub enum SignInConfirm {} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct SignInConfirmParams { - pub user_code: String, +impl lsp::request::Request for SignIn { + type Params = SignInParams; + type Result = PromptUserDeviceFlow; + const METHOD: &'static str = "signIn"; } #[derive(Debug, Serialize, Deserialize)] @@ -67,12 +52,6 @@ pub enum SignInStatus { NotSignedIn, } -impl lsp::request::Request for SignInConfirm { - type Params = SignInConfirmParams; - type Result = SignInStatus; - const METHOD: &'static str = "signInConfirm"; -} - pub enum SignOut {} #[derive(Debug, Serialize, Deserialize)] @@ -89,17 +68,26 @@ impl lsp::request::Request for SignOut { const METHOD: &'static str = "signOut"; } -pub enum StatusNotification {} +pub enum DidChangeStatus {} #[derive(Debug, Serialize, Deserialize)] -pub struct StatusNotificationParams { - pub message: String, - pub status: String, // One of Normal/InProgress +pub struct DidChangeStatusParams { + #[serde(default)] + pub message: Option, + pub kind: StatusKind, } -impl lsp::notification::Notification for StatusNotification { - type Params = StatusNotificationParams; - const METHOD: &'static str = "statusNotification"; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum StatusKind { + Normal, + Error, + Warning, + Inactive, +} + +impl lsp::notification::Notification for DidChangeStatus { + type Params = DidChangeStatusParams; + const METHOD: &'static str = "didChangeStatus"; } pub enum SetEditorInfo {} @@ -191,3 +179,121 @@ impl lsp::request::Request for NextEditSuggestions { const METHOD: &'static str = "textDocument/copilotInlineEdit"; } + +pub(crate) struct DidFocus; + +#[derive(Serialize, Deserialize)] +pub(crate) struct DidFocusParams { + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) uri: Option, +} + +impl lsp::notification::Notification for DidFocus { + type Params = DidFocusParams; + const METHOD: &'static str = "textDocument/didFocus"; +} + +pub(crate) struct DidShowInlineEdit; + +#[derive(Serialize, Deserialize)] +pub(crate) struct DidShowInlineEditParams { + pub(crate) item: serde_json::Value, +} + +impl lsp::notification::Notification for DidShowInlineEdit { + type Params = DidShowInlineEditParams; + const METHOD: &'static str = "textDocument/didShowInlineEdit"; +} + +// Inline Completions (non-NES) - textDocument/inlineCompletion + +pub enum InlineCompletions {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InlineCompletionsParams { + pub text_document: VersionedTextDocumentIdentifier, + pub position: lsp::Position, + pub context: InlineCompletionContext, + #[serde(skip_serializing_if = "Option::is_none")] + pub formatting_options: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InlineCompletionContext { + pub trigger_kind: InlineCompletionTriggerKind, +} + +#[derive(Debug, Clone, Copy)] +pub enum InlineCompletionTriggerKind { + Invoked = 1, + Automatic = 2, +} + +impl Serialize for InlineCompletionTriggerKind { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_u8(*self as u8) + } +} + +impl<'de> Deserialize<'de> for InlineCompletionTriggerKind { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let value = u8::deserialize(deserializer)?; + match value { + 1 => Ok(InlineCompletionTriggerKind::Invoked), + 2 => Ok(InlineCompletionTriggerKind::Automatic), + _ => Err(serde::de::Error::custom("invalid trigger kind")), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FormattingOptions { + pub tab_size: u32, + pub insert_spaces: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InlineCompletionsResult { + pub items: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InlineCompletionItem { + pub insert_text: String, + pub range: lsp::Range, + #[serde(skip_serializing_if = "Option::is_none")] + pub command: Option, +} + +impl lsp::request::Request for InlineCompletions { + type Params = InlineCompletionsParams; + type Result = InlineCompletionsResult; + + const METHOD: &'static str = "textDocument/inlineCompletion"; +} + +// Telemetry notifications for inline completions + +pub(crate) struct DidShowCompletion; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct DidShowCompletionParams { + pub(crate) item: InlineCompletionItem, +} + +impl lsp::notification::Notification for DidShowCompletion { + type Params = DidShowCompletionParams; + const METHOD: &'static str = "textDocument/didShowCompletion"; +} diff --git a/crates/copilot_chat/Cargo.toml b/crates/copilot_chat/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..ac200f05b330823f20c578ea5dd94b2a3bfc3429 --- /dev/null +++ b/crates/copilot_chat/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "copilot_chat" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/copilot_chat.rs" +doctest = false + +[features] +default = [] +test-support = [ + "collections/test-support", + "gpui/test-support", + "settings/test-support", +] + +[dependencies] +anyhow.workspace = true +chrono.workspace = true +collections.workspace = true +dirs.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +http_client.workspace = true +itertools.workspace = true +log.workspace = true +paths.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true + +[dev-dependencies] +gpui = { workspace = true, features = ["test-support"] } +serde_json.workspace = true \ No newline at end of file diff --git a/crates/copilot_chat/LICENSE-GPL b/crates/copilot_chat/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/copilot_chat/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot_chat/src/copilot_chat.rs similarity index 99% rename from crates/copilot/src/copilot_chat.rs rename to crates/copilot_chat/src/copilot_chat.rs index 085959b59f97c0c17f4a4044b71c158b703cc515..922401a87837164fa9863afbe16b561426068da6 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot_chat/src/copilot_chat.rs @@ -1,3 +1,5 @@ +pub mod responses; + use std::path::PathBuf; use std::sync::Arc; use std::sync::OnceLock; @@ -16,7 +18,6 @@ use itertools::Itertools; use paths::home_dir; use serde::{Deserialize, Serialize}; -use crate::copilot_responses as responses; use settings::watch_config_dir; pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN"; diff --git a/crates/copilot/src/copilot_responses.rs b/crates/copilot_chat/src/responses.rs similarity index 99% rename from crates/copilot/src/copilot_responses.rs rename to crates/copilot_chat/src/responses.rs index 2da2eb394b5fc5ba88c8dd3007df394a2dbc15bf..8262d8e4c370a66a44fc65a2b4de05da23dc5f18 100644 --- a/crates/copilot/src/copilot_responses.rs +++ b/crates/copilot_chat/src/responses.rs @@ -1,4 +1,5 @@ -use super::*; +use std::sync::Arc; + use anyhow::{Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; diff --git a/crates/copilot_ui/Cargo.toml b/crates/copilot_ui/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..9f2668d9fb40d12631bff6af3291bdb3a40dea15 --- /dev/null +++ b/crates/copilot_ui/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "copilot_ui" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/copilot_ui.rs" +doctest = false + +[features] +default = [] +test-support = [ + "copilot/test-support", + "gpui/test-support", +] + +[dependencies] +anyhow.workspace = true +copilot.workspace = true +gpui.workspace = true +log.workspace = true +lsp.workspace = true +menu.workspace = true +serde_json.workspace = true +ui.workspace = true +util.workspace = true +workspace.workspace = true diff --git a/crates/copilot_ui/LICENSE-GPL b/crates/copilot_ui/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/copilot_ui/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/copilot_ui/src/copilot_ui.rs b/crates/copilot_ui/src/copilot_ui.rs new file mode 100644 index 0000000000000000000000000000000000000000..e22c2800c4beff1debb31aea9ce4ddca811f2bf0 --- /dev/null +++ b/crates/copilot_ui/src/copilot_ui.rs @@ -0,0 +1,25 @@ +mod sign_in; + +use copilot::{Reinstall, SignIn, SignOut}; +use gpui::App; +use workspace::Workspace; + +pub use sign_in::{ + ConfigurationMode, ConfigurationView, CopilotCodeVerification, initiate_sign_in, + reinstall_and_sign_in, +}; + +pub fn init(cx: &mut App) { + cx.observe_new(|workspace: &mut Workspace, _window, _cx| { + workspace.register_action(|_, _: &SignIn, window, cx| { + sign_in::initiate_sign_in(window, cx); + }); + workspace.register_action(|_, _: &Reinstall, window, cx| { + sign_in::reinstall_and_sign_in(window, cx); + }); + workspace.register_action(|_, _: &SignOut, window, cx| { + sign_in::initiate_sign_out(window, cx); + }); + }) + .detach(); +} diff --git a/crates/copilot/src/sign_in.rs b/crates/copilot_ui/src/sign_in.rs similarity index 87% rename from crates/copilot/src/sign_in.rs rename to crates/copilot_ui/src/sign_in.rs index ed633fe9306abf9b060027e562909edd05dad8fb..a9eda9d3c9b182fcb880c0d4d9812063578b4d1e 100644 --- a/crates/copilot/src/sign_in.rs +++ b/crates/copilot_ui/src/sign_in.rs @@ -1,12 +1,11 @@ -use crate::{Copilot, Status, request::PromptUserDeviceFlow}; use anyhow::Context as _; +use copilot::{Copilot, Status, request, request::PromptUserDeviceFlow}; use gpui::{ App, ClipboardItem, Context, DismissEvent, Element, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, MouseDownEvent, ParentElement, Render, Styled, Subscription, Window, WindowBounds, WindowOptions, div, point, }; use ui::{ButtonLike, CommonAnimationExt, ConfiguredApiCard, Vector, VectorName, prelude::*}; -use url::Url; use util::ResultExt as _; use workspace::{Toast, Workspace, notifications::NotificationId}; @@ -187,22 +186,12 @@ impl CopilotCodeVerification { .detach(); let status = copilot.read(cx).status(); - // Determine sign-up URL based on verification_uri domain if available - let sign_up_url = if let Status::SigningIn { - prompt: Some(ref prompt), - } = status - { - // Extract domain from verification_uri to construct sign-up URL - Self::get_sign_up_url_from_verification(&prompt.verification_uri) - } else { - None - }; Self { status, connect_clicked: false, focus_handle: cx.focus_handle(), copilot: copilot.clone(), - sign_up_url, + sign_up_url: None, _subscription: cx.observe(copilot, |this, copilot, cx| { let status = copilot.read(cx).status(); match status { @@ -216,30 +205,10 @@ impl CopilotCodeVerification { } pub fn set_status(&mut self, status: Status, cx: &mut Context) { - // Update sign-up URL if we have a new verification URI - if let Status::SigningIn { - prompt: Some(ref prompt), - } = status - { - self.sign_up_url = Self::get_sign_up_url_from_verification(&prompt.verification_uri); - } self.status = status; cx.notify(); } - fn get_sign_up_url_from_verification(verification_uri: &str) -> Option { - // Extract domain from verification URI using url crate - if let Ok(url) = Url::parse(verification_uri) - && let Some(host) = url.host_str() - && !host.contains("github.com") - { - // For GHE, construct URL from domain - Some(format!("https://{}/features/copilot", host)) - } else { - None - } - } - fn render_device_code(data: &PromptUserDeviceFlow, cx: &mut Context) -> impl IntoElement { let copied = cx .read_from_clipboard() @@ -303,9 +272,49 @@ impl CopilotCodeVerification { .style(ButtonStyle::Outlined) .size(ButtonSize::Medium) .on_click({ - let verification_uri = data.verification_uri.clone(); + let command = data.command.clone(); cx.listener(move |this, _, _window, cx| { - cx.open_url(&verification_uri); + if let Some(copilot) = Copilot::global(cx) { + let command = command.clone(); + let copilot_clone = copilot.clone(); + copilot.update(cx, |copilot, cx| { + if let Some(server) = copilot.language_server() { + let server = server.clone(); + cx.spawn(async move |_, cx| { + let result = server + .request::( + lsp::ExecuteCommandParams { + command: command.command.clone(), + arguments: command + .arguments + .clone() + .unwrap_or_default(), + ..Default::default() + }, + ) + .await + .into_response() + .ok() + .flatten(); + if let Some(value) = result { + if let Ok(status) = + serde_json::from_value::< + request::SignInStatus, + >(value) + { + copilot_clone + .update(cx, |copilot, cx| { + copilot.update_sign_in_status( + status, cx, + ); + }); + } + } + }) + .detach(); + } + }); + } this.connect_clicked = true; }) }), @@ -450,7 +459,7 @@ impl Render for CopilotCodeVerification { pub struct ConfigurationView { copilot_status: Option, - is_authenticated: fn(cx: &App) -> bool, + is_authenticated: Box bool + 'static>, edit_prediction: bool, _subscription: Option, } @@ -462,7 +471,7 @@ pub enum ConfigurationMode { impl ConfigurationView { pub fn new( - is_authenticated: fn(cx: &App) -> bool, + is_authenticated: impl Fn(&App) -> bool + 'static, mode: ConfigurationMode, cx: &mut Context, ) -> Self { @@ -470,7 +479,7 @@ impl ConfigurationView { Self { copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()), - is_authenticated, + is_authenticated: Box::new(is_authenticated), edit_prediction: matches!(mode, ConfigurationMode::EditPrediction), _subscription: copilot.as_ref().map(|copilot| { cx.observe(copilot, |this, model, cx| { @@ -669,7 +678,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let is_authenticated = self.is_authenticated; + let is_authenticated = &self.is_authenticated; if is_authenticated(cx) { return ConfiguredApiCard::new("Authorized") diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 394afc0eb701f2fd1fba6c5fa733fb76ca5781b1..b91b30a1bc29771866d0237b2f5a108196495b82 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -24,6 +24,7 @@ client.workspace = true cloud_llm_client.workspace = true collections.workspace = true copilot.workspace = true +copilot_ui.workspace = true db.workspace = true edit_prediction_types.workspace = true edit_prediction_context.workspace = true diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 4f9a67880b31f8c958fb9b922bfd5c102d365c1a..a3035a72ace5476c9c30106d2289b45ee352b89d 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -10,6 +10,7 @@ use cloud_llm_client::{ PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME, }; use collections::{HashMap, HashSet}; +use copilot::Copilot; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use edit_prediction_context::EditPredictionExcerptOptions; use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile}; @@ -291,6 +292,7 @@ struct ProjectState { license_detection_watchers: HashMap>, user_actions: VecDeque, _subscription: gpui::Subscription, + copilot: Option>, } impl ProjectState { @@ -662,6 +664,7 @@ impl EditPredictionStore { }, sweep_ai: SweepAi::new(cx), mercury: Mercury::new(cx), + data_collection_choice, reject_predictions_tx: reject_tx, rated_predictions: Default::default(), @@ -783,6 +786,38 @@ impl EditPredictionStore { .unwrap_or_default() } + pub fn copilot_for_project(&self, project: &Entity) -> Option> { + self.projects + .get(&project.entity_id()) + .and_then(|project| project.copilot.clone()) + } + + pub fn start_copilot_for_project( + &mut self, + project: &Entity, + cx: &mut Context, + ) -> Option> { + let state = self.get_or_init_project(project, cx); + + if state.copilot.is_some() { + return state.copilot.clone(); + } + let _project = project.clone(); + let project = project.read(cx); + + let node = project.node_runtime().cloned(); + if let Some(node) = node { + let next_id = project.languages().next_language_server_id(); + let fs = project.fs().clone(); + + let copilot = cx.new(|cx| Copilot::new(_project, next_id, fs, node, cx)); + state.copilot = Some(copilot.clone()); + Some(copilot) + } else { + None + } + } + pub fn context_for_project_with_buffers<'a>( &'a self, project: &Entity, @@ -853,6 +888,7 @@ impl EditPredictionStore { license_detection_watchers: HashMap::default(), user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE), _subscription: cx.subscribe(&project, Self::handle_project_event), + copilot: None, }) } diff --git a/crates/edit_prediction/src/onboarding_modal.rs b/crates/edit_prediction/src/onboarding_modal.rs index 97f529ae38df350ef21ffc04b79df6e8e6a7a501..14f3ce4e1daddc8a2be37a3a18729f8ae85572e0 100644 --- a/crates/edit_prediction/src/onboarding_modal.rs +++ b/crates/edit_prediction/src/onboarding_modal.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::ZedPredictUpsell; +use crate::{EditPredictionStore, ZedPredictUpsell}; use ai_onboarding::EditPredictionOnboarding; use client::{Client, UserStore}; use db::kvp::Dismissable; @@ -50,15 +50,17 @@ impl ZedPredictModal { window: &mut Window, cx: &mut Context, ) { + let project = workspace.project().clone(); workspace.toggle_modal(window, cx, |_window, cx| { let weak_entity = cx.weak_entity(); + let copilot = EditPredictionStore::try_global(cx) + .and_then(|store| store.read(cx).copilot_for_project(&project)); Self { onboarding: cx.new(|cx| { EditPredictionOnboarding::new( user_store.clone(), client.clone(), - copilot::Copilot::global(cx) - .is_some_and(|copilot| copilot.read(cx).status().is_configured()), + copilot.is_some_and(|copilot| copilot.read(cx).status().is_configured()), Arc::new({ let this = weak_entity.clone(); move |_window, cx| { @@ -73,7 +75,7 @@ impl ZedPredictModal { ZedPredictUpsell::set_dismissed(true, cx); set_edit_prediction_provider(EditPredictionProvider::Copilot, cx); this.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); - copilot::initiate_sign_in(window, cx); + copilot_ui::initiate_sign_in(window, cx); } }), cx, diff --git a/crates/edit_prediction_ui/Cargo.toml b/crates/edit_prediction_ui/Cargo.toml index 6c4bf735e7cce1b95666b0195d2da8caab57f703..d4a7c5d3ab800f54476a8e88914dcaaba3a26547 100644 --- a/crates/edit_prediction_ui/Cargo.toml +++ b/crates/edit_prediction_ui/Cargo.toml @@ -22,6 +22,8 @@ cloud_llm_client.workspace = true codestral.workspace = true command_palette_hooks.workspace = true copilot.workspace = true +copilot_chat.workspace = true +copilot_ui.workspace = true edit_prediction_types.workspace = true edit_prediction.workspace = true editor.workspace = true diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index 031e915a10fbe45781700c62364b220ae720e05b..b7b6d61edf59273e7ef72000004d116cab69309e 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -22,7 +22,7 @@ use language::{ EditPredictionsMode, File, Language, language_settings::{self, AllLanguageSettings, EditPredictionProvider, all_language_settings}, }; -use project::DisableAiSettings; +use project::{DisableAiSettings, Project}; use regex::Regex; use settings::{ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, @@ -75,6 +75,7 @@ pub struct EditPredictionButton { fs: Arc, user_store: Entity, popover_menu_handle: PopoverMenuHandle, + project: WeakEntity, } enum SupermavenButtonStatus { @@ -95,7 +96,9 @@ impl Render for EditPredictionButton { match all_language_settings.edit_predictions.provider { EditPredictionProvider::Copilot => { - let Some(copilot) = Copilot::global(cx) else { + let Some(copilot) = EditPredictionStore::try_global(cx) + .and_then(|store| store.read(cx).copilot_for_project(&self.project.upgrade()?)) + else { return div().hidden(); }; let status = copilot.read(cx).status(); @@ -129,7 +132,7 @@ impl Render for EditPredictionButton { .on_click( "Reinstall Copilot", |window, cx| { - copilot::reinstall_and_sign_in(window, cx) + copilot_ui::reinstall_and_sign_in(window, cx) }, ), cx, @@ -143,11 +146,16 @@ impl Render for EditPredictionButton { ); } let this = cx.weak_entity(); - + let project = self.project.clone(); div().child( PopoverMenu::new("copilot") .menu(move |window, cx| { - let current_status = Copilot::global(cx)?.read(cx).status(); + let current_status = EditPredictionStore::try_global(cx) + .and_then(|store| { + store.read(cx).copilot_for_project(&project.upgrade()?) + })? + .read(cx) + .status(); match current_status { Status::Authorized => this.update(cx, |this, cx| { this.build_copilot_context_menu(window, cx) @@ -478,6 +486,7 @@ impl EditPredictionButton { user_store: Entity, popover_menu_handle: PopoverMenuHandle, client: Arc, + project: Entity, cx: &mut Context, ) -> Self { if let Some(copilot) = Copilot::global(cx) { @@ -514,6 +523,7 @@ impl EditPredictionButton { edit_prediction_provider: None, user_store, popover_menu_handle, + project: project.downgrade(), fs, } } @@ -529,10 +539,10 @@ impl EditPredictionButton { )); } - if let Some(copilot) = Copilot::global(cx) { - if matches!(copilot.read(cx).status(), Status::Authorized) { - providers.push(EditPredictionProvider::Copilot); - } + if let Some(_) = EditPredictionStore::try_global(cx) + .and_then(|store| store.read(cx).copilot_for_project(&self.project.upgrade()?)) + { + providers.push(EditPredictionProvider::Copilot); } if let Some(supermaven) = Supermaven::global(cx) { @@ -629,7 +639,7 @@ impl EditPredictionButton { ) -> Entity { let fs = self.fs.clone(); ContextMenu::build(window, cx, |menu, _, _| { - menu.entry("Sign In to Copilot", None, copilot::initiate_sign_in) + menu.entry("Sign In to Copilot", None, copilot_ui::initiate_sign_in) .entry("Disable Copilot", None, { let fs = fs.clone(); move |_window, cx| hide_copilot(fs.clone(), cx) @@ -931,7 +941,7 @@ impl EditPredictionButton { cx: &mut Context, ) -> Entity { let all_language_settings = all_language_settings(None, cx); - let copilot_config = copilot::copilot_chat::CopilotChatConfiguration { + let copilot_config = copilot_chat::CopilotChatConfiguration { enterprise_uri: all_language_settings .edit_predictions .copilot diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 6f5ca58e221207b2732b4a0388351fa40826e296..3b0cf3a31d0718f89994c4cd3cb2bf82f4ec4408 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -26,6 +26,8 @@ collections.workspace = true component.workspace = true convert_case.workspace = true copilot.workspace = true +copilot_chat.workspace = true +copilot_ui.workspace = true credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } extension.workspace = true diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 68eaab1dbed33a8d983de6a919b75dc809410a70..43c4f08d6cdc1b4a33b971d29528c6c3e3812d32 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -5,12 +5,13 @@ use std::sync::Arc; use anyhow::{Result, anyhow}; use cloud_llm_client::CompletionIntent; use collections::HashMap; -use copilot::copilot_chat::{ - ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl, - Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool, - ToolCall, -}; use copilot::{Copilot, Status}; +use copilot_chat::responses as copilot_responses; +use copilot_chat::{ + ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, CopilotChatConfiguration, + Function, FunctionContent, ImageUrl, Model as CopilotChatModel, ModelVendor, + Request as CopilotChatRequest, ResponseEvent, Tool, ToolCall, ToolCallContent, ToolChoice, +}; use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt, Stream, StreamExt}; @@ -60,7 +61,7 @@ impl CopilotChatLanguageModelProvider { _settings_subscription: cx.observe_global::(|_, cx| { if let Some(copilot_chat) = CopilotChat::global(cx) { let language_settings = all_language_settings(None, cx); - let configuration = copilot::copilot_chat::CopilotChatConfiguration { + let configuration = CopilotChatConfiguration { enterprise_uri: language_settings .edit_predictions .copilot @@ -178,13 +179,13 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { cx: &mut App, ) -> AnyView { cx.new(|cx| { - copilot::ConfigurationView::new( + copilot_ui::ConfigurationView::new( |cx| { CopilotChat::global(cx) .map(|m| m.read(cx).is_authenticated()) .unwrap_or(false) }, - copilot::ConfigurationMode::Chat, + copilot_ui::ConfigurationMode::Chat, cx, ) }) @@ -563,7 +564,7 @@ impl CopilotResponsesEventMapper { pub fn map_stream( mut self, - events: Pin>>>, + events: Pin>>>, ) -> impl Stream> { events.flat_map(move |event| { @@ -576,11 +577,11 @@ impl CopilotResponsesEventMapper { fn map_event( &mut self, - event: copilot::copilot_responses::StreamEvent, + event: copilot_responses::StreamEvent, ) -> Vec> { match event { - copilot::copilot_responses::StreamEvent::OutputItemAdded { item, .. } => match item { - copilot::copilot_responses::ResponseOutputItem::Message { id, .. } => { + copilot_responses::StreamEvent::OutputItemAdded { item, .. } => match item { + copilot_responses::ResponseOutputItem::Message { id, .. } => { vec![Ok(LanguageModelCompletionEvent::StartMessage { message_id: id, })] @@ -588,7 +589,7 @@ impl CopilotResponsesEventMapper { _ => Vec::new(), }, - copilot::copilot_responses::StreamEvent::OutputTextDelta { delta, .. } => { + copilot_responses::StreamEvent::OutputTextDelta { delta, .. } => { if delta.is_empty() { Vec::new() } else { @@ -596,9 +597,9 @@ impl CopilotResponsesEventMapper { } } - copilot::copilot_responses::StreamEvent::OutputItemDone { item, .. } => match item { - copilot::copilot_responses::ResponseOutputItem::Message { .. } => Vec::new(), - copilot::copilot_responses::ResponseOutputItem::FunctionCall { + copilot_responses::StreamEvent::OutputItemDone { item, .. } => match item { + copilot_responses::ResponseOutputItem::Message { .. } => Vec::new(), + copilot_responses::ResponseOutputItem::FunctionCall { call_id, name, arguments, @@ -632,7 +633,7 @@ impl CopilotResponsesEventMapper { events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); events } - copilot::copilot_responses::ResponseOutputItem::Reasoning { + copilot_responses::ResponseOutputItem::Reasoning { summary, encrypted_content, .. @@ -660,7 +661,7 @@ impl CopilotResponsesEventMapper { } }, - copilot::copilot_responses::StreamEvent::Completed { response } => { + copilot_responses::StreamEvent::Completed { response } => { let mut events = Vec::new(); if let Some(usage) = response.usage { events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { @@ -676,18 +677,16 @@ impl CopilotResponsesEventMapper { events } - copilot::copilot_responses::StreamEvent::Incomplete { response } => { + copilot_responses::StreamEvent::Incomplete { response } => { let reason = response .incomplete_details .as_ref() .and_then(|details| details.reason.as_ref()); let stop_reason = match reason { - Some(copilot::copilot_responses::IncompleteReason::MaxOutputTokens) => { + Some(copilot_responses::IncompleteReason::MaxOutputTokens) => { StopReason::MaxTokens } - Some(copilot::copilot_responses::IncompleteReason::ContentFilter) => { - StopReason::Refusal - } + Some(copilot_responses::IncompleteReason::ContentFilter) => StopReason::Refusal, _ => self .pending_stop_reason .take() @@ -707,7 +706,7 @@ impl CopilotResponsesEventMapper { events } - copilot::copilot_responses::StreamEvent::Failed { response } => { + copilot_responses::StreamEvent::Failed { response } => { let provider = PROVIDER_NAME; let (status_code, message) = match response.error { Some(error) => { @@ -727,18 +726,18 @@ impl CopilotResponsesEventMapper { })] } - copilot::copilot_responses::StreamEvent::GenericError { error } => vec![Err( + copilot_responses::StreamEvent::GenericError { error } => vec![Err( LanguageModelCompletionError::Other(anyhow!(format!("{error:?}"))), )], - copilot::copilot_responses::StreamEvent::Created { .. } - | copilot::copilot_responses::StreamEvent::Unknown => Vec::new(), + copilot_responses::StreamEvent::Created { .. } + | copilot_responses::StreamEvent::Unknown => Vec::new(), } } } fn into_copilot_chat( - model: &copilot::copilot_chat::Model, + model: &CopilotChatModel, request: LanguageModelRequest, ) -> Result { let mut request_messages: Vec = Vec::new(); @@ -825,8 +824,8 @@ fn into_copilot_chat( if let MessageContent::ToolUse(tool_use) = content { tool_calls.push(ToolCall { id: tool_use.id.to_string(), - content: copilot::copilot_chat::ToolCallContent::Function { - function: copilot::copilot_chat::FunctionContent { + content: ToolCallContent::Function { + function: FunctionContent { name: tool_use.name.to_string(), arguments: serde_json::to_string(&tool_use.input)?, thought_signature: tool_use.thought_signature.clone(), @@ -890,7 +889,7 @@ fn into_copilot_chat( .tools .iter() .map(|tool| Tool::Function { - function: copilot::copilot_chat::Function { + function: Function { name: tool.name.clone(), description: tool.description.clone(), parameters: tool.input_schema.clone(), @@ -907,18 +906,18 @@ fn into_copilot_chat( messages, tools, tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto, - LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any, - LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None, + LanguageModelToolChoice::Auto => ToolChoice::Auto, + LanguageModelToolChoice::Any => ToolChoice::Any, + LanguageModelToolChoice::None => ToolChoice::None, }), }) } fn into_copilot_responses( - model: &copilot::copilot_chat::Model, + model: &CopilotChatModel, request: LanguageModelRequest, -) -> copilot::copilot_responses::Request { - use copilot::copilot_responses as responses; +) -> copilot_responses::Request { + use copilot_responses as responses; let LanguageModelRequest { thread_id: _, @@ -1109,7 +1108,7 @@ fn into_copilot_responses( tool_choice: mapped_tool_choice, reasoning: None, // We would need to add support for setting from user settings. include: Some(vec![ - copilot::copilot_responses::ResponseIncludable::ReasoningEncryptedContent, + copilot_responses::ResponseIncludable::ReasoningEncryptedContent, ]), } } @@ -1117,7 +1116,7 @@ fn into_copilot_responses( #[cfg(test)] mod tests { use super::*; - use copilot::copilot_responses as responses; + use copilot_chat::responses; use futures::StreamExt; fn map_events(events: Vec) -> Vec { @@ -1384,20 +1383,22 @@ mod tests { #[test] fn chat_completions_stream_maps_reasoning_data() { - use copilot::copilot_chat::ResponseEvent; + use copilot_chat::{ + FunctionChunk, ResponseChoice, ResponseDelta, ResponseEvent, Role, ToolCallChunk, + }; let events = vec![ ResponseEvent { - choices: vec![copilot::copilot_chat::ResponseChoice { + choices: vec![ResponseChoice { index: Some(0), finish_reason: None, - delta: Some(copilot::copilot_chat::ResponseDelta { + delta: Some(ResponseDelta { content: None, - role: Some(copilot::copilot_chat::Role::Assistant), - tool_calls: vec![copilot::copilot_chat::ToolCallChunk { + role: Some(Role::Assistant), + tool_calls: vec![ToolCallChunk { index: Some(0), id: Some("call_abc123".to_string()), - function: Some(copilot::copilot_chat::FunctionChunk { + function: Some(FunctionChunk { name: Some("list_directory".to_string()), arguments: Some("{\"path\":\"test\"}".to_string()), thought_signature: None, @@ -1412,10 +1413,10 @@ mod tests { usage: None, }, ResponseEvent { - choices: vec![copilot::copilot_chat::ResponseChoice { + choices: vec![ResponseChoice { index: Some(0), finish_reason: Some("tool_calls".to_string()), - delta: Some(copilot::copilot_chat::ResponseDelta { + delta: Some(ResponseDelta { content: None, role: None, tool_calls: vec![], diff --git a/crates/language_tools/Cargo.toml b/crates/language_tools/Cargo.toml index 6181975a9be8a23a92acb48f74a2c17b17a8d6ff..989e1769ec84c06ea41dadc7187dacf5894e7ca3 100644 --- a/crates/language_tools/Cargo.toml +++ b/crates/language_tools/Cargo.toml @@ -17,8 +17,8 @@ anyhow.workspace = true client.workspace = true collections.workspace = true command_palette_hooks.workspace = true -copilot.workspace = true editor.workspace = true +edit_prediction.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true diff --git a/crates/language_tools/src/lsp_log_view.rs b/crates/language_tools/src/lsp_log_view.rs index 43212cd63818ead409b180babfe1ebda2359001f..a83ebaabc219cc131ed6dd499cfbd3c663e5d9ae 100644 --- a/crates/language_tools/src/lsp_log_view.rs +++ b/crates/language_tools/src/lsp_log_view.rs @@ -1,5 +1,5 @@ use collections::VecDeque; -use copilot::Copilot; +use edit_prediction::EditPredictionStore; use editor::{Editor, EditorEvent, MultiBufferOffset, actions::MoveToEnd, scroll::Autoscroll}; use gpui::{ App, Context, Corner, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, @@ -115,46 +115,6 @@ actions!( pub fn init(on_headless_host: bool, cx: &mut App) { let log_store = log_store::init(on_headless_host, cx); - log_store.update(cx, |_, cx| { - Copilot::global(cx).map(|copilot| { - let copilot = &copilot; - cx.subscribe(copilot, |log_store, copilot, edit_prediction_event, cx| { - if let copilot::Event::CopilotLanguageServerStarted = edit_prediction_event - && let Some(server) = copilot.read(cx).language_server() - { - let server_id = server.server_id(); - let weak_lsp_store = cx.weak_entity(); - log_store.copilot_log_subscription = - Some(server.on_notification::( - move |params, cx| { - weak_lsp_store - .update(cx, |lsp_store, cx| { - lsp_store.add_language_server_log( - server_id, - MessageType::LOG, - ¶ms.message, - cx, - ); - }) - .ok(); - }, - )); - - let name = LanguageServerName::new_static("copilot"); - log_store.add_language_server( - LanguageServerKind::Global, - server.server_id(), - Some(name), - None, - Some(server.clone()), - cx, - ); - } - }) - .detach(); - }) - }); - cx.observe_new(move |workspace: &mut Workspace, _, cx| { log_store.update(cx, |store, cx| { store.add_project(workspace.project(), cx); @@ -381,8 +341,47 @@ impl LspLogView { ); (editor, vec![editor_subscription, search_subscription]) } + pub(crate) fn try_ensure_copilot_for_project(&self, cx: &mut App) { + self.log_store.update(cx, |this, cx| { + let copilot = EditPredictionStore::try_global(cx) + .and_then(|store| store.read(cx).copilot_for_project(&self.project))?; + let server = copilot.read(cx).language_server()?.clone(); + let log_subscription = this.copilot_state_for_project(&self.project.downgrade()); + if let Some(subscription_slot @ None) = log_subscription { + let weak_lsp_store = cx.weak_entity(); + let server_id = server.server_id(); + + let name = LanguageServerName::new_static("copilot"); + *subscription_slot = + Some(server.on_notification::( + move |params, cx| { + weak_lsp_store + .update(cx, |lsp_store, cx| { + lsp_store.add_language_server_log( + server_id, + MessageType::LOG, + ¶ms.message, + cx, + ); + }) + .ok(); + }, + )); + this.add_language_server( + LanguageServerKind::Global, + server.server_id(), + Some(name), + None, + Some(server.clone()), + cx, + ); + } - pub(crate) fn menu_items<'a>(&'a self, cx: &'a App) -> Option> { + Some(()) + }); + } + pub(crate) fn menu_items(&self, cx: &mut App) -> Option> { + self.try_ensure_copilot_for_project(cx); let log_store = self.log_store.read(cx); let unknown_server = LanguageServerName::new_static("unknown server"); diff --git a/crates/project/src/lsp_store/log_store.rs b/crates/project/src/lsp_store/log_store.rs index 877cf44c4a0b511c89172e6cf87f857d200ed178..ae6f9ec09d419232d8f1e5ea76e92e80708f4311 100644 --- a/crates/project/src/lsp_store/log_store.rs +++ b/crates/project/src/lsp_store/log_store.rs @@ -40,13 +40,13 @@ impl EventEmitter for LogStore {} pub struct LogStore { on_headless_host: bool, projects: HashMap, ProjectState>, - pub copilot_log_subscription: Option, pub language_servers: HashMap, io_tx: mpsc::UnboundedSender<(LanguageServerId, IoKind, String)>, } struct ProjectState { _subscriptions: [Subscription; 2], + copilot_log_subscription: Option, } pub trait Message: AsRef { @@ -220,7 +220,7 @@ impl LogStore { let log_store = Self { projects: HashMap::default(), language_servers: HashMap::default(), - copilot_log_subscription: None, + on_headless_host, io_tx, }; @@ -350,6 +350,7 @@ impl LogStore { } }), ], + copilot_log_subscription: None, }, ); } @@ -713,4 +714,12 @@ impl LogStore { } } } + pub fn copilot_state_for_project( + &mut self, + project: &WeakEntity, + ) -> Option<&mut Option> { + self.projects + .get_mut(project) + .map(|project| &mut project.copilot_log_subscription) + } } diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index 256ec2de557e903405d1c3431ef44e98d757d3c6..94630e860f31f9d8f0e624f32253be4515981760 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -18,7 +18,7 @@ test-support = [] [dependencies] anyhow.workspace = true bm25 = "2.3.2" -copilot.workspace = true +copilot_ui.workspace = true edit_prediction.workspace = true language_models.workspace = true editor.workspace = true diff --git a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs index 0bab03a852cb29b6af0bf667ad3375d74014abee..58094df30650b87d18ed771a9819f90ce8095630 100644 --- a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs +++ b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs @@ -1,11 +1,12 @@ use edit_prediction::{ - ApiKeyState, MercuryFeatureFlag, SweepFeatureFlag, + ApiKeyState, EditPredictionStore, MercuryFeatureFlag, SweepFeatureFlag, mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token}, sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token}, }; use feature_flags::FeatureFlagAppExt as _; use gpui::{Entity, ScrollHandle, prelude::*}; use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key}; +use project::Project; use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*}; use crate::{ @@ -30,9 +31,19 @@ impl EditPredictionSetupPage { impl Render for EditPredictionSetupPage { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings_window = self.settings_window.clone(); - + let project = settings_window + .read(cx) + .original_window + .as_ref() + .and_then(|window| { + window + .read_with(cx, |workspace, _| workspace.project().clone()) + .ok() + }); let providers = [ - Some(render_github_copilot_provider(window, cx).into_any_element()), + project.and_then(|project| { + Some(render_github_copilot_provider(project, window, cx)?.into_any_element()) + }), cx.has_flag::().then(|| { render_api_key_provider( IconName::Inception, @@ -337,29 +348,36 @@ fn codestral_settings() -> Box<[SettingsPageItem]> { ]) } -pub(crate) fn render_github_copilot_provider( +fn render_github_copilot_provider( + project: Entity, window: &mut Window, cx: &mut App, -) -> impl IntoElement { +) -> Option { + let copilot = EditPredictionStore::try_global(cx)? + .read(cx) + .copilot_for_project(&project); let configuration_view = window.use_state(cx, |_, cx| { - copilot::ConfigurationView::new( - |cx| { - copilot::Copilot::global(cx) + copilot_ui::ConfigurationView::new( + move |cx| { + copilot + .as_ref() .is_some_and(|copilot| copilot.read(cx).is_authenticated()) }, - copilot::ConfigurationMode::EditPrediction, + copilot_ui::ConfigurationMode::EditPrediction, cx, ) }); - v_flex() - .id("github-copilot") - .min_w_0() - .gap_1p5() - .child( - SettingsSectionHeader::new("GitHub Copilot") - .icon(IconName::Copilot) - .no_padding(true), - ) - .child(configuration_view) + Some( + v_flex() + .id("github-copilot") + .min_w_0() + .gap_1p5() + .child( + SettingsSectionHeader::new("GitHub Copilot") + .icon(IconName::Copilot) + .no_padding(true), + ) + .child(configuration_view), + ) } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 9c2fa63dd6b32f49a74bfd80cf430361a7f06128..7adaf376a2a148f7c34fa4a165de19877c93f8f8 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -87,6 +87,8 @@ command_palette.workspace = true component.workspace = true component_preview.workspace = true copilot.workspace = true +copilot_chat.workspace = true +copilot_ui.workspace = true crashes.workspace = true dap_adapters.workspace = true db.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 1f72a894a290f7fb7224e23558bef2f02c53313c..e56f9e236853eca1c10f18f76cd6e29b7e4b594b 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -590,14 +590,21 @@ fn main() { cx.background_executor().clone(), ); command_palette::init(cx); - let copilot_language_server_id = app_state.languages.next_language_server_id(); - copilot::init( - copilot_language_server_id, + let copilot_chat_configuration = copilot_chat::CopilotChatConfiguration { + enterprise_uri: language::language_settings::all_language_settings(None, cx) + .edit_predictions + .copilot + .enterprise_uri + .clone(), + }; + copilot_chat::init( app_state.fs.clone(), app_state.client.http_client(), - app_state.node_runtime.clone(), + copilot_chat_configuration, cx, ); + + copilot_ui::init(cx); supermaven::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index c2f427226bb7c1f60d5299d65dc1ddcf472bd6a4..9bb79b78646c5d35bdb2137ee28b1a3876994c1d 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -407,6 +407,7 @@ pub fn initialize_workspace( app_state.user_store.clone(), edit_prediction_menu_handle.clone(), app_state.client.clone(), + workspace.project().clone(), cx, ) }); @@ -4922,10 +4923,10 @@ mod tests { project_panel::init(cx); outline_panel::init(cx); terminal_view::init(cx); - copilot::copilot_chat::init( + copilot_chat::init( app_state.fs.clone(), app_state.client.http_client(), - copilot::copilot_chat::CopilotChatConfiguration::default(), + copilot_chat::CopilotChatConfiguration::default(), cx, ); image_viewer::init(cx); diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 98d5fcaad848920bce47c119bfb046c74e6188c1..a59da22be5ff0f557d373b53c90fef5d05b31527 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -1,7 +1,7 @@ use client::{Client, UserStore}; use codestral::CodestralEditPredictionDelegate; use collections::HashMap; -use copilot::{Copilot, CopilotEditPredictionDelegate}; +use copilot::CopilotEditPredictionDelegate; use edit_prediction::{ MercuryFeatureFlag, SweepFeatureFlag, ZedEditPredictionDelegate, Zeta2FeatureFlag, }; @@ -165,7 +165,14 @@ fn assign_edit_prediction_provider( editor.set_edit_prediction_provider::(None, window, cx); } EditPredictionProvider::Copilot => { - if let Some(copilot) = Copilot::global(cx) { + let ep_store = edit_prediction::EditPredictionStore::global(client, &user_store, cx); + let Some(project) = editor.project().cloned() else { + return; + }; + let copilot = + ep_store.update(cx, |this, cx| this.start_copilot_for_project(&project, cx)); + + if let Some(copilot) = copilot { if let Some(buffer) = singleton_buffer && buffer.read(cx).file().is_some() {