copilot.rs

   1pub mod copilot_chat;
   2mod copilot_completion_provider;
   3pub mod request;
   4mod sign_in;
   5
   6use crate::sign_in::initiate_sign_in_within_workspace;
   7use ::fs::Fs;
   8use anyhow::{Context as _, Result, anyhow};
   9use collections::{HashMap, HashSet};
  10use command_palette_hooks::CommandPaletteFilter;
  11use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
  12use gpui::{
  13    App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
  14    WeakEntity, actions,
  15};
  16use http_client::HttpClient;
  17use language::language_settings::CopilotSettings;
  18use language::{
  19    Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16,
  20    language_settings::{EditPredictionProvider, all_language_settings, language_settings},
  21    point_from_lsp, point_to_lsp,
  22};
  23use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
  24use node_runtime::NodeRuntime;
  25use parking_lot::Mutex;
  26use request::StatusNotification;
  27use serde_json::json;
  28use settings::SettingsStore;
  29use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
  30use std::collections::hash_map::Entry;
  31use std::{
  32    any::TypeId,
  33    env,
  34    ffi::OsString,
  35    mem,
  36    ops::Range,
  37    path::{Path, PathBuf},
  38    sync::Arc,
  39};
  40use util::{ResultExt, fs::remove_matching};
  41use workspace::Workspace;
  42
  43pub use crate::copilot_completion_provider::CopilotCompletionProvider;
  44pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
  45
  46actions!(
  47    copilot,
  48    [
  49        Suggest,
  50        NextSuggestion,
  51        PreviousSuggestion,
  52        Reinstall,
  53        SignIn,
  54        SignOut
  55    ]
  56);
  57
  58pub fn init(
  59    new_server_id: LanguageServerId,
  60    fs: Arc<dyn Fs>,
  61    http: Arc<dyn HttpClient>,
  62    node_runtime: NodeRuntime,
  63    cx: &mut App,
  64) {
  65    let language_settings = all_language_settings(None, cx);
  66    let configuration = copilot_chat::CopilotChatConfiguration {
  67        enterprise_uri: language_settings
  68            .edit_predictions
  69            .copilot
  70            .enterprise_uri
  71            .clone(),
  72    };
  73    copilot_chat::init(fs.clone(), http.clone(), configuration, cx);
  74
  75    let copilot = cx.new({
  76        let node_runtime = node_runtime.clone();
  77        move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)
  78    });
  79    Copilot::set_global(copilot.clone(), cx);
  80    cx.observe(&copilot, |handle, cx| {
  81        let copilot_action_types = [
  82            TypeId::of::<Suggest>(),
  83            TypeId::of::<NextSuggestion>(),
  84            TypeId::of::<PreviousSuggestion>(),
  85            TypeId::of::<Reinstall>(),
  86        ];
  87        let copilot_auth_action_types = [TypeId::of::<SignOut>()];
  88        let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
  89        let status = handle.read(cx).status();
  90        let filter = CommandPaletteFilter::global_mut(cx);
  91
  92        match status {
  93            Status::Disabled => {
  94                filter.hide_action_types(&copilot_action_types);
  95                filter.hide_action_types(&copilot_auth_action_types);
  96                filter.hide_action_types(&copilot_no_auth_action_types);
  97            }
  98            Status::Authorized => {
  99                filter.hide_action_types(&copilot_no_auth_action_types);
 100                filter.show_action_types(
 101                    copilot_action_types
 102                        .iter()
 103                        .chain(&copilot_auth_action_types),
 104                );
 105            }
 106            _ => {
 107                filter.hide_action_types(&copilot_action_types);
 108                filter.hide_action_types(&copilot_auth_action_types);
 109                filter.show_action_types(copilot_no_auth_action_types.iter());
 110            }
 111        }
 112    })
 113    .detach();
 114
 115    cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
 116        workspace.register_action(|workspace, _: &SignIn, window, cx| {
 117            if let Some(copilot) = Copilot::global(cx) {
 118                let is_reinstall = false;
 119                initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
 120            }
 121        });
 122        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
 123            if let Some(copilot) = Copilot::global(cx) {
 124                reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
 125            }
 126        });
 127        workspace.register_action(|workspace, _: &SignOut, _window, cx| {
 128            if let Some(copilot) = Copilot::global(cx) {
 129                sign_out_within_workspace(workspace, copilot, cx);
 130            }
 131        });
 132    })
 133    .detach();
 134}
 135
 136enum CopilotServer {
 137    Disabled,
 138    Starting { task: Shared<Task<()>> },
 139    Error(Arc<str>),
 140    Running(RunningCopilotServer),
 141}
 142
 143impl CopilotServer {
 144    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 145        let server = self.as_running()?;
 146        anyhow::ensure!(
 147            matches!(server.sign_in_status, SignInStatus::Authorized { .. }),
 148            "must sign in before using copilot"
 149        );
 150        Ok(server)
 151    }
 152
 153    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 154        match self {
 155            CopilotServer::Starting { .. } => anyhow::bail!("copilot is still starting"),
 156            CopilotServer::Disabled => anyhow::bail!("copilot is disabled"),
 157            CopilotServer::Error(error) => {
 158                anyhow::bail!("copilot was not started because of an error: {error}")
 159            }
 160            CopilotServer::Running(server) => Ok(server),
 161        }
 162    }
 163}
 164
 165struct RunningCopilotServer {
 166    lsp: Arc<LanguageServer>,
 167    sign_in_status: SignInStatus,
 168    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
 169}
 170
 171#[derive(Clone, Debug)]
 172enum SignInStatus {
 173    Authorized,
 174    Unauthorized,
 175    SigningIn {
 176        prompt: Option<request::PromptUserDeviceFlow>,
 177        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 178    },
 179    SignedOut {
 180        awaiting_signing_in: bool,
 181    },
 182}
 183
 184#[derive(Debug, Clone)]
 185pub enum Status {
 186    Starting {
 187        task: Shared<Task<()>>,
 188    },
 189    Error(Arc<str>),
 190    Disabled,
 191    SignedOut {
 192        awaiting_signing_in: bool,
 193    },
 194    SigningIn {
 195        prompt: Option<request::PromptUserDeviceFlow>,
 196    },
 197    Unauthorized,
 198    Authorized,
 199}
 200
 201impl Status {
 202    pub fn is_authorized(&self) -> bool {
 203        matches!(self, Status::Authorized)
 204    }
 205
 206    pub fn is_disabled(&self) -> bool {
 207        matches!(self, Status::Disabled)
 208    }
 209}
 210
 211struct RegisteredBuffer {
 212    uri: lsp::Url,
 213    language_id: String,
 214    snapshot: BufferSnapshot,
 215    snapshot_version: i32,
 216    _subscriptions: [gpui::Subscription; 2],
 217    pending_buffer_change: Task<Option<()>>,
 218}
 219
 220impl RegisteredBuffer {
 221    fn report_changes(
 222        &mut self,
 223        buffer: &Entity<Buffer>,
 224        cx: &mut Context<Copilot>,
 225    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
 226        let (done_tx, done_rx) = oneshot::channel();
 227
 228        if buffer.read(cx).version() == self.snapshot.version {
 229            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
 230        } else {
 231            let buffer = buffer.downgrade();
 232            let id = buffer.entity_id();
 233            let prev_pending_change =
 234                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
 235            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
 236                prev_pending_change.await;
 237
 238                let old_version = copilot
 239                    .update(cx, |copilot, _| {
 240                        let server = copilot.server.as_authenticated().log_err()?;
 241                        let buffer = server.registered_buffers.get_mut(&id)?;
 242                        Some(buffer.snapshot.version.clone())
 243                    })
 244                    .ok()??;
 245                let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()).ok()?;
 246
 247                let content_changes = cx
 248                    .background_spawn({
 249                        let new_snapshot = new_snapshot.clone();
 250                        async move {
 251                            new_snapshot
 252                                .edits_since::<(PointUtf16, usize)>(&old_version)
 253                                .map(|edit| {
 254                                    let edit_start = edit.new.start.0;
 255                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
 256                                    let new_text = new_snapshot
 257                                        .text_for_range(edit.new.start.1..edit.new.end.1)
 258                                        .collect();
 259                                    lsp::TextDocumentContentChangeEvent {
 260                                        range: Some(lsp::Range::new(
 261                                            point_to_lsp(edit_start),
 262                                            point_to_lsp(edit_end),
 263                                        )),
 264                                        range_length: None,
 265                                        text: new_text,
 266                                    }
 267                                })
 268                                .collect::<Vec<_>>()
 269                        }
 270                    })
 271                    .await;
 272
 273                copilot
 274                    .update(cx, |copilot, _| {
 275                        let server = copilot.server.as_authenticated().log_err()?;
 276                        let buffer = server.registered_buffers.get_mut(&id)?;
 277                        if !content_changes.is_empty() {
 278                            buffer.snapshot_version += 1;
 279                            buffer.snapshot = new_snapshot;
 280                            server
 281                                .lsp
 282                                .notify::<lsp::notification::DidChangeTextDocument>(
 283                                    &lsp::DidChangeTextDocumentParams {
 284                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
 285                                            buffer.uri.clone(),
 286                                            buffer.snapshot_version,
 287                                        ),
 288                                        content_changes,
 289                                    },
 290                                )
 291                                .ok();
 292                        }
 293                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
 294                        Some(())
 295                    })
 296                    .ok()?;
 297
 298                Some(())
 299            });
 300        }
 301
 302        done_rx
 303    }
 304}
 305
 306#[derive(Debug)]
 307pub struct Completion {
 308    pub uuid: String,
 309    pub range: Range<Anchor>,
 310    pub text: String,
 311}
 312
 313pub struct Copilot {
 314    fs: Arc<dyn Fs>,
 315    node_runtime: NodeRuntime,
 316    server: CopilotServer,
 317    buffers: HashSet<WeakEntity<Buffer>>,
 318    server_id: LanguageServerId,
 319    _subscription: gpui::Subscription,
 320}
 321
 322pub enum Event {
 323    CopilotLanguageServerStarted,
 324    CopilotAuthSignedIn,
 325    CopilotAuthSignedOut,
 326}
 327
 328impl EventEmitter<Event> for Copilot {}
 329
 330struct GlobalCopilot(Entity<Copilot>);
 331
 332impl Global for GlobalCopilot {}
 333
 334impl Copilot {
 335    pub fn global(cx: &App) -> Option<Entity<Self>> {
 336        cx.try_global::<GlobalCopilot>()
 337            .map(|model| model.0.clone())
 338    }
 339
 340    pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
 341        cx.set_global(GlobalCopilot(copilot));
 342    }
 343
 344    fn start(
 345        new_server_id: LanguageServerId,
 346        fs: Arc<dyn Fs>,
 347        node_runtime: NodeRuntime,
 348        cx: &mut Context<Self>,
 349    ) -> Self {
 350        let mut this = Self {
 351            server_id: new_server_id,
 352            fs,
 353            node_runtime,
 354            server: CopilotServer::Disabled,
 355            buffers: Default::default(),
 356            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 357        };
 358        this.start_copilot(true, false, cx);
 359        cx.observe_global::<SettingsStore>(move |this, cx| {
 360            this.start_copilot(true, false, cx);
 361            this.send_configuration_update(cx);
 362        })
 363        .detach();
 364        this
 365    }
 366
 367    fn shutdown_language_server(
 368        &mut self,
 369        _cx: &mut Context<Self>,
 370    ) -> impl Future<Output = ()> + use<> {
 371        let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
 372            CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
 373            _ => None,
 374        };
 375
 376        async move {
 377            if let Some(shutdown) = shutdown {
 378                shutdown.await;
 379            }
 380        }
 381    }
 382
 383    fn start_copilot(
 384        &mut self,
 385        check_edit_prediction_provider: bool,
 386        awaiting_sign_in_after_start: bool,
 387        cx: &mut Context<Self>,
 388    ) {
 389        if !matches!(self.server, CopilotServer::Disabled) {
 390            return;
 391        }
 392        let language_settings = all_language_settings(None, cx);
 393        if check_edit_prediction_provider
 394            && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
 395        {
 396            return;
 397        }
 398        let server_id = self.server_id;
 399        let fs = self.fs.clone();
 400        let node_runtime = self.node_runtime.clone();
 401        let env = self.build_env(&language_settings.edit_predictions.copilot);
 402        let start_task = cx
 403            .spawn(async move |this, cx| {
 404                Self::start_language_server(
 405                    server_id,
 406                    fs,
 407                    node_runtime,
 408                    env,
 409                    this,
 410                    awaiting_sign_in_after_start,
 411                    cx,
 412                )
 413                .await
 414            })
 415            .shared();
 416        self.server = CopilotServer::Starting { task: start_task };
 417        cx.notify();
 418    }
 419
 420    fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
 421        let proxy_url = copilot_settings.proxy.clone()?;
 422        let no_verify = copilot_settings.proxy_no_verify;
 423        let http_or_https_proxy = if proxy_url.starts_with("http:") {
 424            Some("HTTP_PROXY")
 425        } else if proxy_url.starts_with("https:") {
 426            Some("HTTPS_PROXY")
 427        } else {
 428            log::error!(
 429                "Unsupported protocol scheme for language server proxy (must be http or https)"
 430            );
 431            None
 432        };
 433
 434        let mut env = HashMap::default();
 435
 436        if let Some(proxy_type) = http_or_https_proxy {
 437            env.insert(proxy_type.to_string(), proxy_url);
 438            if let Some(true) = no_verify {
 439                env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
 440            };
 441        }
 442
 443        if let Ok(oauth_token) = env::var(copilot_chat::COPILOT_OAUTH_ENV_VAR) {
 444            env.insert(copilot_chat::COPILOT_OAUTH_ENV_VAR.to_string(), oauth_token);
 445        }
 446
 447        if env.is_empty() { None } else { Some(env) }
 448    }
 449
 450    fn send_configuration_update(&mut self, cx: &mut Context<Self>) {
 451        let copilot_settings = all_language_settings(None, cx)
 452            .edit_predictions
 453            .copilot
 454            .clone();
 455
 456        let settings = json!({
 457            "http": {
 458                "proxy": copilot_settings.proxy,
 459                "proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false)
 460            },
 461            "github-enterprise": {
 462                "uri": copilot_settings.enterprise_uri
 463            }
 464        });
 465
 466        if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) {
 467            copilot_chat.update(cx, |chat, cx| {
 468                chat.set_configuration(
 469                    copilot_chat::CopilotChatConfiguration {
 470                        enterprise_uri: copilot_settings.enterprise_uri.clone(),
 471                    },
 472                    cx,
 473                );
 474            });
 475        }
 476
 477        if let Ok(server) = self.server.as_running() {
 478            server
 479                .lsp
 480                .notify::<lsp::notification::DidChangeConfiguration>(
 481                    &lsp::DidChangeConfigurationParams { settings },
 482                )
 483                .log_err();
 484        }
 485    }
 486
 487    #[cfg(any(test, feature = "test-support"))]
 488    pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
 489        use fs::FakeFs;
 490        use lsp::FakeLanguageServer;
 491        use node_runtime::NodeRuntime;
 492
 493        let (server, fake_server) = FakeLanguageServer::new(
 494            LanguageServerId(0),
 495            LanguageServerBinary {
 496                path: "path/to/copilot".into(),
 497                arguments: vec![],
 498                env: None,
 499            },
 500            "copilot".into(),
 501            Default::default(),
 502            &mut cx.to_async(),
 503        );
 504        let node_runtime = NodeRuntime::unavailable();
 505        let this = cx.new(|cx| Self {
 506            server_id: LanguageServerId(0),
 507            fs: FakeFs::new(cx.background_executor().clone()),
 508            node_runtime,
 509            server: CopilotServer::Running(RunningCopilotServer {
 510                lsp: Arc::new(server),
 511                sign_in_status: SignInStatus::Authorized,
 512                registered_buffers: Default::default(),
 513            }),
 514            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 515            buffers: Default::default(),
 516        });
 517        (this, fake_server)
 518    }
 519
 520    async fn start_language_server(
 521        new_server_id: LanguageServerId,
 522        fs: Arc<dyn Fs>,
 523        node_runtime: NodeRuntime,
 524        env: Option<HashMap<String, String>>,
 525        this: WeakEntity<Self>,
 526        awaiting_sign_in_after_start: bool,
 527        cx: &mut AsyncApp,
 528    ) {
 529        let start_language_server = async {
 530            let server_path = get_copilot_lsp(fs, node_runtime.clone()).await?;
 531            let node_path = node_runtime.binary_path().await?;
 532            let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
 533            let binary = LanguageServerBinary {
 534                path: node_path,
 535                arguments,
 536                env,
 537            };
 538
 539            let root_path = if cfg!(target_os = "windows") {
 540                Path::new("C:/")
 541            } else {
 542                Path::new("/")
 543            };
 544
 545            let server_name = LanguageServerName("copilot".into());
 546            let server = LanguageServer::new(
 547                Arc::new(Mutex::new(None)),
 548                new_server_id,
 549                server_name,
 550                binary,
 551                root_path,
 552                None,
 553                Default::default(),
 554                cx,
 555            )?;
 556
 557            server
 558                .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
 559                .detach();
 560
 561            let configuration = lsp::DidChangeConfigurationParams {
 562                settings: Default::default(),
 563            };
 564
 565            let editor_info = request::SetEditorInfoParams {
 566                editor_info: request::EditorInfo {
 567                    name: "zed".into(),
 568                    version: env!("CARGO_PKG_VERSION").into(),
 569                },
 570                editor_plugin_info: request::EditorPluginInfo {
 571                    name: "zed-copilot".into(),
 572                    version: "0.0.1".into(),
 573                },
 574            };
 575            let editor_info_json = serde_json::to_value(&editor_info)?;
 576
 577            let server = cx
 578                .update(|cx| {
 579                    let mut params = server.default_initialize_params(false, cx);
 580                    params.initialization_options = Some(editor_info_json);
 581                    server.initialize(params, configuration.into(), cx)
 582                })?
 583                .await?;
 584
 585            let status = server
 586                .request::<request::CheckStatus>(request::CheckStatusParams {
 587                    local_checks_only: false,
 588                })
 589                .await
 590                .into_response()
 591                .context("copilot: check status")?;
 592
 593            anyhow::Ok((server, status))
 594        };
 595
 596        let server = start_language_server.await;
 597        this.update(cx, |this, cx| {
 598            cx.notify();
 599            match server {
 600                Ok((server, status)) => {
 601                    this.server = CopilotServer::Running(RunningCopilotServer {
 602                        lsp: server,
 603                        sign_in_status: SignInStatus::SignedOut {
 604                            awaiting_signing_in: awaiting_sign_in_after_start,
 605                        },
 606                        registered_buffers: Default::default(),
 607                    });
 608                    cx.emit(Event::CopilotLanguageServerStarted);
 609                    this.update_sign_in_status(status, cx);
 610                    // Send configuration now that the LSP is fully started
 611                    this.send_configuration_update(cx);
 612                }
 613                Err(error) => {
 614                    this.server = CopilotServer::Error(error.to_string().into());
 615                    cx.notify()
 616                }
 617            }
 618        })
 619        .ok();
 620    }
 621
 622    pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 623        if let CopilotServer::Running(server) = &mut self.server {
 624            let task = match &server.sign_in_status {
 625                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
 626                SignInStatus::SigningIn { task, .. } => {
 627                    cx.notify();
 628                    task.clone()
 629                }
 630                SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
 631                    let lsp = server.lsp.clone();
 632                    let task = cx
 633                        .spawn(async move |this, cx| {
 634                            let sign_in = async {
 635                                let sign_in = lsp
 636                                    .request::<request::SignInInitiate>(
 637                                        request::SignInInitiateParams {},
 638                                    )
 639                                    .await
 640                                    .into_response()
 641                                    .context("copilot sign-in")?;
 642                                match sign_in {
 643                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
 644                                        Ok(request::SignInStatus::Ok { user: Some(user) })
 645                                    }
 646                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
 647                                        this.update(cx, |this, cx| {
 648                                            if let CopilotServer::Running(RunningCopilotServer {
 649                                                sign_in_status: status,
 650                                                ..
 651                                            }) = &mut this.server
 652                                            {
 653                                                if let SignInStatus::SigningIn {
 654                                                    prompt: prompt_flow,
 655                                                    ..
 656                                                } = status
 657                                                {
 658                                                    *prompt_flow = Some(flow.clone());
 659                                                    cx.notify();
 660                                                }
 661                                            }
 662                                        })?;
 663                                        let response = lsp
 664                                            .request::<request::SignInConfirm>(
 665                                                request::SignInConfirmParams {
 666                                                    user_code: flow.user_code,
 667                                                },
 668                                            )
 669                                            .await
 670                                            .into_response()
 671                                            .context("copilot: sign in confirm")?;
 672                                        Ok(response)
 673                                    }
 674                                }
 675                            };
 676
 677                            let sign_in = sign_in.await;
 678                            this.update(cx, |this, cx| match sign_in {
 679                                Ok(status) => {
 680                                    this.update_sign_in_status(status, cx);
 681                                    Ok(())
 682                                }
 683                                Err(error) => {
 684                                    this.update_sign_in_status(
 685                                        request::SignInStatus::NotSignedIn,
 686                                        cx,
 687                                    );
 688                                    Err(Arc::new(error))
 689                                }
 690                            })?
 691                        })
 692                        .shared();
 693                    server.sign_in_status = SignInStatus::SigningIn {
 694                        prompt: None,
 695                        task: task.clone(),
 696                    };
 697                    cx.notify();
 698                    task
 699                }
 700            };
 701
 702            cx.background_spawn(task.map_err(|err| anyhow!("{err:?}")))
 703        } else {
 704            // If we're downloading, wait until download is finished
 705            // If we're in a stuck state, display to the user
 706            Task::ready(Err(anyhow!("copilot hasn't started yet")))
 707        }
 708    }
 709
 710    pub(crate) fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 711        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 712        match &self.server {
 713            CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
 714                let server = server.clone();
 715                cx.background_spawn(async move {
 716                    server
 717                        .request::<request::SignOut>(request::SignOutParams {})
 718                        .await
 719                        .into_response()
 720                        .context("copilot: sign in confirm")?;
 721                    anyhow::Ok(())
 722                })
 723            }
 724            CopilotServer::Disabled => cx.background_spawn(async {
 725                clear_copilot_config_dir().await;
 726                anyhow::Ok(())
 727            }),
 728            _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
 729        }
 730    }
 731
 732    pub(crate) fn reinstall(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
 733        let language_settings = all_language_settings(None, cx);
 734        let env = self.build_env(&language_settings.edit_predictions.copilot);
 735        let start_task = cx
 736            .spawn({
 737                let fs = self.fs.clone();
 738                let node_runtime = self.node_runtime.clone();
 739                let server_id = self.server_id;
 740                async move |this, cx| {
 741                    clear_copilot_dir().await;
 742                    Self::start_language_server(server_id, fs, node_runtime, env, this, false, cx)
 743                        .await
 744                }
 745            })
 746            .shared();
 747
 748        self.server = CopilotServer::Starting {
 749            task: start_task.clone(),
 750        };
 751
 752        cx.notify();
 753
 754        start_task
 755    }
 756
 757    pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
 758        if let CopilotServer::Running(server) = &self.server {
 759            Some(&server.lsp)
 760        } else {
 761            None
 762        }
 763    }
 764
 765    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 766        let weak_buffer = buffer.downgrade();
 767        self.buffers.insert(weak_buffer.clone());
 768
 769        if let CopilotServer::Running(RunningCopilotServer {
 770            lsp: server,
 771            sign_in_status: status,
 772            registered_buffers,
 773            ..
 774        }) = &mut self.server
 775        {
 776            if !matches!(status, SignInStatus::Authorized { .. }) {
 777                return;
 778            }
 779
 780            let entry = registered_buffers.entry(buffer.entity_id());
 781            if let Entry::Vacant(e) = entry {
 782                let Ok(uri) = uri_for_buffer(buffer, cx) else {
 783                    return;
 784                };
 785                let language_id = id_for_language(buffer.read(cx).language());
 786                let snapshot = buffer.read(cx).snapshot();
 787                server
 788                    .notify::<lsp::notification::DidOpenTextDocument>(
 789                        &lsp::DidOpenTextDocumentParams {
 790                            text_document: lsp::TextDocumentItem {
 791                                uri: uri.clone(),
 792                                language_id: language_id.clone(),
 793                                version: 0,
 794                                text: snapshot.text(),
 795                            },
 796                        },
 797                    )
 798                    .ok();
 799
 800                e.insert(RegisteredBuffer {
 801                    uri,
 802                    language_id,
 803                    snapshot,
 804                    snapshot_version: 0,
 805                    pending_buffer_change: Task::ready(Some(())),
 806                    _subscriptions: [
 807                        cx.subscribe(buffer, |this, buffer, event, cx| {
 808                            this.handle_buffer_event(buffer, event, cx).log_err();
 809                        }),
 810                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 811                            this.buffers.remove(&weak_buffer);
 812                            this.unregister_buffer(&weak_buffer);
 813                        }),
 814                    ],
 815                });
 816            }
 817        }
 818    }
 819
 820    fn handle_buffer_event(
 821        &mut self,
 822        buffer: Entity<Buffer>,
 823        event: &language::BufferEvent,
 824        cx: &mut Context<Self>,
 825    ) -> Result<()> {
 826        if let Ok(server) = self.server.as_running() {
 827            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
 828            {
 829                match event {
 830                    language::BufferEvent::Edited => {
 831                        drop(registered_buffer.report_changes(&buffer, cx));
 832                    }
 833                    language::BufferEvent::Saved => {
 834                        server
 835                            .lsp
 836                            .notify::<lsp::notification::DidSaveTextDocument>(
 837                                &lsp::DidSaveTextDocumentParams {
 838                                    text_document: lsp::TextDocumentIdentifier::new(
 839                                        registered_buffer.uri.clone(),
 840                                    ),
 841                                    text: None,
 842                                },
 843                            )?;
 844                    }
 845                    language::BufferEvent::FileHandleChanged
 846                    | language::BufferEvent::LanguageChanged => {
 847                        let new_language_id = id_for_language(buffer.read(cx).language());
 848                        let Ok(new_uri) = uri_for_buffer(&buffer, cx) else {
 849                            return Ok(());
 850                        };
 851                        if new_uri != registered_buffer.uri
 852                            || new_language_id != registered_buffer.language_id
 853                        {
 854                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 855                            registered_buffer.language_id = new_language_id;
 856                            server
 857                                .lsp
 858                                .notify::<lsp::notification::DidCloseTextDocument>(
 859                                    &lsp::DidCloseTextDocumentParams {
 860                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 861                                    },
 862                                )?;
 863                            server
 864                                .lsp
 865                                .notify::<lsp::notification::DidOpenTextDocument>(
 866                                    &lsp::DidOpenTextDocumentParams {
 867                                        text_document: lsp::TextDocumentItem::new(
 868                                            registered_buffer.uri.clone(),
 869                                            registered_buffer.language_id.clone(),
 870                                            registered_buffer.snapshot_version,
 871                                            registered_buffer.snapshot.text(),
 872                                        ),
 873                                    },
 874                                )?;
 875                        }
 876                    }
 877                    _ => {}
 878                }
 879            }
 880        }
 881
 882        Ok(())
 883    }
 884
 885    fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
 886        if let Ok(server) = self.server.as_running() {
 887            if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
 888                server
 889                    .lsp
 890                    .notify::<lsp::notification::DidCloseTextDocument>(
 891                        &lsp::DidCloseTextDocumentParams {
 892                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 893                        },
 894                    )
 895                    .ok();
 896            }
 897        }
 898    }
 899
 900    pub fn completions<T>(
 901        &mut self,
 902        buffer: &Entity<Buffer>,
 903        position: T,
 904        cx: &mut Context<Self>,
 905    ) -> Task<Result<Vec<Completion>>>
 906    where
 907        T: ToPointUtf16,
 908    {
 909        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 910    }
 911
 912    pub fn completions_cycling<T>(
 913        &mut self,
 914        buffer: &Entity<Buffer>,
 915        position: T,
 916        cx: &mut Context<Self>,
 917    ) -> Task<Result<Vec<Completion>>>
 918    where
 919        T: ToPointUtf16,
 920    {
 921        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 922    }
 923
 924    pub fn accept_completion(
 925        &mut self,
 926        completion: &Completion,
 927        cx: &mut Context<Self>,
 928    ) -> Task<Result<()>> {
 929        let server = match self.server.as_authenticated() {
 930            Ok(server) => server,
 931            Err(error) => return Task::ready(Err(error)),
 932        };
 933        let request =
 934            server
 935                .lsp
 936                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 937                    uuid: completion.uuid.clone(),
 938                });
 939        cx.background_spawn(async move {
 940            request
 941                .await
 942                .into_response()
 943                .context("copilot: notify accepted")?;
 944            Ok(())
 945        })
 946    }
 947
 948    pub fn discard_completions(
 949        &mut self,
 950        completions: &[Completion],
 951        cx: &mut Context<Self>,
 952    ) -> Task<Result<()>> {
 953        let server = match self.server.as_authenticated() {
 954            Ok(server) => server,
 955            Err(_) => return Task::ready(Ok(())),
 956        };
 957        let request =
 958            server
 959                .lsp
 960                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 961                    uuids: completions
 962                        .iter()
 963                        .map(|completion| completion.uuid.clone())
 964                        .collect(),
 965                });
 966        cx.background_spawn(async move {
 967            request
 968                .await
 969                .into_response()
 970                .context("copilot: notify rejected")?;
 971            Ok(())
 972        })
 973    }
 974
 975    fn request_completions<R, T>(
 976        &mut self,
 977        buffer: &Entity<Buffer>,
 978        position: T,
 979        cx: &mut Context<Self>,
 980    ) -> Task<Result<Vec<Completion>>>
 981    where
 982        R: 'static
 983            + lsp::request::Request<
 984                Params = request::GetCompletionsParams,
 985                Result = request::GetCompletionsResult,
 986            >,
 987        T: ToPointUtf16,
 988    {
 989        self.register_buffer(buffer, cx);
 990
 991        let server = match self.server.as_authenticated() {
 992            Ok(server) => server,
 993            Err(error) => return Task::ready(Err(error)),
 994        };
 995        let lsp = server.lsp.clone();
 996        let registered_buffer = server
 997            .registered_buffers
 998            .get_mut(&buffer.entity_id())
 999            .unwrap();
1000        let snapshot = registered_buffer.report_changes(buffer, cx);
1001        let buffer = buffer.read(cx);
1002        let uri = registered_buffer.uri.clone();
1003        let position = position.to_point_utf16(buffer);
1004        let settings = language_settings(
1005            buffer.language_at(position).map(|l| l.name()),
1006            buffer.file(),
1007            cx,
1008        );
1009        let tab_size = settings.tab_size;
1010        let hard_tabs = settings.hard_tabs;
1011        let relative_path = buffer
1012            .file()
1013            .map(|file| file.path().to_path_buf())
1014            .unwrap_or_default();
1015
1016        cx.background_spawn(async move {
1017            let (version, snapshot) = snapshot.await?;
1018            let result = lsp
1019                .request::<R>(request::GetCompletionsParams {
1020                    doc: request::GetCompletionsDocument {
1021                        uri,
1022                        tab_size: tab_size.into(),
1023                        indent_size: 1,
1024                        insert_spaces: !hard_tabs,
1025                        relative_path: relative_path.to_string_lossy().into(),
1026                        position: point_to_lsp(position),
1027                        version: version.try_into().unwrap(),
1028                    },
1029                })
1030                .await
1031                .into_response()
1032                .context("copilot: get completions")?;
1033            let completions = result
1034                .completions
1035                .into_iter()
1036                .map(|completion| {
1037                    let start = snapshot
1038                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
1039                    let end =
1040                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
1041                    Completion {
1042                        uuid: completion.uuid,
1043                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
1044                        text: completion.text,
1045                    }
1046                })
1047                .collect();
1048            anyhow::Ok(completions)
1049        })
1050    }
1051
1052    pub fn status(&self) -> Status {
1053        match &self.server {
1054            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
1055            CopilotServer::Disabled => Status::Disabled,
1056            CopilotServer::Error(error) => Status::Error(error.clone()),
1057            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
1058                match sign_in_status {
1059                    SignInStatus::Authorized { .. } => Status::Authorized,
1060                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
1061                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
1062                        prompt: prompt.clone(),
1063                    },
1064                    SignInStatus::SignedOut {
1065                        awaiting_signing_in,
1066                    } => Status::SignedOut {
1067                        awaiting_signing_in: *awaiting_signing_in,
1068                    },
1069                }
1070            }
1071        }
1072    }
1073
1074    fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
1075        self.buffers.retain(|buffer| buffer.is_upgradable());
1076
1077        if let Ok(server) = self.server.as_running() {
1078            match lsp_status {
1079                request::SignInStatus::Ok { user: Some(_) }
1080                | request::SignInStatus::MaybeOk { .. }
1081                | request::SignInStatus::AlreadySignedIn { .. } => {
1082                    server.sign_in_status = SignInStatus::Authorized;
1083                    cx.emit(Event::CopilotAuthSignedIn);
1084                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1085                        if let Some(buffer) = buffer.upgrade() {
1086                            self.register_buffer(&buffer, cx);
1087                        }
1088                    }
1089                }
1090                request::SignInStatus::NotAuthorized { .. } => {
1091                    server.sign_in_status = SignInStatus::Unauthorized;
1092                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1093                        self.unregister_buffer(&buffer);
1094                    }
1095                }
1096                request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
1097                    if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
1098                        server.sign_in_status = SignInStatus::SignedOut {
1099                            awaiting_signing_in: false,
1100                        };
1101                    }
1102                    cx.emit(Event::CopilotAuthSignedOut);
1103                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1104                        self.unregister_buffer(&buffer);
1105                    }
1106                }
1107            }
1108
1109            cx.notify();
1110        }
1111    }
1112}
1113
1114fn id_for_language(language: Option<&Arc<Language>>) -> String {
1115    language
1116        .map(|language| language.lsp_id())
1117        .unwrap_or_else(|| "plaintext".to_string())
1118}
1119
1120fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> Result<lsp::Url, ()> {
1121    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1122        lsp::Url::from_file_path(file.abs_path(cx))
1123    } else {
1124        format!("buffer://{}", buffer.entity_id())
1125            .parse()
1126            .map_err(|_| ())
1127    }
1128}
1129
1130async fn clear_copilot_dir() {
1131    remove_matching(paths::copilot_dir(), |_| true).await
1132}
1133
1134async fn clear_copilot_config_dir() {
1135    remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
1136}
1137
1138async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1139    const PACKAGE_NAME: &str = "@github/copilot-language-server";
1140    const SERVER_PATH: &str =
1141        "node_modules/@github/copilot-language-server/dist/language-server.js";
1142
1143    let latest_version = node_runtime
1144        .npm_package_latest_version(PACKAGE_NAME)
1145        .await?;
1146    let server_path = paths::copilot_dir().join(SERVER_PATH);
1147
1148    fs.create_dir(paths::copilot_dir()).await?;
1149
1150    let should_install = node_runtime
1151        .should_install_npm_package(
1152            PACKAGE_NAME,
1153            &server_path,
1154            paths::copilot_dir(),
1155            &latest_version,
1156        )
1157        .await;
1158    if should_install {
1159        node_runtime
1160            .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1161            .await?;
1162    }
1163
1164    Ok(server_path)
1165}
1166
1167#[cfg(test)]
1168mod tests {
1169    use super::*;
1170    use gpui::TestAppContext;
1171    use util::path;
1172
1173    #[gpui::test(iterations = 10)]
1174    async fn test_buffer_management(cx: &mut TestAppContext) {
1175        let (copilot, mut lsp) = Copilot::fake(cx);
1176
1177        let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx));
1178        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
1179            .parse()
1180            .unwrap();
1181        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
1182        assert_eq!(
1183            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1184                .await,
1185            lsp::DidOpenTextDocumentParams {
1186                text_document: lsp::TextDocumentItem::new(
1187                    buffer_1_uri.clone(),
1188                    "plaintext".into(),
1189                    0,
1190                    "Hello".into()
1191                ),
1192            }
1193        );
1194
1195        let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx));
1196        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
1197            .parse()
1198            .unwrap();
1199        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
1200        assert_eq!(
1201            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1202                .await,
1203            lsp::DidOpenTextDocumentParams {
1204                text_document: lsp::TextDocumentItem::new(
1205                    buffer_2_uri.clone(),
1206                    "plaintext".into(),
1207                    0,
1208                    "Goodbye".into()
1209                ),
1210            }
1211        );
1212
1213        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
1214        assert_eq!(
1215            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
1216                .await,
1217            lsp::DidChangeTextDocumentParams {
1218                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
1219                content_changes: vec![lsp::TextDocumentContentChangeEvent {
1220                    range: Some(lsp::Range::new(
1221                        lsp::Position::new(0, 5),
1222                        lsp::Position::new(0, 5)
1223                    )),
1224                    range_length: None,
1225                    text: " world".into(),
1226                }],
1227            }
1228        );
1229
1230        // Ensure updates to the file are reflected in the LSP.
1231        buffer_1.update(cx, |buffer, cx| {
1232            buffer.file_updated(
1233                Arc::new(File {
1234                    abs_path: path!("/root/child/buffer-1").into(),
1235                    path: Path::new("child/buffer-1").into(),
1236                }),
1237                cx,
1238            )
1239        });
1240        assert_eq!(
1241            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1242                .await,
1243            lsp::DidCloseTextDocumentParams {
1244                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
1245            }
1246        );
1247        let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
1248        assert_eq!(
1249            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1250                .await,
1251            lsp::DidOpenTextDocumentParams {
1252                text_document: lsp::TextDocumentItem::new(
1253                    buffer_1_uri.clone(),
1254                    "plaintext".into(),
1255                    1,
1256                    "Hello world".into()
1257                ),
1258            }
1259        );
1260
1261        // Ensure all previously-registered buffers are closed when signing out.
1262        lsp.set_request_handler::<request::SignOut, _, _>(|_, _| async {
1263            Ok(request::SignOutResult {})
1264        });
1265        copilot
1266            .update(cx, |copilot, cx| copilot.sign_out(cx))
1267            .await
1268            .unwrap();
1269        assert_eq!(
1270            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1271                .await,
1272            lsp::DidCloseTextDocumentParams {
1273                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
1274            }
1275        );
1276        assert_eq!(
1277            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1278                .await,
1279            lsp::DidCloseTextDocumentParams {
1280                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
1281            }
1282        );
1283
1284        // Ensure all previously-registered buffers are re-opened when signing in.
1285        lsp.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
1286            Ok(request::SignInInitiateResult::AlreadySignedIn {
1287                user: "user-1".into(),
1288            })
1289        });
1290        copilot
1291            .update(cx, |copilot, cx| copilot.sign_in(cx))
1292            .await
1293            .unwrap();
1294
1295        assert_eq!(
1296            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1297                .await,
1298            lsp::DidOpenTextDocumentParams {
1299                text_document: lsp::TextDocumentItem::new(
1300                    buffer_1_uri.clone(),
1301                    "plaintext".into(),
1302                    0,
1303                    "Hello world".into()
1304                ),
1305            }
1306        );
1307        assert_eq!(
1308            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1309                .await,
1310            lsp::DidOpenTextDocumentParams {
1311                text_document: lsp::TextDocumentItem::new(
1312                    buffer_2_uri.clone(),
1313                    "plaintext".into(),
1314                    0,
1315                    "Goodbye".into()
1316                ),
1317            }
1318        );
1319        // Dropping a buffer causes it to be closed on the LSP side as well.
1320        cx.update(|_| drop(buffer_2));
1321        assert_eq!(
1322            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1323                .await,
1324            lsp::DidCloseTextDocumentParams {
1325                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
1326            }
1327        );
1328    }
1329
1330    struct File {
1331        abs_path: PathBuf,
1332        path: Arc<Path>,
1333    }
1334
1335    impl language::File for File {
1336        fn as_local(&self) -> Option<&dyn language::LocalFile> {
1337            Some(self)
1338        }
1339
1340        fn disk_state(&self) -> language::DiskState {
1341            language::DiskState::Present {
1342                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
1343            }
1344        }
1345
1346        fn path(&self) -> &Arc<Path> {
1347            &self.path
1348        }
1349
1350        fn full_path(&self, _: &App) -> PathBuf {
1351            unimplemented!()
1352        }
1353
1354        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
1355            unimplemented!()
1356        }
1357
1358        fn to_proto(&self, _: &App) -> rpc::proto::File {
1359            unimplemented!()
1360        }
1361
1362        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
1363            settings::WorktreeId::from_usize(0)
1364        }
1365
1366        fn is_private(&self) -> bool {
1367            false
1368        }
1369    }
1370
1371    impl language::LocalFile for File {
1372        fn abs_path(&self, _: &App) -> PathBuf {
1373            self.abs_path.clone()
1374        }
1375
1376        fn load(&self, _: &App) -> Task<Result<String>> {
1377            unimplemented!()
1378        }
1379
1380        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
1381            unimplemented!()
1382        }
1383    }
1384}
1385
1386#[cfg(test)]
1387#[ctor::ctor]
1388fn init_logger() {
1389    zlog::init_test();
1390}