copilot.rs

   1pub mod copilot_chat;
   2mod copilot_completion_provider;
   3pub mod request;
   4mod sign_in;
   5
   6use ::fs::Fs;
   7use anyhow::{Result, anyhow};
   8use collections::{HashMap, HashSet};
   9use command_palette_hooks::CommandPaletteFilter;
  10use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
  11use gpui::{
  12    App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
  13    WeakEntity, actions,
  14};
  15use http_client::HttpClient;
  16use language::language_settings::CopilotSettings;
  17use language::{
  18    Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16,
  19    language_settings::{EditPredictionProvider, all_language_settings, language_settings},
  20    point_from_lsp, point_to_lsp,
  21};
  22use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
  23use node_runtime::NodeRuntime;
  24use parking_lot::Mutex;
  25use request::StatusNotification;
  26use settings::SettingsStore;
  27use std::{
  28    any::TypeId,
  29    env,
  30    ffi::OsString,
  31    mem,
  32    ops::Range,
  33    path::{Path, PathBuf},
  34    sync::Arc,
  35};
  36use util::{ResultExt, fs::remove_matching};
  37
  38pub use crate::copilot_completion_provider::CopilotCompletionProvider;
  39pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in};
  40
  41actions!(
  42    copilot,
  43    [
  44        Suggest,
  45        NextSuggestion,
  46        PreviousSuggestion,
  47        Reinstall,
  48        SignIn,
  49        SignOut
  50    ]
  51);
  52
  53pub fn init(
  54    new_server_id: LanguageServerId,
  55    fs: Arc<dyn Fs>,
  56    http: Arc<dyn HttpClient>,
  57    node_runtime: NodeRuntime,
  58    cx: &mut App,
  59) {
  60    copilot_chat::init(fs.clone(), http.clone(), cx);
  61
  62    let copilot = cx.new({
  63        let node_runtime = node_runtime.clone();
  64        move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)
  65    });
  66    Copilot::set_global(copilot.clone(), cx);
  67    cx.observe(&copilot, |handle, cx| {
  68        let copilot_action_types = [
  69            TypeId::of::<Suggest>(),
  70            TypeId::of::<NextSuggestion>(),
  71            TypeId::of::<PreviousSuggestion>(),
  72            TypeId::of::<Reinstall>(),
  73        ];
  74        let copilot_auth_action_types = [TypeId::of::<SignOut>()];
  75        let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
  76        let status = handle.read(cx).status();
  77        let filter = CommandPaletteFilter::global_mut(cx);
  78
  79        match status {
  80            Status::Disabled => {
  81                filter.hide_action_types(&copilot_action_types);
  82                filter.hide_action_types(&copilot_auth_action_types);
  83                filter.hide_action_types(&copilot_no_auth_action_types);
  84            }
  85            Status::Authorized => {
  86                filter.hide_action_types(&copilot_no_auth_action_types);
  87                filter.show_action_types(
  88                    copilot_action_types
  89                        .iter()
  90                        .chain(&copilot_auth_action_types),
  91                );
  92            }
  93            _ => {
  94                filter.hide_action_types(&copilot_action_types);
  95                filter.hide_action_types(&copilot_auth_action_types);
  96                filter.show_action_types(copilot_no_auth_action_types.iter());
  97            }
  98        }
  99    })
 100    .detach();
 101
 102    cx.on_action(|_: &SignIn, cx| {
 103        if let Some(copilot) = Copilot::global(cx) {
 104            copilot
 105                .update(cx, |copilot, cx| copilot.sign_in(cx))
 106                .detach_and_log_err(cx);
 107        }
 108    });
 109    cx.on_action(|_: &SignOut, cx| {
 110        if let Some(copilot) = Copilot::global(cx) {
 111            copilot
 112                .update(cx, |copilot, cx| copilot.sign_out(cx))
 113                .detach_and_log_err(cx);
 114        }
 115    });
 116    cx.on_action(|_: &Reinstall, cx| {
 117        if let Some(copilot) = Copilot::global(cx) {
 118            copilot
 119                .update(cx, |copilot, cx| copilot.reinstall(cx))
 120                .detach();
 121        }
 122    });
 123}
 124
/// Lifecycle state of the Copilot language server.
enum CopilotServer {
    /// Not running. This is the state `Copilot::start` begins in, the state left
    /// behind by `shutdown_language_server`, and the only state `start_copilot`
    /// starts a server from.
    Disabled,
    /// A start attempt is in progress; `task` completes when that attempt finishes.
    Starting { task: Shared<Task<()>> },
    /// Starting the server failed; holds the error message.
    Error(Arc<str>),
    /// The server started successfully.
    Running(RunningCopilotServer),
}
 131
 132impl CopilotServer {
 133    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 134        let server = self.as_running()?;
 135        if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
 136            Ok(server)
 137        } else {
 138            Err(anyhow!("must sign in before using copilot"))
 139        }
 140    }
 141
 142    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 143        match self {
 144            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
 145            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
 146            CopilotServer::Error(error) => Err(anyhow!(
 147                "copilot was not started because of an error: {}",
 148                error
 149            )),
 150            CopilotServer::Running(server) => Ok(server),
 151        }
 152    }
 153}
 154
/// A started Copilot language server plus the state that belongs to it.
struct RunningCopilotServer {
    /// Handle to the language server process.
    lsp: Arc<LanguageServer>,
    /// Current sign-in state for this server.
    sign_in_status: SignInStatus,
    /// Buffers that were opened on this server (`didOpen`), keyed by buffer entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
 160
/// Internal sign-in state of a running server.
///
/// The public [`Status`] enum has matching variants, but its `SigningIn` variant
/// does not expose the task.
#[derive(Clone, Debug)]
enum SignInStatus {
    Authorized,
    Unauthorized,
    SigningIn {
        /// Device-flow prompt. This is `None` until `SignInInitiate` returns a device flow.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in task. It is shared so that repeated `sign_in` calls can
        /// await the same task.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    SignedOut {
        /// Set from the `awaiting_sign_in_after_start` flag when the server starts.
        /// NOTE(review): its consumers are not visible here; presumably it means a
        /// sign-in should follow the start.
        awaiting_signing_in: bool,
    },
}
 173
/// Externally visible state of the Copilot integration.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is being started; `task` completes when the attempt ends.
    Starting {
        task: Shared<Task<()>>,
    },
    /// Starting the language server failed; holds the error message.
    Error(Arc<str>),
    /// Copilot is not running.
    Disabled,
    /// The server is running but the user is not signed in.
    SignedOut {
        awaiting_signing_in: bool,
    },
    /// A sign-in is in progress; `prompt` carries the device flow once one is known.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    Unauthorized,
    Authorized,
}
 190
 191impl Status {
 192    pub fn is_authorized(&self) -> bool {
 193        matches!(self, Status::Authorized)
 194    }
 195
 196    pub fn is_disabled(&self) -> bool {
 197        matches!(self, Status::Disabled)
 198    }
 199}
 200
/// Bookkeeping for a buffer that has been opened on the Copilot language server.
struct RegisteredBuffer {
    /// URI under which the document is currently open on the server.
    uri: lsp::Url,
    /// Language id the document was opened with.
    language_id: String,
    /// The last snapshot whose contents were reported to the server.
    snapshot: BufferSnapshot,
    /// LSP document version of `snapshot`. Starts at 0 and is incremented on every
    /// `didChange` that is sent.
    snapshot_version: i32,
    /// Keeps the buffer-event and buffer-release subscriptions alive.
    _subscriptions: [gpui::Subscription; 2],
    /// Task reporting the most recent change. Each new report awaits it first, so
    /// changes reach the server in order.
    pending_buffer_change: Task<Option<()>>,
}
 209
impl RegisteredBuffer {
    /// Brings the server's copy of this buffer up to date with `buffer`.
    ///
    /// Returns a receiver for the resulting `(snapshot_version, snapshot)` pair.
    ///
    /// If the buffer's version equals the version of the last reported snapshot, the
    /// receiver is resolved immediately with the current pair. Otherwise a task is
    /// chained after the previously pending change. That task:
    /// 1. diffs the buffer's new snapshot against the last reported version;
    /// 2. if there are edits, stores the new snapshot, bumps `snapshot_version` and
    ///    sends `textDocument/didChange` with incremental ranges;
    /// 3. sends the (possibly unchanged) pair on the channel.
    ///
    /// The sender is dropped without a value, so the receiver resolves to `Canceled`,
    /// when any of these hold: Copilot is not authenticated, the buffer is no longer
    /// registered, or the Copilot/buffer entity has been released.
    fn report_changes(
        &mut self,
        buffer: &Entity<Buffer>,
        cx: &mut Context<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing new since the last report.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Serialize reports: this change waits for the previous one to finish.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
                prev_pending_change.await;

                // Re-read the reported version *after* the previous change completed.
                // The outer `?` handles a released Copilot entity; the inner one handles
                // an unauthenticated server or an unregistered buffer.
                let old_version = copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()).ok()?;

                // Compute the LSP content changes off the main thread.
                let content_changes = cx
                    .background_spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // LSP applies content changes one after another, so
                                    // each range is expressed against the document with
                                    // the earlier changes already applied. The start is
                                    // therefore taken in new coordinates, and the end
                                    // is start + the extent of the replaced old text.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version (and notify) when something changed.
                        // NOTE(review): with no edits the stored snapshot is not
                        // replaced, so its version stays behind the buffer's.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    &lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .ok();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
 295
/// A single Copilot suggestion.
#[derive(Debug)]
pub struct Completion {
    /// Identifier of the suggestion (presumably the server-assigned uuid; verify against callers).
    pub uuid: String,
    /// Buffer range the suggestion applies to (presumably the text it replaces; confirm).
    pub range: Range<Anchor>,
    /// The suggested text.
    pub text: String,
}
 302
/// The Copilot integration: owns the Copilot language server and tracks buffers.
pub struct Copilot {
    /// Filesystem handle passed along when starting the language server.
    fs: Arc<dyn Fs>,
    /// Node runtime used to run the language server.
    node_runtime: NodeRuntime,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, whether or not it was opened on
    /// the server.
    buffers: HashSet<WeakEntity<Buffer>>,
    /// Id given to the language server when it is started.
    server_id: LanguageServerId,
    /// App-quit hook that shuts the language server down.
    _subscription: gpui::Subscription,
}
 311
/// Events emitted by [`Copilot`].
pub enum Event {
    /// Emitted by `start_language_server` once the server has started successfully.
    CopilotLanguageServerStarted,
    /// Presumably emitted when the user becomes signed in (emitter not visible here).
    CopilotAuthSignedIn,
    /// Presumably emitted when the user becomes signed out (emitter not visible here).
    CopilotAuthSignedOut,
}

impl EventEmitter<Event> for Copilot {}
 319
/// Newtype that lets the Copilot entity be stored as a gpui global
/// (see `Copilot::global` / `Copilot::set_global`).
struct GlobalCopilot(Entity<Copilot>);

impl Global for GlobalCopilot {}
 323
 324impl Copilot {
 325    pub fn global(cx: &App) -> Option<Entity<Self>> {
 326        cx.try_global::<GlobalCopilot>()
 327            .map(|model| model.0.clone())
 328    }
 329
 330    pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
 331        cx.set_global(GlobalCopilot(copilot));
 332    }
 333
 334    fn start(
 335        new_server_id: LanguageServerId,
 336        fs: Arc<dyn Fs>,
 337        node_runtime: NodeRuntime,
 338        cx: &mut Context<Self>,
 339    ) -> Self {
 340        let mut this = Self {
 341            server_id: new_server_id,
 342            fs,
 343            node_runtime,
 344            server: CopilotServer::Disabled,
 345            buffers: Default::default(),
 346            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 347        };
 348        this.start_copilot(true, false, cx);
 349        cx.observe_global::<SettingsStore>(move |this, cx| this.start_copilot(true, false, cx))
 350            .detach();
 351        this
 352    }
 353
 354    fn shutdown_language_server(
 355        &mut self,
 356        _cx: &mut Context<Self>,
 357    ) -> impl Future<Output = ()> + use<> {
 358        let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
 359            CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
 360            _ => None,
 361        };
 362
 363        async move {
 364            if let Some(shutdown) = shutdown {
 365                shutdown.await;
 366            }
 367        }
 368    }
 369
 370    fn start_copilot(
 371        &mut self,
 372        check_edit_prediction_provider: bool,
 373        awaiting_sign_in_after_start: bool,
 374        cx: &mut Context<Self>,
 375    ) {
 376        if !matches!(self.server, CopilotServer::Disabled) {
 377            return;
 378        }
 379        let language_settings = all_language_settings(None, cx);
 380        if check_edit_prediction_provider
 381            && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
 382        {
 383            return;
 384        }
 385        let server_id = self.server_id;
 386        let fs = self.fs.clone();
 387        let node_runtime = self.node_runtime.clone();
 388        let env = self.build_env(&language_settings.edit_predictions.copilot);
 389        let start_task = cx
 390            .spawn(async move |this, cx| {
 391                Self::start_language_server(
 392                    server_id,
 393                    fs,
 394                    node_runtime,
 395                    env,
 396                    this,
 397                    awaiting_sign_in_after_start,
 398                    cx,
 399                )
 400                .await
 401            })
 402            .shared();
 403        self.server = CopilotServer::Starting { task: start_task };
 404        cx.notify();
 405    }
 406
 407    fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
 408        let proxy_url = copilot_settings.proxy.clone()?;
 409        let no_verify = copilot_settings.proxy_no_verify;
 410        let http_or_https_proxy = if proxy_url.starts_with("http:") {
 411            "HTTP_PROXY"
 412        } else if proxy_url.starts_with("https:") {
 413            "HTTPS_PROXY"
 414        } else {
 415            log::error!(
 416                "Unsupported protocol scheme for language server proxy (must be http or https)"
 417            );
 418            return None;
 419        };
 420
 421        let mut env = HashMap::default();
 422        env.insert(http_or_https_proxy.to_string(), proxy_url);
 423
 424        if let Some(true) = no_verify {
 425            env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
 426        };
 427
 428        Some(env)
 429    }
 430
 431    #[cfg(any(test, feature = "test-support"))]
 432    pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
 433        use fs::FakeFs;
 434        use lsp::FakeLanguageServer;
 435        use node_runtime::NodeRuntime;
 436
 437        let (server, fake_server) = FakeLanguageServer::new(
 438            LanguageServerId(0),
 439            LanguageServerBinary {
 440                path: "path/to/copilot".into(),
 441                arguments: vec![],
 442                env: None,
 443            },
 444            "copilot".into(),
 445            Default::default(),
 446            &mut cx.to_async(),
 447        );
 448        let node_runtime = NodeRuntime::unavailable();
 449        let this = cx.new(|cx| Self {
 450            server_id: LanguageServerId(0),
 451            fs: FakeFs::new(cx.background_executor().clone()),
 452            node_runtime,
 453            server: CopilotServer::Running(RunningCopilotServer {
 454                lsp: Arc::new(server),
 455                sign_in_status: SignInStatus::Authorized,
 456                registered_buffers: Default::default(),
 457            }),
 458            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 459            buffers: Default::default(),
 460        });
 461        (this, fake_server)
 462    }
 463
    /// Locates and configures the Copilot language server, performs the initial
    /// handshake, and records the outcome on `this`.
    ///
    /// Steps, in order:
    /// 1. resolve the server script with `get_copilot_lsp` and Node's binary path, and
    ///    describe the binary as `node <server script> --stdio` with the optional `env`;
    /// 2. create the `LanguageServer` named "copilot" (root `/`, or `C:/` on Windows)
    ///    and silence its `StatusNotification`s;
    /// 3. `initialize` it, with the editor info as `initialization_options`;
    /// 4. send `CheckStatus` (`local_checks_only: false`), then `SetEditorInfo`.
    ///
    /// On success, `this.server` becomes `Running` with a `SignedOut` sign-in status
    /// whose `awaiting_signing_in` is `awaiting_sign_in_after_start`. Then
    /// `Event::CopilotLanguageServerStarted` is emitted and the `CheckStatus` result is
    /// applied with `update_sign_in_status`. On failure, `this.server` becomes
    /// `CopilotServer::Error` holding the error message. If `this` has been released,
    /// nothing is recorded.
    async fn start_language_server(
        new_server_id: LanguageServerId,
        fs: Arc<dyn Fs>,
        node_runtime: NodeRuntime,
        env: Option<HashMap<String, String>>,
        this: WeakEntity<Self>,
        awaiting_sign_in_after_start: bool,
        cx: &mut AsyncApp,
    ) {
        // All fallible startup steps live in one async block, so every `?` failure
        // ends up in the single `Err` branch handled below.
        let start_language_server = async {
            let server_path = get_copilot_lsp(fs, node_runtime.clone()).await?;
            let node_path = node_runtime.binary_path().await?;
            let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
            let binary = LanguageServerBinary {
                path: node_path,
                arguments,
                env,
            };

            // Copilot is not tied to a project, so the filesystem root is used.
            let root_path = if cfg!(target_os = "windows") {
                Path::new("C:/")
            } else {
                Path::new("/")
            };

            let server_name = LanguageServerName("copilot".into());
            let server = LanguageServer::new(
                Arc::new(Mutex::new(None)),
                new_server_id,
                server_name,
                binary,
                root_path,
                None,
                Default::default(),
                cx,
            )?;

            server
                .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
                .detach();

            // The initial configuration sent with `initialize` is empty.
            let configuration = lsp::DidChangeConfigurationParams {
                settings: Default::default(),
            };

            // The editor info is sent twice: as `initialization_options` and, after
            // initialization, through `SetEditorInfo`.
            let editor_info = request::SetEditorInfoParams {
                editor_info: request::EditorInfo {
                    name: "zed".into(),
                    version: env!("CARGO_PKG_VERSION").into(),
                },
                editor_plugin_info: request::EditorPluginInfo {
                    name: "zed-copilot".into(),
                    version: "0.0.1".into(),
                },
            };
            let editor_info_json = serde_json::to_value(&editor_info)?;

            let server = cx
                .update(|cx| {
                    let mut params = server.default_initialize_params(cx);
                    params.initialization_options = Some(editor_info_json);
                    server.initialize(params, configuration.into(), cx)
                })?
                .await?;

            // Ask the server for the current authentication status.
            let status = server
                .request::<request::CheckStatus>(request::CheckStatusParams {
                    local_checks_only: false,
                })
                .await?;

            server
                .request::<request::SetEditorInfo>(editor_info)
                .await?;

            anyhow::Ok((server, status))
        };

        let server = start_language_server.await;
        // `update` errors when `this` has been released; there is nothing to record then.
        this.update(cx, |this, cx| {
            cx.notify();
            match server {
                Ok((server, status)) => {
                    this.server = CopilotServer::Running(RunningCopilotServer {
                        lsp: server,
                        sign_in_status: SignInStatus::SignedOut {
                            awaiting_signing_in: awaiting_sign_in_after_start,
                        },
                        registered_buffers: Default::default(),
                    });
                    cx.emit(Event::CopilotLanguageServerStarted);
                    this.update_sign_in_status(status, cx);
                }
                Err(error) => {
                    this.server = CopilotServer::Error(error.to_string().into());
                    cx.notify()
                }
            }
        })
        .ok();
    }
 565
    /// Starts the Copilot sign-in flow, or joins the one already in progress.
    ///
    /// What happens depends on the running server's sign-in status:
    /// - `Authorized`: resolves immediately with `Ok(())`.
    /// - `SigningIn`: returns the in-flight sign-in task and notifies observers.
    /// - `SignedOut` / `Unauthorized`: spawns a sign-in task and moves the status to
    ///   `SigningIn { prompt: None, .. }`. The task sends `SignInInitiate`.
    ///   - An `AlreadySignedIn` answer is turned into an `Ok` status.
    ///   - For a device flow, the flow is stored as the `SigningIn` prompt and
    ///     `SignInConfirm` is sent with its user code.
    ///
    ///   The outcome is applied with `update_sign_in_status`; any error resets the
    ///   status to `NotSignedIn`.
    ///
    /// # Errors
    /// Fails immediately with "copilot hasn't started yet" when the server is not
    /// `Running`, and otherwise propagates any failure of the sign-in task.
    pub fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(async move |this, cx| {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt, but only while
                                        // we are still in the `SigningIn` state.
                                        this.update(cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // Wrapped in `Arc` because a shared task's output
                                    // must be `Clone`.
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        // Shared so that `sign_in` calls made while `SigningIn` can
                        // await this same task.
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // Convert the shared `Arc<anyhow::Error>` back into an `anyhow::Error`.
            cx.background_spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // Not running (starting, disabled or failed).
            // TODO: if a download is in progress, wait for it to finish; if the server
            // is stuck, tell the user. For now this fails immediately.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 649
 650    pub fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 651        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 652        match &self.server {
 653            CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
 654                let server = server.clone();
 655                cx.background_spawn(async move {
 656                    server
 657                        .request::<request::SignOut>(request::SignOutParams {})
 658                        .await?;
 659                    anyhow::Ok(())
 660                })
 661            }
 662            CopilotServer::Disabled => cx.background_spawn(async {
 663                clear_copilot_config_dir().await;
 664                anyhow::Ok(())
 665            }),
 666            _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
 667        }
 668    }
 669
 670    pub fn reinstall(&mut self, cx: &mut Context<Self>) -> Task<()> {
 671        let language_settings = all_language_settings(None, cx);
 672        let env = self.build_env(&language_settings.edit_predictions.copilot);
 673        let start_task = cx
 674            .spawn({
 675                let fs = self.fs.clone();
 676                let node_runtime = self.node_runtime.clone();
 677                let server_id = self.server_id;
 678                async move |this, cx| {
 679                    clear_copilot_dir().await;
 680                    Self::start_language_server(server_id, fs, node_runtime, env, this, false, cx)
 681                        .await
 682                }
 683            })
 684            .shared();
 685
 686        self.server = CopilotServer::Starting {
 687            task: start_task.clone(),
 688        };
 689
 690        cx.notify();
 691
 692        cx.background_spawn(start_task)
 693    }
 694
 695    pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
 696        if let CopilotServer::Running(server) = &self.server {
 697            Some(&server.lsp)
 698        } else {
 699            None
 700        }
 701    }
 702
 703    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 704        let weak_buffer = buffer.downgrade();
 705        self.buffers.insert(weak_buffer.clone());
 706
 707        if let CopilotServer::Running(RunningCopilotServer {
 708            lsp: server,
 709            sign_in_status: status,
 710            registered_buffers,
 711            ..
 712        }) = &mut self.server
 713        {
 714            if !matches!(status, SignInStatus::Authorized { .. }) {
 715                return;
 716            }
 717
 718            registered_buffers
 719                .entry(buffer.entity_id())
 720                .or_insert_with(|| {
 721                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
 722                    let language_id = id_for_language(buffer.read(cx).language());
 723                    let snapshot = buffer.read(cx).snapshot();
 724                    server
 725                        .notify::<lsp::notification::DidOpenTextDocument>(
 726                            &lsp::DidOpenTextDocumentParams {
 727                                text_document: lsp::TextDocumentItem {
 728                                    uri: uri.clone(),
 729                                    language_id: language_id.clone(),
 730                                    version: 0,
 731                                    text: snapshot.text(),
 732                                },
 733                            },
 734                        )
 735                        .ok();
 736
 737                    RegisteredBuffer {
 738                        uri,
 739                        language_id,
 740                        snapshot,
 741                        snapshot_version: 0,
 742                        pending_buffer_change: Task::ready(Some(())),
 743                        _subscriptions: [
 744                            cx.subscribe(buffer, |this, buffer, event, cx| {
 745                                this.handle_buffer_event(buffer, event, cx).log_err();
 746                            }),
 747                            cx.observe_release(buffer, move |this, _buffer, _cx| {
 748                                this.buffers.remove(&weak_buffer);
 749                                this.unregister_buffer(&weak_buffer);
 750                            }),
 751                        ],
 752                    }
 753                });
 754        }
 755    }
 756
 757    fn handle_buffer_event(
 758        &mut self,
 759        buffer: Entity<Buffer>,
 760        event: &language::BufferEvent,
 761        cx: &mut Context<Self>,
 762    ) -> Result<()> {
 763        if let Ok(server) = self.server.as_running() {
 764            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
 765            {
 766                match event {
 767                    language::BufferEvent::Edited => {
 768                        drop(registered_buffer.report_changes(&buffer, cx));
 769                    }
 770                    language::BufferEvent::Saved => {
 771                        server
 772                            .lsp
 773                            .notify::<lsp::notification::DidSaveTextDocument>(
 774                                &lsp::DidSaveTextDocumentParams {
 775                                    text_document: lsp::TextDocumentIdentifier::new(
 776                                        registered_buffer.uri.clone(),
 777                                    ),
 778                                    text: None,
 779                                },
 780                            )?;
 781                    }
 782                    language::BufferEvent::FileHandleChanged
 783                    | language::BufferEvent::LanguageChanged => {
 784                        let new_language_id = id_for_language(buffer.read(cx).language());
 785                        let new_uri = uri_for_buffer(&buffer, cx);
 786                        if new_uri != registered_buffer.uri
 787                            || new_language_id != registered_buffer.language_id
 788                        {
 789                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 790                            registered_buffer.language_id = new_language_id;
 791                            server
 792                                .lsp
 793                                .notify::<lsp::notification::DidCloseTextDocument>(
 794                                    &lsp::DidCloseTextDocumentParams {
 795                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 796                                    },
 797                                )?;
 798                            server
 799                                .lsp
 800                                .notify::<lsp::notification::DidOpenTextDocument>(
 801                                    &lsp::DidOpenTextDocumentParams {
 802                                        text_document: lsp::TextDocumentItem::new(
 803                                            registered_buffer.uri.clone(),
 804                                            registered_buffer.language_id.clone(),
 805                                            registered_buffer.snapshot_version,
 806                                            registered_buffer.snapshot.text(),
 807                                        ),
 808                                    },
 809                                )?;
 810                        }
 811                    }
 812                    _ => {}
 813                }
 814            }
 815        }
 816
 817        Ok(())
 818    }
 819
 820    fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
 821        if let Ok(server) = self.server.as_running() {
 822            if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
 823                server
 824                    .lsp
 825                    .notify::<lsp::notification::DidCloseTextDocument>(
 826                        &lsp::DidCloseTextDocumentParams {
 827                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 828                        },
 829                    )
 830                    .ok();
 831            }
 832        }
 833    }
 834
 835    pub fn completions<T>(
 836        &mut self,
 837        buffer: &Entity<Buffer>,
 838        position: T,
 839        cx: &mut Context<Self>,
 840    ) -> Task<Result<Vec<Completion>>>
 841    where
 842        T: ToPointUtf16,
 843    {
 844        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 845    }
 846
 847    pub fn completions_cycling<T>(
 848        &mut self,
 849        buffer: &Entity<Buffer>,
 850        position: T,
 851        cx: &mut Context<Self>,
 852    ) -> Task<Result<Vec<Completion>>>
 853    where
 854        T: ToPointUtf16,
 855    {
 856        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 857    }
 858
 859    pub fn accept_completion(
 860        &mut self,
 861        completion: &Completion,
 862        cx: &mut Context<Self>,
 863    ) -> Task<Result<()>> {
 864        let server = match self.server.as_authenticated() {
 865            Ok(server) => server,
 866            Err(error) => return Task::ready(Err(error)),
 867        };
 868        let request =
 869            server
 870                .lsp
 871                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 872                    uuid: completion.uuid.clone(),
 873                });
 874        cx.background_spawn(async move {
 875            request.await?;
 876            Ok(())
 877        })
 878    }
 879
 880    pub fn discard_completions(
 881        &mut self,
 882        completions: &[Completion],
 883        cx: &mut Context<Self>,
 884    ) -> Task<Result<()>> {
 885        let server = match self.server.as_authenticated() {
 886            Ok(server) => server,
 887            Err(_) => return Task::ready(Ok(())),
 888        };
 889        let request =
 890            server
 891                .lsp
 892                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 893                    uuids: completions
 894                        .iter()
 895                        .map(|completion| completion.uuid.clone())
 896                        .collect(),
 897                });
 898        cx.background_spawn(async move {
 899            request.await?;
 900            Ok(())
 901        })
 902    }
 903
 904    fn request_completions<R, T>(
 905        &mut self,
 906        buffer: &Entity<Buffer>,
 907        position: T,
 908        cx: &mut Context<Self>,
 909    ) -> Task<Result<Vec<Completion>>>
 910    where
 911        R: 'static
 912            + lsp::request::Request<
 913                Params = request::GetCompletionsParams,
 914                Result = request::GetCompletionsResult,
 915            >,
 916        T: ToPointUtf16,
 917    {
 918        self.register_buffer(buffer, cx);
 919
 920        let server = match self.server.as_authenticated() {
 921            Ok(server) => server,
 922            Err(error) => return Task::ready(Err(error)),
 923        };
 924        let lsp = server.lsp.clone();
 925        let registered_buffer = server
 926            .registered_buffers
 927            .get_mut(&buffer.entity_id())
 928            .unwrap();
 929        let snapshot = registered_buffer.report_changes(buffer, cx);
 930        let buffer = buffer.read(cx);
 931        let uri = registered_buffer.uri.clone();
 932        let position = position.to_point_utf16(buffer);
 933        let settings = language_settings(
 934            buffer.language_at(position).map(|l| l.name()),
 935            buffer.file(),
 936            cx,
 937        );
 938        let tab_size = settings.tab_size;
 939        let hard_tabs = settings.hard_tabs;
 940        let relative_path = buffer
 941            .file()
 942            .map(|file| file.path().to_path_buf())
 943            .unwrap_or_default();
 944
 945        cx.background_spawn(async move {
 946            let (version, snapshot) = snapshot.await?;
 947            let result = lsp
 948                .request::<R>(request::GetCompletionsParams {
 949                    doc: request::GetCompletionsDocument {
 950                        uri,
 951                        tab_size: tab_size.into(),
 952                        indent_size: 1,
 953                        insert_spaces: !hard_tabs,
 954                        relative_path: relative_path.to_string_lossy().into(),
 955                        position: point_to_lsp(position),
 956                        version: version.try_into().unwrap(),
 957                    },
 958                })
 959                .await?;
 960            let completions = result
 961                .completions
 962                .into_iter()
 963                .map(|completion| {
 964                    let start = snapshot
 965                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 966                    let end =
 967                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 968                    Completion {
 969                        uuid: completion.uuid,
 970                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 971                        text: completion.text,
 972                    }
 973                })
 974                .collect();
 975            anyhow::Ok(completions)
 976        })
 977    }
 978
 979    pub fn status(&self) -> Status {
 980        match &self.server {
 981            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
 982            CopilotServer::Disabled => Status::Disabled,
 983            CopilotServer::Error(error) => Status::Error(error.clone()),
 984            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
 985                match sign_in_status {
 986                    SignInStatus::Authorized { .. } => Status::Authorized,
 987                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
 988                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
 989                        prompt: prompt.clone(),
 990                    },
 991                    SignInStatus::SignedOut {
 992                        awaiting_signing_in,
 993                    } => Status::SignedOut {
 994                        awaiting_signing_in: *awaiting_signing_in,
 995                    },
 996                }
 997            }
 998        }
 999    }
1000
1001    fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
1002        self.buffers.retain(|buffer| buffer.is_upgradable());
1003
1004        if let Ok(server) = self.server.as_running() {
1005            match lsp_status {
1006                request::SignInStatus::Ok { user: Some(_) }
1007                | request::SignInStatus::MaybeOk { .. }
1008                | request::SignInStatus::AlreadySignedIn { .. } => {
1009                    server.sign_in_status = SignInStatus::Authorized;
1010                    cx.emit(Event::CopilotAuthSignedIn);
1011                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1012                        if let Some(buffer) = buffer.upgrade() {
1013                            self.register_buffer(&buffer, cx);
1014                        }
1015                    }
1016                }
1017                request::SignInStatus::NotAuthorized { .. } => {
1018                    server.sign_in_status = SignInStatus::Unauthorized;
1019                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1020                        self.unregister_buffer(&buffer);
1021                    }
1022                }
1023                request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
1024                    if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
1025                        server.sign_in_status = SignInStatus::SignedOut {
1026                            awaiting_signing_in: false,
1027                        };
1028                    }
1029                    cx.emit(Event::CopilotAuthSignedOut);
1030                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1031                        self.unregister_buffer(&buffer);
1032                    }
1033                }
1034            }
1035
1036            cx.notify();
1037        }
1038    }
1039}
1040
1041fn id_for_language(language: Option<&Arc<Language>>) -> String {
1042    language
1043        .map(|language| language.lsp_id())
1044        .unwrap_or_else(|| "plaintext".to_string())
1045}
1046
1047fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> lsp::Url {
1048    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1049        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
1050    } else {
1051        format!("buffer://{}", buffer.entity_id()).parse().unwrap()
1052    }
1053}
1054
1055async fn clear_copilot_dir() {
1056    remove_matching(paths::copilot_dir(), |_| true).await
1057}
1058
1059async fn clear_copilot_config_dir() {
1060    remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
1061}
1062
1063async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1064    const PACKAGE_NAME: &str = "@github/copilot-language-server";
1065    const SERVER_PATH: &str =
1066        "node_modules/@github/copilot-language-server/dist/language-server.js";
1067
1068    let latest_version = node_runtime
1069        .npm_package_latest_version(PACKAGE_NAME)
1070        .await?;
1071    let server_path = paths::copilot_dir().join(SERVER_PATH);
1072
1073    fs.create_dir(paths::copilot_dir()).await?;
1074
1075    let should_install = node_runtime
1076        .should_install_npm_package(
1077            PACKAGE_NAME,
1078            &server_path,
1079            paths::copilot_dir(),
1080            &latest_version,
1081        )
1082        .await;
1083    if should_install {
1084        node_runtime
1085            .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1086            .await?;
1087    }
1088
1089    Ok(server_path)
1090}
1091
// Tests for how Copilot mirrors buffer registration into LSP text-document notifications,
// driven against the fake server handle returned by `Copilot::fake`.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::path;

    /// Walks one buffer lifecycle end to end: open on register, incremental change on edit,
    /// close + re-open when the file changes, close on sign-out, re-open on sign-in, and close
    /// when the buffer entity is dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A file-less buffer is opened under a synthetic `buffer://<entity id>` URI at version 0.
        let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is forwarded as an incremental range change and bumps the version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // The old synthetic URI is closed first, then the document is re-opened under its
        // `file://` URI with the current text and the already-reported version (1).
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: path!("/root/child/buffer-1").into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.set_request_handler::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Re-registration starts each document again at version 0.
        lsp.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local `language::File` stub for the test above. Only `as_local`, `disk_state`,
    /// `path`, `worktree_id`, `is_private`, and `abs_path` return values; every other method
    /// is `unimplemented!()`.
    struct File {
        // Absolute path reported through `LocalFile::abs_path`.
        abs_path: PathBuf,
        // Worktree-relative path reported through `File::path`.
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn disk_state(&self) -> language::DiskState {
            language::DiskState::Present {
                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
            }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &App) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn to_proto(&self, _: &App) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &App) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &App) -> Task<Result<String>> {
            unimplemented!()
        }

        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}
1310
/// Test-only constructor (registered via `ctor`) that installs `env_logger`, but only when
/// `RUST_LOG` is set, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}