copilot.rs

   1pub mod request;
   2mod sign_in;
   3
   4use anyhow::{anyhow, Context, Result};
   5use async_compression::futures::bufread::GzipDecoder;
   6use async_tar::Archive;
   7use collections::HashMap;
   8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
   9use gpui::{
  10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
  11};
  12use language::{
  13    point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
  14    ToPointUtf16,
  15};
  16use log::{debug, error};
  17use lsp::{LanguageServer, LanguageServerId};
  18use node_runtime::NodeRuntime;
  19use request::{LogMessage, StatusNotification};
  20use settings::Settings;
  21use smol::{fs, io::BufReader, stream::StreamExt};
  22use std::{
  23    ffi::OsString,
  24    mem,
  25    ops::Range,
  26    path::{Path, PathBuf},
  27    pin::Pin,
  28    sync::Arc,
  29};
  30use util::{
  31    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
  32};
  33
  34const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
  35actions!(copilot_auth, [SignIn, SignOut]);
  36
  37const COPILOT_NAMESPACE: &'static str = "copilot";
  38actions!(
  39    copilot,
  40    [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
  41);
  42
  43pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
  44    let copilot = cx.add_model({
  45        let node_runtime = node_runtime.clone();
  46        move |cx| Copilot::start(http, node_runtime, cx)
  47    });
  48    cx.set_global(copilot.clone());
  49
  50    //////////////////////////////////////
  51    // SUBSCRIBE TO COPILOT EVENTS HERE //
  52    //////////////////////////////////////
  53
  54    cx.observe(&copilot, |handle, cx| {
  55        let status = handle.read(cx).status();
  56        cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
  57            match status {
  58                Status::Disabled => {
  59                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  60                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
  61                }
  62                Status::Authorized => {
  63                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
  64                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
  65                }
  66                _ => {
  67                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  68                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
  69                }
  70            }
  71        });
  72    })
  73    .detach();
  74
  75    sign_in::init(cx);
  76    cx.add_global_action(|_: &SignIn, cx| {
  77        if let Some(copilot) = Copilot::global(cx) {
  78            copilot
  79                .update(cx, |copilot, cx| copilot.sign_in(cx))
  80                .detach_and_log_err(cx);
  81        }
  82    });
  83    cx.add_global_action(|_: &SignOut, cx| {
  84        if let Some(copilot) = Copilot::global(cx) {
  85            copilot
  86                .update(cx, |copilot, cx| copilot.sign_out(cx))
  87                .detach_and_log_err(cx);
  88        }
  89    });
  90
  91    cx.add_global_action(|_: &Reinstall, cx| {
  92        if let Some(copilot) = Copilot::global(cx) {
  93            copilot
  94                .update(cx, |copilot, cx| copilot.reinstall(cx))
  95                .detach();
  96        }
  97    });
  98}
  99
 100enum CopilotServer {
 101    Disabled,
 102    Starting { task: Shared<Task<()>> },
 103    Error(Arc<str>),
 104    Running(RunningCopilotServer),
 105}
 106
 107impl CopilotServer {
 108    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 109        let server = self.as_running()?;
 110        if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
 111            Ok(server)
 112        } else {
 113            Err(anyhow!("must sign in before using copilot"))
 114        }
 115    }
 116
 117    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 118        match self {
 119            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
 120            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
 121            CopilotServer::Error(error) => Err(anyhow!(
 122                "copilot was not started because of an error: {}",
 123                error
 124            )),
 125            CopilotServer::Running(server) => Ok(server),
 126        }
 127    }
 128}
 129
/// State kept while the Copilot language server is running.
struct RunningCopilotServer {
    /// Shared handle to the language server used for requests and
    /// notifications.
    lsp: Arc<LanguageServer>,
    /// Current sign-in state of this server.
    sign_in_status: SignInStatus,
    /// Buffers that have been opened on the server, keyed by the buffer's
    /// `remote_id()`.
    registered_buffers: HashMap<u64, RegisteredBuffer>,
}
 135
/// Internal sign-in state of a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in and allowed to use Copilot; the only state accepted by
    /// `CopilotServer::as_authenticated`.
    Authorized,
    // NOTE(review): presumably "signed in, but the account may not use
    // Copilot" — the code assigning this variant is not in view; confirm.
    Unauthorized,
    /// A sign-in attempt is in progress.
    SigningIn {
        /// Device-flow prompt returned by the server, once it is known.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared handle to the in-flight sign-in task, so repeated sign-in
        /// requests can await the same attempt.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in; also the initial state of a freshly started server.
    SignedOut,
}
 146
 147#[derive(Debug, Clone)]
 148pub enum Status {
 149    Starting {
 150        task: Shared<Task<()>>,
 151    },
 152    Error(Arc<str>),
 153    Disabled,
 154    SignedOut,
 155    SigningIn {
 156        prompt: Option<request::PromptUserDeviceFlow>,
 157    },
 158    Unauthorized,
 159    Authorized,
 160}
 161
 162impl Status {
 163    pub fn is_authorized(&self) -> bool {
 164        matches!(self, Status::Authorized)
 165    }
 166}
 167
/// Book-keeping for a buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// The buffer's `remote_id()`; also its key in `registered_buffers`.
    id: u64,
    /// Document URI under which the buffer is currently open on the server.
    uri: lsp::Url,
    /// Language identifier last sent to the server for this document.
    language_id: String,
    /// Buffer snapshot matching the content last reported to the server.
    snapshot: BufferSnapshot,
    /// Document version sent with LSP notifications; starts at 0 and is
    /// incremented for every reported change.
    snapshot_version: i32,
    /// Subscription to the buffer's events and observer of its release.
    // NOTE(review): presumably held only so the callbacks stay registered for
    // as long as this entry lives — confirm against gpui's `Subscription`.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recently queued change-report task; each new report first awaits
    /// the previous one.
    pending_buffer_change: Task<Option<()>>,
}
 177
impl RegisteredBuffer {
    /// Brings the server's copy of this buffer up to date with `buffer`.
    ///
    /// Returns a receiver that yields the document version and snapshot the
    /// server has been told about. If the buffer has not changed since the
    /// recorded snapshot the value is sent immediately; otherwise a task is
    /// queued behind any previously pending report, which computes the edits,
    /// sends `textDocument/didChange`, and then sends the value.
    ///
    /// If the queued task bails out early (Copilot or the buffer was dropped,
    /// the server is not authenticated, or the buffer is no longer
    /// registered), nothing is sent and the sender is dropped.
    fn report_changes(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let id = self.id;
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing changed since the last report: answer right away.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            // Take the previous pending task so the new one can await it,
            // keeping change reports strictly ordered.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the recorded version only now: the previous task
                // may have advanced it.
                let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    Some(buffer.snapshot.version.clone())
                })?;
                let new_snapshot = buffer
                    .upgrade(&cx)?
                    .read_with(&cx, |buffer, _| buffer.snapshot());

                // Turn the edits since `old_version` into LSP content changes
                // on the background executor.
                let content_changes = cx
                    .background()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The replaced range starts at the edit's
                                    // new start and spans the old extent.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    // Only bump the version and notify the server when there
                    // is something to report.
                    if !content_changes.is_empty() {
                        buffer.snapshot_version += 1;
                        buffer.snapshot = new_snapshot;
                        server
                            .lsp
                            .notify::<lsp::notification::DidChangeTextDocument>(
                                lsp::DidChangeTextDocumentParams {
                                    text_document: lsp::VersionedTextDocumentIdentifier::new(
                                        buffer.uri.clone(),
                                        buffer.snapshot_version,
                                    ),
                                    content_changes,
                                },
                            )
                            .log_err();
                    }
                    // A failed send only means the receiver was dropped.
                    let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                    Some(())
                })?;

                Some(())
            });
        }

        done_rx
    }
}
 262
/// A single completion suggested by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Identifier of the completion; sent back to the server when the
    /// completion is accepted or discarded.
    uuid: String,
    // NOTE(review): presumably the buffer range that `text` replaces — the
    // code constructing `Completion` is not in view; confirm.
    pub range: Range<Anchor>,
    /// The suggested text.
    pub text: String,
}
 269
/// Model that owns the Copilot language server and tracks the buffers known
/// to it.
pub struct Copilot {
    /// HTTP client handed to the language-server start-up code.
    http: Arc<dyn HttpClient>,
    /// Node runtime used to locate the `node` binary that runs the server.
    node_runtime: Arc<NodeRuntime>,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Weak handles to every buffer passed to `register_buffer`, keyed by the
    /// buffer's `remote_id()`; entries are removed when the buffer is
    /// released.
    buffers: HashMap<u64, WeakModelHandle<Buffer>>,
}
 276
/// Events emitted by the [`Copilot`] model.
pub enum Event {
    /// Emitted by `accept_completion` before the server is notified.
    CompletionAccepted {
        /// Identifier of the accepted completion.
        uuid: String,
        /// File type supplied by the caller, if any.
        file_type: Option<Arc<str>>,
    },
    /// Emitted by `discard_completions` before the server is notified.
    CompletionsDiscarded {
        /// Identifiers of the discarded completions.
        uuids: Vec<String>,
        /// File type supplied by the caller, if any.
        file_type: Option<Arc<str>>,
    },
}
 287
 288impl Entity for Copilot {
 289    type Event = Event;
 290
 291    fn app_will_quit(
 292        &mut self,
 293        _: &mut AppContext,
 294    ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>> {
 295        match mem::replace(&mut self.server, CopilotServer::Disabled) {
 296            CopilotServer::Running(server) => Some(Box::pin(async move {
 297                if let Some(shutdown) = server.lsp.shutdown() {
 298                    shutdown.await;
 299                }
 300            })),
 301            _ => None,
 302        }
 303    }
 304}
 305
 306impl Copilot {
 307    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
 308        if cx.has_global::<ModelHandle<Self>>() {
 309            Some(cx.global::<ModelHandle<Self>>().clone())
 310        } else {
 311            None
 312        }
 313    }
 314
 315    fn start(
 316        http: Arc<dyn HttpClient>,
 317        node_runtime: Arc<NodeRuntime>,
 318        cx: &mut ModelContext<Self>,
 319    ) -> Self {
 320        cx.observe_global::<Settings, _>({
 321            let http = http.clone();
 322            let node_runtime = node_runtime.clone();
 323            move |this, cx| {
 324                if cx.global::<Settings>().features.copilot {
 325                    if matches!(this.server, CopilotServer::Disabled) {
 326                        let start_task = cx
 327                            .spawn({
 328                                let http = http.clone();
 329                                let node_runtime = node_runtime.clone();
 330                                move |this, cx| {
 331                                    Self::start_language_server(http, node_runtime, this, cx)
 332                                }
 333                            })
 334                            .shared();
 335                        this.server = CopilotServer::Starting { task: start_task };
 336                        cx.notify();
 337                    }
 338                } else {
 339                    this.server = CopilotServer::Disabled;
 340                    cx.notify();
 341                }
 342            }
 343        })
 344        .detach();
 345
 346        if cx.global::<Settings>().features.copilot {
 347            let start_task = cx
 348                .spawn({
 349                    let http = http.clone();
 350                    let node_runtime = node_runtime.clone();
 351                    move |this, cx| async {
 352                        Self::start_language_server(http, node_runtime, this, cx).await
 353                    }
 354                })
 355                .shared();
 356
 357            Self {
 358                http,
 359                node_runtime,
 360                server: CopilotServer::Starting { task: start_task },
 361                buffers: Default::default(),
 362            }
 363        } else {
 364            Self {
 365                http,
 366                node_runtime,
 367                server: CopilotServer::Disabled,
 368                buffers: Default::default(),
 369            }
 370        }
 371    }
 372
 373    #[cfg(any(test, feature = "test-support"))]
 374    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
 375        let (server, fake_server) =
 376            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
 377        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
 378        let this = cx.add_model(|cx| Self {
 379            http: http.clone(),
 380            node_runtime: NodeRuntime::new(http, cx.background().clone()),
 381            server: CopilotServer::Running(RunningCopilotServer {
 382                lsp: Arc::new(server),
 383                sign_in_status: SignInStatus::Authorized,
 384                registered_buffers: Default::default(),
 385            }),
 386            buffers: Default::default(),
 387        });
 388        (this, fake_server)
 389    }
 390
 391    fn start_language_server(
 392        http: Arc<dyn HttpClient>,
 393        node_runtime: Arc<NodeRuntime>,
 394        this: ModelHandle<Self>,
 395        mut cx: AsyncAppContext,
 396    ) -> impl Future<Output = ()> {
 397        async move {
 398            let start_language_server = async {
 399                let server_path = get_copilot_lsp(http).await?;
 400                let node_path = node_runtime.binary_path().await?;
 401                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
 402                let server = LanguageServer::new(
 403                    LanguageServerId(0),
 404                    &node_path,
 405                    arguments,
 406                    Path::new("/"),
 407                    None,
 408                    cx.clone(),
 409                )?;
 410
 411                server
 412                    .on_notification::<LogMessage, _>(|params, _cx| {
 413                        match params.level {
 414                            // Copilot is pretty agressive about logging
 415                            0 => debug!("copilot: {}", params.message),
 416                            1 => debug!("copilot: {}", params.message),
 417                            _ => error!("copilot: {}", params.message),
 418                        }
 419
 420                        debug!("copilot metadata: {}", params.metadata_str);
 421                        debug!("copilot extra: {:?}", params.extra);
 422                    })
 423                    .detach();
 424
 425                server
 426                    .on_notification::<StatusNotification, _>(
 427                        |_, _| { /* Silence the notification */ },
 428                    )
 429                    .detach();
 430
 431                let server = server.initialize(Default::default()).await?;
 432
 433                let status = server
 434                    .request::<request::CheckStatus>(request::CheckStatusParams {
 435                        local_checks_only: false,
 436                    })
 437                    .await?;
 438
 439                server
 440                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
 441                        editor_info: request::EditorInfo {
 442                            name: "zed".into(),
 443                            version: env!("CARGO_PKG_VERSION").into(),
 444                        },
 445                        editor_plugin_info: request::EditorPluginInfo {
 446                            name: "zed-copilot".into(),
 447                            version: "0.0.1".into(),
 448                        },
 449                    })
 450                    .await?;
 451
 452                anyhow::Ok((server, status))
 453            };
 454
 455            let server = start_language_server.await;
 456            this.update(&mut cx, |this, cx| {
 457                cx.notify();
 458                match server {
 459                    Ok((server, status)) => {
 460                        this.server = CopilotServer::Running(RunningCopilotServer {
 461                            lsp: server,
 462                            sign_in_status: SignInStatus::SignedOut,
 463                            registered_buffers: Default::default(),
 464                        });
 465                        this.update_sign_in_status(status, cx);
 466                    }
 467                    Err(error) => {
 468                        this.server = CopilotServer::Error(error.to_string().into());
 469                        cx.notify()
 470                    }
 471                }
 472            })
 473        }
 474    }
 475
    /// Signs in to Copilot, returning a task that resolves when the attempt
    /// finishes.
    ///
    /// * Already authorized: resolves immediately with `Ok(())`.
    /// * A sign-in is already in progress: joins the existing shared task.
    /// * Signed out or unauthorized: starts a new attempt and records it as
    ///   `SignInStatus::SigningIn`.
    ///
    /// # Errors
    ///
    /// The task fails with "copilot hasn't started yet" when the server is not
    /// running, or with the debug rendering of the underlying error when the
    /// sign-in requests fail.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt on the
                                        // model, but only if the model is still
                                        // in the `SigningIn` state.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        // Confirm with the user code from the
                                        // prompt; the response is the resulting
                                        // sign-in status.
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            // Apply the final status; a failed attempt is
                            // recorded as not signed in.
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // Wrapped in `Arc` so the error type
                                    // matches the shared task stored in
                                    // `SignInStatus::SigningIn`.
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: If we're downloading, wait until the download is finished.
            // TODO: If we're in a stuck state, display that to the user.
            // For now every non-running state is reported as an error.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 560
 561    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 562        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 563        if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
 564            let server = server.clone();
 565            cx.background().spawn(async move {
 566                server
 567                    .request::<request::SignOut>(request::SignOutParams {})
 568                    .await?;
 569                anyhow::Ok(())
 570            })
 571        } else {
 572            Task::ready(Err(anyhow!("copilot hasn't started yet")))
 573        }
 574    }
 575
 576    pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 577        let start_task = cx
 578            .spawn({
 579                let http = self.http.clone();
 580                let node_runtime = self.node_runtime.clone();
 581                move |this, cx| async move {
 582                    clear_copilot_dir().await;
 583                    Self::start_language_server(http, node_runtime, this, cx).await
 584                }
 585            })
 586            .shared();
 587
 588        self.server = CopilotServer::Starting {
 589            task: start_task.clone(),
 590        };
 591
 592        cx.notify();
 593
 594        cx.foreground().spawn(start_task)
 595    }
 596
 597    pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
 598        let buffer_id = buffer.read(cx).remote_id();
 599        self.buffers.insert(buffer_id, buffer.downgrade());
 600
 601        if let CopilotServer::Running(RunningCopilotServer {
 602            lsp: server,
 603            sign_in_status: status,
 604            registered_buffers,
 605            ..
 606        }) = &mut self.server
 607        {
 608            if !matches!(status, SignInStatus::Authorized { .. }) {
 609                return;
 610            }
 611
 612            let buffer_id = buffer.read(cx).remote_id();
 613            registered_buffers.entry(buffer_id).or_insert_with(|| {
 614                let uri: lsp::Url = uri_for_buffer(buffer, cx);
 615                let language_id = id_for_language(buffer.read(cx).language());
 616                let snapshot = buffer.read(cx).snapshot();
 617                server
 618                    .notify::<lsp::notification::DidOpenTextDocument>(
 619                        lsp::DidOpenTextDocumentParams {
 620                            text_document: lsp::TextDocumentItem {
 621                                uri: uri.clone(),
 622                                language_id: language_id.clone(),
 623                                version: 0,
 624                                text: snapshot.text(),
 625                            },
 626                        },
 627                    )
 628                    .log_err();
 629
 630                RegisteredBuffer {
 631                    id: buffer_id,
 632                    uri,
 633                    language_id,
 634                    snapshot,
 635                    snapshot_version: 0,
 636                    pending_buffer_change: Task::ready(Some(())),
 637                    _subscriptions: [
 638                        cx.subscribe(buffer, |this, buffer, event, cx| {
 639                            this.handle_buffer_event(buffer, event, cx).log_err();
 640                        }),
 641                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 642                            this.buffers.remove(&buffer_id);
 643                            this.unregister_buffer(buffer_id);
 644                        }),
 645                    ],
 646                }
 647            });
 648        }
 649    }
 650
 651    fn handle_buffer_event(
 652        &mut self,
 653        buffer: ModelHandle<Buffer>,
 654        event: &language::Event,
 655        cx: &mut ModelContext<Self>,
 656    ) -> Result<()> {
 657        if let Ok(server) = self.server.as_running() {
 658            let buffer_id = buffer.read(cx).remote_id();
 659            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer_id) {
 660                match event {
 661                    language::Event::Edited => {
 662                        let _ = registered_buffer.report_changes(&buffer, cx);
 663                    }
 664                    language::Event::Saved => {
 665                        server
 666                            .lsp
 667                            .notify::<lsp::notification::DidSaveTextDocument>(
 668                                lsp::DidSaveTextDocumentParams {
 669                                    text_document: lsp::TextDocumentIdentifier::new(
 670                                        registered_buffer.uri.clone(),
 671                                    ),
 672                                    text: None,
 673                                },
 674                            )?;
 675                    }
 676                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
 677                        let new_language_id = id_for_language(buffer.read(cx).language());
 678                        let new_uri = uri_for_buffer(&buffer, cx);
 679                        if new_uri != registered_buffer.uri
 680                            || new_language_id != registered_buffer.language_id
 681                        {
 682                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 683                            registered_buffer.language_id = new_language_id;
 684                            server
 685                                .lsp
 686                                .notify::<lsp::notification::DidCloseTextDocument>(
 687                                    lsp::DidCloseTextDocumentParams {
 688                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 689                                    },
 690                                )?;
 691                            server
 692                                .lsp
 693                                .notify::<lsp::notification::DidOpenTextDocument>(
 694                                    lsp::DidOpenTextDocumentParams {
 695                                        text_document: lsp::TextDocumentItem::new(
 696                                            registered_buffer.uri.clone(),
 697                                            registered_buffer.language_id.clone(),
 698                                            registered_buffer.snapshot_version,
 699                                            registered_buffer.snapshot.text(),
 700                                        ),
 701                                    },
 702                                )?;
 703                        }
 704                    }
 705                    _ => {}
 706                }
 707            }
 708        }
 709
 710        Ok(())
 711    }
 712
 713    fn unregister_buffer(&mut self, buffer_id: u64) {
 714        if let Ok(server) = self.server.as_running() {
 715            if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
 716                server
 717                    .lsp
 718                    .notify::<lsp::notification::DidCloseTextDocument>(
 719                        lsp::DidCloseTextDocumentParams {
 720                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 721                        },
 722                    )
 723                    .log_err();
 724            }
 725        }
 726    }
 727
    /// Requests completions for `position` in `buffer`.
    ///
    /// Thin wrapper that delegates to `request_completions` with the
    /// `request::GetCompletions` request type.
    pub fn completions<T>(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        T: ToPointUtf16,
    {
        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
    }
 739
    /// Requests completions for `position` in `buffer` using the cycling
    /// variant of the request.
    ///
    /// Thin wrapper that delegates to `request_completions` with the
    /// `request::GetCompletionsCycling` request type.
    pub fn completions_cycling<T>(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        T: ToPointUtf16,
    {
        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
    }
 751
 752    pub fn accept_completion(
 753        &mut self,
 754        completion: &Completion,
 755        file_type: Option<Arc<str>>,
 756        cx: &mut ModelContext<Self>,
 757    ) -> Task<Result<()>> {
 758        let server = match self.server.as_authenticated() {
 759            Ok(server) => server,
 760            Err(error) => return Task::ready(Err(error)),
 761        };
 762
 763        cx.emit(Event::CompletionAccepted {
 764            uuid: completion.uuid.clone(),
 765            file_type,
 766        });
 767
 768        let request =
 769            server
 770                .lsp
 771                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 772                    uuid: completion.uuid.clone(),
 773                });
 774
 775        cx.background().spawn(async move {
 776            request.await?;
 777            Ok(())
 778        })
 779    }
 780
 781    pub fn discard_completions(
 782        &mut self,
 783        completions: &[Completion],
 784        file_type: Option<Arc<str>>,
 785        cx: &mut ModelContext<Self>,
 786    ) -> Task<Result<()>> {
 787        let server = match self.server.as_authenticated() {
 788            Ok(server) => server,
 789            Err(error) => return Task::ready(Err(error)),
 790        };
 791
 792        cx.emit(Event::CompletionsDiscarded {
 793            uuids: completions
 794                .iter()
 795                .map(|completion| completion.uuid.clone())
 796                .collect(),
 797            file_type: file_type.clone(),
 798        });
 799
 800        let request =
 801            server
 802                .lsp
 803                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 804                    uuids: completions
 805                        .iter()
 806                        .map(|completion| completion.uuid.clone())
 807                        .collect(),
 808                });
 809
 810        cx.background().spawn(async move {
 811            request.await?;
 812            Ok(())
 813        })
 814    }
 815
 816    fn request_completions<R, T>(
 817        &mut self,
 818        buffer: &ModelHandle<Buffer>,
 819        position: T,
 820        cx: &mut ModelContext<Self>,
 821    ) -> Task<Result<Vec<Completion>>>
 822    where
 823        R: 'static
 824            + lsp::request::Request<
 825                Params = request::GetCompletionsParams,
 826                Result = request::GetCompletionsResult,
 827            >,
 828        T: ToPointUtf16,
 829    {
 830        self.register_buffer(buffer, cx);
 831
 832        let server = match self.server.as_authenticated() {
 833            Ok(server) => server,
 834            Err(error) => return Task::ready(Err(error)),
 835        };
 836        let lsp = server.lsp.clone();
 837        let buffer_id = buffer.read(cx).remote_id();
 838        let registered_buffer = server.registered_buffers.get_mut(&buffer_id).unwrap();
 839        let snapshot = registered_buffer.report_changes(buffer, cx);
 840        let buffer = buffer.read(cx);
 841        let uri = registered_buffer.uri.clone();
 842        let settings = cx.global::<Settings>();
 843        let position = position.to_point_utf16(buffer);
 844        let language = buffer.language_at(position);
 845        let language_name = language.map(|language| language.name());
 846        let language_name = language_name.as_deref();
 847        let tab_size = settings.tab_size(language_name);
 848        let hard_tabs = settings.hard_tabs(language_name);
 849        let relative_path = buffer
 850            .file()
 851            .map(|file| file.path().to_path_buf())
 852            .unwrap_or_default();
 853
 854        cx.foreground().spawn(async move {
 855            let (version, snapshot) = snapshot.await?;
 856            let result = lsp
 857                .request::<R>(request::GetCompletionsParams {
 858                    doc: request::GetCompletionsDocument {
 859                        uri,
 860                        tab_size: tab_size.into(),
 861                        indent_size: 1,
 862                        insert_spaces: !hard_tabs,
 863                        relative_path: relative_path.to_string_lossy().into(),
 864                        position: point_to_lsp(position),
 865                        version: version.try_into().unwrap(),
 866                    },
 867                })
 868                .await?;
 869            let completions = result
 870                .completions
 871                .into_iter()
 872                .map(|completion| {
 873                    let start = snapshot
 874                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 875                    let end =
 876                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 877                    Completion {
 878                        uuid: completion.uuid,
 879                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 880                        text: completion.text,
 881                    }
 882                })
 883                .collect();
 884            anyhow::Ok(completions)
 885        })
 886    }
 887
 888    pub fn status(&self) -> Status {
 889        match &self.server {
 890            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
 891            CopilotServer::Disabled => Status::Disabled,
 892            CopilotServer::Error(error) => Status::Error(error.clone()),
 893            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
 894                match sign_in_status {
 895                    SignInStatus::Authorized { .. } => Status::Authorized,
 896                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
 897                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
 898                        prompt: prompt.clone(),
 899                    },
 900                    SignInStatus::SignedOut => Status::SignedOut,
 901                }
 902            }
 903        }
 904    }
 905
 906    fn update_sign_in_status(
 907        &mut self,
 908        lsp_status: request::SignInStatus,
 909        cx: &mut ModelContext<Self>,
 910    ) {
 911        self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
 912
 913        if let Ok(server) = self.server.as_running() {
 914            match lsp_status {
 915                request::SignInStatus::Ok { .. }
 916                | request::SignInStatus::MaybeOk { .. }
 917                | request::SignInStatus::AlreadySignedIn { .. } => {
 918                    server.sign_in_status = SignInStatus::Authorized;
 919                    for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
 920                        if let Some(buffer) = buffer.upgrade(cx) {
 921                            self.register_buffer(&buffer, cx);
 922                        }
 923                    }
 924                }
 925                request::SignInStatus::NotAuthorized { .. } => {
 926                    server.sign_in_status = SignInStatus::Unauthorized;
 927                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
 928                        self.unregister_buffer(buffer_id);
 929                    }
 930                }
 931                request::SignInStatus::NotSignedIn => {
 932                    server.sign_in_status = SignInStatus::SignedOut;
 933                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
 934                        self.unregister_buffer(buffer_id);
 935                    }
 936                }
 937            }
 938
 939            cx.notify();
 940        }
 941    }
 942}
 943
 944fn id_for_language(language: Option<&Arc<Language>>) -> String {
 945    let language_name = language.map(|language| language.name());
 946    match language_name.as_deref() {
 947        Some("Plain Text") => "plaintext".to_string(),
 948        Some(language_name) => language_name.to_lowercase(),
 949        None => "plaintext".to_string(),
 950    }
 951}
 952
 953fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
 954    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
 955        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
 956    } else {
 957        format!("buffer://{}", buffer.read(cx).remote_id())
 958            .parse()
 959            .unwrap()
 960    }
 961}
 962
 963async fn clear_copilot_dir() {
 964    remove_matching(&paths::COPILOT_DIR, |_| true).await
 965}
 966
 967async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 968    const SERVER_PATH: &'static str = "dist/agent.js";
 969
 970    ///Check for the latest copilot language server and download it if we haven't already
 971    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 972        let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?;
 973
 974        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
 975
 976        fs::create_dir_all(version_dir).await?;
 977        let server_path = version_dir.join(SERVER_PATH);
 978
 979        if fs::metadata(&server_path).await.is_err() {
 980            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
 981            let dist_dir = version_dir.join("dist");
 982            fs::create_dir_all(dist_dir.as_path()).await?;
 983
 984            let url = &release
 985                .assets
 986                .get(0)
 987                .context("Github release for copilot contained no assets")?
 988                .browser_download_url;
 989
 990            let mut response = http
 991                .get(&url, Default::default(), true)
 992                .await
 993                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
 994            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
 995            let archive = Archive::new(decompressed_bytes);
 996            archive.unpack(dist_dir).await?;
 997
 998            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
 999        }
1000
1001        Ok(server_path)
1002    }
1003
1004    match fetch_latest(http).await {
1005        ok @ Result::Ok(..) => ok,
1006        e @ Err(..) => {
1007            e.log_err();
1008            // Fetch a cached binary, if it exists
1009            (|| async move {
1010                let mut last_version_dir = None;
1011                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
1012                while let Some(entry) = entries.next().await {
1013                    let entry = entry?;
1014                    if entry.file_type().await?.is_dir() {
1015                        last_version_dir = Some(entry.path());
1016                    }
1017                }
1018                let last_version_dir =
1019                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1020                let server_path = last_version_dir.join(SERVER_PATH);
1021                if server_path.exists() {
1022                    Ok(server_path)
1023                } else {
1024                    Err(anyhow!(
1025                        "missing executable in directory {:?}",
1026                        last_version_dir
1027                    ))
1028                }
1029            })()
1030            .await
1031        }
1032    }
1033}
1034
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{executor::Deterministic, TestAppContext};

    /// Drives a fake Copilot server through registering, editing, renaming,
    /// signing out/in and dropping buffers, asserting the exact sequence of
    /// document notifications the server receives.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();
        let (copilot, mut lsp) = Copilot::fake(cx);

        // Builders for the notification payloads expected below.
        let did_open = |uri: &lsp::Url, version, text: &str| lsp::DidOpenTextDocumentParams {
            text_document: lsp::TextDocumentItem::new(
                uri.clone(),
                "plaintext".into(),
                version,
                text.into(),
            ),
        };
        let did_close = |uri: &lsp::Url| lsp::DidCloseTextDocumentParams {
            text_document: lsp::TextDocumentIdentifier::new(uri.clone()),
        };

        // Registering a buffer opens it on the server.
        let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            did_open(&buffer_1_uri, 0, "Hello")
        );

        let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            did_open(&buffer_2_uri, 0, "Goodbye")
        );

        // Editing a registered buffer is forwarded as a ranged change.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        let insertion_point = lsp::Position::new(0, 5);
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(insertion_point, insertion_point)),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP: the buffer is closed
        // under its old URI and re-opened under the file's URI.
        buffer_1
            .update(cx, |buffer, cx| {
                buffer.file_updated(
                    Arc::new(File {
                        abs_path: "/root/child/buffer-1".into(),
                        path: Path::new("child/buffer-1").into(),
                    }),
                    cx,
                )
            })
            .await;
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            did_close(&buffer_1_uri)
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            did_open(&buffer_1_uri, 1, "Hello world")
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            did_close(&buffer_2_uri)
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            did_close(&buffer_1_uri)
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            did_open(&buffer_2_uri, 0, "Goodbye")
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            did_open(&buffer_1_uri, 0, "Hello world")
        );

        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            did_close(&buffer_2_uri)
        );
    }

    /// Minimal file stand-in for the test above: only `as_local`, `path` and
    /// `abs_path` are implemented; every other method is `unimplemented!()`.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> std::time::SystemTime {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn buffer_reloaded(
            &self,
            _: u64,
            _: &clock::Global,
            _: language::RopeFingerprint,
            _: ::fs::LineEnding,
            _: std::time::SystemTime,
            _: &mut AppContext,
        ) {
            unimplemented!()
        }
    }
}