1pub mod copilot_chat;
2mod copilot_completion_provider;
3pub mod request;
4mod sign_in;
5
6use crate::sign_in::initiate_sign_in_within_workspace;
7use ::fs::Fs;
8use anyhow::{Context as _, Result, anyhow};
9use collections::{HashMap, HashSet};
10use command_palette_hooks::CommandPaletteFilter;
11use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
12use gpui::{
13 App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
14 WeakEntity, actions,
15};
16use http_client::HttpClient;
17use language::language_settings::CopilotSettings;
18use language::{
19 Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16,
20 language_settings::{EditPredictionProvider, all_language_settings, language_settings},
21 point_from_lsp, point_to_lsp,
22};
23use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
24use node_runtime::NodeRuntime;
25use parking_lot::Mutex;
26use request::StatusNotification;
27use serde_json::json;
28use settings::SettingsStore;
29use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
30use std::collections::hash_map::Entry;
31use std::{
32 any::TypeId,
33 env,
34 ffi::OsString,
35 mem,
36 ops::Range,
37 path::{Path, PathBuf},
38 sync::Arc,
39};
40use util::{ResultExt, fs::remove_matching};
41use workspace::Workspace;
42
43pub use crate::copilot_completion_provider::CopilotCompletionProvider;
44pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
45
// Declares the actions of the `copilot` namespace. `init` uses the `TypeId` of each
// action to show or hide it in the command palette depending on the Copilot status,
// and registers workspace handlers for `SignIn`, `Reinstall` and `SignOut`.
actions!(
    copilot,
    [
        Suggest,
        NextSuggestion,
        PreviousSuggestion,
        Reinstall,
        SignIn,
        SignOut
    ]
);
57
58pub fn init(
59 new_server_id: LanguageServerId,
60 fs: Arc<dyn Fs>,
61 http: Arc<dyn HttpClient>,
62 node_runtime: NodeRuntime,
63 cx: &mut App,
64) {
65 let language_settings = all_language_settings(None, cx);
66 let configuration = copilot_chat::CopilotChatConfiguration {
67 enterprise_uri: language_settings
68 .edit_predictions
69 .copilot
70 .enterprise_uri
71 .clone(),
72 };
73 copilot_chat::init(fs.clone(), http.clone(), configuration, cx);
74
75 let copilot = cx.new({
76 let node_runtime = node_runtime.clone();
77 move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)
78 });
79 Copilot::set_global(copilot.clone(), cx);
80 cx.observe(&copilot, |handle, cx| {
81 let copilot_action_types = [
82 TypeId::of::<Suggest>(),
83 TypeId::of::<NextSuggestion>(),
84 TypeId::of::<PreviousSuggestion>(),
85 TypeId::of::<Reinstall>(),
86 ];
87 let copilot_auth_action_types = [TypeId::of::<SignOut>()];
88 let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
89 let status = handle.read(cx).status();
90 let filter = CommandPaletteFilter::global_mut(cx);
91
92 match status {
93 Status::Disabled => {
94 filter.hide_action_types(&copilot_action_types);
95 filter.hide_action_types(&copilot_auth_action_types);
96 filter.hide_action_types(&copilot_no_auth_action_types);
97 }
98 Status::Authorized => {
99 filter.hide_action_types(&copilot_no_auth_action_types);
100 filter.show_action_types(
101 copilot_action_types
102 .iter()
103 .chain(&copilot_auth_action_types),
104 );
105 }
106 _ => {
107 filter.hide_action_types(&copilot_action_types);
108 filter.hide_action_types(&copilot_auth_action_types);
109 filter.show_action_types(copilot_no_auth_action_types.iter());
110 }
111 }
112 })
113 .detach();
114
115 cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
116 workspace.register_action(|workspace, _: &SignIn, window, cx| {
117 if let Some(copilot) = Copilot::global(cx) {
118 let is_reinstall = false;
119 initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
120 }
121 });
122 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
123 if let Some(copilot) = Copilot::global(cx) {
124 reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
125 }
126 });
127 workspace.register_action(|workspace, _: &SignOut, _window, cx| {
128 if let Some(copilot) = Copilot::global(cx) {
129 sign_out_within_workspace(workspace, copilot, cx);
130 }
131 });
132 })
133 .detach();
134}
135
/// Lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// No server is running; `start_copilot` only starts a server from this state.
    Disabled,
    /// The server is being started; `task` is the shared start-up task.
    Starting { task: Shared<Task<()>> },
    /// Start-up failed; holds the error rendered as a string.
    Error(Arc<str>),
    /// The server has been started and initialized.
    Running(RunningCopilotServer),
}
142
143impl CopilotServer {
144 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
145 let server = self.as_running()?;
146 anyhow::ensure!(
147 matches!(server.sign_in_status, SignInStatus::Authorized { .. }),
148 "must sign in before using copilot"
149 );
150 Ok(server)
151 }
152
153 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
154 match self {
155 CopilotServer::Starting { .. } => anyhow::bail!("copilot is still starting"),
156 CopilotServer::Disabled => anyhow::bail!("copilot is disabled"),
157 CopilotServer::Error(error) => {
158 anyhow::bail!("copilot was not started because of an error: {error}")
159 }
160 CopilotServer::Running(server) => Ok(server),
161 }
162 }
163}
164
/// State kept while the Copilot language server is running.
struct RunningCopilotServer {
    /// Handle to the language server used for all requests and notifications.
    lsp: Arc<LanguageServer>,
    /// Current authentication state of the server.
    sign_in_status: SignInStatus,
    /// Buffers that have been opened on the server, keyed by the buffer's entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
170
/// Internal authentication state of a running Copilot server.
///
/// `Copilot::status` maps this onto the public [`Status`] enum.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The user is signed in; buffers may be registered and completions requested.
    Authorized,
    /// The server reported `NotAuthorized` for the current user.
    Unauthorized,
    /// A sign-in is in flight.
    SigningIn {
        /// Device-flow prompt, filled in once the server asks for it.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared sign-in task, so repeated `sign_in` calls await the same attempt.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// The user is signed out.
    SignedOut {
        /// Initialized from `awaiting_sign_in_after_start` when the server starts;
        /// set to `false` when a later status update moves the server to `SignedOut`.
        awaiting_signing_in: bool,
    },
}
183
/// Public view of Copilot's state, as returned by `Copilot::status`.
///
/// Combines the server lifecycle (`Starting`, `Error`, `Disabled`) with the sign-in
/// state of a running server (the remaining variants).
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is starting; `task` is the shared start-up task.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// No language server is running.
    Disabled,
    /// The server is running and the user is signed out.
    SignedOut {
        awaiting_signing_in: bool,
    },
    /// A sign-in is in progress; `prompt` is the device-flow prompt once available.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server is running but reported the user as not authorized.
    Unauthorized,
    /// The server is running and the user is signed in.
    Authorized,
}
200
201impl Status {
202 pub fn is_authorized(&self) -> bool {
203 matches!(self, Status::Authorized)
204 }
205
206 pub fn is_disabled(&self) -> bool {
207 matches!(self, Status::Disabled)
208 }
209}
210
/// Bookkeeping for a buffer that has been opened on the Copilot language server.
struct RegisteredBuffer {
    /// URI under which the buffer was opened on the server.
    uri: lsp::Url,
    /// Language id that was sent when the buffer was opened.
    language_id: String,
    /// Buffer contents as last reported to the server.
    snapshot: BufferSnapshot,
    /// Document version last sent to the server; starts at 0 and is incremented for
    /// every change notification.
    snapshot_version: i32,
    /// Subscriptions to the buffer's events and to its release.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recent change-reporting task; each new task awaits the previous one first.
    pending_buffer_change: Task<Option<()>>,
}
219
impl RegisteredBuffer {
    /// Reports edits made to `buffer` since the last reported snapshot to the Copilot
    /// language server via a `DidChangeTextDocument` notification.
    ///
    /// Returns a receiver that yields the `(snapshot_version, snapshot)` pair known to
    /// the server once reporting has finished. If the buffer has not changed, the pair
    /// is sent immediately. If the spawned task bails out early (Copilot was dropped,
    /// the server is no longer authenticated, or the buffer is no longer registered),
    /// the sender is dropped without sending anything.
    fn report_changes(
        &mut self,
        buffer: &Entity<Buffer>,
        cx: &mut Context<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing changed since the last report: answer with the current state.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous report so that change notifications are computed
            // and sent one after another.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
                prev_pending_change.await;

                // Read the version the server currently knows about; it may have
                // advanced while the previous report was running.
                let old_version = copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()).ok()?;

                // Translate the edits into LSP content changes on a background task.
                let content_changes = cx
                    .background_spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The replaced range starts at the edit's position in
                                    // the new text and spans the extent of the old text.
                                    // NOTE(review): presumably this relies on the server
                                    // applying the changes in order — confirm.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version and notify when there is something to send.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    &lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .ok();
                        }
                        // The receiver may already be gone (e.g. for plain edit events).
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
305
/// A single completion returned by the Copilot language server.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned identifier, sent back when the completion is accepted or rejected.
    pub uuid: String,
    /// Buffer range the completion applies to, anchored in the snapshot that was
    /// reported to the server.
    pub range: Range<Anchor>,
    /// Text of the completion.
    pub text: String,
}
312
/// Entity that owns the Copilot language server and tracks the buffers shared with it.
pub struct Copilot {
    /// File system handle, passed on when installing/starting the language server.
    fs: Arc<dyn Fs>,
    /// Node runtime used to install and run the language server.
    node_runtime: NodeRuntime,
    /// Lifecycle state of the language server.
    server: CopilotServer,
    /// Weak handles to every buffer passed to `register_buffer`, including ones seen
    /// while not signed in, so they can be registered once sign-in succeeds.
    buffers: HashSet<WeakEntity<Buffer>>,
    /// Id used for the language server whenever it is (re)started.
    server_id: LanguageServerId,
    /// Keeps the app-quit hook (`shutdown_language_server`) registered.
    _subscription: gpui::Subscription,
}
321
/// Events emitted by [`Copilot`].
pub enum Event {
    /// Emitted after the language server has started and been initialized.
    CopilotLanguageServerStarted,
    /// Emitted when a status update marks the user as signed in.
    CopilotAuthSignedIn,
    /// Emitted when a status update marks the user as signed out.
    CopilotAuthSignedOut,
}
327
// Allows `Copilot` to emit `Event`s through `cx.emit`.
impl EventEmitter<Event> for Copilot {}
329
/// Wrapper that stores the application-wide [`Copilot`] entity as a gpui global.
struct GlobalCopilot(Entity<Copilot>);

impl Global for GlobalCopilot {}
333
334impl Copilot {
335 pub fn global(cx: &App) -> Option<Entity<Self>> {
336 cx.try_global::<GlobalCopilot>()
337 .map(|model| model.0.clone())
338 }
339
/// Installs `copilot` as the application-wide instance returned by [`Self::global`].
pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
    cx.set_global(GlobalCopilot(copilot));
}
343
344 fn start(
345 new_server_id: LanguageServerId,
346 fs: Arc<dyn Fs>,
347 node_runtime: NodeRuntime,
348 cx: &mut Context<Self>,
349 ) -> Self {
350 let mut this = Self {
351 server_id: new_server_id,
352 fs,
353 node_runtime,
354 server: CopilotServer::Disabled,
355 buffers: Default::default(),
356 _subscription: cx.on_app_quit(Self::shutdown_language_server),
357 };
358 this.start_copilot(true, false, cx);
359 cx.observe_global::<SettingsStore>(move |this, cx| {
360 this.start_copilot(true, false, cx);
361 this.send_configuration_update(cx);
362 })
363 .detach();
364 this
365 }
366
/// App-quit hook: takes the running language server (leaving the state `Disabled`)
/// and returns a future that shuts it down.
///
/// The returned future captures nothing from `self` or `_cx` (`use<>`).
fn shutdown_language_server(
    &mut self,
    _cx: &mut Context<Self>,
) -> impl Future<Output = ()> + use<> {
    let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
        // NOTE(review): the async block evaluates to whatever `lsp.shutdown()` returns.
        // If that is itself a future (or an `Option` of one), it is dropped below
        // without being awaited — confirm against `LanguageServer::shutdown`.
        CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
        _ => None,
    };

    async move {
        if let Some(shutdown) = shutdown {
            shutdown.await;
        }
    }
}
382
383 fn start_copilot(
384 &mut self,
385 check_edit_prediction_provider: bool,
386 awaiting_sign_in_after_start: bool,
387 cx: &mut Context<Self>,
388 ) {
389 if !matches!(self.server, CopilotServer::Disabled) {
390 return;
391 }
392 let language_settings = all_language_settings(None, cx);
393 if check_edit_prediction_provider
394 && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
395 {
396 return;
397 }
398 let server_id = self.server_id;
399 let fs = self.fs.clone();
400 let node_runtime = self.node_runtime.clone();
401 let env = self.build_env(&language_settings.edit_predictions.copilot);
402 let start_task = cx
403 .spawn(async move |this, cx| {
404 Self::start_language_server(
405 server_id,
406 fs,
407 node_runtime,
408 env,
409 this,
410 awaiting_sign_in_after_start,
411 cx,
412 )
413 .await
414 })
415 .shared();
416 self.server = CopilotServer::Starting { task: start_task };
417 cx.notify();
418 }
419
420 fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
421 let proxy_url = copilot_settings.proxy.clone()?;
422 let no_verify = copilot_settings.proxy_no_verify;
423 let http_or_https_proxy = if proxy_url.starts_with("http:") {
424 Some("HTTP_PROXY")
425 } else if proxy_url.starts_with("https:") {
426 Some("HTTPS_PROXY")
427 } else {
428 log::error!(
429 "Unsupported protocol scheme for language server proxy (must be http or https)"
430 );
431 None
432 };
433
434 let mut env = HashMap::default();
435
436 if let Some(proxy_type) = http_or_https_proxy {
437 env.insert(proxy_type.to_string(), proxy_url);
438 if let Some(true) = no_verify {
439 env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
440 };
441 }
442
443 if let Ok(oauth_token) = env::var(copilot_chat::COPILOT_OAUTH_ENV_VAR) {
444 env.insert(copilot_chat::COPILOT_OAUTH_ENV_VAR.to_string(), oauth_token);
445 }
446
447 if env.is_empty() { None } else { Some(env) }
448 }
449
450 fn send_configuration_update(&mut self, cx: &mut Context<Self>) {
451 let copilot_settings = all_language_settings(None, cx)
452 .edit_predictions
453 .copilot
454 .clone();
455
456 let settings = json!({
457 "http": {
458 "proxy": copilot_settings.proxy,
459 "proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false)
460 },
461 "github-enterprise": {
462 "uri": copilot_settings.enterprise_uri
463 }
464 });
465
466 if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) {
467 copilot_chat.update(cx, |chat, cx| {
468 chat.set_configuration(
469 copilot_chat::CopilotChatConfiguration {
470 enterprise_uri: copilot_settings.enterprise_uri.clone(),
471 },
472 cx,
473 );
474 });
475 }
476
477 if let Ok(server) = self.server.as_running() {
478 server
479 .lsp
480 .notify::<lsp::notification::DidChangeConfiguration>(
481 &lsp::DidChangeConfigurationParams { settings },
482 )
483 .log_err();
484 }
485 }
486
/// Creates a [`Copilot`] backed by a fake language server, for use in tests.
///
/// The returned instance is already `Running` and `Authorized`, uses a fake file
/// system and an unavailable node runtime. The second tuple element is the fake
/// server handle used to observe notifications and answer requests.
#[cfg(any(test, feature = "test-support"))]
pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
    use fs::FakeFs;
    use lsp::FakeLanguageServer;
    use node_runtime::NodeRuntime;

    let server_id = LanguageServerId(0);
    let binary = LanguageServerBinary {
        path: "path/to/copilot".into(),
        arguments: vec![],
        env: None,
    };
    let (language_server, fake_language_server) = FakeLanguageServer::new(
        server_id,
        binary,
        "copilot".into(),
        Default::default(),
        &mut cx.to_async(),
    );

    let node_runtime = NodeRuntime::unavailable();
    let copilot = cx.new(|cx| {
        let fs = FakeFs::new(cx.background_executor().clone());
        let running = RunningCopilotServer {
            lsp: Arc::new(language_server),
            sign_in_status: SignInStatus::Authorized,
            registered_buffers: Default::default(),
        };
        Self {
            server_id,
            fs,
            node_runtime,
            server: CopilotServer::Running(running),
            _subscription: cx.on_app_quit(Self::shutdown_language_server),
            buffers: Default::default(),
        }
    });

    (copilot, fake_language_server)
}
519
/// Installs (if needed), launches and initializes the Copilot language server, then
/// stores the outcome on `this`.
///
/// On success the server becomes `Running` (initially `SignedOut` with
/// `awaiting_signing_in` set to `awaiting_sign_in_after_start`), the
/// `CopilotLanguageServerStarted` event is emitted, the sign-in status returned by
/// the server's status check is applied, and the current configuration is sent.
/// On failure the server state becomes `Error` with the error message.
/// If `this` has been dropped in the meantime, the result is discarded.
async fn start_language_server(
    new_server_id: LanguageServerId,
    fs: Arc<dyn Fs>,
    node_runtime: NodeRuntime,
    env: Option<HashMap<String, String>>,
    this: WeakEntity<Self>,
    awaiting_sign_in_after_start: bool,
    cx: &mut AsyncApp,
) {
    let start_language_server = async {
        // The server is a script that is run as `node <server_path> --stdio`.
        let server_path = get_copilot_lsp(fs, node_runtime.clone()).await?;
        let node_path = node_runtime.binary_path().await?;
        let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
        let binary = LanguageServerBinary {
            path: node_path,
            arguments,
            env,
        };

        // The server is not tied to a project, so the filesystem root is used.
        let root_path = if cfg!(target_os = "windows") {
            Path::new("C:/")
        } else {
            Path::new("/")
        };

        let server_name = LanguageServerName("copilot".into());
        let server = LanguageServer::new(
            Arc::new(Mutex::new(None)),
            new_server_id,
            server_name,
            binary,
            root_path,
            None,
            Default::default(),
            cx,
        )?;

        server
            .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
            .detach();

        // Start with an empty configuration; the real one is sent once the server
        // is running (see the `send_configuration_update` call below).
        let configuration = lsp::DidChangeConfigurationParams {
            settings: Default::default(),
        };

        // Editor identification, passed as the LSP initialization options.
        let editor_info = request::SetEditorInfoParams {
            editor_info: request::EditorInfo {
                name: "zed".into(),
                version: env!("CARGO_PKG_VERSION").into(),
            },
            editor_plugin_info: request::EditorPluginInfo {
                name: "zed-copilot".into(),
                version: "0.0.1".into(),
            },
        };
        let editor_info_json = serde_json::to_value(&editor_info)?;

        let server = cx
            .update(|cx| {
                let mut params = server.default_initialize_params(false, cx);
                params.initialization_options = Some(editor_info_json);
                server.initialize(params, configuration.into(), cx)
            })?
            .await?;

        // Ask the server for the current sign-in status.
        let status = server
            .request::<request::CheckStatus>(request::CheckStatusParams {
                local_checks_only: false,
            })
            .await
            .into_response()
            .context("copilot: check status")?;

        anyhow::Ok((server, status))
    };

    let server = start_language_server.await;
    this.update(cx, |this, cx| {
        cx.notify();
        match server {
            Ok((server, status)) => {
                this.server = CopilotServer::Running(RunningCopilotServer {
                    lsp: server,
                    sign_in_status: SignInStatus::SignedOut {
                        awaiting_signing_in: awaiting_sign_in_after_start,
                    },
                    registered_buffers: Default::default(),
                });
                cx.emit(Event::CopilotLanguageServerStarted);
                this.update_sign_in_status(status, cx);
                // Send configuration now that the LSP is fully started
                this.send_configuration_update(cx);
            }
            Err(error) => {
                this.server = CopilotServer::Error(error.to_string().into());
                // Second `notify` in this branch; the first one is at the top of the closure.
                cx.notify()
            }
        }
    })
    .ok();
}
621
/// Signs the user in to Copilot.
///
/// * Already authorized: resolves immediately with `Ok(())`.
/// * Sign-in already in flight: returns a handle to the same shared task.
/// * Signed out / unauthorized: sends `SignInInitiate`; if the server asks for the
///   device flow, stores the prompt in the `SigningIn` state and then waits for
///   `SignInConfirm`. The resulting status is applied via `update_sign_in_status`;
///   on error the status is reset to `NotSignedIn` and the error is returned.
///
/// Fails immediately with "copilot hasn't started yet" when the server is not running.
pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
    if let CopilotServer::Running(server) = &mut self.server {
        let task = match &server.sign_in_status {
            SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
            SignInStatus::SigningIn { task, .. } => {
                cx.notify();
                task.clone()
            }
            SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
                let lsp = server.lsp.clone();
                let task = cx
                    .spawn(async move |this, cx| {
                        let sign_in = async {
                            let sign_in = lsp
                                .request::<request::SignInInitiate>(
                                    request::SignInInitiateParams {},
                                )
                                .await
                                .into_response()
                                .context("copilot sign-in")?;
                            match sign_in {
                                request::SignInInitiateResult::AlreadySignedIn { user } => {
                                    Ok(request::SignInStatus::Ok { user: Some(user) })
                                }
                                request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                    // Publish the prompt on the `SigningIn` state so
                                    // observers of `status()` can read it.
                                    this.update(cx, |this, cx| {
                                        if let CopilotServer::Running(RunningCopilotServer {
                                            sign_in_status: status,
                                            ..
                                        }) = &mut this.server
                                        {
                                            if let SignInStatus::SigningIn {
                                                prompt: prompt_flow,
                                                ..
                                            } = status
                                            {
                                                *prompt_flow = Some(flow.clone());
                                                cx.notify();
                                            }
                                        }
                                    })?;
                                    let response = lsp
                                        .request::<request::SignInConfirm>(
                                            request::SignInConfirmParams {
                                                user_code: flow.user_code,
                                            },
                                        )
                                        .await
                                        .into_response()
                                        .context("copilot: sign in confirm")?;
                                    Ok(response)
                                }
                            }
                        };

                        let sign_in = sign_in.await;
                        this.update(cx, |this, cx| match sign_in {
                            Ok(status) => {
                                this.update_sign_in_status(status, cx);
                                Ok(())
                            }
                            Err(error) => {
                                this.update_sign_in_status(
                                    request::SignInStatus::NotSignedIn,
                                    cx,
                                );
                                // Wrapped in `Arc` so the shared task's output is `Clone`.
                                Err(Arc::new(error))
                            }
                        })?
                    })
                    .shared();
                // Record the in-flight attempt so concurrent callers reuse it.
                server.sign_in_status = SignInStatus::SigningIn {
                    prompt: None,
                    task: task.clone(),
                };
                cx.notify();
                task
            }
        };

        cx.background_spawn(task.map_err(|err| anyhow!("{err:?}")))
    } else {
        // TODO: if the server is still downloading, wait until the download has
        // finished; if it is stuck in an error state, surface that to the user.
        // For now this simply fails.
        Task::ready(Err(anyhow!("copilot hasn't started yet")))
    }
}
709
710 pub(crate) fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
711 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
712 match &self.server {
713 CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
714 let server = server.clone();
715 cx.background_spawn(async move {
716 server
717 .request::<request::SignOut>(request::SignOutParams {})
718 .await
719 .into_response()
720 .context("copilot: sign in confirm")?;
721 anyhow::Ok(())
722 })
723 }
724 CopilotServer::Disabled => cx.background_spawn(async {
725 clear_copilot_config_dir().await;
726 anyhow::Ok(())
727 }),
728 _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
729 }
730 }
731
732 pub(crate) fn reinstall(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
733 let language_settings = all_language_settings(None, cx);
734 let env = self.build_env(&language_settings.edit_predictions.copilot);
735 let start_task = cx
736 .spawn({
737 let fs = self.fs.clone();
738 let node_runtime = self.node_runtime.clone();
739 let server_id = self.server_id;
740 async move |this, cx| {
741 clear_copilot_dir().await;
742 Self::start_language_server(server_id, fs, node_runtime, env, this, false, cx)
743 .await
744 }
745 })
746 .shared();
747
748 self.server = CopilotServer::Starting {
749 task: start_task.clone(),
750 };
751
752 cx.notify();
753
754 start_task
755 }
756
757 pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
758 if let CopilotServer::Running(server) = &self.server {
759 Some(&server.lsp)
760 } else {
761 None
762 }
763 }
764
765 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
766 let weak_buffer = buffer.downgrade();
767 self.buffers.insert(weak_buffer.clone());
768
769 if let CopilotServer::Running(RunningCopilotServer {
770 lsp: server,
771 sign_in_status: status,
772 registered_buffers,
773 ..
774 }) = &mut self.server
775 {
776 if !matches!(status, SignInStatus::Authorized { .. }) {
777 return;
778 }
779
780 let entry = registered_buffers.entry(buffer.entity_id());
781 if let Entry::Vacant(e) = entry {
782 let Ok(uri) = uri_for_buffer(buffer, cx) else {
783 return;
784 };
785 let language_id = id_for_language(buffer.read(cx).language());
786 let snapshot = buffer.read(cx).snapshot();
787 server
788 .notify::<lsp::notification::DidOpenTextDocument>(
789 &lsp::DidOpenTextDocumentParams {
790 text_document: lsp::TextDocumentItem {
791 uri: uri.clone(),
792 language_id: language_id.clone(),
793 version: 0,
794 text: snapshot.text(),
795 },
796 },
797 )
798 .ok();
799
800 e.insert(RegisteredBuffer {
801 uri,
802 language_id,
803 snapshot,
804 snapshot_version: 0,
805 pending_buffer_change: Task::ready(Some(())),
806 _subscriptions: [
807 cx.subscribe(buffer, |this, buffer, event, cx| {
808 this.handle_buffer_event(buffer, event, cx).log_err();
809 }),
810 cx.observe_release(buffer, move |this, _buffer, _cx| {
811 this.buffers.remove(&weak_buffer);
812 this.unregister_buffer(&weak_buffer);
813 }),
814 ],
815 });
816 }
817 }
818 }
819
820 fn handle_buffer_event(
821 &mut self,
822 buffer: Entity<Buffer>,
823 event: &language::BufferEvent,
824 cx: &mut Context<Self>,
825 ) -> Result<()> {
826 if let Ok(server) = self.server.as_running() {
827 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
828 {
829 match event {
830 language::BufferEvent::Edited => {
831 drop(registered_buffer.report_changes(&buffer, cx));
832 }
833 language::BufferEvent::Saved => {
834 server
835 .lsp
836 .notify::<lsp::notification::DidSaveTextDocument>(
837 &lsp::DidSaveTextDocumentParams {
838 text_document: lsp::TextDocumentIdentifier::new(
839 registered_buffer.uri.clone(),
840 ),
841 text: None,
842 },
843 )?;
844 }
845 language::BufferEvent::FileHandleChanged
846 | language::BufferEvent::LanguageChanged => {
847 let new_language_id = id_for_language(buffer.read(cx).language());
848 let Ok(new_uri) = uri_for_buffer(&buffer, cx) else {
849 return Ok(());
850 };
851 if new_uri != registered_buffer.uri
852 || new_language_id != registered_buffer.language_id
853 {
854 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
855 registered_buffer.language_id = new_language_id;
856 server
857 .lsp
858 .notify::<lsp::notification::DidCloseTextDocument>(
859 &lsp::DidCloseTextDocumentParams {
860 text_document: lsp::TextDocumentIdentifier::new(old_uri),
861 },
862 )?;
863 server
864 .lsp
865 .notify::<lsp::notification::DidOpenTextDocument>(
866 &lsp::DidOpenTextDocumentParams {
867 text_document: lsp::TextDocumentItem::new(
868 registered_buffer.uri.clone(),
869 registered_buffer.language_id.clone(),
870 registered_buffer.snapshot_version,
871 registered_buffer.snapshot.text(),
872 ),
873 },
874 )?;
875 }
876 }
877 _ => {}
878 }
879 }
880 }
881
882 Ok(())
883 }
884
885 fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
886 if let Ok(server) = self.server.as_running() {
887 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
888 server
889 .lsp
890 .notify::<lsp::notification::DidCloseTextDocument>(
891 &lsp::DidCloseTextDocumentParams {
892 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
893 },
894 )
895 .ok();
896 }
897 }
898 }
899
/// Requests completions for `position` in `buffer` using the `GetCompletions` request.
///
/// See `request_completions` for registration and error behavior.
pub fn completions<T>(
    &mut self,
    buffer: &Entity<Buffer>,
    position: T,
    cx: &mut Context<Self>,
) -> Task<Result<Vec<Completion>>>
where
    T: ToPointUtf16,
{
    self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
}
911
/// Requests completions for `position` in `buffer` using the `GetCompletionsCycling`
/// request.
///
/// See `request_completions` for registration and error behavior.
pub fn completions_cycling<T>(
    &mut self,
    buffer: &Entity<Buffer>,
    position: T,
    cx: &mut Context<Self>,
) -> Task<Result<Vec<Completion>>>
where
    T: ToPointUtf16,
{
    self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
}
923
924 pub fn accept_completion(
925 &mut self,
926 completion: &Completion,
927 cx: &mut Context<Self>,
928 ) -> Task<Result<()>> {
929 let server = match self.server.as_authenticated() {
930 Ok(server) => server,
931 Err(error) => return Task::ready(Err(error)),
932 };
933 let request =
934 server
935 .lsp
936 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
937 uuid: completion.uuid.clone(),
938 });
939 cx.background_spawn(async move {
940 request
941 .await
942 .into_response()
943 .context("copilot: notify accepted")?;
944 Ok(())
945 })
946 }
947
948 pub fn discard_completions(
949 &mut self,
950 completions: &[Completion],
951 cx: &mut Context<Self>,
952 ) -> Task<Result<()>> {
953 let server = match self.server.as_authenticated() {
954 Ok(server) => server,
955 Err(_) => return Task::ready(Ok(())),
956 };
957 let request =
958 server
959 .lsp
960 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
961 uuids: completions
962 .iter()
963 .map(|completion| completion.uuid.clone())
964 .collect(),
965 });
966 cx.background_spawn(async move {
967 request
968 .await
969 .into_response()
970 .context("copilot: notify rejected")?;
971 Ok(())
972 })
973 }
974
975 fn request_completions<R, T>(
976 &mut self,
977 buffer: &Entity<Buffer>,
978 position: T,
979 cx: &mut Context<Self>,
980 ) -> Task<Result<Vec<Completion>>>
981 where
982 R: 'static
983 + lsp::request::Request<
984 Params = request::GetCompletionsParams,
985 Result = request::GetCompletionsResult,
986 >,
987 T: ToPointUtf16,
988 {
989 self.register_buffer(buffer, cx);
990
991 let server = match self.server.as_authenticated() {
992 Ok(server) => server,
993 Err(error) => return Task::ready(Err(error)),
994 };
995 let lsp = server.lsp.clone();
996 let registered_buffer = server
997 .registered_buffers
998 .get_mut(&buffer.entity_id())
999 .unwrap();
1000 let snapshot = registered_buffer.report_changes(buffer, cx);
1001 let buffer = buffer.read(cx);
1002 let uri = registered_buffer.uri.clone();
1003 let position = position.to_point_utf16(buffer);
1004 let settings = language_settings(
1005 buffer.language_at(position).map(|l| l.name()),
1006 buffer.file(),
1007 cx,
1008 );
1009 let tab_size = settings.tab_size;
1010 let hard_tabs = settings.hard_tabs;
1011 let relative_path = buffer
1012 .file()
1013 .map(|file| file.path().to_path_buf())
1014 .unwrap_or_default();
1015
1016 cx.background_spawn(async move {
1017 let (version, snapshot) = snapshot.await?;
1018 let result = lsp
1019 .request::<R>(request::GetCompletionsParams {
1020 doc: request::GetCompletionsDocument {
1021 uri,
1022 tab_size: tab_size.into(),
1023 indent_size: 1,
1024 insert_spaces: !hard_tabs,
1025 relative_path: relative_path.to_string_lossy().into(),
1026 position: point_to_lsp(position),
1027 version: version.try_into().unwrap(),
1028 },
1029 })
1030 .await
1031 .into_response()
1032 .context("copilot: get completions")?;
1033 let completions = result
1034 .completions
1035 .into_iter()
1036 .map(|completion| {
1037 let start = snapshot
1038 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
1039 let end =
1040 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
1041 Completion {
1042 uuid: completion.uuid,
1043 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
1044 text: completion.text,
1045 }
1046 })
1047 .collect();
1048 anyhow::Ok(completions)
1049 })
1050 }
1051
1052 pub fn status(&self) -> Status {
1053 match &self.server {
1054 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
1055 CopilotServer::Disabled => Status::Disabled,
1056 CopilotServer::Error(error) => Status::Error(error.clone()),
1057 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
1058 match sign_in_status {
1059 SignInStatus::Authorized { .. } => Status::Authorized,
1060 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
1061 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
1062 prompt: prompt.clone(),
1063 },
1064 SignInStatus::SignedOut {
1065 awaiting_signing_in,
1066 } => Status::SignedOut {
1067 awaiting_signing_in: *awaiting_signing_in,
1068 },
1069 }
1070 }
1071 }
1072 }
1073
/// Applies a sign-in status reported by the language server.
///
/// * Signed in (`Ok` with a user, `MaybeOk`, `AlreadySignedIn`): marks the server
///   `Authorized`, emits `CopilotAuthSignedIn` and registers all tracked buffers.
/// * `NotAuthorized`: marks the server `Unauthorized` and unregisters all buffers.
/// * Signed out (`Ok` without a user, `NotSignedIn`): marks the server `SignedOut`
///   (keeping an existing `SignedOut` state, including its flag, untouched), emits
///   `CopilotAuthSignedOut` and unregisters all buffers.
///
/// Does nothing beyond pruning dropped buffers when the server is not running.
fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
    // Drop handles to buffers that no longer exist.
    self.buffers.retain(|buffer| buffer.is_upgradable());

    if let Ok(server) = self.server.as_running() {
        match lsp_status {
            request::SignInStatus::Ok { user: Some(_) }
            | request::SignInStatus::MaybeOk { .. }
            | request::SignInStatus::AlreadySignedIn { .. } => {
                server.sign_in_status = SignInStatus::Authorized;
                cx.emit(Event::CopilotAuthSignedIn);
                // Collect first: `register_buffer` needs `&mut self`.
                for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
                    if let Some(buffer) = buffer.upgrade() {
                        self.register_buffer(&buffer, cx);
                    }
                }
            }
            request::SignInStatus::NotAuthorized { .. } => {
                server.sign_in_status = SignInStatus::Unauthorized;
                for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
                    self.unregister_buffer(&buffer);
                }
            }
            request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
                // Preserve an existing `SignedOut` state and its `awaiting_signing_in` flag.
                if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
                    server.sign_in_status = SignInStatus::SignedOut {
                        awaiting_signing_in: false,
                    };
                }
                cx.emit(Event::CopilotAuthSignedOut);
                for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
                    self.unregister_buffer(&buffer);
                }
            }
        }

        cx.notify();
    }
}
1112}
1113
1114fn id_for_language(language: Option<&Arc<Language>>) -> String {
1115 language
1116 .map(|language| language.lsp_id())
1117 .unwrap_or_else(|| "plaintext".to_string())
1118}
1119
1120fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> Result<lsp::Url, ()> {
1121 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1122 lsp::Url::from_file_path(file.abs_path(cx))
1123 } else {
1124 format!("buffer://{}", buffer.entity_id())
1125 .parse()
1126 .map_err(|_| ())
1127 }
1128}
1129
/// Removes every entry in the Copilot installation directory (`paths::copilot_dir()`);
/// the predicate passed to `remove_matching` accepts everything.
async fn clear_copilot_dir() {
    remove_matching(paths::copilot_dir(), |_| true).await
}
1133
/// Removes every entry in the Copilot Chat configuration directory; the predicate
/// passed to `remove_matching` accepts everything.
async fn clear_copilot_config_dir() {
    remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
}
1137
1138async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1139 const PACKAGE_NAME: &str = "@github/copilot-language-server";
1140 const SERVER_PATH: &str =
1141 "node_modules/@github/copilot-language-server/dist/language-server.js";
1142
1143 let latest_version = node_runtime
1144 .npm_package_latest_version(PACKAGE_NAME)
1145 .await?;
1146 let server_path = paths::copilot_dir().join(SERVER_PATH);
1147
1148 fs.create_dir(paths::copilot_dir()).await?;
1149
1150 let should_install = node_runtime
1151 .should_install_npm_package(
1152 PACKAGE_NAME,
1153 &server_path,
1154 paths::copilot_dir(),
1155 &latest_version,
1156 )
1157 .await;
1158 if should_install {
1159 node_runtime
1160 .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1161 .await?;
1162 }
1163
1164 Ok(server_path)
1165}
1166
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::path;

    // Drives a fake Copilot server through the buffer lifecycle and checks the
    // document notifications it receives: open on registration, incremental
    // change on edit, close + reopen when the buffer gains a file, close on
    // sign-out, reopen on sign-in, and close when a buffer is dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A buffer without a file is expected to be opened under a
        // `buffer://<entity id>` URI with language id "plaintext" and version 0.
        let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        // A second registered buffer gets its own open notification.
        let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // Editing a registered buffer is expected to produce a ranged change
        // (an insertion at 0:5) with the document version bumped to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: path!("/root/child/buffer-1").into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        // The document is closed under its old `buffer://` URI first...
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        // ...and then reopened under the file URI, with the edited text and
        // version 1.
        let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.set_request_handler::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        lsp.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        // Both buffers are expected to be reopened at version 0 with their
        // current contents.
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    // Test stand-in for a local file: it carries only an absolute path and a
    // relative path, which are what the trait impls below hand back.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    // Only the methods that return a value here are implemented; the others
    // panic via `unimplemented!()` if they are ever called.
    impl language::File for File {
        // Reports this file as local.
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        // Always reports the file as present on disk with a fixed mtime.
        fn disk_state(&self) -> language::DiskState {
            language::DiskState::Present {
                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
            }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &App) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn to_proto(&self, _: &App) -> rpc::proto::File {
            unimplemented!()
        }

        // Fixed worktree id; the stub does not belong to a real worktree.
        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    // Only `abs_path` is implemented; the loading methods panic via
    // `unimplemented!()` if they are ever called.
    impl language::LocalFile for File {
        fn abs_path(&self, _: &App) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &App) -> Task<Result<String>> {
            unimplemented!()
        }

        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}
1385
// Test-only logger setup: calls `zlog::init_test()`. The function is never
// called by name in this file; it is registered through the `#[ctor::ctor]`
// attribute (presumably so it runs at load time before any test - confirm
// against the `ctor` crate docs).
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}