copilot.rs

   1mod copilot_completion_provider;
   2pub mod request;
   3mod sign_in;
   4
   5use anyhow::{anyhow, Context as _, Result};
   6use async_compression::futures::bufread::GzipDecoder;
   7use async_tar::Archive;
   8use collections::{HashMap, HashSet};
   9use command_palette_hooks::CommandPaletteFilter;
  10use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
  11use gpui::{
  12    actions, AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Global, Model,
  13    ModelContext, Task, WeakModel,
  14};
  15use language::{
  16    language_settings::{all_language_settings, language_settings, InlineCompletionProvider},
  17    point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
  18    ToPointUtf16,
  19};
  20use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId};
  21use node_runtime::NodeRuntime;
  22use parking_lot::Mutex;
  23use request::StatusNotification;
  24use settings::SettingsStore;
  25use smol::{fs, io::BufReader, stream::StreamExt};
  26use std::{
  27    any::TypeId,
  28    ffi::OsString,
  29    mem,
  30    ops::Range,
  31    path::{Path, PathBuf},
  32    sync::Arc,
  33};
  34use util::{
  35    fs::remove_matching, github::latest_github_release, http::HttpClient, maybe, paths, ResultExt,
  36};
  37
  38pub use copilot_completion_provider::CopilotCompletionProvider;
  39pub use sign_in::CopilotCodeVerification;
  40
  41actions!(
  42    copilot,
  43    [
  44        Suggest,
  45        NextSuggestion,
  46        PreviousSuggestion,
  47        Reinstall,
  48        SignIn,
  49        SignOut
  50    ]
  51);
  52
  53pub fn init(
  54    new_server_id: LanguageServerId,
  55    http: Arc<dyn HttpClient>,
  56    node_runtime: Arc<dyn NodeRuntime>,
  57    cx: &mut AppContext,
  58) {
  59    let copilot = cx.new_model({
  60        let node_runtime = node_runtime.clone();
  61        move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
  62    });
  63    Copilot::set_global(copilot.clone(), cx);
  64    cx.observe(&copilot, |handle, cx| {
  65        let copilot_action_types = [
  66            TypeId::of::<Suggest>(),
  67            TypeId::of::<NextSuggestion>(),
  68            TypeId::of::<PreviousSuggestion>(),
  69            TypeId::of::<Reinstall>(),
  70        ];
  71        let copilot_auth_action_types = [TypeId::of::<SignOut>()];
  72        let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
  73        let status = handle.read(cx).status();
  74        let filter = CommandPaletteFilter::global_mut(cx);
  75
  76        match status {
  77            Status::Disabled => {
  78                filter.hide_action_types(&copilot_action_types);
  79                filter.hide_action_types(&copilot_auth_action_types);
  80                filter.hide_action_types(&copilot_no_auth_action_types);
  81            }
  82            Status::Authorized => {
  83                filter.hide_action_types(&copilot_no_auth_action_types);
  84                filter.show_action_types(
  85                    copilot_action_types
  86                        .iter()
  87                        .chain(&copilot_auth_action_types),
  88                );
  89            }
  90            _ => {
  91                filter.hide_action_types(&copilot_action_types);
  92                filter.hide_action_types(&copilot_auth_action_types);
  93                filter.show_action_types(copilot_no_auth_action_types.iter());
  94            }
  95        }
  96    })
  97    .detach();
  98
  99    cx.on_action(|_: &SignIn, cx| {
 100        if let Some(copilot) = Copilot::global(cx) {
 101            copilot
 102                .update(cx, |copilot, cx| copilot.sign_in(cx))
 103                .detach_and_log_err(cx);
 104        }
 105    });
 106    cx.on_action(|_: &SignOut, cx| {
 107        if let Some(copilot) = Copilot::global(cx) {
 108            copilot
 109                .update(cx, |copilot, cx| copilot.sign_out(cx))
 110                .detach_and_log_err(cx);
 111        }
 112    });
 113    cx.on_action(|_: &Reinstall, cx| {
 114        if let Some(copilot) = Copilot::global(cx) {
 115            copilot
 116                .update(cx, |copilot, cx| copilot.reinstall(cx))
 117                .detach();
 118        }
 119    });
 120}
 121
 122enum CopilotServer {
 123    Disabled,
 124    Starting { task: Shared<Task<()>> },
 125    Error(Arc<str>),
 126    Running(RunningCopilotServer),
 127}
 128
 129impl CopilotServer {
 130    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 131        let server = self.as_running()?;
 132        if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
 133            Ok(server)
 134        } else {
 135            Err(anyhow!("must sign in before using copilot"))
 136        }
 137    }
 138
 139    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 140        match self {
 141            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
 142            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
 143            CopilotServer::Error(error) => Err(anyhow!(
 144                "copilot was not started because of an error: {}",
 145                error
 146            )),
 147            CopilotServer::Running(server) => Ok(server),
 148        }
 149    }
 150}
 151
 152struct RunningCopilotServer {
 153    lsp: Arc<LanguageServer>,
 154    sign_in_status: SignInStatus,
 155    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
 156}
 157
 158#[derive(Clone, Debug)]
 159enum SignInStatus {
 160    Authorized,
 161    Unauthorized,
 162    SigningIn {
 163        prompt: Option<request::PromptUserDeviceFlow>,
 164        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 165    },
 166    SignedOut,
 167}
 168
 169#[derive(Debug, Clone)]
 170pub enum Status {
 171    Starting {
 172        task: Shared<Task<()>>,
 173    },
 174    Error(Arc<str>),
 175    Disabled,
 176    SignedOut,
 177    SigningIn {
 178        prompt: Option<request::PromptUserDeviceFlow>,
 179    },
 180    Unauthorized,
 181    Authorized,
 182}
 183
 184impl Status {
 185    pub fn is_authorized(&self) -> bool {
 186        matches!(self, Status::Authorized)
 187    }
 188}
 189
/// Per-buffer state for a buffer that has been opened on the Copilot language server.
struct RegisteredBuffer {
    /// URI the document is currently open under on the server.
    uri: lsp::Url,
    /// Language id the document is currently open with.
    language_id: String,
    /// Most recent snapshot whose contents the server has been told about.
    snapshot: BufferSnapshot,
    /// Document version last sent to the server; 0 on `didOpen`, incremented
    /// for every `didChange`.
    snapshot_version: i32,
    /// Buffer event subscription and release observer created in `register_buffer`.
    _subscriptions: [gpui::Subscription; 2],
    /// Latest change-reporting task; each new report awaits the previous one so
    /// changes are diffed and sent to the server in order.
    pending_buffer_change: Task<Option<()>>,
}

impl RegisteredBuffer {
    /// Sends the server any edits made to `buffer` since the last reported
    /// snapshot. Returns a receiver for the `(version, snapshot)` pair the
    /// server holds once those edits have been reported.
    ///
    /// If the buffer's version equals `self.snapshot.version`, the current pair
    /// is sent immediately. Otherwise a task is chained after any pending
    /// report. That task diffs the registered snapshot against a fresh one on
    /// the background executor. If there are edits, it increments
    /// `snapshot_version`, stores the new snapshot, and sends `didChange`.
    ///
    /// The sender is dropped, so the receiver resolves to `Canceled`, in any
    /// of these cases:
    /// - the `Copilot` model or the buffer is gone;
    /// - the server isn't authenticated;
    /// - the buffer is no longer registered.
    fn report_changes(
        &mut self,
        buffer: &Model<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing new since the last report: answer right away.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous report so reports run one after another.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move {
                prev_pending_change.await;

                // Read the registered snapshot's version only now, because an
                // earlier report in the chain may have advanced it.
                let old_version = copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?;

                // Compute the LSP content changes off the main thread.
                let content_changes = cx
                    .background_executor()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // LSP applies `content_changes` in order, so each range is
                                    // expressed against the document with all earlier changes
                                    // applied: it starts at the edit's new position and spans
                                    // the length of the text it replaced.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version and notify when something changed.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
 285
/// An inline completion suggested by the Copilot server.
#[derive(Debug)]
pub struct Completion {
    /// Identifier sent back to the server in `NotifyAccepted` / `NotifyRejected`
    /// requests (see `accept_completion` / `discard_completions`).
    pub uuid: String,
    /// Buffer range the suggestion applies to (presumably the text it
    /// replaces; TODO confirm against `request_completions`).
    pub range: Range<Anchor>,
    /// Suggested text.
    pub text: String,
}

/// Model that owns the Copilot language server and the buffers it tracks.
pub struct Copilot {
    /// Passed to `get_copilot_lsp` when (re)starting the server.
    http: Arc<dyn HttpClient>,
    /// Provides the node binary used to launch the server.
    node_runtime: Arc<dyn NodeRuntime>,
    /// Current server lifecycle state.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, including ones not currently
    /// opened on the server.
    buffers: HashSet<WeakModel<Buffer>>,
    /// Id given to the language server when it is started.
    server_id: LanguageServerId,
    /// App-quit hook that runs `shutdown_language_server`.
    _subscription: gpui::Subscription,
}

/// Events emitted by [`Copilot`].
pub enum Event {
    /// Emitted once the language server has been started and initialized.
    CopilotLanguageServerStarted,
}

impl EventEmitter<Event> for Copilot {}

/// App-global wrapper holding the shared [`Copilot`] model.
struct GlobalCopilot(Model<Copilot>);

impl Global for GlobalCopilot {}
 311
 312impl Copilot {
 313    pub fn global(cx: &AppContext) -> Option<Model<Self>> {
 314        cx.try_global::<GlobalCopilot>()
 315            .map(|model| model.0.clone())
 316    }
 317
 318    pub fn set_global(copilot: Model<Self>, cx: &mut AppContext) {
 319        cx.set_global(GlobalCopilot(copilot));
 320    }
 321
 322    fn start(
 323        new_server_id: LanguageServerId,
 324        http: Arc<dyn HttpClient>,
 325        node_runtime: Arc<dyn NodeRuntime>,
 326        cx: &mut ModelContext<Self>,
 327    ) -> Self {
 328        let mut this = Self {
 329            server_id: new_server_id,
 330            http,
 331            node_runtime,
 332            server: CopilotServer::Disabled,
 333            buffers: Default::default(),
 334            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 335        };
 336        this.enable_or_disable_copilot(cx);
 337        cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
 338            .detach();
 339        this
 340    }
 341
 342    fn shutdown_language_server(
 343        &mut self,
 344        _cx: &mut ModelContext<Self>,
 345    ) -> impl Future<Output = ()> {
 346        let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
 347            CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
 348            _ => None,
 349        };
 350
 351        async move {
 352            if let Some(shutdown) = shutdown {
 353                shutdown.await;
 354            }
 355        }
 356    }
 357
 358    fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Self>) {
 359        let server_id = self.server_id;
 360        let http = self.http.clone();
 361        let node_runtime = self.node_runtime.clone();
 362        if all_language_settings(None, cx).inline_completions.provider
 363            == InlineCompletionProvider::Copilot
 364        {
 365            if matches!(self.server, CopilotServer::Disabled) {
 366                let start_task = cx
 367                    .spawn(move |this, cx| {
 368                        Self::start_language_server(server_id, http, node_runtime, this, cx)
 369                    })
 370                    .shared();
 371                self.server = CopilotServer::Starting { task: start_task };
 372                cx.notify();
 373            }
 374        } else {
 375            self.server = CopilotServer::Disabled;
 376            cx.notify();
 377        }
 378    }
 379
 380    #[cfg(any(test, feature = "test-support"))]
 381    pub fn fake(cx: &mut gpui::TestAppContext) -> (Model<Self>, lsp::FakeLanguageServer) {
 382        use lsp::FakeLanguageServer;
 383        use node_runtime::FakeNodeRuntime;
 384
 385        let (server, fake_server) = FakeLanguageServer::new(
 386            LanguageServerId(0),
 387            LanguageServerBinary {
 388                path: "path/to/copilot".into(),
 389                arguments: vec![],
 390                env: None,
 391            },
 392            "copilot".into(),
 393            Default::default(),
 394            cx.to_async(),
 395        );
 396        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
 397        let node_runtime = FakeNodeRuntime::new();
 398        let this = cx.new_model(|cx| Self {
 399            server_id: LanguageServerId(0),
 400            http: http.clone(),
 401            node_runtime,
 402            server: CopilotServer::Running(RunningCopilotServer {
 403                lsp: Arc::new(server),
 404                sign_in_status: SignInStatus::Authorized,
 405                registered_buffers: Default::default(),
 406            }),
 407            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 408            buffers: Default::default(),
 409        });
 410        (this, fake_server)
 411    }
 412
 413    fn start_language_server(
 414        new_server_id: LanguageServerId,
 415        http: Arc<dyn HttpClient>,
 416        node_runtime: Arc<dyn NodeRuntime>,
 417        this: WeakModel<Self>,
 418        mut cx: AsyncAppContext,
 419    ) -> impl Future<Output = ()> {
 420        async move {
 421            let start_language_server = async {
 422                let server_path = get_copilot_lsp(http).await?;
 423                let node_path = node_runtime.binary_path().await?;
 424                let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
 425                let binary = LanguageServerBinary {
 426                    path: node_path,
 427                    arguments,
 428                    // TODO: We could set HTTP_PROXY etc here and fix the copilot issue.
 429                    env: None,
 430                };
 431
 432                let server = LanguageServer::new(
 433                    Arc::new(Mutex::new(None)),
 434                    new_server_id,
 435                    binary,
 436                    Path::new("/"),
 437                    None,
 438                    cx.clone(),
 439                )?;
 440
 441                server
 442                    .on_notification::<StatusNotification, _>(
 443                        |_, _| { /* Silence the notification */ },
 444                    )
 445                    .detach();
 446                let server = cx.update(|cx| server.initialize(None, cx))?.await?;
 447
 448                let status = server
 449                    .request::<request::CheckStatus>(request::CheckStatusParams {
 450                        local_checks_only: false,
 451                    })
 452                    .await?;
 453
 454                server
 455                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
 456                        editor_info: request::EditorInfo {
 457                            name: "zed".into(),
 458                            version: env!("CARGO_PKG_VERSION").into(),
 459                        },
 460                        editor_plugin_info: request::EditorPluginInfo {
 461                            name: "zed-copilot".into(),
 462                            version: "0.0.1".into(),
 463                        },
 464                    })
 465                    .await?;
 466
 467                anyhow::Ok((server, status))
 468            };
 469
 470            let server = start_language_server.await;
 471            this.update(&mut cx, |this, cx| {
 472                cx.notify();
 473                match server {
 474                    Ok((server, status)) => {
 475                        this.server = CopilotServer::Running(RunningCopilotServer {
 476                            lsp: server,
 477                            sign_in_status: SignInStatus::SignedOut,
 478                            registered_buffers: Default::default(),
 479                        });
 480                        cx.emit(Event::CopilotLanguageServerStarted);
 481                        this.update_sign_in_status(status, cx);
 482                    }
 483                    Err(error) => {
 484                        this.server = CopilotServer::Error(error.to_string().into());
 485                        cx.notify()
 486                    }
 487                }
 488            })
 489            .ok();
 490        }
 491    }
 492
    /// Begins signing in to Copilot, or joins a sign-in that is already underway.
    ///
    /// - Already authorized: resolves to `Ok(())` immediately.
    /// - Sign-in in progress: notifies observers and returns that flow's result.
    /// - Signed out / unauthorized: spawns a task that sends `SignInInitiate`.
    ///   - If the server asks for a device flow, the prompt is stored in
    ///     `SignInStatus::SigningIn` (observers are notified) before waiting
    ///     on `SignInConfirm`.
    ///   - The outcome is applied with `update_sign_in_status`; any error
    ///     resets the status to `NotSignedIn`.
    ///
    /// Returns an error if the language server is not running.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt (including the user
                                        // code) on the in-progress sign-in state so observers
                                        // can display it.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        // Wait for the server to confirm the device flow.
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    // Record the in-flight task so later calls join it instead
                    // of starting a second flow.
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // The shared task's error is an `Arc<anyhow::Error>`; convert it
            // back into a plain `anyhow::Error` for the caller.
            cx.background_executor()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if the server is still downloading, wait for it to finish
            // instead of failing; if it is stuck, surface that to the user.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 577
 578    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 579        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 580        if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
 581            let server = server.clone();
 582            cx.background_executor().spawn(async move {
 583                server
 584                    .request::<request::SignOut>(request::SignOutParams {})
 585                    .await?;
 586                anyhow::Ok(())
 587            })
 588        } else {
 589            Task::ready(Err(anyhow!("copilot hasn't started yet")))
 590        }
 591    }
 592
 593    pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 594        let start_task = cx
 595            .spawn({
 596                let http = self.http.clone();
 597                let node_runtime = self.node_runtime.clone();
 598                let server_id = self.server_id;
 599                move |this, cx| async move {
 600                    clear_copilot_dir().await;
 601                    Self::start_language_server(server_id, http, node_runtime, this, cx).await
 602                }
 603            })
 604            .shared();
 605
 606        self.server = CopilotServer::Starting {
 607            task: start_task.clone(),
 608        };
 609
 610        cx.notify();
 611
 612        cx.background_executor().spawn(start_task)
 613    }
 614
 615    pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
 616        if let CopilotServer::Running(server) = &self.server {
 617            Some(&server.lsp)
 618        } else {
 619            None
 620        }
 621    }
 622
 623    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
 624        let weak_buffer = buffer.downgrade();
 625        self.buffers.insert(weak_buffer.clone());
 626
 627        if let CopilotServer::Running(RunningCopilotServer {
 628            lsp: server,
 629            sign_in_status: status,
 630            registered_buffers,
 631            ..
 632        }) = &mut self.server
 633        {
 634            if !matches!(status, SignInStatus::Authorized { .. }) {
 635                return;
 636            }
 637
 638            registered_buffers
 639                .entry(buffer.entity_id())
 640                .or_insert_with(|| {
 641                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
 642                    let language_id = id_for_language(buffer.read(cx).language());
 643                    let snapshot = buffer.read(cx).snapshot();
 644                    server
 645                        .notify::<lsp::notification::DidOpenTextDocument>(
 646                            lsp::DidOpenTextDocumentParams {
 647                                text_document: lsp::TextDocumentItem {
 648                                    uri: uri.clone(),
 649                                    language_id: language_id.clone(),
 650                                    version: 0,
 651                                    text: snapshot.text(),
 652                                },
 653                            },
 654                        )
 655                        .log_err();
 656
 657                    RegisteredBuffer {
 658                        uri,
 659                        language_id,
 660                        snapshot,
 661                        snapshot_version: 0,
 662                        pending_buffer_change: Task::ready(Some(())),
 663                        _subscriptions: [
 664                            cx.subscribe(buffer, |this, buffer, event, cx| {
 665                                this.handle_buffer_event(buffer, event, cx).log_err();
 666                            }),
 667                            cx.observe_release(buffer, move |this, _buffer, _cx| {
 668                                this.buffers.remove(&weak_buffer);
 669                                this.unregister_buffer(&weak_buffer);
 670                            }),
 671                        ],
 672                    }
 673                });
 674        }
 675    }
 676
 677    fn handle_buffer_event(
 678        &mut self,
 679        buffer: Model<Buffer>,
 680        event: &language::Event,
 681        cx: &mut ModelContext<Self>,
 682    ) -> Result<()> {
 683        if let Ok(server) = self.server.as_running() {
 684            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
 685            {
 686                match event {
 687                    language::Event::Edited => {
 688                        let _ = registered_buffer.report_changes(&buffer, cx);
 689                    }
 690                    language::Event::Saved => {
 691                        server
 692                            .lsp
 693                            .notify::<lsp::notification::DidSaveTextDocument>(
 694                                lsp::DidSaveTextDocumentParams {
 695                                    text_document: lsp::TextDocumentIdentifier::new(
 696                                        registered_buffer.uri.clone(),
 697                                    ),
 698                                    text: None,
 699                                },
 700                            )?;
 701                    }
 702                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
 703                        let new_language_id = id_for_language(buffer.read(cx).language());
 704                        let new_uri = uri_for_buffer(&buffer, cx);
 705                        if new_uri != registered_buffer.uri
 706                            || new_language_id != registered_buffer.language_id
 707                        {
 708                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 709                            registered_buffer.language_id = new_language_id;
 710                            server
 711                                .lsp
 712                                .notify::<lsp::notification::DidCloseTextDocument>(
 713                                    lsp::DidCloseTextDocumentParams {
 714                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 715                                    },
 716                                )?;
 717                            server
 718                                .lsp
 719                                .notify::<lsp::notification::DidOpenTextDocument>(
 720                                    lsp::DidOpenTextDocumentParams {
 721                                        text_document: lsp::TextDocumentItem::new(
 722                                            registered_buffer.uri.clone(),
 723                                            registered_buffer.language_id.clone(),
 724                                            registered_buffer.snapshot_version,
 725                                            registered_buffer.snapshot.text(),
 726                                        ),
 727                                    },
 728                                )?;
 729                        }
 730                    }
 731                    _ => {}
 732                }
 733            }
 734        }
 735
 736        Ok(())
 737    }
 738
 739    fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
 740        if let Ok(server) = self.server.as_running() {
 741            if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
 742                server
 743                    .lsp
 744                    .notify::<lsp::notification::DidCloseTextDocument>(
 745                        lsp::DidCloseTextDocumentParams {
 746                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 747                        },
 748                    )
 749                    .log_err();
 750            }
 751        }
 752    }
 753
 754    pub fn completions<T>(
 755        &mut self,
 756        buffer: &Model<Buffer>,
 757        position: T,
 758        cx: &mut ModelContext<Self>,
 759    ) -> Task<Result<Vec<Completion>>>
 760    where
 761        T: ToPointUtf16,
 762    {
 763        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 764    }
 765
 766    pub fn completions_cycling<T>(
 767        &mut self,
 768        buffer: &Model<Buffer>,
 769        position: T,
 770        cx: &mut ModelContext<Self>,
 771    ) -> Task<Result<Vec<Completion>>>
 772    where
 773        T: ToPointUtf16,
 774    {
 775        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 776    }
 777
 778    pub fn accept_completion(
 779        &mut self,
 780        completion: &Completion,
 781        cx: &mut ModelContext<Self>,
 782    ) -> Task<Result<()>> {
 783        let server = match self.server.as_authenticated() {
 784            Ok(server) => server,
 785            Err(error) => return Task::ready(Err(error)),
 786        };
 787        let request =
 788            server
 789                .lsp
 790                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 791                    uuid: completion.uuid.clone(),
 792                });
 793        cx.background_executor().spawn(async move {
 794            request.await?;
 795            Ok(())
 796        })
 797    }
 798
 799    pub fn discard_completions(
 800        &mut self,
 801        completions: &[Completion],
 802        cx: &mut ModelContext<Self>,
 803    ) -> Task<Result<()>> {
 804        let server = match self.server.as_authenticated() {
 805            Ok(server) => server,
 806            Err(_) => return Task::ready(Ok(())),
 807        };
 808        let request =
 809            server
 810                .lsp
 811                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 812                    uuids: completions
 813                        .iter()
 814                        .map(|completion| completion.uuid.clone())
 815                        .collect(),
 816                });
 817        cx.background_executor().spawn(async move {
 818            request.await?;
 819            Ok(())
 820        })
 821    }
 822
 823    fn request_completions<R, T>(
 824        &mut self,
 825        buffer: &Model<Buffer>,
 826        position: T,
 827        cx: &mut ModelContext<Self>,
 828    ) -> Task<Result<Vec<Completion>>>
 829    where
 830        R: 'static
 831            + lsp::request::Request<
 832                Params = request::GetCompletionsParams,
 833                Result = request::GetCompletionsResult,
 834            >,
 835        T: ToPointUtf16,
 836    {
 837        self.register_buffer(buffer, cx);
 838
 839        let server = match self.server.as_authenticated() {
 840            Ok(server) => server,
 841            Err(error) => return Task::ready(Err(error)),
 842        };
 843        let lsp = server.lsp.clone();
 844        let registered_buffer = server
 845            .registered_buffers
 846            .get_mut(&buffer.entity_id())
 847            .unwrap();
 848        let snapshot = registered_buffer.report_changes(buffer, cx);
 849        let buffer = buffer.read(cx);
 850        let uri = registered_buffer.uri.clone();
 851        let position = position.to_point_utf16(buffer);
 852        let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
 853        let tab_size = settings.tab_size;
 854        let hard_tabs = settings.hard_tabs;
 855        let relative_path = buffer
 856            .file()
 857            .map(|file| file.path().to_path_buf())
 858            .unwrap_or_default();
 859
 860        cx.background_executor().spawn(async move {
 861            let (version, snapshot) = snapshot.await?;
 862            let result = lsp
 863                .request::<R>(request::GetCompletionsParams {
 864                    doc: request::GetCompletionsDocument {
 865                        uri,
 866                        tab_size: tab_size.into(),
 867                        indent_size: 1,
 868                        insert_spaces: !hard_tabs,
 869                        relative_path: relative_path.to_string_lossy().into(),
 870                        position: point_to_lsp(position),
 871                        version: version.try_into().unwrap(),
 872                    },
 873                })
 874                .await?;
 875            let completions = result
 876                .completions
 877                .into_iter()
 878                .map(|completion| {
 879                    let start = snapshot
 880                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 881                    let end =
 882                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 883                    Completion {
 884                        uuid: completion.uuid,
 885                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 886                        text: completion.text,
 887                    }
 888                })
 889                .collect();
 890            anyhow::Ok(completions)
 891        })
 892    }
 893
 894    pub fn status(&self) -> Status {
 895        match &self.server {
 896            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
 897            CopilotServer::Disabled => Status::Disabled,
 898            CopilotServer::Error(error) => Status::Error(error.clone()),
 899            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
 900                match sign_in_status {
 901                    SignInStatus::Authorized { .. } => Status::Authorized,
 902                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
 903                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
 904                        prompt: prompt.clone(),
 905                    },
 906                    SignInStatus::SignedOut => Status::SignedOut,
 907                }
 908            }
 909        }
 910    }
 911
 912    fn update_sign_in_status(
 913        &mut self,
 914        lsp_status: request::SignInStatus,
 915        cx: &mut ModelContext<Self>,
 916    ) {
 917        self.buffers.retain(|buffer| buffer.is_upgradable());
 918
 919        if let Ok(server) = self.server.as_running() {
 920            match lsp_status {
 921                request::SignInStatus::Ok { user: Some(_) }
 922                | request::SignInStatus::MaybeOk { .. }
 923                | request::SignInStatus::AlreadySignedIn { .. } => {
 924                    server.sign_in_status = SignInStatus::Authorized;
 925                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 926                        if let Some(buffer) = buffer.upgrade() {
 927                            self.register_buffer(&buffer, cx);
 928                        }
 929                    }
 930                }
 931                request::SignInStatus::NotAuthorized { .. } => {
 932                    server.sign_in_status = SignInStatus::Unauthorized;
 933                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 934                        self.unregister_buffer(&buffer);
 935                    }
 936                }
 937                request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
 938                    server.sign_in_status = SignInStatus::SignedOut;
 939                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 940                        self.unregister_buffer(&buffer);
 941                    }
 942                }
 943            }
 944
 945            cx.notify();
 946        }
 947    }
 948}
 949
 950fn id_for_language(language: Option<&Arc<Language>>) -> String {
 951    language
 952        .map(|language| language.lsp_id())
 953        .unwrap_or_else(|| "plaintext".to_string())
 954}
 955
 956fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp::Url {
 957    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
 958        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
 959    } else {
 960        format!("buffer://{}", buffer.entity_id()).parse().unwrap()
 961    }
 962}
 963
 964async fn clear_copilot_dir() {
 965    remove_matching(&paths::COPILOT_DIR, |_| true).await
 966}
 967
 968async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 969    const SERVER_PATH: &str = "dist/agent.js";
 970
 971    ///Check for the latest copilot language server and download it if we haven't already
 972    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 973        let release =
 974            latest_github_release("zed-industries/copilot", true, false, http.clone()).await?;
 975
 976        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.tag_name));
 977
 978        fs::create_dir_all(version_dir).await?;
 979        let server_path = version_dir.join(SERVER_PATH);
 980
 981        if fs::metadata(&server_path).await.is_err() {
 982            // Copilot LSP looks for this dist dir specifically, so lets add it in.
 983            let dist_dir = version_dir.join("dist");
 984            fs::create_dir_all(dist_dir.as_path()).await?;
 985
 986            let url = &release
 987                .assets
 988                .get(0)
 989                .context("Github release for copilot contained no assets")?
 990                .browser_download_url;
 991
 992            let mut response = http
 993                .get(url, Default::default(), true)
 994                .await
 995                .context("error downloading copilot release")?;
 996            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
 997            let archive = Archive::new(decompressed_bytes);
 998            archive.unpack(dist_dir).await?;
 999
1000            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
1001        }
1002
1003        Ok(server_path)
1004    }
1005
1006    match fetch_latest(http).await {
1007        ok @ Result::Ok(..) => ok,
1008        e @ Err(..) => {
1009            e.log_err();
1010            // Fetch a cached binary, if it exists
1011            maybe!(async {
1012                let mut last_version_dir = None;
1013                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
1014                while let Some(entry) = entries.next().await {
1015                    let entry = entry?;
1016                    if entry.file_type().await?.is_dir() {
1017                        last_version_dir = Some(entry.path());
1018                    }
1019                }
1020                let last_version_dir =
1021                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1022                let server_path = last_version_dir.join(SERVER_PATH);
1023                if server_path.exists() {
1024                    Ok(server_path)
1025                } else {
1026                    Err(anyhow!(
1027                        "missing executable in directory {:?}",
1028                        last_version_dir
1029                    ))
1030                }
1031            })
1032            .await
1033        }
1034    }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::BufferId;

    /// Drives a fake Copilot language server through a buffer's lifecycle and
    /// checks the exact text-document notifications it receives.
    ///
    /// Covered steps: open on registration, incremental changes, close/reopen
    /// when the buffer's file changes, close/reopen across sign-out and sign-in,
    /// and close when the buffer is dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // Registering a file-less buffer opens it under a `buffer://<entity id>`
        // URI at version 0 with its full text.
        let buffer_1 = cx.new_model(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        // A second buffer is opened the same way under its own URI.
        let buffer_2 = cx.new_model(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental change at the next version (1).
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // The old `buffer://` URI is closed, then the buffer is reopened under its
        // file path, keeping the current version (1) and text.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: "/root/child/buffer-1".into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Reopened buffers start again at version 0 with their current text.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local-file stub for the test above.
    ///
    /// Only `as_local`, `path`, `worktree_id`, `is_private` and `abs_path` return
    /// values. Every other method panics via `unimplemented!()`.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> Option<std::time::SystemTime> {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self) -> usize {
            0
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn buffer_reloaded(
            &self,
            _: BufferId,
            _: &clock::Global,
            _: language::LineEnding,
            _: Option<std::time::SystemTime>,
            _: &mut AppContext,
        ) {
            unimplemented!()
        }
    }
}