copilot.rs

   1pub mod copilot_chat;
   2mod copilot_completion_provider;
   3pub mod request;
   4mod sign_in;
   5
   6use crate::sign_in::initiate_sign_in_within_workspace;
   7use ::fs::Fs;
   8use anyhow::{Context as _, Result, anyhow};
   9use collections::{HashMap, HashSet};
  10use command_palette_hooks::CommandPaletteFilter;
  11use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
  12use gpui::{
  13    App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
  14    WeakEntity, actions,
  15};
  16use http_client::HttpClient;
  17use language::language_settings::CopilotSettings;
  18use language::{
  19    Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16,
  20    language_settings::{EditPredictionProvider, all_language_settings, language_settings},
  21    point_from_lsp, point_to_lsp,
  22};
  23use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
  24use node_runtime::NodeRuntime;
  25use parking_lot::Mutex;
  26use request::StatusNotification;
  27use settings::SettingsStore;
  28use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
  29use std::{
  30    any::TypeId,
  31    env,
  32    ffi::OsString,
  33    mem,
  34    ops::Range,
  35    path::{Path, PathBuf},
  36    sync::Arc,
  37};
  38use util::{ResultExt, fs::remove_matching};
  39use workspace::Workspace;
  40
  41pub use crate::copilot_completion_provider::CopilotCompletionProvider;
  42pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
  43
// Declares the Copilot action types. `init` shows or hides them in the command
// palette depending on Copilot's status, and registers workspace handlers for
// `SignIn`, `Reinstall` and `SignOut`.
actions!(
    copilot,
    [
        Suggest,
        NextSuggestion,
        PreviousSuggestion,
        Reinstall,
        SignIn,
        SignOut
    ]
);
  55
  56pub fn init(
  57    new_server_id: LanguageServerId,
  58    fs: Arc<dyn Fs>,
  59    http: Arc<dyn HttpClient>,
  60    node_runtime: NodeRuntime,
  61    cx: &mut App,
  62) {
  63    copilot_chat::init(fs.clone(), http.clone(), cx);
  64
  65    let copilot = cx.new({
  66        let node_runtime = node_runtime.clone();
  67        move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)
  68    });
  69    Copilot::set_global(copilot.clone(), cx);
  70    cx.observe(&copilot, |handle, cx| {
  71        let copilot_action_types = [
  72            TypeId::of::<Suggest>(),
  73            TypeId::of::<NextSuggestion>(),
  74            TypeId::of::<PreviousSuggestion>(),
  75            TypeId::of::<Reinstall>(),
  76        ];
  77        let copilot_auth_action_types = [TypeId::of::<SignOut>()];
  78        let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
  79        let status = handle.read(cx).status();
  80        let filter = CommandPaletteFilter::global_mut(cx);
  81
  82        match status {
  83            Status::Disabled => {
  84                filter.hide_action_types(&copilot_action_types);
  85                filter.hide_action_types(&copilot_auth_action_types);
  86                filter.hide_action_types(&copilot_no_auth_action_types);
  87            }
  88            Status::Authorized => {
  89                filter.hide_action_types(&copilot_no_auth_action_types);
  90                filter.show_action_types(
  91                    copilot_action_types
  92                        .iter()
  93                        .chain(&copilot_auth_action_types),
  94                );
  95            }
  96            _ => {
  97                filter.hide_action_types(&copilot_action_types);
  98                filter.hide_action_types(&copilot_auth_action_types);
  99                filter.show_action_types(copilot_no_auth_action_types.iter());
 100            }
 101        }
 102    })
 103    .detach();
 104
 105    cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
 106        workspace.register_action(|workspace, _: &SignIn, window, cx| {
 107            if let Some(copilot) = Copilot::global(cx) {
 108                let is_reinstall = false;
 109                initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
 110            }
 111        });
 112        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
 113            if let Some(copilot) = Copilot::global(cx) {
 114                reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
 115            }
 116        });
 117        workspace.register_action(|workspace, _: &SignOut, _window, cx| {
 118            if let Some(copilot) = Copilot::global(cx) {
 119                sign_out_within_workspace(workspace, copilot, cx);
 120            }
 121        });
 122    })
 123    .detach();
 124}
 125
/// Lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// No server is running. This is the state `Copilot::start` begins in, the
    /// state left behind by `shutdown_language_server`, and the only state from
    /// which `start_copilot` will launch a server.
    Disabled,
    /// `start_language_server` is in flight; `task` is the shared handle to it.
    Starting { task: Shared<Task<()>> },
    /// Starting the server failed; holds the error rendered as a string.
    Error(Arc<str>),
    /// The server started successfully.
    Running(RunningCopilotServer),
}
 132
 133impl CopilotServer {
 134    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 135        let server = self.as_running()?;
 136        anyhow::ensure!(
 137            matches!(server.sign_in_status, SignInStatus::Authorized { .. }),
 138            "must sign in before using copilot"
 139        );
 140        Ok(server)
 141    }
 142
 143    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 144        match self {
 145            CopilotServer::Starting { .. } => anyhow::bail!("copilot is still starting"),
 146            CopilotServer::Disabled => anyhow::bail!("copilot is disabled"),
 147            CopilotServer::Error(error) => {
 148                anyhow::bail!("copilot was not started because of an error: {error}")
 149            }
 150            CopilotServer::Running(server) => Ok(server),
 151        }
 152    }
 153}
 154
/// State kept while the Copilot language server is up.
struct RunningCopilotServer {
    /// Handle used to send requests and notifications to the server.
    lsp: Arc<LanguageServer>,
    /// Where the user currently is in the sign-in flow.
    sign_in_status: SignInStatus,
    /// Buffers that have been opened on the server, keyed by buffer entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
 160
/// Sign-in state tracked for a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The user is signed in; required by `CopilotServer::as_authenticated`.
    Authorized,
    /// Not authorized. `sign_in` treats this like `SignedOut` and starts a new
    /// sign-in flow.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt; `None` until the server has returned one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared handle to the in-flight sign-in; handed out again when
        /// `sign_in` is called while already signing in.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in. This is the status a freshly started server gets.
    SignedOut {
        /// Copied from the `awaiting_sign_in_after_start` argument when the
        /// server is started.
        awaiting_signing_in: bool,
    },
}
 173
/// Public view of Copilot's overall state. `init` uses it to decide which
/// actions the command palette offers.
///
/// NOTE(review): the variants mirror `CopilotServer` and `SignInStatus`; the
/// mapping presumably lives in `Copilot::status`, which is not in this part of
/// the file — confirm.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is still starting.
    Starting {
        /// Shared handle to the start-up task.
        task: Shared<Task<()>>,
    },
    /// Starting the language server failed; holds the error message.
    Error(Arc<str>),
    /// Copilot is disabled.
    Disabled,
    /// The user is signed out.
    SignedOut {
        awaiting_signing_in: bool,
    },
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt, if one is available.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The user is not authorized.
    Unauthorized,
    /// The user is signed in.
    Authorized,
}
 190
 191impl Status {
 192    pub fn is_authorized(&self) -> bool {
 193        matches!(self, Status::Authorized)
 194    }
 195
 196    pub fn is_disabled(&self) -> bool {
 197        matches!(self, Status::Disabled)
 198    }
 199}
 200
/// Bookkeeping for a buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// Document URI the buffer is currently open under on the server.
    uri: lsp::Url,
    /// LSP language id the buffer was opened with.
    language_id: String,
    /// Buffer snapshot the server was last told about.
    snapshot: BufferSnapshot,
    /// Document version matching `snapshot`: 0 when the buffer is opened,
    /// incremented for every change notification sent.
    snapshot_version: i32,
    /// Buffer-event subscription and release observer created on registration.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recently spawned change-reporting task; each new one awaits it
    /// first, so changes are reported in order.
    pending_buffer_change: Task<Option<()>>,
}
 209
impl RegisteredBuffer {
    /// Brings the language server up to date with `buffer` and reports the
    /// resulting document version.
    ///
    /// The returned receiver yields `(snapshot_version, snapshot)` once the
    /// server has been sent every edit up to that snapshot. If the buffer has
    /// not changed since the stored snapshot, the value is sent immediately.
    /// Otherwise a task is spawned that runs after any previously pending
    /// change task. If that task bails out early (Copilot gone, not
    /// authenticated, buffer unregistered or dropped), the sender is dropped
    /// without sending, so the receiver never gets a value.
    fn report_changes(
        &mut self,
        buffer: &Entity<Buffer>,
        cx: &mut Context<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing changed since the last report: answer right away.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Take the previous task out so the new one can await it; this keeps
            // change notifications in order.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
                prev_pending_change.await;

                // Re-read the last reported version *after* the previous task ran,
                // since that task may have advanced it.
                let old_version = copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()).ok()?;

                // Compute the edits between the two versions via `background_spawn`.
                let content_changes = cx
                    .background_spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The replaced range starts at the edit's position in
                                    // the new text and spans the extent of the old text.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version and notify when there is something
                        // to send; the stored snapshot is left untouched otherwise.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    &lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .ok();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
 295
/// A single completion.
///
/// NOTE(review): presumably produced by the completion request methods defined
/// later in this file — confirm there.
#[derive(Debug)]
pub struct Completion {
    /// Identifier of this completion.
    pub uuid: String,
    /// Anchored buffer range associated with the completion.
    pub range: Range<Anchor>,
    /// Text of the completion.
    pub text: String,
}
 302
/// Entity that owns the Copilot language server and tracks the buffers
/// registered with it.
pub struct Copilot {
    /// File system handle passed along when the language server is started.
    fs: Arc<dyn Fs>,
    /// Node runtime passed along when the language server is started.
    node_runtime: NodeRuntime,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Weak handles to every buffer passed to `register_buffer`, whether or not
    /// it could be opened on the server at that time.
    buffers: HashSet<WeakEntity<Buffer>>,
    /// Id assigned to the Copilot language server.
    server_id: LanguageServerId,
    /// App-quit subscription that runs `shutdown_language_server`.
    _subscription: gpui::Subscription,
}
 311
/// Events emitted by [`Copilot`].
pub enum Event {
    /// Emitted by `start_language_server` right after the server state becomes
    /// `Running`.
    CopilotLanguageServerStarted,
    /// NOTE(review): presumably emitted once sign-in succeeds; the emit site is
    /// not in this part of the file — confirm.
    CopilotAuthSignedIn,
    /// NOTE(review): presumably emitted once the user is signed out; the emit
    /// site is not in this part of the file — confirm.
    CopilotAuthSignedOut,
}
 317
// Lets `Copilot` emit `Event`s (see the `cx.emit` call in `start_language_server`).
impl EventEmitter<Event> for Copilot {}
 319
/// Newtype wrapper under which the shared [`Copilot`] entity is stored as a
/// global; read by `Copilot::global` and written by `Copilot::set_global`.
struct GlobalCopilot(Entity<Copilot>);

impl Global for GlobalCopilot {}
 323
 324impl Copilot {
 325    pub fn global(cx: &App) -> Option<Entity<Self>> {
 326        cx.try_global::<GlobalCopilot>()
 327            .map(|model| model.0.clone())
 328    }
 329
 330    pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
 331        cx.set_global(GlobalCopilot(copilot));
 332    }
 333
 334    fn start(
 335        new_server_id: LanguageServerId,
 336        fs: Arc<dyn Fs>,
 337        node_runtime: NodeRuntime,
 338        cx: &mut Context<Self>,
 339    ) -> Self {
 340        let mut this = Self {
 341            server_id: new_server_id,
 342            fs,
 343            node_runtime,
 344            server: CopilotServer::Disabled,
 345            buffers: Default::default(),
 346            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 347        };
 348        this.start_copilot(true, false, cx);
 349        cx.observe_global::<SettingsStore>(move |this, cx| this.start_copilot(true, false, cx))
 350            .detach();
 351        this
 352    }
 353
 354    fn shutdown_language_server(
 355        &mut self,
 356        _cx: &mut Context<Self>,
 357    ) -> impl Future<Output = ()> + use<> {
 358        let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
 359            CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
 360            _ => None,
 361        };
 362
 363        async move {
 364            if let Some(shutdown) = shutdown {
 365                shutdown.await;
 366            }
 367        }
 368    }
 369
 370    fn start_copilot(
 371        &mut self,
 372        check_edit_prediction_provider: bool,
 373        awaiting_sign_in_after_start: bool,
 374        cx: &mut Context<Self>,
 375    ) {
 376        if !matches!(self.server, CopilotServer::Disabled) {
 377            return;
 378        }
 379        let language_settings = all_language_settings(None, cx);
 380        if check_edit_prediction_provider
 381            && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
 382        {
 383            return;
 384        }
 385        let server_id = self.server_id;
 386        let fs = self.fs.clone();
 387        let node_runtime = self.node_runtime.clone();
 388        let env = self.build_env(&language_settings.edit_predictions.copilot);
 389        let start_task = cx
 390            .spawn(async move |this, cx| {
 391                Self::start_language_server(
 392                    server_id,
 393                    fs,
 394                    node_runtime,
 395                    env,
 396                    this,
 397                    awaiting_sign_in_after_start,
 398                    cx,
 399                )
 400                .await
 401            })
 402            .shared();
 403        self.server = CopilotServer::Starting { task: start_task };
 404        cx.notify();
 405    }
 406
 407    fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
 408        let proxy_url = copilot_settings.proxy.clone()?;
 409        let no_verify = copilot_settings.proxy_no_verify;
 410        let http_or_https_proxy = if proxy_url.starts_with("http:") {
 411            "HTTP_PROXY"
 412        } else if proxy_url.starts_with("https:") {
 413            "HTTPS_PROXY"
 414        } else {
 415            log::error!(
 416                "Unsupported protocol scheme for language server proxy (must be http or https)"
 417            );
 418            return None;
 419        };
 420
 421        let mut env = HashMap::default();
 422        env.insert(http_or_https_proxy.to_string(), proxy_url);
 423
 424        if let Some(true) = no_verify {
 425            env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
 426        };
 427
 428        Some(env)
 429    }
 430
 431    #[cfg(any(test, feature = "test-support"))]
 432    pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
 433        use fs::FakeFs;
 434        use lsp::FakeLanguageServer;
 435        use node_runtime::NodeRuntime;
 436
 437        let (server, fake_server) = FakeLanguageServer::new(
 438            LanguageServerId(0),
 439            LanguageServerBinary {
 440                path: "path/to/copilot".into(),
 441                arguments: vec![],
 442                env: None,
 443            },
 444            "copilot".into(),
 445            Default::default(),
 446            &mut cx.to_async(),
 447        );
 448        let node_runtime = NodeRuntime::unavailable();
 449        let this = cx.new(|cx| Self {
 450            server_id: LanguageServerId(0),
 451            fs: FakeFs::new(cx.background_executor().clone()),
 452            node_runtime,
 453            server: CopilotServer::Running(RunningCopilotServer {
 454                lsp: Arc::new(server),
 455                sign_in_status: SignInStatus::Authorized,
 456                registered_buffers: Default::default(),
 457            }),
 458            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 459            buffers: Default::default(),
 460        });
 461        (this, fake_server)
 462    }
 463
    /// Starts the Copilot language server and records the outcome on `this`.
    ///
    /// Steps, in order: resolve the Copilot server script and the node binary,
    /// create the `LanguageServer`, initialize it with the editor info as
    /// initialization options, query the sign-in status (`CheckStatus`), and
    /// send `SetEditorInfo`.
    ///
    /// On success the entity's server state becomes `Running` (initially
    /// `SignedOut` with `awaiting_sign_in_after_start`), the
    /// `CopilotLanguageServerStarted` event is emitted, and the status returned
    /// by `CheckStatus` is applied via `update_sign_in_status`. On failure the
    /// state becomes `Error` with the error's message. If the entity is already
    /// gone, the result is discarded.
    async fn start_language_server(
        new_server_id: LanguageServerId,
        fs: Arc<dyn Fs>,
        node_runtime: NodeRuntime,
        env: Option<HashMap<String, String>>,
        this: WeakEntity<Self>,
        awaiting_sign_in_after_start: bool,
        cx: &mut AsyncApp,
    ) {
        let start_language_server = async {
            let server_path = get_copilot_lsp(fs, node_runtime.clone()).await?;
            let node_path = node_runtime.binary_path().await?;
            // The server is launched as `<node> <server script> --stdio`.
            let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
            let binary = LanguageServerBinary {
                path: node_path,
                arguments,
                env,
            };

            // The server is rooted at the file system root rather than at any
            // particular project directory.
            let root_path = if cfg!(target_os = "windows") {
                Path::new("C:/")
            } else {
                Path::new("/")
            };

            let server_name = LanguageServerName("copilot".into());
            let server = LanguageServer::new(
                Arc::new(Mutex::new(None)),
                new_server_id,
                server_name,
                binary,
                root_path,
                None,
                Default::default(),
                cx,
            )?;

            server
                .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
                .detach();

            let configuration = lsp::DidChangeConfigurationParams {
                settings: Default::default(),
            };

            let editor_info = request::SetEditorInfoParams {
                editor_info: request::EditorInfo {
                    name: "zed".into(),
                    version: env!("CARGO_PKG_VERSION").into(),
                },
                editor_plugin_info: request::EditorPluginInfo {
                    name: "zed-copilot".into(),
                    version: "0.0.1".into(),
                },
            };
            // The same editor info is sent twice: as initialization options here,
            // and again through the explicit `SetEditorInfo` request below.
            let editor_info_json = serde_json::to_value(&editor_info)?;

            let server = cx
                .update(|cx| {
                    let mut params = server.default_initialize_params(cx);
                    params.initialization_options = Some(editor_info_json);
                    server.initialize(params, configuration.into(), cx)
                })?
                .await?;

            let status = server
                .request::<request::CheckStatus>(request::CheckStatusParams {
                    local_checks_only: false,
                })
                .await
                .into_response()
                .context("copilot: check status")?;

            server
                .request::<request::SetEditorInfo>(editor_info)
                .await
                .into_response()
                .context("copilot: set editor info")?;

            anyhow::Ok((server, status))
        };

        let server = start_language_server.await;
        this.update(cx, |this, cx| {
            cx.notify();
            match server {
                Ok((server, status)) => {
                    // Start out signed out, then apply whatever the server
                    // reported through `CheckStatus`.
                    this.server = CopilotServer::Running(RunningCopilotServer {
                        lsp: server,
                        sign_in_status: SignInStatus::SignedOut {
                            awaiting_signing_in: awaiting_sign_in_after_start,
                        },
                        registered_buffers: Default::default(),
                    });
                    cx.emit(Event::CopilotLanguageServerStarted);
                    this.update_sign_in_status(status, cx);
                }
                Err(error) => {
                    this.server = CopilotServer::Error(error.to_string().into());
                    // NOTE(review): repeats the `cx.notify()` at the top of this
                    // closure.
                    cx.notify()
                }
            }
        })
        .ok();
    }
 569
    /// Starts (or joins) a Copilot sign-in and returns a task that resolves
    /// when it finishes.
    ///
    /// - Already authorized: resolves immediately with `Ok(())`.
    /// - Sign-in already in progress: joins the existing shared task.
    /// - Signed out or unauthorized: sends `SignInInitiate`; if the server asks
    ///   for the device flow, the prompt is stored on the `SigningIn` status and
    ///   `SignInConfirm` is awaited. The final status is applied through
    ///   `update_sign_in_status`; on error the status is reset to not signed in.
    ///
    /// Returns an error task when the language server is not running.
    pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(async move |this, cx| {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await
                                    .into_response()
                                    .context("copilot sign-in")?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt on the status,
                                        // provided we are still in the `SigningIn` state.
                                        this.update(cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await
                                            .into_response()
                                            .context("copilot: sign in confirm")?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    // A failed sign-in leaves the user signed out.
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    // Record the in-flight task so later `sign_in` calls join it
                    // instead of starting a second flow.
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // The shared task's error is an `Arc<anyhow::Error>`; turn it back
            // into an owned `anyhow::Error` for the caller.
            cx.background_spawn(task.map_err(|err| anyhow!("{err:?}")))
        } else {
            // TODO: if the server is still downloading, wait until the download
            // has finished; if it is stuck in a failed state, surface that to
            // the user.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 657
 658    pub(crate) fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 659        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 660        match &self.server {
 661            CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
 662                let server = server.clone();
 663                cx.background_spawn(async move {
 664                    server
 665                        .request::<request::SignOut>(request::SignOutParams {})
 666                        .await
 667                        .into_response()
 668                        .context("copilot: sign in confirm")?;
 669                    anyhow::Ok(())
 670                })
 671            }
 672            CopilotServer::Disabled => cx.background_spawn(async {
 673                clear_copilot_config_dir().await;
 674                anyhow::Ok(())
 675            }),
 676            _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
 677        }
 678    }
 679
 680    pub(crate) fn reinstall(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
 681        let language_settings = all_language_settings(None, cx);
 682        let env = self.build_env(&language_settings.edit_predictions.copilot);
 683        let start_task = cx
 684            .spawn({
 685                let fs = self.fs.clone();
 686                let node_runtime = self.node_runtime.clone();
 687                let server_id = self.server_id;
 688                async move |this, cx| {
 689                    clear_copilot_dir().await;
 690                    Self::start_language_server(server_id, fs, node_runtime, env, this, false, cx)
 691                        .await
 692                }
 693            })
 694            .shared();
 695
 696        self.server = CopilotServer::Starting {
 697            task: start_task.clone(),
 698        };
 699
 700        cx.notify();
 701
 702        start_task
 703    }
 704
 705    pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
 706        if let CopilotServer::Running(server) = &self.server {
 707            Some(&server.lsp)
 708        } else {
 709            None
 710        }
 711    }
 712
 713    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 714        let weak_buffer = buffer.downgrade();
 715        self.buffers.insert(weak_buffer.clone());
 716
 717        if let CopilotServer::Running(RunningCopilotServer {
 718            lsp: server,
 719            sign_in_status: status,
 720            registered_buffers,
 721            ..
 722        }) = &mut self.server
 723        {
 724            if !matches!(status, SignInStatus::Authorized { .. }) {
 725                return;
 726            }
 727
 728            registered_buffers
 729                .entry(buffer.entity_id())
 730                .or_insert_with(|| {
 731                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
 732                    let language_id = id_for_language(buffer.read(cx).language());
 733                    let snapshot = buffer.read(cx).snapshot();
 734                    server
 735                        .notify::<lsp::notification::DidOpenTextDocument>(
 736                            &lsp::DidOpenTextDocumentParams {
 737                                text_document: lsp::TextDocumentItem {
 738                                    uri: uri.clone(),
 739                                    language_id: language_id.clone(),
 740                                    version: 0,
 741                                    text: snapshot.text(),
 742                                },
 743                            },
 744                        )
 745                        .ok();
 746
 747                    RegisteredBuffer {
 748                        uri,
 749                        language_id,
 750                        snapshot,
 751                        snapshot_version: 0,
 752                        pending_buffer_change: Task::ready(Some(())),
 753                        _subscriptions: [
 754                            cx.subscribe(buffer, |this, buffer, event, cx| {
 755                                this.handle_buffer_event(buffer, event, cx).log_err();
 756                            }),
 757                            cx.observe_release(buffer, move |this, _buffer, _cx| {
 758                                this.buffers.remove(&weak_buffer);
 759                                this.unregister_buffer(&weak_buffer);
 760                            }),
 761                        ],
 762                    }
 763                });
 764        }
 765    }
 766
 767    fn handle_buffer_event(
 768        &mut self,
 769        buffer: Entity<Buffer>,
 770        event: &language::BufferEvent,
 771        cx: &mut Context<Self>,
 772    ) -> Result<()> {
 773        if let Ok(server) = self.server.as_running() {
 774            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
 775            {
 776                match event {
 777                    language::BufferEvent::Edited => {
 778                        drop(registered_buffer.report_changes(&buffer, cx));
 779                    }
 780                    language::BufferEvent::Saved => {
 781                        server
 782                            .lsp
 783                            .notify::<lsp::notification::DidSaveTextDocument>(
 784                                &lsp::DidSaveTextDocumentParams {
 785                                    text_document: lsp::TextDocumentIdentifier::new(
 786                                        registered_buffer.uri.clone(),
 787                                    ),
 788                                    text: None,
 789                                },
 790                            )?;
 791                    }
 792                    language::BufferEvent::FileHandleChanged
 793                    | language::BufferEvent::LanguageChanged => {
 794                        let new_language_id = id_for_language(buffer.read(cx).language());
 795                        let new_uri = uri_for_buffer(&buffer, cx);
 796                        if new_uri != registered_buffer.uri
 797                            || new_language_id != registered_buffer.language_id
 798                        {
 799                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 800                            registered_buffer.language_id = new_language_id;
 801                            server
 802                                .lsp
 803                                .notify::<lsp::notification::DidCloseTextDocument>(
 804                                    &lsp::DidCloseTextDocumentParams {
 805                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 806                                    },
 807                                )?;
 808                            server
 809                                .lsp
 810                                .notify::<lsp::notification::DidOpenTextDocument>(
 811                                    &lsp::DidOpenTextDocumentParams {
 812                                        text_document: lsp::TextDocumentItem::new(
 813                                            registered_buffer.uri.clone(),
 814                                            registered_buffer.language_id.clone(),
 815                                            registered_buffer.snapshot_version,
 816                                            registered_buffer.snapshot.text(),
 817                                        ),
 818                                    },
 819                                )?;
 820                        }
 821                    }
 822                    _ => {}
 823                }
 824            }
 825        }
 826
 827        Ok(())
 828    }
 829
 830    fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
 831        if let Ok(server) = self.server.as_running() {
 832            if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
 833                server
 834                    .lsp
 835                    .notify::<lsp::notification::DidCloseTextDocument>(
 836                        &lsp::DidCloseTextDocumentParams {
 837                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 838                        },
 839                    )
 840                    .ok();
 841            }
 842        }
 843    }
 844
 845    pub fn completions<T>(
 846        &mut self,
 847        buffer: &Entity<Buffer>,
 848        position: T,
 849        cx: &mut Context<Self>,
 850    ) -> Task<Result<Vec<Completion>>>
 851    where
 852        T: ToPointUtf16,
 853    {
 854        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 855    }
 856
 857    pub fn completions_cycling<T>(
 858        &mut self,
 859        buffer: &Entity<Buffer>,
 860        position: T,
 861        cx: &mut Context<Self>,
 862    ) -> Task<Result<Vec<Completion>>>
 863    where
 864        T: ToPointUtf16,
 865    {
 866        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 867    }
 868
 869    pub fn accept_completion(
 870        &mut self,
 871        completion: &Completion,
 872        cx: &mut Context<Self>,
 873    ) -> Task<Result<()>> {
 874        let server = match self.server.as_authenticated() {
 875            Ok(server) => server,
 876            Err(error) => return Task::ready(Err(error)),
 877        };
 878        let request =
 879            server
 880                .lsp
 881                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 882                    uuid: completion.uuid.clone(),
 883                });
 884        cx.background_spawn(async move {
 885            request
 886                .await
 887                .into_response()
 888                .context("copilot: notify accepted")?;
 889            Ok(())
 890        })
 891    }
 892
 893    pub fn discard_completions(
 894        &mut self,
 895        completions: &[Completion],
 896        cx: &mut Context<Self>,
 897    ) -> Task<Result<()>> {
 898        let server = match self.server.as_authenticated() {
 899            Ok(server) => server,
 900            Err(_) => return Task::ready(Ok(())),
 901        };
 902        let request =
 903            server
 904                .lsp
 905                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 906                    uuids: completions
 907                        .iter()
 908                        .map(|completion| completion.uuid.clone())
 909                        .collect(),
 910                });
 911        cx.background_spawn(async move {
 912            request
 913                .await
 914                .into_response()
 915                .context("copilot: notify rejected")?;
 916            Ok(())
 917        })
 918    }
 919
 920    fn request_completions<R, T>(
 921        &mut self,
 922        buffer: &Entity<Buffer>,
 923        position: T,
 924        cx: &mut Context<Self>,
 925    ) -> Task<Result<Vec<Completion>>>
 926    where
 927        R: 'static
 928            + lsp::request::Request<
 929                Params = request::GetCompletionsParams,
 930                Result = request::GetCompletionsResult,
 931            >,
 932        T: ToPointUtf16,
 933    {
 934        self.register_buffer(buffer, cx);
 935
 936        let server = match self.server.as_authenticated() {
 937            Ok(server) => server,
 938            Err(error) => return Task::ready(Err(error)),
 939        };
 940        let lsp = server.lsp.clone();
 941        let registered_buffer = server
 942            .registered_buffers
 943            .get_mut(&buffer.entity_id())
 944            .unwrap();
 945        let snapshot = registered_buffer.report_changes(buffer, cx);
 946        let buffer = buffer.read(cx);
 947        let uri = registered_buffer.uri.clone();
 948        let position = position.to_point_utf16(buffer);
 949        let settings = language_settings(
 950            buffer.language_at(position).map(|l| l.name()),
 951            buffer.file(),
 952            cx,
 953        );
 954        let tab_size = settings.tab_size;
 955        let hard_tabs = settings.hard_tabs;
 956        let relative_path = buffer
 957            .file()
 958            .map(|file| file.path().to_path_buf())
 959            .unwrap_or_default();
 960
 961        cx.background_spawn(async move {
 962            let (version, snapshot) = snapshot.await?;
 963            let result = lsp
 964                .request::<R>(request::GetCompletionsParams {
 965                    doc: request::GetCompletionsDocument {
 966                        uri,
 967                        tab_size: tab_size.into(),
 968                        indent_size: 1,
 969                        insert_spaces: !hard_tabs,
 970                        relative_path: relative_path.to_string_lossy().into(),
 971                        position: point_to_lsp(position),
 972                        version: version.try_into().unwrap(),
 973                    },
 974                })
 975                .await
 976                .into_response()
 977                .context("copilot: get completions")?;
 978            let completions = result
 979                .completions
 980                .into_iter()
 981                .map(|completion| {
 982                    let start = snapshot
 983                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 984                    let end =
 985                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 986                    Completion {
 987                        uuid: completion.uuid,
 988                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 989                        text: completion.text,
 990                    }
 991                })
 992                .collect();
 993            anyhow::Ok(completions)
 994        })
 995    }
 996
 997    pub fn status(&self) -> Status {
 998        match &self.server {
 999            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
1000            CopilotServer::Disabled => Status::Disabled,
1001            CopilotServer::Error(error) => Status::Error(error.clone()),
1002            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
1003                match sign_in_status {
1004                    SignInStatus::Authorized { .. } => Status::Authorized,
1005                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
1006                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
1007                        prompt: prompt.clone(),
1008                    },
1009                    SignInStatus::SignedOut {
1010                        awaiting_signing_in,
1011                    } => Status::SignedOut {
1012                        awaiting_signing_in: *awaiting_signing_in,
1013                    },
1014                }
1015            }
1016        }
1017    }
1018
1019    fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
1020        self.buffers.retain(|buffer| buffer.is_upgradable());
1021
1022        if let Ok(server) = self.server.as_running() {
1023            match lsp_status {
1024                request::SignInStatus::Ok { user: Some(_) }
1025                | request::SignInStatus::MaybeOk { .. }
1026                | request::SignInStatus::AlreadySignedIn { .. } => {
1027                    server.sign_in_status = SignInStatus::Authorized;
1028                    cx.emit(Event::CopilotAuthSignedIn);
1029                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1030                        if let Some(buffer) = buffer.upgrade() {
1031                            self.register_buffer(&buffer, cx);
1032                        }
1033                    }
1034                }
1035                request::SignInStatus::NotAuthorized { .. } => {
1036                    server.sign_in_status = SignInStatus::Unauthorized;
1037                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1038                        self.unregister_buffer(&buffer);
1039                    }
1040                }
1041                request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
1042                    if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
1043                        server.sign_in_status = SignInStatus::SignedOut {
1044                            awaiting_signing_in: false,
1045                        };
1046                    }
1047                    cx.emit(Event::CopilotAuthSignedOut);
1048                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1049                        self.unregister_buffer(&buffer);
1050                    }
1051                }
1052            }
1053
1054            cx.notify();
1055        }
1056    }
1057}
1058
1059fn id_for_language(language: Option<&Arc<Language>>) -> String {
1060    language
1061        .map(|language| language.lsp_id())
1062        .unwrap_or_else(|| "plaintext".to_string())
1063}
1064
1065fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> lsp::Url {
1066    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1067        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
1068    } else {
1069        format!("buffer://{}", buffer.entity_id()).parse().unwrap()
1070    }
1071}
1072
1073async fn clear_copilot_dir() {
1074    remove_matching(paths::copilot_dir(), |_| true).await
1075}
1076
1077async fn clear_copilot_config_dir() {
1078    remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
1079}
1080
1081async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1082    const PACKAGE_NAME: &str = "@github/copilot-language-server";
1083    const SERVER_PATH: &str =
1084        "node_modules/@github/copilot-language-server/dist/language-server.js";
1085
1086    let latest_version = node_runtime
1087        .npm_package_latest_version(PACKAGE_NAME)
1088        .await?;
1089    let server_path = paths::copilot_dir().join(SERVER_PATH);
1090
1091    fs.create_dir(paths::copilot_dir()).await?;
1092
1093    let should_install = node_runtime
1094        .should_install_npm_package(
1095            PACKAGE_NAME,
1096            &server_path,
1097            paths::copilot_dir(),
1098            &latest_version,
1099        )
1100        .await;
1101    if should_install {
1102        node_runtime
1103            .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1104            .await?;
1105    }
1106
1107    Ok(server_path)
1108}
1109
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use lsp::notification::{DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument};
    use util::path;

    /// URI expected for a buffer that is not backed by a file.
    fn placeholder_uri(buffer: &Entity<Buffer>) -> lsp::Url {
        format!("buffer://{}", buffer.entity_id().as_u64())
            .parse()
            .unwrap()
    }

    /// Expected `didOpen` payload for a plaintext document.
    fn did_open(uri: &lsp::Url, version: i32, text: &str) -> lsp::DidOpenTextDocumentParams {
        lsp::DidOpenTextDocumentParams {
            text_document: lsp::TextDocumentItem::new(
                uri.clone(),
                "plaintext".into(),
                version,
                text.into(),
            ),
        }
    }

    /// Expected `didClose` payload for the given document.
    fn did_close(uri: &lsp::Url) -> lsp::DidCloseTextDocumentParams {
        lsp::DidCloseTextDocumentParams {
            text_document: lsp::TextDocumentIdentifier::new(uri.clone()),
        }
    }

    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut fake_server) = Copilot::fake(cx);

        // Registering a buffer opens it on the server.
        let hello = cx.new(|cx| Buffer::local("Hello", cx));
        let hello_placeholder_uri = placeholder_uri(&hello);
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&hello, cx));
        assert_eq!(
            fake_server
                .receive_notification::<DidOpenTextDocument>()
                .await,
            did_open(&hello_placeholder_uri, 0, "Hello")
        );

        let goodbye = cx.new(|cx| Buffer::local("Goodbye", cx));
        let goodbye_uri = placeholder_uri(&goodbye);
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&goodbye, cx));
        assert_eq!(
            fake_server
                .receive_notification::<DidOpenTextDocument>()
                .await,
            did_open(&goodbye_uri, 0, "Goodbye")
        );

        // Edits are forwarded as incremental changes with a bumped version.
        hello.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            fake_server
                .receive_notification::<DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(
                    hello_placeholder_uri.clone(),
                    1
                ),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP: the document is
        // closed under its old URI and re-opened under the file's URI.
        hello.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: path!("/root/child/buffer-1").into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            fake_server
                .receive_notification::<DidCloseTextDocument>()
                .await,
            did_close(&hello_placeholder_uri)
        );
        let hello_file_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
        assert_eq!(
            fake_server
                .receive_notification::<DidOpenTextDocument>()
                .await,
            did_open(&hello_file_uri, 1, "Hello world")
        );

        // Ensure all previously-registered buffers are closed when signing out.
        fake_server.set_request_handler::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            fake_server
                .receive_notification::<DidCloseTextDocument>()
                .await,
            did_close(&hello_file_uri)
        );
        assert_eq!(
            fake_server
                .receive_notification::<DidCloseTextDocument>()
                .await,
            did_close(&goodbye_uri)
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        fake_server.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            fake_server
                .receive_notification::<DidOpenTextDocument>()
                .await,
            did_open(&hello_file_uri, 0, "Hello world")
        );
        assert_eq!(
            fake_server
                .receive_notification::<DidOpenTextDocument>()
                .await,
            did_open(&goodbye_uri, 0, "Goodbye")
        );

        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(goodbye));
        assert_eq!(
            fake_server
                .receive_notification::<DidCloseTextDocument>()
                .await,
            did_close(&goodbye_uri)
        );
    }

    /// Minimal local-file stand-in used to give a test buffer a path.
    /// Methods the test does not need are left `unimplemented!()`.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn disk_state(&self) -> language::DiskState {
            let mtime = ::fs::MTime::from_seconds_and_nanos(100, 42);
            language::DiskState::Present { mtime }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &App) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn to_proto(&self, _: &App) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &App) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &App) -> Task<Result<String>> {
            unimplemented!()
        }

        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}
1328
/// Initializes test logging by calling `zlog::init_test()`.
///
/// Only compiled for test builds. The function is registered with
/// `#[ctor::ctor]` rather than being called explicitly by any test.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}