copilot.rs

   1mod copilot_completion_provider;
   2pub mod request;
   3mod sign_in;
   4
   5use anyhow::{anyhow, Context as _, Result};
   6use async_compression::futures::bufread::GzipDecoder;
   7use async_tar::Archive;
   8use collections::{HashMap, HashSet};
   9use command_palette_hooks::CommandPaletteFilter;
  10use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
  11use gpui::{
  12    actions, AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Global, Model,
  13    ModelContext, Task, WeakModel,
  14};
  15use http::github::latest_github_release;
  16use http::HttpClient;
  17use language::{
  18    language_settings::{all_language_settings, language_settings, InlineCompletionProvider},
  19    point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
  20    ToPointUtf16,
  21};
  22use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId};
  23use node_runtime::NodeRuntime;
  24use parking_lot::Mutex;
  25use request::StatusNotification;
  26use settings::SettingsStore;
  27use smol::{fs, io::BufReader, stream::StreamExt};
  28use std::{
  29    any::TypeId,
  30    ffi::OsString,
  31    mem,
  32    ops::Range,
  33    path::{Path, PathBuf},
  34    sync::Arc,
  35};
  36use util::{fs::remove_matching, maybe, paths, ResultExt};
  37
  38pub use copilot_completion_provider::CopilotCompletionProvider;
  39pub use sign_in::CopilotCodeVerification;
  40
// Copilot actions. `init` shows or hides them in the command palette based on
// the current Copilot status, and registers global handlers for `SignIn`,
// `SignOut` and `Reinstall`.
actions!(
    copilot,
    [
        Suggest,
        NextSuggestion,
        PreviousSuggestion,
        Reinstall,
        SignIn,
        SignOut
    ]
);
  52
  53pub fn init(
  54    new_server_id: LanguageServerId,
  55    http: Arc<dyn HttpClient>,
  56    node_runtime: Arc<dyn NodeRuntime>,
  57    cx: &mut AppContext,
  58) {
  59    let copilot = cx.new_model({
  60        let node_runtime = node_runtime.clone();
  61        move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
  62    });
  63    Copilot::set_global(copilot.clone(), cx);
  64    cx.observe(&copilot, |handle, cx| {
  65        let copilot_action_types = [
  66            TypeId::of::<Suggest>(),
  67            TypeId::of::<NextSuggestion>(),
  68            TypeId::of::<PreviousSuggestion>(),
  69            TypeId::of::<Reinstall>(),
  70        ];
  71        let copilot_auth_action_types = [TypeId::of::<SignOut>()];
  72        let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
  73        let status = handle.read(cx).status();
  74        let filter = CommandPaletteFilter::global_mut(cx);
  75
  76        match status {
  77            Status::Disabled => {
  78                filter.hide_action_types(&copilot_action_types);
  79                filter.hide_action_types(&copilot_auth_action_types);
  80                filter.hide_action_types(&copilot_no_auth_action_types);
  81            }
  82            Status::Authorized => {
  83                filter.hide_action_types(&copilot_no_auth_action_types);
  84                filter.show_action_types(
  85                    copilot_action_types
  86                        .iter()
  87                        .chain(&copilot_auth_action_types),
  88                );
  89            }
  90            _ => {
  91                filter.hide_action_types(&copilot_action_types);
  92                filter.hide_action_types(&copilot_auth_action_types);
  93                filter.show_action_types(copilot_no_auth_action_types.iter());
  94            }
  95        }
  96    })
  97    .detach();
  98
  99    cx.on_action(|_: &SignIn, cx| {
 100        if let Some(copilot) = Copilot::global(cx) {
 101            copilot
 102                .update(cx, |copilot, cx| copilot.sign_in(cx))
 103                .detach_and_log_err(cx);
 104        }
 105    });
 106    cx.on_action(|_: &SignOut, cx| {
 107        if let Some(copilot) = Copilot::global(cx) {
 108            copilot
 109                .update(cx, |copilot, cx| copilot.sign_out(cx))
 110                .detach_and_log_err(cx);
 111        }
 112    });
 113    cx.on_action(|_: &Reinstall, cx| {
 114        if let Some(copilot) = Copilot::global(cx) {
 115            copilot
 116                .update(cx, |copilot, cx| copilot.reinstall(cx))
 117                .detach();
 118        }
 119    });
 120}
 121
/// Lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// No server: Copilot is not the configured inline completion provider, or the
    /// server was taken out for shutdown when the app quit.
    Disabled,
    /// The server is being started; `task` completes when the start attempt finishes
    /// (whether it succeeded or failed).
    Starting { task: Shared<Task<()>> },
    /// Starting the server failed; holds the error message.
    Error(Arc<str>),
    /// The server is up and initialized.
    Running(RunningCopilotServer),
}
 128
 129impl CopilotServer {
 130    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 131        let server = self.as_running()?;
 132        if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
 133            Ok(server)
 134        } else {
 135            Err(anyhow!("must sign in before using copilot"))
 136        }
 137    }
 138
 139    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 140        match self {
 141            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
 142            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
 143            CopilotServer::Error(error) => Err(anyhow!(
 144                "copilot was not started because of an error: {}",
 145                error
 146            )),
 147            CopilotServer::Running(server) => Ok(server),
 148        }
 149    }
 150}
 151
/// State kept while the Copilot language server is running.
struct RunningCopilotServer {
    /// Handle to the running language server.
    lsp: Arc<LanguageServer>,
    /// Current sign-in state for this server.
    sign_in_status: SignInStatus,
    /// Buffers that have been opened on the server (`textDocument/didOpen`),
    /// keyed by the buffer's entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
 157
/// Internal sign-in state of a running server.
///
/// Unlike the public [`Status`], `SigningIn` also carries the in-flight sign-in
/// task, so repeated `sign_in` calls share a single attempt.
#[derive(Clone, Debug)]
enum SignInStatus {
    Authorized,
    Unauthorized,
    SigningIn {
        /// Device-flow prompt for the user; `None` until the server provides one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The shared sign-in task; `sign_in` hands out clones of it.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    SignedOut,
}
 168
/// Publicly visible Copilot status.
///
/// NOTE(review): presumably produced by `Copilot::status` (not shown here) from
/// `CopilotServer` and, when running, `SignInStatus` — confirm the exact mapping there.
#[derive(Debug, Clone)]
pub enum Status {
    /// The server is starting; `task` completes when the start attempt finishes.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot is turned off.
    Disabled,
    SignedOut,
    /// A sign-in is in progress; `prompt` holds the device-flow prompt once available.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    Unauthorized,
    Authorized,
}
 183
 184impl Status {
 185    pub fn is_authorized(&self) -> bool {
 186        matches!(self, Status::Authorized)
 187    }
 188}
 189
/// A buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// URI under which the buffer was opened on the server.
    uri: lsp::Url,
    /// Language id reported to the server for this buffer.
    language_id: String,
    /// The last snapshot whose contents were reported to the server.
    snapshot: BufferSnapshot,
    /// LSP document version of `snapshot`; starts at 0 and is bumped on every
    /// `textDocument/didChange` sent.
    snapshot_version: i32,
    /// Buffer event subscription and buffer-release observer. Held only to keep
    /// them alive.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recent change-reporting task. Each new report waits for this one first,
    /// so changes are sent in order.
    pending_buffer_change: Task<Option<()>>,
}
 198
impl RegisteredBuffer {
    /// Reports any edits made to `buffer` since the last reported snapshot to the
    /// Copilot server via `textDocument/didChange`.
    ///
    /// Returns a receiver that resolves to the `(document version, snapshot)` the
    /// server now knows about. If the buffer hasn't changed, it resolves immediately
    /// with the current values. If the server is not running or signed in, or the
    /// buffer is no longer registered or alive, the background task bails out early
    /// and drops the sender, so the receiver sees a cancellation.
    fn report_changes(
        &mut self,
        buffer: &Model<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing new to report.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous report so didChange notifications go out in order.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the last reported version: the previous task may have
                // advanced it since this report was scheduled.
                let old_version = copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?;

                // Diffing runs on the background executor.
                let content_changes = cx
                    .background_executor()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The LSP range covers the replaced text: it starts at
                                    // the edit's start and spans the old (pre-edit) extent.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version and notify when there is something to send.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
 285
/// A single inline completion returned by the Copilot server.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, sent back in accept/reject notifications
    /// (`NotifyAccepted` / `NotifyRejected`).
    pub uuid: String,
    /// Buffer range the completion applies to.
    /// NOTE(review): presumably the range replaced by `text` — confirm in `request_completions`.
    pub range: Range<Anchor>,
    /// Suggested text.
    pub text: String,
}
 292
/// Model that owns the GitHub Copilot language server and the buffers registered with it.
pub struct Copilot {
    /// HTTP client passed to `get_copilot_lsp` when (re)starting the server.
    http: Arc<dyn HttpClient>,
    /// Node runtime whose binary launches the Copilot server.
    node_runtime: Arc<dyn NodeRuntime>,
    /// Current server lifecycle state.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, including those not (yet) opened on
    /// the server, e.g. while signed out. Entries are removed when the buffer is released.
    buffers: HashSet<WeakModel<Buffer>>,
    /// Id given to the language server when it is started.
    server_id: LanguageServerId,
    /// App-quit subscription that shuts the language server down.
    _subscription: gpui::Subscription,
}
 301
/// Events emitted by [`Copilot`].
pub enum Event {
    /// Emitted after the language server has started and been initialized successfully.
    CopilotLanguageServerStarted,
}

impl EventEmitter<Event> for Copilot {}
 307
/// Wrapper that stores the [`Copilot`] model as an app global
/// (see `Copilot::global` / `Copilot::set_global`).
struct GlobalCopilot(Model<Copilot>);

impl Global for GlobalCopilot {}
 311
 312impl Copilot {
 313    pub fn global(cx: &AppContext) -> Option<Model<Self>> {
 314        cx.try_global::<GlobalCopilot>()
 315            .map(|model| model.0.clone())
 316    }
 317
 318    pub fn set_global(copilot: Model<Self>, cx: &mut AppContext) {
 319        cx.set_global(GlobalCopilot(copilot));
 320    }
 321
 322    fn start(
 323        new_server_id: LanguageServerId,
 324        http: Arc<dyn HttpClient>,
 325        node_runtime: Arc<dyn NodeRuntime>,
 326        cx: &mut ModelContext<Self>,
 327    ) -> Self {
 328        let mut this = Self {
 329            server_id: new_server_id,
 330            http,
 331            node_runtime,
 332            server: CopilotServer::Disabled,
 333            buffers: Default::default(),
 334            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 335        };
 336        this.enable_or_disable_copilot(cx);
 337        cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
 338            .detach();
 339        this
 340    }
 341
 342    fn shutdown_language_server(
 343        &mut self,
 344        _cx: &mut ModelContext<Self>,
 345    ) -> impl Future<Output = ()> {
 346        let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
 347            CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
 348            _ => None,
 349        };
 350
 351        async move {
 352            if let Some(shutdown) = shutdown {
 353                shutdown.await;
 354            }
 355        }
 356    }
 357
 358    fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Self>) {
 359        let server_id = self.server_id;
 360        let http = self.http.clone();
 361        let node_runtime = self.node_runtime.clone();
 362        if all_language_settings(None, cx).inline_completions.provider
 363            == InlineCompletionProvider::Copilot
 364        {
 365            if matches!(self.server, CopilotServer::Disabled) {
 366                let start_task = cx
 367                    .spawn(move |this, cx| {
 368                        Self::start_language_server(server_id, http, node_runtime, this, cx)
 369                    })
 370                    .shared();
 371                self.server = CopilotServer::Starting { task: start_task };
 372                cx.notify();
 373            }
 374        } else {
 375            self.server = CopilotServer::Disabled;
 376            cx.notify();
 377        }
 378    }
 379
 380    #[cfg(any(test, feature = "test-support"))]
 381    pub fn fake(cx: &mut gpui::TestAppContext) -> (Model<Self>, lsp::FakeLanguageServer) {
 382        use lsp::FakeLanguageServer;
 383        use node_runtime::FakeNodeRuntime;
 384
 385        let (server, fake_server) = FakeLanguageServer::new(
 386            LanguageServerId(0),
 387            LanguageServerBinary {
 388                path: "path/to/copilot".into(),
 389                arguments: vec![],
 390                env: None,
 391            },
 392            "copilot".into(),
 393            Default::default(),
 394            cx.to_async(),
 395        );
 396        let http = http::FakeHttpClient::create(|_| async { unreachable!() });
 397        let node_runtime = FakeNodeRuntime::new();
 398        let this = cx.new_model(|cx| Self {
 399            server_id: LanguageServerId(0),
 400            http: http.clone(),
 401            node_runtime,
 402            server: CopilotServer::Running(RunningCopilotServer {
 403                lsp: Arc::new(server),
 404                sign_in_status: SignInStatus::Authorized,
 405                registered_buffers: Default::default(),
 406            }),
 407            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 408            buffers: Default::default(),
 409        });
 410        (this, fake_server)
 411    }
 412
 413    fn start_language_server(
 414        new_server_id: LanguageServerId,
 415        http: Arc<dyn HttpClient>,
 416        node_runtime: Arc<dyn NodeRuntime>,
 417        this: WeakModel<Self>,
 418        mut cx: AsyncAppContext,
 419    ) -> impl Future<Output = ()> {
 420        async move {
 421            let start_language_server = async {
 422                let server_path = get_copilot_lsp(http).await?;
 423                let node_path = node_runtime.binary_path().await?;
 424                let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
 425                let binary = LanguageServerBinary {
 426                    path: node_path,
 427                    arguments,
 428                    // TODO: We could set HTTP_PROXY etc here and fix the copilot issue.
 429                    env: None,
 430                };
 431
 432                let root_path = if cfg!(target_os = "windows") {
 433                    Path::new("C:/")
 434                } else {
 435                    Path::new("/")
 436                };
 437
 438                let server = LanguageServer::new(
 439                    Arc::new(Mutex::new(None)),
 440                    new_server_id,
 441                    binary,
 442                    root_path,
 443                    None,
 444                    cx.clone(),
 445                )?;
 446
 447                server
 448                    .on_notification::<StatusNotification, _>(
 449                        |_, _| { /* Silence the notification */ },
 450                    )
 451                    .detach();
 452                let server = cx.update(|cx| server.initialize(None, cx))?.await?;
 453
 454                let status = server
 455                    .request::<request::CheckStatus>(request::CheckStatusParams {
 456                        local_checks_only: false,
 457                    })
 458                    .await?;
 459
 460                server
 461                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
 462                        editor_info: request::EditorInfo {
 463                            name: "zed".into(),
 464                            version: env!("CARGO_PKG_VERSION").into(),
 465                        },
 466                        editor_plugin_info: request::EditorPluginInfo {
 467                            name: "zed-copilot".into(),
 468                            version: "0.0.1".into(),
 469                        },
 470                    })
 471                    .await?;
 472
 473                anyhow::Ok((server, status))
 474            };
 475
 476            let server = start_language_server.await;
 477            this.update(&mut cx, |this, cx| {
 478                cx.notify();
 479                match server {
 480                    Ok((server, status)) => {
 481                        this.server = CopilotServer::Running(RunningCopilotServer {
 482                            lsp: server,
 483                            sign_in_status: SignInStatus::SignedOut,
 484                            registered_buffers: Default::default(),
 485                        });
 486                        cx.emit(Event::CopilotLanguageServerStarted);
 487                        this.update_sign_in_status(status, cx);
 488                    }
 489                    Err(error) => {
 490                        this.server = CopilotServer::Error(error.to_string().into());
 491                        cx.notify()
 492                    }
 493                }
 494            })
 495            .ok();
 496        }
 497    }
 498
    /// Signs the user in to Copilot, returning a task that resolves when sign-in finishes.
    ///
    /// Behavior by current state:
    /// - already `Authorized`: resolves immediately with `Ok(())`;
    /// - `SigningIn`: returns (a clone of) the in-flight sign-in task;
    /// - `SignedOut` / `Unauthorized`: starts a new sign-in via `SignInInitiate`.
    ///
    /// If the server asks for the device flow, the prompt is stored in the `SigningIn` state
    /// (so the UI can show it), and the task then waits for `SignInConfirm`. The final status
    /// is applied with `update_sign_in_status`; on any error the status is reset to
    /// `NotSignedIn` and the error is returned.
    ///
    /// Fails immediately with "copilot hasn't started yet" if the server isn't running.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt, but only if we are
                                        // still in the `SigningIn` state.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        // Wait for the server to confirm the user's code.
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // Errors are wrapped in `Arc` so the shared task's
                                    // output can be cloned for every waiter.
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            cx.background_executor()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if the server is still downloading/starting, wait for it instead of
            // failing; if it is stuck in an error state, surface that to the user. For now
            // every non-running state is reported as this error.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 583
 584    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 585        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 586        if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
 587            let server = server.clone();
 588            cx.background_executor().spawn(async move {
 589                server
 590                    .request::<request::SignOut>(request::SignOutParams {})
 591                    .await?;
 592                anyhow::Ok(())
 593            })
 594        } else {
 595            Task::ready(Err(anyhow!("copilot hasn't started yet")))
 596        }
 597    }
 598
 599    pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 600        let start_task = cx
 601            .spawn({
 602                let http = self.http.clone();
 603                let node_runtime = self.node_runtime.clone();
 604                let server_id = self.server_id;
 605                move |this, cx| async move {
 606                    clear_copilot_dir().await;
 607                    Self::start_language_server(server_id, http, node_runtime, this, cx).await
 608                }
 609            })
 610            .shared();
 611
 612        self.server = CopilotServer::Starting {
 613            task: start_task.clone(),
 614        };
 615
 616        cx.notify();
 617
 618        cx.background_executor().spawn(start_task)
 619    }
 620
 621    pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
 622        if let CopilotServer::Running(server) = &self.server {
 623            Some(&server.lsp)
 624        } else {
 625            None
 626        }
 627    }
 628
 629    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
 630        let weak_buffer = buffer.downgrade();
 631        self.buffers.insert(weak_buffer.clone());
 632
 633        if let CopilotServer::Running(RunningCopilotServer {
 634            lsp: server,
 635            sign_in_status: status,
 636            registered_buffers,
 637            ..
 638        }) = &mut self.server
 639        {
 640            if !matches!(status, SignInStatus::Authorized { .. }) {
 641                return;
 642            }
 643
 644            registered_buffers
 645                .entry(buffer.entity_id())
 646                .or_insert_with(|| {
 647                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
 648                    let language_id = id_for_language(buffer.read(cx).language());
 649                    let snapshot = buffer.read(cx).snapshot();
 650                    server
 651                        .notify::<lsp::notification::DidOpenTextDocument>(
 652                            lsp::DidOpenTextDocumentParams {
 653                                text_document: lsp::TextDocumentItem {
 654                                    uri: uri.clone(),
 655                                    language_id: language_id.clone(),
 656                                    version: 0,
 657                                    text: snapshot.text(),
 658                                },
 659                            },
 660                        )
 661                        .log_err();
 662
 663                    RegisteredBuffer {
 664                        uri,
 665                        language_id,
 666                        snapshot,
 667                        snapshot_version: 0,
 668                        pending_buffer_change: Task::ready(Some(())),
 669                        _subscriptions: [
 670                            cx.subscribe(buffer, |this, buffer, event, cx| {
 671                                this.handle_buffer_event(buffer, event, cx).log_err();
 672                            }),
 673                            cx.observe_release(buffer, move |this, _buffer, _cx| {
 674                                this.buffers.remove(&weak_buffer);
 675                                this.unregister_buffer(&weak_buffer);
 676                            }),
 677                        ],
 678                    }
 679                });
 680        }
 681    }
 682
 683    fn handle_buffer_event(
 684        &mut self,
 685        buffer: Model<Buffer>,
 686        event: &language::Event,
 687        cx: &mut ModelContext<Self>,
 688    ) -> Result<()> {
 689        if let Ok(server) = self.server.as_running() {
 690            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
 691            {
 692                match event {
 693                    language::Event::Edited => {
 694                        let _ = registered_buffer.report_changes(&buffer, cx);
 695                    }
 696                    language::Event::Saved => {
 697                        server
 698                            .lsp
 699                            .notify::<lsp::notification::DidSaveTextDocument>(
 700                                lsp::DidSaveTextDocumentParams {
 701                                    text_document: lsp::TextDocumentIdentifier::new(
 702                                        registered_buffer.uri.clone(),
 703                                    ),
 704                                    text: None,
 705                                },
 706                            )?;
 707                    }
 708                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
 709                        let new_language_id = id_for_language(buffer.read(cx).language());
 710                        let new_uri = uri_for_buffer(&buffer, cx);
 711                        if new_uri != registered_buffer.uri
 712                            || new_language_id != registered_buffer.language_id
 713                        {
 714                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 715                            registered_buffer.language_id = new_language_id;
 716                            server
 717                                .lsp
 718                                .notify::<lsp::notification::DidCloseTextDocument>(
 719                                    lsp::DidCloseTextDocumentParams {
 720                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 721                                    },
 722                                )?;
 723                            server
 724                                .lsp
 725                                .notify::<lsp::notification::DidOpenTextDocument>(
 726                                    lsp::DidOpenTextDocumentParams {
 727                                        text_document: lsp::TextDocumentItem::new(
 728                                            registered_buffer.uri.clone(),
 729                                            registered_buffer.language_id.clone(),
 730                                            registered_buffer.snapshot_version,
 731                                            registered_buffer.snapshot.text(),
 732                                        ),
 733                                    },
 734                                )?;
 735                        }
 736                    }
 737                    _ => {}
 738                }
 739            }
 740        }
 741
 742        Ok(())
 743    }
 744
 745    fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
 746        if let Ok(server) = self.server.as_running() {
 747            if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
 748                server
 749                    .lsp
 750                    .notify::<lsp::notification::DidCloseTextDocument>(
 751                        lsp::DidCloseTextDocumentParams {
 752                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 753                        },
 754                    )
 755                    .log_err();
 756            }
 757        }
 758    }
 759
 760    pub fn completions<T>(
 761        &mut self,
 762        buffer: &Model<Buffer>,
 763        position: T,
 764        cx: &mut ModelContext<Self>,
 765    ) -> Task<Result<Vec<Completion>>>
 766    where
 767        T: ToPointUtf16,
 768    {
 769        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 770    }
 771
 772    pub fn completions_cycling<T>(
 773        &mut self,
 774        buffer: &Model<Buffer>,
 775        position: T,
 776        cx: &mut ModelContext<Self>,
 777    ) -> Task<Result<Vec<Completion>>>
 778    where
 779        T: ToPointUtf16,
 780    {
 781        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 782    }
 783
 784    pub fn accept_completion(
 785        &mut self,
 786        completion: &Completion,
 787        cx: &mut ModelContext<Self>,
 788    ) -> Task<Result<()>> {
 789        let server = match self.server.as_authenticated() {
 790            Ok(server) => server,
 791            Err(error) => return Task::ready(Err(error)),
 792        };
 793        let request =
 794            server
 795                .lsp
 796                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 797                    uuid: completion.uuid.clone(),
 798                });
 799        cx.background_executor().spawn(async move {
 800            request.await?;
 801            Ok(())
 802        })
 803    }
 804
 805    pub fn discard_completions(
 806        &mut self,
 807        completions: &[Completion],
 808        cx: &mut ModelContext<Self>,
 809    ) -> Task<Result<()>> {
 810        let server = match self.server.as_authenticated() {
 811            Ok(server) => server,
 812            Err(_) => return Task::ready(Ok(())),
 813        };
 814        let request =
 815            server
 816                .lsp
 817                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 818                    uuids: completions
 819                        .iter()
 820                        .map(|completion| completion.uuid.clone())
 821                        .collect(),
 822                });
 823        cx.background_executor().spawn(async move {
 824            request.await?;
 825            Ok(())
 826        })
 827    }
 828
 829    fn request_completions<R, T>(
 830        &mut self,
 831        buffer: &Model<Buffer>,
 832        position: T,
 833        cx: &mut ModelContext<Self>,
 834    ) -> Task<Result<Vec<Completion>>>
 835    where
 836        R: 'static
 837            + lsp::request::Request<
 838                Params = request::GetCompletionsParams,
 839                Result = request::GetCompletionsResult,
 840            >,
 841        T: ToPointUtf16,
 842    {
 843        self.register_buffer(buffer, cx);
 844
 845        let server = match self.server.as_authenticated() {
 846            Ok(server) => server,
 847            Err(error) => return Task::ready(Err(error)),
 848        };
 849        let lsp = server.lsp.clone();
 850        let registered_buffer = server
 851            .registered_buffers
 852            .get_mut(&buffer.entity_id())
 853            .unwrap();
 854        let snapshot = registered_buffer.report_changes(buffer, cx);
 855        let buffer = buffer.read(cx);
 856        let uri = registered_buffer.uri.clone();
 857        let position = position.to_point_utf16(buffer);
 858        let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
 859        let tab_size = settings.tab_size;
 860        let hard_tabs = settings.hard_tabs;
 861        let relative_path = buffer
 862            .file()
 863            .map(|file| file.path().to_path_buf())
 864            .unwrap_or_default();
 865
 866        cx.background_executor().spawn(async move {
 867            let (version, snapshot) = snapshot.await?;
 868            let result = lsp
 869                .request::<R>(request::GetCompletionsParams {
 870                    doc: request::GetCompletionsDocument {
 871                        uri,
 872                        tab_size: tab_size.into(),
 873                        indent_size: 1,
 874                        insert_spaces: !hard_tabs,
 875                        relative_path: relative_path.to_string_lossy().into(),
 876                        position: point_to_lsp(position),
 877                        version: version.try_into().unwrap(),
 878                    },
 879                })
 880                .await?;
 881            let completions = result
 882                .completions
 883                .into_iter()
 884                .map(|completion| {
 885                    let start = snapshot
 886                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 887                    let end =
 888                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 889                    Completion {
 890                        uuid: completion.uuid,
 891                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 892                        text: completion.text,
 893                    }
 894                })
 895                .collect();
 896            anyhow::Ok(completions)
 897        })
 898    }
 899
 900    pub fn status(&self) -> Status {
 901        match &self.server {
 902            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
 903            CopilotServer::Disabled => Status::Disabled,
 904            CopilotServer::Error(error) => Status::Error(error.clone()),
 905            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
 906                match sign_in_status {
 907                    SignInStatus::Authorized { .. } => Status::Authorized,
 908                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
 909                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
 910                        prompt: prompt.clone(),
 911                    },
 912                    SignInStatus::SignedOut => Status::SignedOut,
 913                }
 914            }
 915        }
 916    }
 917
 918    fn update_sign_in_status(
 919        &mut self,
 920        lsp_status: request::SignInStatus,
 921        cx: &mut ModelContext<Self>,
 922    ) {
 923        self.buffers.retain(|buffer| buffer.is_upgradable());
 924
 925        if let Ok(server) = self.server.as_running() {
 926            match lsp_status {
 927                request::SignInStatus::Ok { user: Some(_) }
 928                | request::SignInStatus::MaybeOk { .. }
 929                | request::SignInStatus::AlreadySignedIn { .. } => {
 930                    server.sign_in_status = SignInStatus::Authorized;
 931                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 932                        if let Some(buffer) = buffer.upgrade() {
 933                            self.register_buffer(&buffer, cx);
 934                        }
 935                    }
 936                }
 937                request::SignInStatus::NotAuthorized { .. } => {
 938                    server.sign_in_status = SignInStatus::Unauthorized;
 939                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 940                        self.unregister_buffer(&buffer);
 941                    }
 942                }
 943                request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
 944                    server.sign_in_status = SignInStatus::SignedOut;
 945                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 946                        self.unregister_buffer(&buffer);
 947                    }
 948                }
 949            }
 950
 951            cx.notify();
 952        }
 953    }
 954}
 955
 956fn id_for_language(language: Option<&Arc<Language>>) -> String {
 957    language
 958        .map(|language| language.lsp_id())
 959        .unwrap_or_else(|| "plaintext".to_string())
 960}
 961
 962fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp::Url {
 963    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
 964        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
 965    } else {
 966        format!("buffer://{}", buffer.entity_id()).parse().unwrap()
 967    }
 968}
 969
 970async fn clear_copilot_dir() {
 971    remove_matching(&paths::COPILOT_DIR, |_| true).await
 972}
 973
 974async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 975    const SERVER_PATH: &str = "dist/agent.js";
 976
 977    ///Check for the latest copilot language server and download it if we haven't already
 978    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 979        let release =
 980            latest_github_release("zed-industries/copilot", true, false, http.clone()).await?;
 981
 982        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.tag_name));
 983
 984        fs::create_dir_all(version_dir).await?;
 985        let server_path = version_dir.join(SERVER_PATH);
 986
 987        if fs::metadata(&server_path).await.is_err() {
 988            // Copilot LSP looks for this dist dir specifically, so lets add it in.
 989            let dist_dir = version_dir.join("dist");
 990            fs::create_dir_all(dist_dir.as_path()).await?;
 991
 992            let url = &release
 993                .assets
 994                .get(0)
 995                .context("Github release for copilot contained no assets")?
 996                .browser_download_url;
 997
 998            let mut response = http
 999                .get(url, Default::default(), true)
1000                .await
1001                .context("error downloading copilot release")?;
1002            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
1003            let archive = Archive::new(decompressed_bytes);
1004            archive.unpack(dist_dir).await?;
1005
1006            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
1007        }
1008
1009        Ok(server_path)
1010    }
1011
1012    match fetch_latest(http).await {
1013        ok @ Result::Ok(..) => ok,
1014        e @ Err(..) => {
1015            e.log_err();
1016            // Fetch a cached binary, if it exists
1017            maybe!(async {
1018                let mut last_version_dir = None;
1019                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
1020                while let Some(entry) = entries.next().await {
1021                    let entry = entry?;
1022                    if entry.file_type().await?.is_dir() {
1023                        last_version_dir = Some(entry.path());
1024                    }
1025                }
1026                let last_version_dir =
1027                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1028                let server_path = last_version_dir.join(SERVER_PATH);
1029                if server_path.exists() {
1030                    Ok(server_path)
1031                } else {
1032                    Err(anyhow!(
1033                        "missing executable in directory {:?}",
1034                        last_version_dir
1035                    ))
1036                }
1037            })
1038            .await
1039        }
1040    }
1041}
1042
// Tests for Copilot's buffer registration bookkeeping, observed through the
// notifications a fake language server receives.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::BufferId;

    // Walks a pair of buffers through the full lifecycle, in order:
    // open, edit, gain a file, sign-out, sign-in, drop.
    // At each step it checks the exact notifications sent to the server.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        // `Copilot::fake` returns the model plus an LSP handle used below to
        // observe notifications and to stub request handlers.
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A buffer without a file is opened under a `buffer://<entity id>` URI
        // at version 0.
        let buffer_1 = cx.new_model(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.new_model(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is sent as a range-based change, and the document version
        // becomes 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // Attaching a local file changes the buffer's URI. The old document is
        // closed and then reopened under the file:// URI, keeping version 1.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: "/root/child/buffer-1".into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Reopened documents start again at version 0.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    // Minimal local-file stub for the test. It stores only an absolute path and
    // a worktree-relative path; most trait methods are `unimplemented!()`.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    // Only `as_local`, `path`, `worktree_id` and `is_private` return values;
    // the remaining methods panic if called.
    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> Option<std::time::SystemTime> {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self) -> usize {
            0
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    // Only `abs_path` is implemented; it returns the stored absolute path,
    // which feeds the buffer's file:// URI.
    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn buffer_reloaded(
            &self,
            _: BufferId,
            _: &clock::Global,
            _: language::LineEnding,
            _: Option<std::time::SystemTime>,
            _: &mut AppContext,
        ) {
            unimplemented!()
        }
    }
}