copilot.rs

   1pub mod copilot_chat;
   2mod copilot_completion_provider;
   3pub mod request;
   4mod sign_in;
   5
   6use crate::sign_in::initiate_sign_in_within_workspace;
   7use ::fs::Fs;
   8use anyhow::{Context as _, Result, anyhow};
   9use collections::{HashMap, HashSet};
  10use command_palette_hooks::CommandPaletteFilter;
  11use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
  12use gpui::{
  13    App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
  14    WeakEntity, actions,
  15};
  16use http_client::HttpClient;
  17use language::language_settings::CopilotSettings;
  18use language::{
  19    Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16,
  20    language_settings::{EditPredictionProvider, all_language_settings, language_settings},
  21    point_from_lsp, point_to_lsp,
  22};
  23use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
  24use node_runtime::NodeRuntime;
  25use parking_lot::Mutex;
  26use request::StatusNotification;
  27use settings::SettingsStore;
  28use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
  29use std::{
  30    any::TypeId,
  31    env,
  32    ffi::OsString,
  33    mem,
  34    ops::Range,
  35    path::{Path, PathBuf},
  36    sync::Arc,
  37};
  38use util::{ResultExt, fs::remove_matching};
  39use workspace::Workspace;
  40
  41pub use crate::copilot_completion_provider::CopilotCompletionProvider;
  42pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
  43
  44actions!(
  45    copilot,
  46    [
  47        Suggest,
  48        NextSuggestion,
  49        PreviousSuggestion,
  50        Reinstall,
  51        SignIn,
  52        SignOut
  53    ]
  54);
  55
  56pub fn init(
  57    new_server_id: LanguageServerId,
  58    fs: Arc<dyn Fs>,
  59    http: Arc<dyn HttpClient>,
  60    node_runtime: NodeRuntime,
  61    cx: &mut App,
  62) {
  63    copilot_chat::init(fs.clone(), http.clone(), cx);
  64
  65    let copilot = cx.new({
  66        let node_runtime = node_runtime.clone();
  67        move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)
  68    });
  69    Copilot::set_global(copilot.clone(), cx);
  70    cx.observe(&copilot, |handle, cx| {
  71        let copilot_action_types = [
  72            TypeId::of::<Suggest>(),
  73            TypeId::of::<NextSuggestion>(),
  74            TypeId::of::<PreviousSuggestion>(),
  75            TypeId::of::<Reinstall>(),
  76        ];
  77        let copilot_auth_action_types = [TypeId::of::<SignOut>()];
  78        let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
  79        let status = handle.read(cx).status();
  80        let filter = CommandPaletteFilter::global_mut(cx);
  81
  82        match status {
  83            Status::Disabled => {
  84                filter.hide_action_types(&copilot_action_types);
  85                filter.hide_action_types(&copilot_auth_action_types);
  86                filter.hide_action_types(&copilot_no_auth_action_types);
  87            }
  88            Status::Authorized => {
  89                filter.hide_action_types(&copilot_no_auth_action_types);
  90                filter.show_action_types(
  91                    copilot_action_types
  92                        .iter()
  93                        .chain(&copilot_auth_action_types),
  94                );
  95            }
  96            _ => {
  97                filter.hide_action_types(&copilot_action_types);
  98                filter.hide_action_types(&copilot_auth_action_types);
  99                filter.show_action_types(copilot_no_auth_action_types.iter());
 100            }
 101        }
 102    })
 103    .detach();
 104
 105    cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
 106        workspace.register_action(|workspace, _: &SignIn, window, cx| {
 107            if let Some(copilot) = Copilot::global(cx) {
 108                let is_reinstall = false;
 109                initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
 110            }
 111        });
 112        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
 113            if let Some(copilot) = Copilot::global(cx) {
 114                reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
 115            }
 116        });
 117        workspace.register_action(|workspace, _: &SignOut, _window, cx| {
 118            if let Some(copilot) = Copilot::global(cx) {
 119                sign_out_within_workspace(workspace, copilot, cx);
 120            }
 121        });
 122    })
 123    .detach();
 124}
 125
 126enum CopilotServer {
 127    Disabled,
 128    Starting { task: Shared<Task<()>> },
 129    Error(Arc<str>),
 130    Running(RunningCopilotServer),
 131}
 132
 133impl CopilotServer {
 134    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 135        let server = self.as_running()?;
 136        anyhow::ensure!(
 137            matches!(server.sign_in_status, SignInStatus::Authorized { .. }),
 138            "must sign in before using copilot"
 139        );
 140        Ok(server)
 141    }
 142
 143    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 144        match self {
 145            CopilotServer::Starting { .. } => anyhow::bail!("copilot is still starting"),
 146            CopilotServer::Disabled => anyhow::bail!("copilot is disabled"),
 147            CopilotServer::Error(error) => {
 148                anyhow::bail!("copilot was not started because of an error: {error}")
 149            }
 150            CopilotServer::Running(server) => Ok(server),
 151        }
 152    }
 153}
 154
 155struct RunningCopilotServer {
 156    lsp: Arc<LanguageServer>,
 157    sign_in_status: SignInStatus,
 158    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
 159}
 160
 161#[derive(Clone, Debug)]
 162enum SignInStatus {
 163    Authorized,
 164    Unauthorized,
 165    SigningIn {
 166        prompt: Option<request::PromptUserDeviceFlow>,
 167        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 168    },
 169    SignedOut {
 170        awaiting_signing_in: bool,
 171    },
 172}
 173
 174#[derive(Debug, Clone)]
 175pub enum Status {
 176    Starting {
 177        task: Shared<Task<()>>,
 178    },
 179    Error(Arc<str>),
 180    Disabled,
 181    SignedOut {
 182        awaiting_signing_in: bool,
 183    },
 184    SigningIn {
 185        prompt: Option<request::PromptUserDeviceFlow>,
 186    },
 187    Unauthorized,
 188    Authorized,
 189}
 190
 191impl Status {
 192    pub fn is_authorized(&self) -> bool {
 193        matches!(self, Status::Authorized)
 194    }
 195
 196    pub fn is_disabled(&self) -> bool {
 197        matches!(self, Status::Disabled)
 198    }
 199}
 200
 201struct RegisteredBuffer {
 202    uri: lsp::Url,
 203    language_id: String,
 204    snapshot: BufferSnapshot,
 205    snapshot_version: i32,
 206    _subscriptions: [gpui::Subscription; 2],
 207    pending_buffer_change: Task<Option<()>>,
 208}
 209
impl RegisteredBuffer {
    /// Reports edits made to `buffer` since the last reported snapshot to the
    /// Copilot server as an incremental `DidChangeTextDocument` notification.
    ///
    /// Returns a receiver for the `(version, snapshot)` pair the server is in
    /// sync with. If the buffer version is unchanged, the pair is sent
    /// immediately. Otherwise it is sent by a task chained after any previously
    /// pending change.
    ///
    /// That task gives up, dropping the sender and cancelling the receiver, in
    /// three cases: Copilot is no longer authenticated, the buffer is no longer
    /// registered, or the Copilot/buffer entity has been released.
    fn report_changes(
        &mut self,
        buffer: &Entity<Buffer>,
        cx: &mut Context<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Already in sync: answer with the last reported snapshot.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous change task so notifications go out in
            // order and each diff starts from the latest reported snapshot.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
                prev_pending_change.await;

                // Read the reported version only now, since the previous task
                // may have advanced it.
                let old_version = copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()).ok()?;

                // Compute the LSP content changes on the background executor.
                let content_changes = cx
                    .background_spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // LSP applies content changes in sequence, so each
                                    // range must address the document *after* the
                                    // earlier changes. Start at the edit's position in
                                    // the new text and span the extent of the text it
                                    // replaced. Assumes `edits_since` yields edits in
                                    // ascending order.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the document version and notify the server
                        // when there is something to send.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    &lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .ok();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
 295
 296#[derive(Debug)]
 297pub struct Completion {
 298    pub uuid: String,
 299    pub range: Range<Anchor>,
 300    pub text: String,
 301}
 302
 303pub struct Copilot {
 304    fs: Arc<dyn Fs>,
 305    node_runtime: NodeRuntime,
 306    server: CopilotServer,
 307    buffers: HashSet<WeakEntity<Buffer>>,
 308    server_id: LanguageServerId,
 309    _subscription: gpui::Subscription,
 310}
 311
 312pub enum Event {
 313    CopilotLanguageServerStarted,
 314    CopilotAuthSignedIn,
 315    CopilotAuthSignedOut,
 316}
 317
 318impl EventEmitter<Event> for Copilot {}
 319
 320struct GlobalCopilot(Entity<Copilot>);
 321
 322impl Global for GlobalCopilot {}
 323
 324impl Copilot {
 325    pub fn global(cx: &App) -> Option<Entity<Self>> {
 326        cx.try_global::<GlobalCopilot>()
 327            .map(|model| model.0.clone())
 328    }
 329
 330    pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
 331        cx.set_global(GlobalCopilot(copilot));
 332    }
 333
 334    fn start(
 335        new_server_id: LanguageServerId,
 336        fs: Arc<dyn Fs>,
 337        node_runtime: NodeRuntime,
 338        cx: &mut Context<Self>,
 339    ) -> Self {
 340        let mut this = Self {
 341            server_id: new_server_id,
 342            fs,
 343            node_runtime,
 344            server: CopilotServer::Disabled,
 345            buffers: Default::default(),
 346            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 347        };
 348        this.start_copilot(true, false, cx);
 349        cx.observe_global::<SettingsStore>(move |this, cx| this.start_copilot(true, false, cx))
 350            .detach();
 351        this
 352    }
 353
 354    fn shutdown_language_server(
 355        &mut self,
 356        _cx: &mut Context<Self>,
 357    ) -> impl Future<Output = ()> + use<> {
 358        let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
 359            CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
 360            _ => None,
 361        };
 362
 363        async move {
 364            if let Some(shutdown) = shutdown {
 365                shutdown.await;
 366            }
 367        }
 368    }
 369
 370    fn start_copilot(
 371        &mut self,
 372        check_edit_prediction_provider: bool,
 373        awaiting_sign_in_after_start: bool,
 374        cx: &mut Context<Self>,
 375    ) {
 376        if !matches!(self.server, CopilotServer::Disabled) {
 377            return;
 378        }
 379        let language_settings = all_language_settings(None, cx);
 380        if check_edit_prediction_provider
 381            && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
 382        {
 383            return;
 384        }
 385        let server_id = self.server_id;
 386        let fs = self.fs.clone();
 387        let node_runtime = self.node_runtime.clone();
 388        let env = self.build_env(&language_settings.edit_predictions.copilot);
 389        let start_task = cx
 390            .spawn(async move |this, cx| {
 391                Self::start_language_server(
 392                    server_id,
 393                    fs,
 394                    node_runtime,
 395                    env,
 396                    this,
 397                    awaiting_sign_in_after_start,
 398                    cx,
 399                )
 400                .await
 401            })
 402            .shared();
 403        self.server = CopilotServer::Starting { task: start_task };
 404        cx.notify();
 405    }
 406
 407    fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
 408        let proxy_url = copilot_settings.proxy.clone()?;
 409        let no_verify = copilot_settings.proxy_no_verify;
 410        let http_or_https_proxy = if proxy_url.starts_with("http:") {
 411            Some("HTTP_PROXY")
 412        } else if proxy_url.starts_with("https:") {
 413            Some("HTTPS_PROXY")
 414        } else {
 415            log::error!(
 416                "Unsupported protocol scheme for language server proxy (must be http or https)"
 417            );
 418            None
 419        };
 420
 421        let mut env = HashMap::default();
 422
 423        if let Some(proxy_type) = http_or_https_proxy {
 424            env.insert(proxy_type.to_string(), proxy_url);
 425            if let Some(true) = no_verify {
 426                env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
 427            };
 428        }
 429
 430        if let Ok(oauth_token) = env::var(copilot_chat::COPILOT_OAUTH_ENV_VAR) {
 431            env.insert(copilot_chat::COPILOT_OAUTH_ENV_VAR.to_string(), oauth_token);
 432        }
 433
 434        if env.is_empty() { None } else { Some(env) }
 435    }
 436
 437    #[cfg(any(test, feature = "test-support"))]
 438    pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
 439        use fs::FakeFs;
 440        use lsp::FakeLanguageServer;
 441        use node_runtime::NodeRuntime;
 442
 443        let (server, fake_server) = FakeLanguageServer::new(
 444            LanguageServerId(0),
 445            LanguageServerBinary {
 446                path: "path/to/copilot".into(),
 447                arguments: vec![],
 448                env: None,
 449            },
 450            "copilot".into(),
 451            Default::default(),
 452            &mut cx.to_async(),
 453        );
 454        let node_runtime = NodeRuntime::unavailable();
 455        let this = cx.new(|cx| Self {
 456            server_id: LanguageServerId(0),
 457            fs: FakeFs::new(cx.background_executor().clone()),
 458            node_runtime,
 459            server: CopilotServer::Running(RunningCopilotServer {
 460                lsp: Arc::new(server),
 461                sign_in_status: SignInStatus::Authorized,
 462                registered_buffers: Default::default(),
 463            }),
 464            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 465            buffers: Default::default(),
 466        });
 467        (this, fake_server)
 468    }
 469
    /// Launches and initializes the Copilot language server, then records the
    /// outcome on `this`.
    ///
    /// The server script is obtained via `get_copilot_lsp` and run with Node's
    /// binary.
    ///
    /// On success:
    /// * the server becomes `CopilotServer::Running`, signed out, with
    ///   `awaiting_signing_in` taken from `awaiting_sign_in_after_start`;
    /// * `Event::CopilotLanguageServerStarted` is emitted;
    /// * the status from the server's `CheckStatus` response is applied.
    ///
    /// On failure, the error message is stored as `CopilotServer::Error`. If
    /// `this` has been released meanwhile, the result is discarded.
    async fn start_language_server(
        new_server_id: LanguageServerId,
        fs: Arc<dyn Fs>,
        node_runtime: NodeRuntime,
        env: Option<HashMap<String, String>>,
        this: WeakEntity<Self>,
        awaiting_sign_in_after_start: bool,
        cx: &mut AsyncApp,
    ) {
        let start_language_server = async {
            let server_path = get_copilot_lsp(fs, node_runtime.clone()).await?;
            let node_path = node_runtime.binary_path().await?;
            // Run the server script under Node, speaking LSP over stdio.
            let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
            let binary = LanguageServerBinary {
                path: node_path,
                arguments,
                env,
            };

            // A fixed filesystem root is used as the root path (presumably
            // because Copilot is not tied to any one project).
            let root_path = if cfg!(target_os = "windows") {
                Path::new("C:/")
            } else {
                Path::new("/")
            };

            let server_name = LanguageServerName("copilot".into());
            let server = LanguageServer::new(
                Arc::new(Mutex::new(None)),
                new_server_id,
                server_name,
                binary,
                root_path,
                None,
                Default::default(),
                cx,
            )?;

            server
                .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
                .detach();

            let configuration = lsp::DidChangeConfigurationParams {
                settings: Default::default(),
            };

            // Identifies Zed to the server. Sent both as initialization options
            // and via `SetEditorInfo` below.
            let editor_info = request::SetEditorInfoParams {
                editor_info: request::EditorInfo {
                    name: "zed".into(),
                    version: env!("CARGO_PKG_VERSION").into(),
                },
                editor_plugin_info: request::EditorPluginInfo {
                    name: "zed-copilot".into(),
                    version: "0.0.1".into(),
                },
            };
            let editor_info_json = serde_json::to_value(&editor_info)?;

            let server = cx
                .update(|cx| {
                    let mut params = server.default_initialize_params(false, cx);
                    params.initialization_options = Some(editor_info_json);
                    server.initialize(params, configuration.into(), cx)
                })?
                .await?;

            // Ask for the current auth state so it can be applied right after startup.
            let status = server
                .request::<request::CheckStatus>(request::CheckStatusParams {
                    local_checks_only: false,
                })
                .await
                .into_response()
                .context("copilot: check status")?;

            server
                .request::<request::SetEditorInfo>(editor_info)
                .await
                .into_response()
                .context("copilot: set editor info")?;

            anyhow::Ok((server, status))
        };

        let server = start_language_server.await;
        this.update(cx, |this, cx| {
            cx.notify();
            match server {
                Ok((server, status)) => {
                    this.server = CopilotServer::Running(RunningCopilotServer {
                        lsp: server,
                        sign_in_status: SignInStatus::SignedOut {
                            awaiting_signing_in: awaiting_sign_in_after_start,
                        },
                        registered_buffers: Default::default(),
                    });
                    cx.emit(Event::CopilotLanguageServerStarted);
                    this.update_sign_in_status(status, cx);
                }
                Err(error) => {
                    this.server = CopilotServer::Error(error.to_string().into());
                    // NOTE(review): redundant; `cx.notify()` already ran at the
                    // top of this closure.
                    cx.notify()
                }
            }
        })
        // The entity may have been released while the server was starting.
        .ok();
    }
 575
    /// Starts, or joins, the Copilot sign-in flow.
    ///
    /// * Already authorized: resolves immediately.
    /// * Sign-in in progress: returns the existing shared task.
    /// * Signed out / unauthorized: sends `SignInInitiate`.
    ///   - If the server requests a device-flow confirmation, the prompt is
    ///     stored on the `SigningIn` status (presumably so the UI can show it).
    ///     Then `SignInConfirm` is sent with the user code.
    ///   - The resulting status, or `NotSignedIn` on error, is applied via
    ///     `update_sign_in_status`.
    ///
    /// # Errors
    /// Fails immediately if the server is not running; otherwise fails if the
    /// sign-in task fails.
    pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(async move |this, cx| {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await
                                    .into_response()
                                    .context("copilot sign-in")?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt on the
                                        // in-progress sign-in status.
                                        this.update(cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await
                                            .into_response()
                                            .context("copilot: sign in confirm")?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    // Record the in-flight task so later callers join it
                    // (the `SigningIn` arm above).
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // The shared task's error is `Arc<anyhow::Error>`; convert it into
            // a plain `anyhow::Error` for the caller.
            cx.background_spawn(task.map_err(|err| anyhow!("{err:?}")))
        } else {
            // If we're downloading, wait until download is finished
            // If we're in a stuck state, display to the user
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 663
 664    pub(crate) fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 665        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 666        match &self.server {
 667            CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
 668                let server = server.clone();
 669                cx.background_spawn(async move {
 670                    server
 671                        .request::<request::SignOut>(request::SignOutParams {})
 672                        .await
 673                        .into_response()
 674                        .context("copilot: sign in confirm")?;
 675                    anyhow::Ok(())
 676                })
 677            }
 678            CopilotServer::Disabled => cx.background_spawn(async {
 679                clear_copilot_config_dir().await;
 680                anyhow::Ok(())
 681            }),
 682            _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
 683        }
 684    }
 685
 686    pub(crate) fn reinstall(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
 687        let language_settings = all_language_settings(None, cx);
 688        let env = self.build_env(&language_settings.edit_predictions.copilot);
 689        let start_task = cx
 690            .spawn({
 691                let fs = self.fs.clone();
 692                let node_runtime = self.node_runtime.clone();
 693                let server_id = self.server_id;
 694                async move |this, cx| {
 695                    clear_copilot_dir().await;
 696                    Self::start_language_server(server_id, fs, node_runtime, env, this, false, cx)
 697                        .await
 698                }
 699            })
 700            .shared();
 701
 702        self.server = CopilotServer::Starting {
 703            task: start_task.clone(),
 704        };
 705
 706        cx.notify();
 707
 708        start_task
 709    }
 710
 711    pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
 712        if let CopilotServer::Running(server) = &self.server {
 713            Some(&server.lsp)
 714        } else {
 715            None
 716        }
 717    }
 718
 719    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 720        let weak_buffer = buffer.downgrade();
 721        self.buffers.insert(weak_buffer.clone());
 722
 723        if let CopilotServer::Running(RunningCopilotServer {
 724            lsp: server,
 725            sign_in_status: status,
 726            registered_buffers,
 727            ..
 728        }) = &mut self.server
 729        {
 730            if !matches!(status, SignInStatus::Authorized { .. }) {
 731                return;
 732            }
 733
 734            registered_buffers
 735                .entry(buffer.entity_id())
 736                .or_insert_with(|| {
 737                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
 738                    let language_id = id_for_language(buffer.read(cx).language());
 739                    let snapshot = buffer.read(cx).snapshot();
 740                    server
 741                        .notify::<lsp::notification::DidOpenTextDocument>(
 742                            &lsp::DidOpenTextDocumentParams {
 743                                text_document: lsp::TextDocumentItem {
 744                                    uri: uri.clone(),
 745                                    language_id: language_id.clone(),
 746                                    version: 0,
 747                                    text: snapshot.text(),
 748                                },
 749                            },
 750                        )
 751                        .ok();
 752
 753                    RegisteredBuffer {
 754                        uri,
 755                        language_id,
 756                        snapshot,
 757                        snapshot_version: 0,
 758                        pending_buffer_change: Task::ready(Some(())),
 759                        _subscriptions: [
 760                            cx.subscribe(buffer, |this, buffer, event, cx| {
 761                                this.handle_buffer_event(buffer, event, cx).log_err();
 762                            }),
 763                            cx.observe_release(buffer, move |this, _buffer, _cx| {
 764                                this.buffers.remove(&weak_buffer);
 765                                this.unregister_buffer(&weak_buffer);
 766                            }),
 767                        ],
 768                    }
 769                });
 770        }
 771    }
 772
 773    fn handle_buffer_event(
 774        &mut self,
 775        buffer: Entity<Buffer>,
 776        event: &language::BufferEvent,
 777        cx: &mut Context<Self>,
 778    ) -> Result<()> {
 779        if let Ok(server) = self.server.as_running() {
 780            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
 781            {
 782                match event {
 783                    language::BufferEvent::Edited => {
 784                        drop(registered_buffer.report_changes(&buffer, cx));
 785                    }
 786                    language::BufferEvent::Saved => {
 787                        server
 788                            .lsp
 789                            .notify::<lsp::notification::DidSaveTextDocument>(
 790                                &lsp::DidSaveTextDocumentParams {
 791                                    text_document: lsp::TextDocumentIdentifier::new(
 792                                        registered_buffer.uri.clone(),
 793                                    ),
 794                                    text: None,
 795                                },
 796                            )?;
 797                    }
 798                    language::BufferEvent::FileHandleChanged
 799                    | language::BufferEvent::LanguageChanged => {
 800                        let new_language_id = id_for_language(buffer.read(cx).language());
 801                        let new_uri = uri_for_buffer(&buffer, cx);
 802                        if new_uri != registered_buffer.uri
 803                            || new_language_id != registered_buffer.language_id
 804                        {
 805                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 806                            registered_buffer.language_id = new_language_id;
 807                            server
 808                                .lsp
 809                                .notify::<lsp::notification::DidCloseTextDocument>(
 810                                    &lsp::DidCloseTextDocumentParams {
 811                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 812                                    },
 813                                )?;
 814                            server
 815                                .lsp
 816                                .notify::<lsp::notification::DidOpenTextDocument>(
 817                                    &lsp::DidOpenTextDocumentParams {
 818                                        text_document: lsp::TextDocumentItem::new(
 819                                            registered_buffer.uri.clone(),
 820                                            registered_buffer.language_id.clone(),
 821                                            registered_buffer.snapshot_version,
 822                                            registered_buffer.snapshot.text(),
 823                                        ),
 824                                    },
 825                                )?;
 826                        }
 827                    }
 828                    _ => {}
 829                }
 830            }
 831        }
 832
 833        Ok(())
 834    }
 835
 836    fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
 837        if let Ok(server) = self.server.as_running() {
 838            if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
 839                server
 840                    .lsp
 841                    .notify::<lsp::notification::DidCloseTextDocument>(
 842                        &lsp::DidCloseTextDocumentParams {
 843                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 844                        },
 845                    )
 846                    .ok();
 847            }
 848        }
 849    }
 850
 851    pub fn completions<T>(
 852        &mut self,
 853        buffer: &Entity<Buffer>,
 854        position: T,
 855        cx: &mut Context<Self>,
 856    ) -> Task<Result<Vec<Completion>>>
 857    where
 858        T: ToPointUtf16,
 859    {
 860        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 861    }
 862
 863    pub fn completions_cycling<T>(
 864        &mut self,
 865        buffer: &Entity<Buffer>,
 866        position: T,
 867        cx: &mut Context<Self>,
 868    ) -> Task<Result<Vec<Completion>>>
 869    where
 870        T: ToPointUtf16,
 871    {
 872        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 873    }
 874
 875    pub fn accept_completion(
 876        &mut self,
 877        completion: &Completion,
 878        cx: &mut Context<Self>,
 879    ) -> Task<Result<()>> {
 880        let server = match self.server.as_authenticated() {
 881            Ok(server) => server,
 882            Err(error) => return Task::ready(Err(error)),
 883        };
 884        let request =
 885            server
 886                .lsp
 887                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 888                    uuid: completion.uuid.clone(),
 889                });
 890        cx.background_spawn(async move {
 891            request
 892                .await
 893                .into_response()
 894                .context("copilot: notify accepted")?;
 895            Ok(())
 896        })
 897    }
 898
 899    pub fn discard_completions(
 900        &mut self,
 901        completions: &[Completion],
 902        cx: &mut Context<Self>,
 903    ) -> Task<Result<()>> {
 904        let server = match self.server.as_authenticated() {
 905            Ok(server) => server,
 906            Err(_) => return Task::ready(Ok(())),
 907        };
 908        let request =
 909            server
 910                .lsp
 911                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 912                    uuids: completions
 913                        .iter()
 914                        .map(|completion| completion.uuid.clone())
 915                        .collect(),
 916                });
 917        cx.background_spawn(async move {
 918            request
 919                .await
 920                .into_response()
 921                .context("copilot: notify rejected")?;
 922            Ok(())
 923        })
 924    }
 925
 926    fn request_completions<R, T>(
 927        &mut self,
 928        buffer: &Entity<Buffer>,
 929        position: T,
 930        cx: &mut Context<Self>,
 931    ) -> Task<Result<Vec<Completion>>>
 932    where
 933        R: 'static
 934            + lsp::request::Request<
 935                Params = request::GetCompletionsParams,
 936                Result = request::GetCompletionsResult,
 937            >,
 938        T: ToPointUtf16,
 939    {
 940        self.register_buffer(buffer, cx);
 941
 942        let server = match self.server.as_authenticated() {
 943            Ok(server) => server,
 944            Err(error) => return Task::ready(Err(error)),
 945        };
 946        let lsp = server.lsp.clone();
 947        let registered_buffer = server
 948            .registered_buffers
 949            .get_mut(&buffer.entity_id())
 950            .unwrap();
 951        let snapshot = registered_buffer.report_changes(buffer, cx);
 952        let buffer = buffer.read(cx);
 953        let uri = registered_buffer.uri.clone();
 954        let position = position.to_point_utf16(buffer);
 955        let settings = language_settings(
 956            buffer.language_at(position).map(|l| l.name()),
 957            buffer.file(),
 958            cx,
 959        );
 960        let tab_size = settings.tab_size;
 961        let hard_tabs = settings.hard_tabs;
 962        let relative_path = buffer
 963            .file()
 964            .map(|file| file.path().to_path_buf())
 965            .unwrap_or_default();
 966
 967        cx.background_spawn(async move {
 968            let (version, snapshot) = snapshot.await?;
 969            let result = lsp
 970                .request::<R>(request::GetCompletionsParams {
 971                    doc: request::GetCompletionsDocument {
 972                        uri,
 973                        tab_size: tab_size.into(),
 974                        indent_size: 1,
 975                        insert_spaces: !hard_tabs,
 976                        relative_path: relative_path.to_string_lossy().into(),
 977                        position: point_to_lsp(position),
 978                        version: version.try_into().unwrap(),
 979                    },
 980                })
 981                .await
 982                .into_response()
 983                .context("copilot: get completions")?;
 984            let completions = result
 985                .completions
 986                .into_iter()
 987                .map(|completion| {
 988                    let start = snapshot
 989                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 990                    let end =
 991                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 992                    Completion {
 993                        uuid: completion.uuid,
 994                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 995                        text: completion.text,
 996                    }
 997                })
 998                .collect();
 999            anyhow::Ok(completions)
1000        })
1001    }
1002
1003    pub fn status(&self) -> Status {
1004        match &self.server {
1005            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
1006            CopilotServer::Disabled => Status::Disabled,
1007            CopilotServer::Error(error) => Status::Error(error.clone()),
1008            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
1009                match sign_in_status {
1010                    SignInStatus::Authorized { .. } => Status::Authorized,
1011                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
1012                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
1013                        prompt: prompt.clone(),
1014                    },
1015                    SignInStatus::SignedOut {
1016                        awaiting_signing_in,
1017                    } => Status::SignedOut {
1018                        awaiting_signing_in: *awaiting_signing_in,
1019                    },
1020                }
1021            }
1022        }
1023    }
1024
1025    fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
1026        self.buffers.retain(|buffer| buffer.is_upgradable());
1027
1028        if let Ok(server) = self.server.as_running() {
1029            match lsp_status {
1030                request::SignInStatus::Ok { user: Some(_) }
1031                | request::SignInStatus::MaybeOk { .. }
1032                | request::SignInStatus::AlreadySignedIn { .. } => {
1033                    server.sign_in_status = SignInStatus::Authorized;
1034                    cx.emit(Event::CopilotAuthSignedIn);
1035                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1036                        if let Some(buffer) = buffer.upgrade() {
1037                            self.register_buffer(&buffer, cx);
1038                        }
1039                    }
1040                }
1041                request::SignInStatus::NotAuthorized { .. } => {
1042                    server.sign_in_status = SignInStatus::Unauthorized;
1043                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1044                        self.unregister_buffer(&buffer);
1045                    }
1046                }
1047                request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
1048                    if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
1049                        server.sign_in_status = SignInStatus::SignedOut {
1050                            awaiting_signing_in: false,
1051                        };
1052                    }
1053                    cx.emit(Event::CopilotAuthSignedOut);
1054                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1055                        self.unregister_buffer(&buffer);
1056                    }
1057                }
1058            }
1059
1060            cx.notify();
1061        }
1062    }
1063}
1064
1065fn id_for_language(language: Option<&Arc<Language>>) -> String {
1066    language
1067        .map(|language| language.lsp_id())
1068        .unwrap_or_else(|| "plaintext".to_string())
1069}
1070
1071fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> lsp::Url {
1072    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1073        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
1074    } else {
1075        format!("buffer://{}", buffer.entity_id()).parse().unwrap()
1076    }
1077}
1078
1079async fn clear_copilot_dir() {
1080    remove_matching(paths::copilot_dir(), |_| true).await
1081}
1082
1083async fn clear_copilot_config_dir() {
1084    remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
1085}
1086
1087async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1088    const PACKAGE_NAME: &str = "@github/copilot-language-server";
1089    const SERVER_PATH: &str =
1090        "node_modules/@github/copilot-language-server/dist/language-server.js";
1091
1092    let latest_version = node_runtime
1093        .npm_package_latest_version(PACKAGE_NAME)
1094        .await?;
1095    let server_path = paths::copilot_dir().join(SERVER_PATH);
1096
1097    fs.create_dir(paths::copilot_dir()).await?;
1098
1099    let should_install = node_runtime
1100        .should_install_npm_package(
1101            PACKAGE_NAME,
1102            &server_path,
1103            paths::copilot_dir(),
1104            &latest_version,
1105        )
1106        .await;
1107    if should_install {
1108        node_runtime
1109            .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1110            .await?;
1111    }
1112
1113    Ok(server_path)
1114}
1115
#[cfg(test)]
// Tests for how `Copilot` announces buffers to its language server, driven against the
// language-server handle returned by `Copilot::fake`.
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::path;

    // Walks one buffer lifecycle end to end and checks the exact LSP notification sequence:
    // open -> change -> file handle change (close + reopen) -> sign-out (close all)
    // -> sign-in (reopen all) -> buffer dropped (close).
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A buffer without a file is announced under a synthetic `buffer://<entity id>` URI,
        // as plaintext, at version 0.
        let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is sent as a ranged DidChange and bumps the document version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // Giving the buffer a local file changes its URI, so the old URI is closed and the
        // document is reopened under the `file://` URI, keeping the current version (1).
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: path!("/root/child/buffer-1").into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        // NOTE(review): these assertions expect buffer_1 to be closed before buffer_2, which
        // depends on the order the tracked buffers are iterated — confirm it is deterministic.
        lsp.set_request_handler::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Re-registration announces each buffer afresh, so versions start over at 0 while the
        // text reflects the buffer's current contents.
        lsp.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    // Minimal local `language::File` used to give a test buffer an on-disk path. Methods
    // implemented with `unimplemented!()` panic if called.
    struct File {
        // Value returned from `LocalFile::abs_path`.
        abs_path: PathBuf,
        // Value returned from `File::path`.
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        // Always reports the file as present on disk with a fixed mtime.
        fn disk_state(&self) -> language::DiskState {
            language::DiskState::Present {
                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
            }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &App) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn to_proto(&self, _: &App) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &App) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &App) -> Task<Result<String>> {
            unimplemented!()
        }

        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}
1334
#[cfg(test)]
#[ctor::ctor]
// Test builds only: `#[ctor::ctor]` runs this before the tests, so `zlog::init_test()`
// sets up logging without each test having to call it.
fn init_logger() {
    zlog::init_test();
}