copilot.rs

   1pub mod copilot_chat;
   2mod copilot_completion_provider;
   3pub mod request;
   4mod sign_in;
   5
   6use ::fs::Fs;
   7use anyhow::{anyhow, Result};
   8use collections::{HashMap, HashSet};
   9use command_palette_hooks::CommandPaletteFilter;
  10use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
  11use gpui::{
  12    actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
  13    WeakEntity,
  14};
  15use http_client::HttpClient;
  16use language::language_settings::CopilotSettings;
  17use language::{
  18    language_settings::{all_language_settings, language_settings, EditPredictionProvider},
  19    point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
  20    ToPointUtf16,
  21};
  22use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
  23use node_runtime::NodeRuntime;
  24use parking_lot::Mutex;
  25use request::StatusNotification;
  26use settings::SettingsStore;
  27use std::{
  28    any::TypeId,
  29    env,
  30    ffi::OsString,
  31    mem,
  32    ops::Range,
  33    path::{Path, PathBuf},
  34    sync::Arc,
  35};
  36use util::{fs::remove_matching, ResultExt};
  37
  38pub use crate::copilot_completion_provider::CopilotCompletionProvider;
  39pub use crate::sign_in::{initiate_sign_in, CopilotCodeVerification};
  40
  41actions!(
  42    copilot,
  43    [
  44        Suggest,
  45        NextSuggestion,
  46        PreviousSuggestion,
  47        Reinstall,
  48        SignIn,
  49        SignOut
  50    ]
  51);
  52
  53pub fn init(
  54    new_server_id: LanguageServerId,
  55    fs: Arc<dyn Fs>,
  56    http: Arc<dyn HttpClient>,
  57    node_runtime: NodeRuntime,
  58    cx: &mut App,
  59) {
  60    copilot_chat::init(fs, http.clone(), cx);
  61
  62    let copilot = cx.new({
  63        let node_runtime = node_runtime.clone();
  64        move |cx| Copilot::start(new_server_id, node_runtime, cx)
  65    });
  66    Copilot::set_global(copilot.clone(), cx);
  67    cx.observe(&copilot, |handle, cx| {
  68        let copilot_action_types = [
  69            TypeId::of::<Suggest>(),
  70            TypeId::of::<NextSuggestion>(),
  71            TypeId::of::<PreviousSuggestion>(),
  72            TypeId::of::<Reinstall>(),
  73        ];
  74        let copilot_auth_action_types = [TypeId::of::<SignOut>()];
  75        let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
  76        let status = handle.read(cx).status();
  77        let filter = CommandPaletteFilter::global_mut(cx);
  78
  79        match status {
  80            Status::Disabled => {
  81                filter.hide_action_types(&copilot_action_types);
  82                filter.hide_action_types(&copilot_auth_action_types);
  83                filter.hide_action_types(&copilot_no_auth_action_types);
  84            }
  85            Status::Authorized => {
  86                filter.hide_action_types(&copilot_no_auth_action_types);
  87                filter.show_action_types(
  88                    copilot_action_types
  89                        .iter()
  90                        .chain(&copilot_auth_action_types),
  91                );
  92            }
  93            _ => {
  94                filter.hide_action_types(&copilot_action_types);
  95                filter.hide_action_types(&copilot_auth_action_types);
  96                filter.show_action_types(copilot_no_auth_action_types.iter());
  97            }
  98        }
  99    })
 100    .detach();
 101
 102    cx.on_action(|_: &SignIn, cx| {
 103        if let Some(copilot) = Copilot::global(cx) {
 104            copilot
 105                .update(cx, |copilot, cx| copilot.sign_in(cx))
 106                .detach_and_log_err(cx);
 107        }
 108    });
 109    cx.on_action(|_: &SignOut, cx| {
 110        if let Some(copilot) = Copilot::global(cx) {
 111            copilot
 112                .update(cx, |copilot, cx| copilot.sign_out(cx))
 113                .detach_and_log_err(cx);
 114        }
 115    });
 116    cx.on_action(|_: &Reinstall, cx| {
 117        if let Some(copilot) = Copilot::global(cx) {
 118            copilot
 119                .update(cx, |copilot, cx| copilot.reinstall(cx))
 120                .detach();
 121        }
 122    });
 123}
 124
 125enum CopilotServer {
 126    Disabled,
 127    Starting { task: Shared<Task<()>> },
 128    Error(Arc<str>),
 129    Running(RunningCopilotServer),
 130}
 131
 132impl CopilotServer {
 133    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 134        let server = self.as_running()?;
 135        if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
 136            Ok(server)
 137        } else {
 138            Err(anyhow!("must sign in before using copilot"))
 139        }
 140    }
 141
 142    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 143        match self {
 144            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
 145            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
 146            CopilotServer::Error(error) => Err(anyhow!(
 147                "copilot was not started because of an error: {}",
 148                error
 149            )),
 150            CopilotServer::Running(server) => Ok(server),
 151        }
 152    }
 153}
 154
/// State held while the Copilot language server process is up.
struct RunningCopilotServer {
    /// Handle to the language server.
    lsp: Arc<LanguageServer>,
    /// Where the user is in the sign-in flow; operations that go through
    /// `CopilotServer::as_authenticated` require `Authorized`.
    sign_in_status: SignInStatus,
    /// Buffers opened on the server, keyed by the buffer entity's id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}

/// Internal sign-in state of a running server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in; the only state in which `CopilotServer::as_authenticated` succeeds.
    Authorized,
    // NOTE(review): presumably "signed in but not entitled to Copilot"; the mapping from the
    // server's status is done in `update_sign_in_status`, outside this view.
    Unauthorized,
    /// A sign-in flow started by `Copilot::sign_in` is in progress.
    SigningIn {
        /// Device-flow prompt, filled in once the server asks the user to confirm a code.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared sign-in task; repeated `sign_in` calls return clones of this same task.
        /// The error is wrapped in `Arc` so the shared output can be cloned.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut {
        /// Initialized from `awaiting_sign_in_after_start` when the server starts.
        // NOTE(review): presumably signals that a sign-in should follow startup — confirm.
        awaiting_signing_in: bool,
    },
}
 173
 174#[derive(Debug, Clone)]
 175pub enum Status {
 176    Starting {
 177        task: Shared<Task<()>>,
 178    },
 179    Error(Arc<str>),
 180    Disabled,
 181    SignedOut {
 182        awaiting_signing_in: bool,
 183    },
 184    SigningIn {
 185        prompt: Option<request::PromptUserDeviceFlow>,
 186    },
 187    Unauthorized,
 188    Authorized,
 189}
 190
 191impl Status {
 192    pub fn is_authorized(&self) -> bool {
 193        matches!(self, Status::Authorized)
 194    }
 195
 196    pub fn is_disabled(&self) -> bool {
 197        matches!(self, Status::Disabled)
 198    }
 199}
 200
/// A buffer that has been opened on the Copilot language server.
struct RegisteredBuffer {
    /// Document URI the server knows this buffer by.
    uri: lsp::Url,
    /// Language id sent to the server when the document was (re)opened.
    language_id: String,
    /// Last buffer snapshot whose contents have been reported to the server.
    snapshot: BufferSnapshot,
    /// LSP document version matching `snapshot`; bumped by one for every
    /// `DidChangeTextDocument` sent.
    snapshot_version: i32,
    /// Buffer subscriptions owned by this registration.
    // NOTE(review): held only so they stay alive while the buffer is registered; relies on
    // gpui cancelling a `Subscription` when it is dropped.
    _subscriptions: [gpui::Subscription; 2],
    /// The most recently scheduled change-reporting task. Each new task awaits the previous
    /// one, so change notifications are sent in order.
    pending_buffer_change: Task<Option<()>>,
}

impl RegisteredBuffer {
    /// Reports any edits made to `buffer` since the last reported snapshot to the server.
    ///
    /// Returns a receiver that yields `(snapshot_version, snapshot)` once the server is up to
    /// date.
    /// - If the buffer's version already matches `self.snapshot`, the receiver resolves
    ///   immediately.
    /// - Otherwise a task is chained after the previous pending change. It diffs the new
    ///   snapshot against the old version on a background thread and sends an incremental
    ///   `DidChangeTextDocument` if anything changed.
    ///
    /// If that task bails out early (entity dropped, server no longer authenticated, buffer
    /// unregistered), `done_tx` is dropped without sending, so the receiver resolves as
    /// canceled.
    fn report_changes(
        &mut self,
        buffer: &Entity<Buffer>,
        cx: &mut Context<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing new since the last report.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Take the in-flight task so the new one can await it, serializing reports.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
                prev_pending_change.await;

                // Re-read the registration: the previous task may have advanced the snapshot.
                let old_version = copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()).ok()?;

                let content_changes = cx
                    .background_spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The range end is the new start plus the extent of the
                                    // replaced (old) text.
                                    // NOTE(review): this expresses each range in the document
                                    // as it stands after the preceding changes, which LSP's
                                    // sequentially-applied contentChanges require. It assumes
                                    // `edits_since` yields ascending, non-overlapping edits.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // When there are no textual changes the stored snapshot and version are
                        // left as they were.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    &lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
 295
/// A Copilot completion for a buffer position (produced via `Copilot::completions`).
#[derive(Debug)]
pub struct Completion {
    /// Identifier of this completion.
    // NOTE(review): presumably assigned by the Copilot server and echoed back when reporting
    // acceptance/rejection — confirm.
    pub uuid: String,
    /// Buffer range the completion applies to.
    pub range: Range<Anchor>,
    /// Completion text.
    // NOTE(review): presumably replaces `range` when accepted — confirm.
    pub text: String,
}
 302
/// Application-wide handle to the GitHub Copilot language server and its per-buffer state.
pub struct Copilot {
    /// Node runtime whose binary is used to launch the language server.
    node_runtime: NodeRuntime,
    /// Current server lifecycle state.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, held weakly, including buffers that are not
    /// (yet) registered with the server.
    buffers: HashSet<WeakEntity<Buffer>>,
    /// Id given to the language server whenever it is (re)started.
    server_id: LanguageServerId,
    /// App-quit subscription that runs `shutdown_language_server`.
    _subscription: gpui::Subscription,
}

/// Events emitted by [`Copilot`].
pub enum Event {
    /// Emitted after the language server has started and completed its initial handshake.
    CopilotLanguageServerStarted,
    /// Presumably emitted when the user becomes signed in (emitter not visible here).
    CopilotAuthSignedIn,
    /// Presumably emitted when the user becomes signed out (emitter not visible here).
    CopilotAuthSignedOut,
}

impl EventEmitter<Event> for Copilot {}

/// Global slot holding the application's [`Copilot`] entity; see `Copilot::global` and
/// `Copilot::set_global`.
struct GlobalCopilot(Entity<Copilot>);

impl Global for GlobalCopilot {}
 322
 323impl Copilot {
 324    pub fn global(cx: &App) -> Option<Entity<Self>> {
 325        cx.try_global::<GlobalCopilot>()
 326            .map(|model| model.0.clone())
 327    }
 328
 329    pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
 330        cx.set_global(GlobalCopilot(copilot));
 331    }
 332
 333    fn start(
 334        new_server_id: LanguageServerId,
 335        node_runtime: NodeRuntime,
 336        cx: &mut Context<Self>,
 337    ) -> Self {
 338        let mut this = Self {
 339            server_id: new_server_id,
 340            node_runtime,
 341            server: CopilotServer::Disabled,
 342            buffers: Default::default(),
 343            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 344        };
 345        this.start_copilot(true, false, cx);
 346        cx.observe_global::<SettingsStore>(move |this, cx| this.start_copilot(true, false, cx))
 347            .detach();
 348        this
 349    }
 350
 351    fn shutdown_language_server(&mut self, _cx: &mut Context<Self>) -> impl Future<Output = ()> {
 352        let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
 353            CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
 354            _ => None,
 355        };
 356
 357        async move {
 358            if let Some(shutdown) = shutdown {
 359                shutdown.await;
 360            }
 361        }
 362    }
 363
 364    fn start_copilot(
 365        &mut self,
 366        check_edit_prediction_provider: bool,
 367        awaiting_sign_in_after_start: bool,
 368        cx: &mut Context<Self>,
 369    ) {
 370        if !matches!(self.server, CopilotServer::Disabled) {
 371            return;
 372        }
 373        let language_settings = all_language_settings(None, cx);
 374        if check_edit_prediction_provider
 375            && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
 376        {
 377            return;
 378        }
 379        let server_id = self.server_id;
 380        let node_runtime = self.node_runtime.clone();
 381        let env = self.build_env(&language_settings.edit_predictions.copilot);
 382        let start_task = cx
 383            .spawn(async move |this, cx| {
 384                Self::start_language_server(
 385                    server_id,
 386                    node_runtime,
 387                    env,
 388                    this,
 389                    awaiting_sign_in_after_start,
 390                    cx,
 391                )
 392                .await
 393            })
 394            .shared();
 395        self.server = CopilotServer::Starting { task: start_task };
 396        cx.notify();
 397    }
 398
 399    fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
 400        let proxy_url = copilot_settings.proxy.clone()?;
 401        let no_verify = copilot_settings.proxy_no_verify;
 402        let http_or_https_proxy = if proxy_url.starts_with("http:") {
 403            "HTTP_PROXY"
 404        } else if proxy_url.starts_with("https:") {
 405            "HTTPS_PROXY"
 406        } else {
 407            log::error!(
 408                "Unsupported protocol scheme for language server proxy (must be http or https)"
 409            );
 410            return None;
 411        };
 412
 413        let mut env = HashMap::default();
 414        env.insert(http_or_https_proxy.to_string(), proxy_url);
 415
 416        if let Some(true) = no_verify {
 417            env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
 418        };
 419
 420        Some(env)
 421    }
 422
 423    #[cfg(any(test, feature = "test-support"))]
 424    pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
 425        use lsp::FakeLanguageServer;
 426        use node_runtime::NodeRuntime;
 427
 428        let (server, fake_server) = FakeLanguageServer::new(
 429            LanguageServerId(0),
 430            LanguageServerBinary {
 431                path: "path/to/copilot".into(),
 432                arguments: vec![],
 433                env: None,
 434            },
 435            "copilot".into(),
 436            Default::default(),
 437            &mut cx.to_async(),
 438        );
 439        let node_runtime = NodeRuntime::unavailable();
 440        let this = cx.new(|cx| Self {
 441            server_id: LanguageServerId(0),
 442            node_runtime,
 443            server: CopilotServer::Running(RunningCopilotServer {
 444                lsp: Arc::new(server),
 445                sign_in_status: SignInStatus::Authorized,
 446                registered_buffers: Default::default(),
 447            }),
 448            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 449            buffers: Default::default(),
 450        });
 451        (this, fake_server)
 452    }
 453
    /// Launches the Copilot language server and records the outcome on `this`.
    ///
    /// Steps:
    /// 1. Resolve the server script via `get_copilot_lsp` (defined elsewhere; presumably
    ///    installs it when missing).
    /// 2. Run it with the Node binary over `--stdio`, using `env` as extra environment.
    /// 3. Initialize it, passing the editor info as initialization options.
    /// 4. Query `CheckStatus`, then send `SetEditorInfo`.
    ///
    /// On success the server becomes `Running` with a `SignedOut` status (carrying
    /// `awaiting_sign_in_after_start`). `CopilotLanguageServerStarted` is then emitted and the
    /// sign-in status is updated from the `CheckStatus` result. On failure the server becomes
    /// `Error` with the error's message. If `this` has been dropped, the result is discarded.
    async fn start_language_server(
        new_server_id: LanguageServerId,
        node_runtime: NodeRuntime,
        env: Option<HashMap<String, String>>,
        this: WeakEntity<Self>,
        awaiting_sign_in_after_start: bool,
        cx: &mut AsyncApp,
    ) {
        let start_language_server = async {
            let server_path = get_copilot_lsp(node_runtime.clone()).await?;
            let node_path = node_runtime.binary_path().await?;
            let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
            let binary = LanguageServerBinary {
                path: node_path,
                arguments,
                env,
            };

            // The server is not tied to a project, so a filesystem root serves as its root path.
            let root_path = if cfg!(target_os = "windows") {
                Path::new("C:/")
            } else {
                Path::new("/")
            };

            let server_name = LanguageServerName("copilot".into());
            // NOTE(review): the first argument is presumably a stderr-capture slot — confirm
            // against `LanguageServer::new`.
            let server = LanguageServer::new(
                Arc::new(Mutex::new(None)),
                new_server_id,
                server_name,
                binary,
                root_path,
                None,
                Default::default(),
                cx,
            )?;

            server
                .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
                .detach();

            let configuration = lsp::DidChangeConfigurationParams {
                settings: Default::default(),
            };

            // Sent twice: as `initializationOptions` and via the `SetEditorInfo` request below.
            let editor_info = request::SetEditorInfoParams {
                editor_info: request::EditorInfo {
                    name: "zed".into(),
                    version: env!("CARGO_PKG_VERSION").into(),
                },
                editor_plugin_info: request::EditorPluginInfo {
                    name: "zed-copilot".into(),
                    version: "0.0.1".into(),
                },
            };
            let editor_info_json = serde_json::to_value(&editor_info)?;

            let server = cx
                .update(|cx| {
                    let mut params = server.default_initialize_params(cx);
                    params.initialization_options = Some(editor_info_json);
                    server.initialize(params, configuration.into(), cx)
                })?
                .await?;

            // The resulting status seeds the sign-in state once the server is stored.
            let status = server
                .request::<request::CheckStatus>(request::CheckStatusParams {
                    local_checks_only: false,
                })
                .await?;

            server
                .request::<request::SetEditorInfo>(editor_info)
                .await?;

            anyhow::Ok((server, status))
        };

        let server = start_language_server.await;
        this.update(cx, |this, cx| {
            cx.notify();
            match server {
                Ok((server, status)) => {
                    this.server = CopilotServer::Running(RunningCopilotServer {
                        lsp: server,
                        sign_in_status: SignInStatus::SignedOut {
                            awaiting_signing_in: awaiting_sign_in_after_start,
                        },
                        registered_buffers: Default::default(),
                    });
                    cx.emit(Event::CopilotLanguageServerStarted);
                    this.update_sign_in_status(status, cx);
                }
                Err(error) => {
                    this.server = CopilotServer::Error(error.to_string().into());
                    // Second notify (one was already issued above); harmless.
                    cx.notify()
                }
            }
        })
        .ok();
    }
 554
    /// Starts, or joins, the Copilot sign-in flow.
    ///
    /// The returned task resolves once sign-in has finished:
    /// - Already `Authorized`: it resolves immediately.
    /// - A flow already in progress: it waits on that same shared task.
    /// - Otherwise: a new flow is started. `SignInInitiate` either reports that the user is
    ///   already signed in, or returns a device-flow prompt. The prompt is stored in the
    ///   `SigningIn` state, and `SignInConfirm` is then sent with the user code.
    ///
    /// Whatever the outcome, the sign-in status is updated. On failure it is reset to
    /// `NotSignedIn`.
    ///
    /// # Errors
    /// Fails if the server is not running, or if the sign-in flow fails.
    pub fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                // Nothing to do.
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                // Reuse the in-flight flow so concurrent callers wait on the same task.
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(async move |this, cx| {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt while still in the
                                        // `SigningIn` state, and notify observers.
                                        this.update(cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // `Arc` makes the error cloneable, as the shared task's
                                    // output must be.
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // Convert the shared `Arc<anyhow::Error>` back into a plain error (via its Debug
            // rendering).
            cx.background_spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // The server is disabled, still starting, or failed to start.
            // NOTE(review): ideally this would wait for an in-progress download/start to
            // finish, or surface a stuck state to the user; today it simply returns an error.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 638
 639    pub fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 640        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 641        match &self.server {
 642            CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
 643                let server = server.clone();
 644                cx.background_spawn(async move {
 645                    server
 646                        .request::<request::SignOut>(request::SignOutParams {})
 647                        .await?;
 648                    anyhow::Ok(())
 649                })
 650            }
 651            CopilotServer::Disabled => cx.background_spawn(async {
 652                clear_copilot_config_dir().await;
 653                anyhow::Ok(())
 654            }),
 655            _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
 656        }
 657    }
 658
 659    pub fn reinstall(&mut self, cx: &mut Context<Self>) -> Task<()> {
 660        let language_settings = all_language_settings(None, cx);
 661        let env = self.build_env(&language_settings.edit_predictions.copilot);
 662        let start_task = cx
 663            .spawn({
 664                let node_runtime = self.node_runtime.clone();
 665                let server_id = self.server_id;
 666                async move |this, cx| {
 667                    clear_copilot_dir().await;
 668                    Self::start_language_server(server_id, node_runtime, env, this, false, cx).await
 669                }
 670            })
 671            .shared();
 672
 673        self.server = CopilotServer::Starting {
 674            task: start_task.clone(),
 675        };
 676
 677        cx.notify();
 678
 679        cx.background_spawn(start_task)
 680    }
 681
 682    pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
 683        if let CopilotServer::Running(server) = &self.server {
 684            Some(&server.lsp)
 685        } else {
 686            None
 687        }
 688    }
 689
 690    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 691        let weak_buffer = buffer.downgrade();
 692        self.buffers.insert(weak_buffer.clone());
 693
 694        if let CopilotServer::Running(RunningCopilotServer {
 695            lsp: server,
 696            sign_in_status: status,
 697            registered_buffers,
 698            ..
 699        }) = &mut self.server
 700        {
 701            if !matches!(status, SignInStatus::Authorized { .. }) {
 702                return;
 703            }
 704
 705            registered_buffers
 706                .entry(buffer.entity_id())
 707                .or_insert_with(|| {
 708                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
 709                    let language_id = id_for_language(buffer.read(cx).language());
 710                    let snapshot = buffer.read(cx).snapshot();
 711                    server
 712                        .notify::<lsp::notification::DidOpenTextDocument>(
 713                            &lsp::DidOpenTextDocumentParams {
 714                                text_document: lsp::TextDocumentItem {
 715                                    uri: uri.clone(),
 716                                    language_id: language_id.clone(),
 717                                    version: 0,
 718                                    text: snapshot.text(),
 719                                },
 720                            },
 721                        )
 722                        .log_err();
 723
 724                    RegisteredBuffer {
 725                        uri,
 726                        language_id,
 727                        snapshot,
 728                        snapshot_version: 0,
 729                        pending_buffer_change: Task::ready(Some(())),
 730                        _subscriptions: [
 731                            cx.subscribe(buffer, |this, buffer, event, cx| {
 732                                this.handle_buffer_event(buffer, event, cx).log_err();
 733                            }),
 734                            cx.observe_release(buffer, move |this, _buffer, _cx| {
 735                                this.buffers.remove(&weak_buffer);
 736                                this.unregister_buffer(&weak_buffer);
 737                            }),
 738                        ],
 739                    }
 740                });
 741        }
 742    }
 743
 744    fn handle_buffer_event(
 745        &mut self,
 746        buffer: Entity<Buffer>,
 747        event: &language::BufferEvent,
 748        cx: &mut Context<Self>,
 749    ) -> Result<()> {
 750        if let Ok(server) = self.server.as_running() {
 751            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
 752            {
 753                match event {
 754                    language::BufferEvent::Edited => {
 755                        drop(registered_buffer.report_changes(&buffer, cx));
 756                    }
 757                    language::BufferEvent::Saved => {
 758                        server
 759                            .lsp
 760                            .notify::<lsp::notification::DidSaveTextDocument>(
 761                                &lsp::DidSaveTextDocumentParams {
 762                                    text_document: lsp::TextDocumentIdentifier::new(
 763                                        registered_buffer.uri.clone(),
 764                                    ),
 765                                    text: None,
 766                                },
 767                            )?;
 768                    }
 769                    language::BufferEvent::FileHandleChanged
 770                    | language::BufferEvent::LanguageChanged => {
 771                        let new_language_id = id_for_language(buffer.read(cx).language());
 772                        let new_uri = uri_for_buffer(&buffer, cx);
 773                        if new_uri != registered_buffer.uri
 774                            || new_language_id != registered_buffer.language_id
 775                        {
 776                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 777                            registered_buffer.language_id = new_language_id;
 778                            server
 779                                .lsp
 780                                .notify::<lsp::notification::DidCloseTextDocument>(
 781                                    &lsp::DidCloseTextDocumentParams {
 782                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 783                                    },
 784                                )?;
 785                            server
 786                                .lsp
 787                                .notify::<lsp::notification::DidOpenTextDocument>(
 788                                    &lsp::DidOpenTextDocumentParams {
 789                                        text_document: lsp::TextDocumentItem::new(
 790                                            registered_buffer.uri.clone(),
 791                                            registered_buffer.language_id.clone(),
 792                                            registered_buffer.snapshot_version,
 793                                            registered_buffer.snapshot.text(),
 794                                        ),
 795                                    },
 796                                )?;
 797                        }
 798                    }
 799                    _ => {}
 800                }
 801            }
 802        }
 803
 804        Ok(())
 805    }
 806
 807    fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
 808        if let Ok(server) = self.server.as_running() {
 809            if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
 810                server
 811                    .lsp
 812                    .notify::<lsp::notification::DidCloseTextDocument>(
 813                        &lsp::DidCloseTextDocumentParams {
 814                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 815                        },
 816                    )
 817                    .log_err();
 818            }
 819        }
 820    }
 821
 822    pub fn completions<T>(
 823        &mut self,
 824        buffer: &Entity<Buffer>,
 825        position: T,
 826        cx: &mut Context<Self>,
 827    ) -> Task<Result<Vec<Completion>>>
 828    where
 829        T: ToPointUtf16,
 830    {
 831        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 832    }
 833
 834    pub fn completions_cycling<T>(
 835        &mut self,
 836        buffer: &Entity<Buffer>,
 837        position: T,
 838        cx: &mut Context<Self>,
 839    ) -> Task<Result<Vec<Completion>>>
 840    where
 841        T: ToPointUtf16,
 842    {
 843        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 844    }
 845
 846    pub fn accept_completion(
 847        &mut self,
 848        completion: &Completion,
 849        cx: &mut Context<Self>,
 850    ) -> Task<Result<()>> {
 851        let server = match self.server.as_authenticated() {
 852            Ok(server) => server,
 853            Err(error) => return Task::ready(Err(error)),
 854        };
 855        let request =
 856            server
 857                .lsp
 858                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 859                    uuid: completion.uuid.clone(),
 860                });
 861        cx.background_spawn(async move {
 862            request.await?;
 863            Ok(())
 864        })
 865    }
 866
 867    pub fn discard_completions(
 868        &mut self,
 869        completions: &[Completion],
 870        cx: &mut Context<Self>,
 871    ) -> Task<Result<()>> {
 872        let server = match self.server.as_authenticated() {
 873            Ok(server) => server,
 874            Err(_) => return Task::ready(Ok(())),
 875        };
 876        let request =
 877            server
 878                .lsp
 879                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 880                    uuids: completions
 881                        .iter()
 882                        .map(|completion| completion.uuid.clone())
 883                        .collect(),
 884                });
 885        cx.background_spawn(async move {
 886            request.await?;
 887            Ok(())
 888        })
 889    }
 890
 891    fn request_completions<R, T>(
 892        &mut self,
 893        buffer: &Entity<Buffer>,
 894        position: T,
 895        cx: &mut Context<Self>,
 896    ) -> Task<Result<Vec<Completion>>>
 897    where
 898        R: 'static
 899            + lsp::request::Request<
 900                Params = request::GetCompletionsParams,
 901                Result = request::GetCompletionsResult,
 902            >,
 903        T: ToPointUtf16,
 904    {
 905        self.register_buffer(buffer, cx);
 906
 907        let server = match self.server.as_authenticated() {
 908            Ok(server) => server,
 909            Err(error) => return Task::ready(Err(error)),
 910        };
 911        let lsp = server.lsp.clone();
 912        let registered_buffer = server
 913            .registered_buffers
 914            .get_mut(&buffer.entity_id())
 915            .unwrap();
 916        let snapshot = registered_buffer.report_changes(buffer, cx);
 917        let buffer = buffer.read(cx);
 918        let uri = registered_buffer.uri.clone();
 919        let position = position.to_point_utf16(buffer);
 920        let settings = language_settings(
 921            buffer.language_at(position).map(|l| l.name()),
 922            buffer.file(),
 923            cx,
 924        );
 925        let tab_size = settings.tab_size;
 926        let hard_tabs = settings.hard_tabs;
 927        let relative_path = buffer
 928            .file()
 929            .map(|file| file.path().to_path_buf())
 930            .unwrap_or_default();
 931
 932        cx.background_spawn(async move {
 933            let (version, snapshot) = snapshot.await?;
 934            let result = lsp
 935                .request::<R>(request::GetCompletionsParams {
 936                    doc: request::GetCompletionsDocument {
 937                        uri,
 938                        tab_size: tab_size.into(),
 939                        indent_size: 1,
 940                        insert_spaces: !hard_tabs,
 941                        relative_path: relative_path.to_string_lossy().into(),
 942                        position: point_to_lsp(position),
 943                        version: version.try_into().unwrap(),
 944                    },
 945                })
 946                .await?;
 947            let completions = result
 948                .completions
 949                .into_iter()
 950                .map(|completion| {
 951                    let start = snapshot
 952                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 953                    let end =
 954                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 955                    Completion {
 956                        uuid: completion.uuid,
 957                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 958                        text: completion.text,
 959                    }
 960                })
 961                .collect();
 962            anyhow::Ok(completions)
 963        })
 964    }
 965
 966    pub fn status(&self) -> Status {
 967        match &self.server {
 968            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
 969            CopilotServer::Disabled => Status::Disabled,
 970            CopilotServer::Error(error) => Status::Error(error.clone()),
 971            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
 972                match sign_in_status {
 973                    SignInStatus::Authorized { .. } => Status::Authorized,
 974                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
 975                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
 976                        prompt: prompt.clone(),
 977                    },
 978                    SignInStatus::SignedOut {
 979                        awaiting_signing_in,
 980                    } => Status::SignedOut {
 981                        awaiting_signing_in: *awaiting_signing_in,
 982                    },
 983                }
 984            }
 985        }
 986    }
 987
 988    fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
 989        self.buffers.retain(|buffer| buffer.is_upgradable());
 990
 991        if let Ok(server) = self.server.as_running() {
 992            match lsp_status {
 993                request::SignInStatus::Ok { user: Some(_) }
 994                | request::SignInStatus::MaybeOk { .. }
 995                | request::SignInStatus::AlreadySignedIn { .. } => {
 996                    server.sign_in_status = SignInStatus::Authorized;
 997                    cx.emit(Event::CopilotAuthSignedIn);
 998                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 999                        if let Some(buffer) = buffer.upgrade() {
1000                            self.register_buffer(&buffer, cx);
1001                        }
1002                    }
1003                }
1004                request::SignInStatus::NotAuthorized { .. } => {
1005                    server.sign_in_status = SignInStatus::Unauthorized;
1006                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1007                        self.unregister_buffer(&buffer);
1008                    }
1009                }
1010                request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
1011                    if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
1012                        server.sign_in_status = SignInStatus::SignedOut {
1013                            awaiting_signing_in: false,
1014                        };
1015                    }
1016                    cx.emit(Event::CopilotAuthSignedOut);
1017                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1018                        self.unregister_buffer(&buffer);
1019                    }
1020                }
1021            }
1022
1023            cx.notify();
1024        }
1025    }
1026}
1027
1028fn id_for_language(language: Option<&Arc<Language>>) -> String {
1029    language
1030        .map(|language| language.lsp_id())
1031        .unwrap_or_else(|| "plaintext".to_string())
1032}
1033
1034fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> lsp::Url {
1035    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1036        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
1037    } else {
1038        format!("buffer://{}", buffer.entity_id()).parse().unwrap()
1039    }
1040}
1041
1042async fn clear_copilot_dir() {
1043    remove_matching(paths::copilot_dir(), |_| true).await
1044}
1045
1046async fn clear_copilot_config_dir() {
1047    remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
1048}
1049
1050async fn get_copilot_lsp(node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1051    const PACKAGE_NAME: &str = "@github/copilot-language-server";
1052    const SERVER_PATH: &str =
1053        "node_modules/@github/copilot-language-server/dist/language-server.js";
1054
1055    let latest_version = node_runtime
1056        .npm_package_latest_version(PACKAGE_NAME)
1057        .await?;
1058    let server_path = paths::copilot_dir().join(SERVER_PATH);
1059
1060    let should_install = node_runtime
1061        .should_install_npm_package(
1062            PACKAGE_NAME,
1063            &server_path,
1064            paths::copilot_dir(),
1065            &latest_version,
1066        )
1067        .await;
1068    if should_install {
1069        node_runtime
1070            .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1071            .await?;
1072    }
1073
1074    Ok(server_path)
1075}
1076
// Tests for how `Copilot` keeps the language server's view of open buffers in sync
// (didOpen / didChange / didClose notifications) across edits, file changes, sign-out,
// sign-in and buffer drops.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::path;

    // Drives a fake Copilot server and asserts the exact sequence of notifications it receives.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A buffer without a file is identified by a synthetic `buffer://<entity id>` URI.
        let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental change and bumps the version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // Attaching a local file changes the URI, so the old document is closed and the
        // file-backed one is opened with the current text and version.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: path!("/root/child/buffer-1").into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.set_request_handler::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Note the expected version is 0 here for both buffers (contrast with 1 above).
        lsp.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    // Minimal local file used to give a buffer an on-disk identity in the test above.
    // Methods that are not expected to be called here are left `unimplemented!()`.
    struct File {
        // Absolute path returned from `LocalFile::abs_path`.
        abs_path: PathBuf,
        // Path returned from `File::path`.
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        // Always reports the file as present with a fixed mtime.
        fn disk_state(&self) -> language::DiskState {
            language::DiskState::Present {
                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
            }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &App) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self, _: &App) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &App) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &App) -> Task<Result<String>> {
            unimplemented!()
        }

        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}
1299
/// Test-only constructor that installs `env_logger` before tests run, but only when the
/// `RUST_LOG` environment variable is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}