copilot.rs

   1pub mod request;
   2mod sign_in;
   3
   4use anyhow::{anyhow, Context, Result};
   5use async_compression::futures::bufread::GzipDecoder;
   6use async_tar::Archive;
   7use collections::HashMap;
   8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
   9use gpui::{
  10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
  11};
  12use language::{
  13    point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
  14    ToPointUtf16,
  15};
  16use log::{debug, error};
  17use lsp::{LanguageServer, LanguageServerId};
  18use node_runtime::NodeRuntime;
  19use request::{LogMessage, StatusNotification};
  20use settings::Settings;
  21use smol::{fs, io::BufReader, stream::StreamExt};
  22use std::{
  23    ffi::OsString,
  24    mem,
  25    ops::Range,
  26    path::{Path, PathBuf},
  27    pin::Pin,
  28    sync::Arc,
  29};
  30use util::{
  31    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
  32};
  33
  34const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
  35actions!(copilot_auth, [SignIn, SignOut]);
  36
  37const COPILOT_NAMESPACE: &'static str = "copilot";
  38actions!(
  39    copilot,
  40    [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
  41);
  42
  43pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
  44    let copilot = cx.add_model({
  45        let node_runtime = node_runtime.clone();
  46        move |cx| Copilot::start(http, node_runtime, cx)
  47    });
  48    cx.set_global(copilot.clone());
  49
  50    cx.subscribe(&copilot, |_, event, _| {
  51        match event {
  52            Event::CompletionAccepted { uuid, file_type } => {
  53                // Build event object and pass it in
  54                // telemetry.report_clickhouse_event(event, settings.telemetry())
  55            }
  56            Event::CompletionsDiscarded { uuids, file_type } => {
  57                for uuid in uuids {
  58                    // Build event object and pass it in
  59                    // telemetry.report_clickhouse_event(event, settings.telemetry())
  60                }
  61            }
  62        };
  63    })
  64    .detach();
  65
  66    cx.observe(&copilot, |handle, cx| {
  67        let status = handle.read(cx).status();
  68        cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
  69            match status {
  70                Status::Disabled => {
  71                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  72                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
  73                }
  74                Status::Authorized => {
  75                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
  76                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
  77                }
  78                _ => {
  79                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  80                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
  81                }
  82            }
  83        });
  84    })
  85    .detach();
  86
  87    sign_in::init(cx);
  88    cx.add_global_action(|_: &SignIn, cx| {
  89        if let Some(copilot) = Copilot::global(cx) {
  90            copilot
  91                .update(cx, |copilot, cx| copilot.sign_in(cx))
  92                .detach_and_log_err(cx);
  93        }
  94    });
  95    cx.add_global_action(|_: &SignOut, cx| {
  96        if let Some(copilot) = Copilot::global(cx) {
  97            copilot
  98                .update(cx, |copilot, cx| copilot.sign_out(cx))
  99                .detach_and_log_err(cx);
 100        }
 101    });
 102
 103    cx.add_global_action(|_: &Reinstall, cx| {
 104        if let Some(copilot) = Copilot::global(cx) {
 105            copilot
 106                .update(cx, |copilot, cx| copilot.reinstall(cx))
 107                .detach();
 108        }
 109    });
 110}
 111
 112enum CopilotServer {
 113    Disabled,
 114    Starting { task: Shared<Task<()>> },
 115    Error(Arc<str>),
 116    Running(RunningCopilotServer),
 117}
 118
 119impl CopilotServer {
 120    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 121        let server = self.as_running()?;
 122        if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
 123            Ok(server)
 124        } else {
 125            Err(anyhow!("must sign in before using copilot"))
 126        }
 127    }
 128
 129    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 130        match self {
 131            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
 132            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
 133            CopilotServer::Error(error) => Err(anyhow!(
 134                "copilot was not started because of an error: {}",
 135                error
 136            )),
 137            CopilotServer::Running(server) => Ok(server),
 138        }
 139    }
 140}
 141
/// State that exists only while the Copilot language server is running.
struct RunningCopilotServer {
    /// Connection to the language server.
    lsp: Arc<LanguageServer>,
    /// Where the user currently is in the sign-in flow.
    sign_in_status: SignInStatus,
    /// Buffers that were opened on the server (`didOpen`), keyed by the
    /// buffer's remote id.
    registered_buffers: HashMap<u64, RegisteredBuffer>,
}

/// Sign-in state tracked for a running server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in. `CopilotServer::as_authenticated` only succeeds in this state.
    Authorized,
    /// Server reported the user as unauthorized. The mapping from the server's
    /// status happens in `update_sign_in_status` (not shown here; TODO confirm).
    Unauthorized,
    /// A sign-in started by `Copilot::sign_in` is in progress.
    SigningIn {
        /// Device-flow prompt for the user, filled in once the server asks for it.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in task; repeated `sign_in` calls share it.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in (the initial state after the server starts).
    SignedOut,
}
 158
 159#[derive(Debug, Clone)]
 160pub enum Status {
 161    Starting {
 162        task: Shared<Task<()>>,
 163    },
 164    Error(Arc<str>),
 165    Disabled,
 166    SignedOut,
 167    SigningIn {
 168        prompt: Option<request::PromptUserDeviceFlow>,
 169    },
 170    Unauthorized,
 171    Authorized,
 172}
 173
 174impl Status {
 175    pub fn is_authorized(&self) -> bool {
 176        matches!(self, Status::Authorized)
 177    }
 178}
 179
/// A buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// The buffer's remote id (also its key in `registered_buffers`).
    id: u64,
    /// URI the document was opened under on the server.
    uri: lsp::Url,
    /// LSP language id the document was opened with.
    language_id: String,
    /// Last buffer snapshot whose contents were sent to the server.
    snapshot: BufferSnapshot,
    /// Document version sent with `snapshot`. It is 0 on `didOpen` and
    /// incremented for every `didChange`.
    snapshot_version: i32,
    /// Buffer event subscription and release observer. They live as long as
    /// the registration does.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recent change-reporting task. Each new report awaits it first.
    pending_buffer_change: Task<Option<()>>,
}
 189
impl RegisteredBuffer {
    /// Brings the server's copy of this buffer up to date with `buffer`.
    ///
    /// Returns a receiver for the `(snapshot_version, snapshot)` pair the server
    /// knows about after the report.
    ///
    /// * If the registered snapshot already has the buffer's current version,
    ///   the receiver is fulfilled immediately.
    /// * Otherwise a task is spawned that:
    ///   1. awaits the previously pending report, so `didChange` notifications
    ///      go out in order,
    ///   2. diffs the registered snapshot against a fresh one on the background
    ///      executor,
    ///   3. sends a `didChange` with the next version number if anything changed.
    ///
    /// The task may exit early: the model or buffer was released, Copilot is no
    /// longer authenticated, or the buffer was unregistered meanwhile. In that
    /// case it drops the sender, and the receiver resolves to `Canceled`.
    fn report_changes(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let id = self.id;
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing new to report; hand back what the server already has.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            // Chain onto the previous report so changes reach the server in order.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the version the server knows about here: the previous
                // report may have advanced it while this task was waiting.
                let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    Some(buffer.snapshot.version.clone())
                })?;
                let new_snapshot = buffer
                    .upgrade(&cx)?
                    .read_with(&cx, |buffer, _| buffer.snapshot());

                let content_changes = cx
                    .background()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // Edits come in new-buffer coordinates, and LSP
                                    // applies content changes in order. So each
                                    // change starts at the edit's new start and
                                    // spans the extent of the replaced (old) text.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    // Only bump the version and notify when the text actually changed.
                    if !content_changes.is_empty() {
                        buffer.snapshot_version += 1;
                        buffer.snapshot = new_snapshot;
                        server
                            .lsp
                            .notify::<lsp::notification::DidChangeTextDocument>(
                                lsp::DidChangeTextDocumentParams {
                                    text_document: lsp::VersionedTextDocumentIdentifier::new(
                                        buffer.uri.clone(),
                                        buffer.snapshot_version,
                                    ),
                                    content_changes,
                                },
                            )
                            .log_err();
                    }
                    let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                    Some(())
                })?;

                Some(())
            });
        }

        done_rx
    }
}
 274
/// A completion suggested by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id. It is sent back when the completion is accepted or
    /// rejected.
    uuid: String,
    /// Buffer range the suggestion applies to.
    pub range: Range<Anchor>,
    /// Suggested text for `range`.
    pub text: String,
}

/// Model owning the Copilot language server and the buffers shared with it.
pub struct Copilot {
    /// HTTP client passed to `get_copilot_lsp` whenever the server is (re)started.
    http: Arc<dyn HttpClient>,
    /// Provides the node binary that runs the server script.
    node_runtime: Arc<NodeRuntime>,
    /// Current server lifecycle state.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, held weakly and keyed by
    /// remote id.
    buffers: HashMap<u64, WeakModelHandle<Buffer>>,
}

/// Events emitted by [`Copilot`]. `init` subscribes to them as a
/// (not yet implemented) telemetry hook.
pub enum Event {
    /// A completion was accepted (emitted by `accept_completion`).
    CompletionAccepted {
        uuid: String,
        /// File type supplied by the caller.
        file_type: Option<Arc<str>>,
    },
    /// Completions were discarded (emitted by `discard_completions`).
    CompletionsDiscarded {
        uuids: Vec<String>,
        /// File type supplied by the caller.
        file_type: Option<Arc<str>>,
    },
}
 299
 300impl Entity for Copilot {
 301    type Event = Event;
 302
 303    fn app_will_quit(
 304        &mut self,
 305        _: &mut AppContext,
 306    ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>> {
 307        match mem::replace(&mut self.server, CopilotServer::Disabled) {
 308            CopilotServer::Running(server) => Some(Box::pin(async move {
 309                if let Some(shutdown) = server.lsp.shutdown() {
 310                    shutdown.await;
 311                }
 312            })),
 313            _ => None,
 314        }
 315    }
 316}
 317
 318impl Copilot {
 319    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
 320        if cx.has_global::<ModelHandle<Self>>() {
 321            Some(cx.global::<ModelHandle<Self>>().clone())
 322        } else {
 323            None
 324        }
 325    }
 326
 327    fn start(
 328        http: Arc<dyn HttpClient>,
 329        node_runtime: Arc<NodeRuntime>,
 330        cx: &mut ModelContext<Self>,
 331    ) -> Self {
 332        cx.observe_global::<Settings, _>({
 333            let http = http.clone();
 334            let node_runtime = node_runtime.clone();
 335            move |this, cx| {
 336                if cx.global::<Settings>().features.copilot {
 337                    if matches!(this.server, CopilotServer::Disabled) {
 338                        let start_task = cx
 339                            .spawn({
 340                                let http = http.clone();
 341                                let node_runtime = node_runtime.clone();
 342                                move |this, cx| {
 343                                    Self::start_language_server(http, node_runtime, this, cx)
 344                                }
 345                            })
 346                            .shared();
 347                        this.server = CopilotServer::Starting { task: start_task };
 348                        cx.notify();
 349                    }
 350                } else {
 351                    this.server = CopilotServer::Disabled;
 352                    cx.notify();
 353                }
 354            }
 355        })
 356        .detach();
 357
 358        if cx.global::<Settings>().features.copilot {
 359            let start_task = cx
 360                .spawn({
 361                    let http = http.clone();
 362                    let node_runtime = node_runtime.clone();
 363                    move |this, cx| async {
 364                        Self::start_language_server(http, node_runtime, this, cx).await
 365                    }
 366                })
 367                .shared();
 368
 369            Self {
 370                http,
 371                node_runtime,
 372                server: CopilotServer::Starting { task: start_task },
 373                buffers: Default::default(),
 374            }
 375        } else {
 376            Self {
 377                http,
 378                node_runtime,
 379                server: CopilotServer::Disabled,
 380                buffers: Default::default(),
 381            }
 382        }
 383    }
 384
 385    #[cfg(any(test, feature = "test-support"))]
 386    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
 387        let (server, fake_server) =
 388            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
 389        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
 390        let this = cx.add_model(|cx| Self {
 391            http: http.clone(),
 392            node_runtime: NodeRuntime::new(http, cx.background().clone()),
 393            server: CopilotServer::Running(RunningCopilotServer {
 394                lsp: Arc::new(server),
 395                sign_in_status: SignInStatus::Authorized,
 396                registered_buffers: Default::default(),
 397            }),
 398            buffers: Default::default(),
 399        });
 400        (this, fake_server)
 401    }
 402
 403    fn start_language_server(
 404        http: Arc<dyn HttpClient>,
 405        node_runtime: Arc<NodeRuntime>,
 406        this: ModelHandle<Self>,
 407        mut cx: AsyncAppContext,
 408    ) -> impl Future<Output = ()> {
 409        async move {
 410            let start_language_server = async {
 411                let server_path = get_copilot_lsp(http).await?;
 412                let node_path = node_runtime.binary_path().await?;
 413                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
 414                let server = LanguageServer::new(
 415                    LanguageServerId(0),
 416                    &node_path,
 417                    arguments,
 418                    Path::new("/"),
 419                    None,
 420                    cx.clone(),
 421                )?;
 422
 423                server
 424                    .on_notification::<LogMessage, _>(|params, _cx| {
 425                        match params.level {
 426                            // Copilot is pretty agressive about logging
 427                            0 => debug!("copilot: {}", params.message),
 428                            1 => debug!("copilot: {}", params.message),
 429                            _ => error!("copilot: {}", params.message),
 430                        }
 431
 432                        debug!("copilot metadata: {}", params.metadata_str);
 433                        debug!("copilot extra: {:?}", params.extra);
 434                    })
 435                    .detach();
 436
 437                server
 438                    .on_notification::<StatusNotification, _>(
 439                        |_, _| { /* Silence the notification */ },
 440                    )
 441                    .detach();
 442
 443                let server = server.initialize(Default::default()).await?;
 444
 445                let status = server
 446                    .request::<request::CheckStatus>(request::CheckStatusParams {
 447                        local_checks_only: false,
 448                    })
 449                    .await?;
 450
 451                server
 452                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
 453                        editor_info: request::EditorInfo {
 454                            name: "zed".into(),
 455                            version: env!("CARGO_PKG_VERSION").into(),
 456                        },
 457                        editor_plugin_info: request::EditorPluginInfo {
 458                            name: "zed-copilot".into(),
 459                            version: "0.0.1".into(),
 460                        },
 461                    })
 462                    .await?;
 463
 464                anyhow::Ok((server, status))
 465            };
 466
 467            let server = start_language_server.await;
 468            this.update(&mut cx, |this, cx| {
 469                cx.notify();
 470                match server {
 471                    Ok((server, status)) => {
 472                        this.server = CopilotServer::Running(RunningCopilotServer {
 473                            lsp: server,
 474                            sign_in_status: SignInStatus::SignedOut,
 475                            registered_buffers: Default::default(),
 476                        });
 477                        this.update_sign_in_status(status, cx);
 478                    }
 479                    Err(error) => {
 480                        this.server = CopilotServer::Error(error.to_string().into());
 481                        cx.notify()
 482                    }
 483                }
 484            })
 485        }
 486    }
 487
    /// Signs the user in to Copilot, starting the device flow if needed.
    ///
    /// Behavior by current sign-in status:
    /// * `Authorized`: resolves immediately.
    /// * `SigningIn`: returns the in-flight task (shared) and notifies the model.
    /// * `SignedOut` / `Unauthorized`: spawns a new sign-in task and stores it
    ///   as `SigningIn { prompt: None, .. }`.
    ///
    /// The new task sends `SignInInitiate`. If the server asks for the device
    /// flow, the prompt is stored on the `SigningIn` state (for display) and
    /// `SignInConfirm` is awaited. The final status is applied via
    /// `update_sign_in_status`; any error resets the status to `NotSignedIn`.
    ///
    /// # Errors
    /// Fails immediately if the server is not running. The sign-in task's
    /// errors are converted into `anyhow` errors via their `Debug` output.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    // Presumably lets observers re-show the pending prompt.
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt, but only if
                                        // we are still in the SigningIn state.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        // Waits for the user to complete the flow.
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // Arc so the error can be cloned by `Shared`.
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if the server is still starting (e.g. downloading), wait for
            // it instead of failing; if it is stuck (errored), surface that to
            // the user. Currently every non-running state fails immediately.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 572
 573    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 574        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 575        if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
 576            let server = server.clone();
 577            cx.background().spawn(async move {
 578                server
 579                    .request::<request::SignOut>(request::SignOutParams {})
 580                    .await?;
 581                anyhow::Ok(())
 582            })
 583        } else {
 584            Task::ready(Err(anyhow!("copilot hasn't started yet")))
 585        }
 586    }
 587
 588    pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 589        let start_task = cx
 590            .spawn({
 591                let http = self.http.clone();
 592                let node_runtime = self.node_runtime.clone();
 593                move |this, cx| async move {
 594                    clear_copilot_dir().await;
 595                    Self::start_language_server(http, node_runtime, this, cx).await
 596                }
 597            })
 598            .shared();
 599
 600        self.server = CopilotServer::Starting {
 601            task: start_task.clone(),
 602        };
 603
 604        cx.notify();
 605
 606        cx.foreground().spawn(start_task)
 607    }
 608
 609    pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
 610        let buffer_id = buffer.read(cx).remote_id();
 611        self.buffers.insert(buffer_id, buffer.downgrade());
 612
 613        if let CopilotServer::Running(RunningCopilotServer {
 614            lsp: server,
 615            sign_in_status: status,
 616            registered_buffers,
 617            ..
 618        }) = &mut self.server
 619        {
 620            if !matches!(status, SignInStatus::Authorized { .. }) {
 621                return;
 622            }
 623
 624            let buffer_id = buffer.read(cx).remote_id();
 625            registered_buffers.entry(buffer_id).or_insert_with(|| {
 626                let uri: lsp::Url = uri_for_buffer(buffer, cx);
 627                let language_id = id_for_language(buffer.read(cx).language());
 628                let snapshot = buffer.read(cx).snapshot();
 629                server
 630                    .notify::<lsp::notification::DidOpenTextDocument>(
 631                        lsp::DidOpenTextDocumentParams {
 632                            text_document: lsp::TextDocumentItem {
 633                                uri: uri.clone(),
 634                                language_id: language_id.clone(),
 635                                version: 0,
 636                                text: snapshot.text(),
 637                            },
 638                        },
 639                    )
 640                    .log_err();
 641
 642                RegisteredBuffer {
 643                    id: buffer_id,
 644                    uri,
 645                    language_id,
 646                    snapshot,
 647                    snapshot_version: 0,
 648                    pending_buffer_change: Task::ready(Some(())),
 649                    _subscriptions: [
 650                        cx.subscribe(buffer, |this, buffer, event, cx| {
 651                            this.handle_buffer_event(buffer, event, cx).log_err();
 652                        }),
 653                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 654                            this.buffers.remove(&buffer_id);
 655                            this.unregister_buffer(buffer_id);
 656                        }),
 657                    ],
 658                }
 659            });
 660        }
 661    }
 662
 663    fn handle_buffer_event(
 664        &mut self,
 665        buffer: ModelHandle<Buffer>,
 666        event: &language::Event,
 667        cx: &mut ModelContext<Self>,
 668    ) -> Result<()> {
 669        if let Ok(server) = self.server.as_running() {
 670            let buffer_id = buffer.read(cx).remote_id();
 671            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer_id) {
 672                match event {
 673                    language::Event::Edited => {
 674                        let _ = registered_buffer.report_changes(&buffer, cx);
 675                    }
 676                    language::Event::Saved => {
 677                        server
 678                            .lsp
 679                            .notify::<lsp::notification::DidSaveTextDocument>(
 680                                lsp::DidSaveTextDocumentParams {
 681                                    text_document: lsp::TextDocumentIdentifier::new(
 682                                        registered_buffer.uri.clone(),
 683                                    ),
 684                                    text: None,
 685                                },
 686                            )?;
 687                    }
 688                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
 689                        let new_language_id = id_for_language(buffer.read(cx).language());
 690                        let new_uri = uri_for_buffer(&buffer, cx);
 691                        if new_uri != registered_buffer.uri
 692                            || new_language_id != registered_buffer.language_id
 693                        {
 694                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 695                            registered_buffer.language_id = new_language_id;
 696                            server
 697                                .lsp
 698                                .notify::<lsp::notification::DidCloseTextDocument>(
 699                                    lsp::DidCloseTextDocumentParams {
 700                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 701                                    },
 702                                )?;
 703                            server
 704                                .lsp
 705                                .notify::<lsp::notification::DidOpenTextDocument>(
 706                                    lsp::DidOpenTextDocumentParams {
 707                                        text_document: lsp::TextDocumentItem::new(
 708                                            registered_buffer.uri.clone(),
 709                                            registered_buffer.language_id.clone(),
 710                                            registered_buffer.snapshot_version,
 711                                            registered_buffer.snapshot.text(),
 712                                        ),
 713                                    },
 714                                )?;
 715                        }
 716                    }
 717                    _ => {}
 718                }
 719            }
 720        }
 721
 722        Ok(())
 723    }
 724
 725    fn unregister_buffer(&mut self, buffer_id: u64) {
 726        if let Ok(server) = self.server.as_running() {
 727            if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
 728                server
 729                    .lsp
 730                    .notify::<lsp::notification::DidCloseTextDocument>(
 731                        lsp::DidCloseTextDocumentParams {
 732                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 733                        },
 734                    )
 735                    .log_err();
 736            }
 737        }
 738    }
 739
 740    pub fn completions<T>(
 741        &mut self,
 742        buffer: &ModelHandle<Buffer>,
 743        position: T,
 744        cx: &mut ModelContext<Self>,
 745    ) -> Task<Result<Vec<Completion>>>
 746    where
 747        T: ToPointUtf16,
 748    {
 749        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 750    }
 751
 752    pub fn completions_cycling<T>(
 753        &mut self,
 754        buffer: &ModelHandle<Buffer>,
 755        position: T,
 756        cx: &mut ModelContext<Self>,
 757    ) -> Task<Result<Vec<Completion>>>
 758    where
 759        T: ToPointUtf16,
 760    {
 761        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 762    }
 763
 764    pub fn accept_completion(
 765        &mut self,
 766        completion: &Completion,
 767        file_type: Option<Arc<str>>,
 768        cx: &mut ModelContext<Self>,
 769    ) -> Task<Result<()>> {
 770        let server = match self.server.as_authenticated() {
 771            Ok(server) => server,
 772            Err(error) => return Task::ready(Err(error)),
 773        };
 774
 775        cx.emit(Event::CompletionAccepted {
 776            uuid: completion.uuid.clone(),
 777            file_type,
 778        });
 779
 780        let request =
 781            server
 782                .lsp
 783                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 784                    uuid: completion.uuid.clone(),
 785                });
 786
 787        cx.background().spawn(async move {
 788            request.await?;
 789            Ok(())
 790        })
 791    }
 792
 793    pub fn discard_completions(
 794        &mut self,
 795        completions: &[Completion],
 796        file_type: Option<Arc<str>>,
 797        cx: &mut ModelContext<Self>,
 798    ) -> Task<Result<()>> {
 799        let server = match self.server.as_authenticated() {
 800            Ok(server) => server,
 801            Err(error) => return Task::ready(Err(error)),
 802        };
 803
 804        cx.emit(Event::CompletionsDiscarded {
 805            uuids: completions
 806                .iter()
 807                .map(|completion| completion.uuid.clone())
 808                .collect(),
 809            file_type: file_type.clone(),
 810        });
 811
 812        let request =
 813            server
 814                .lsp
 815                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 816                    uuids: completions
 817                        .iter()
 818                        .map(|completion| completion.uuid.clone())
 819                        .collect(),
 820                });
 821
 822        cx.background().spawn(async move {
 823            request.await?;
 824            Ok(())
 825        })
 826    }
 827
 828    fn request_completions<R, T>(
 829        &mut self,
 830        buffer: &ModelHandle<Buffer>,
 831        position: T,
 832        cx: &mut ModelContext<Self>,
 833    ) -> Task<Result<Vec<Completion>>>
 834    where
 835        R: 'static
 836            + lsp::request::Request<
 837                Params = request::GetCompletionsParams,
 838                Result = request::GetCompletionsResult,
 839            >,
 840        T: ToPointUtf16,
 841    {
 842        self.register_buffer(buffer, cx);
 843
 844        let server = match self.server.as_authenticated() {
 845            Ok(server) => server,
 846            Err(error) => return Task::ready(Err(error)),
 847        };
 848        let lsp = server.lsp.clone();
 849        let buffer_id = buffer.read(cx).remote_id();
 850        let registered_buffer = server.registered_buffers.get_mut(&buffer_id).unwrap();
 851        let snapshot = registered_buffer.report_changes(buffer, cx);
 852        let buffer = buffer.read(cx);
 853        let uri = registered_buffer.uri.clone();
 854        let settings = cx.global::<Settings>();
 855        let position = position.to_point_utf16(buffer);
 856        let language = buffer.language_at(position);
 857        let language_name = language.map(|language| language.name());
 858        let language_name = language_name.as_deref();
 859        let tab_size = settings.tab_size(language_name);
 860        let hard_tabs = settings.hard_tabs(language_name);
 861        let relative_path = buffer
 862            .file()
 863            .map(|file| file.path().to_path_buf())
 864            .unwrap_or_default();
 865
 866        cx.foreground().spawn(async move {
 867            let (version, snapshot) = snapshot.await?;
 868            let result = lsp
 869                .request::<R>(request::GetCompletionsParams {
 870                    doc: request::GetCompletionsDocument {
 871                        uri,
 872                        tab_size: tab_size.into(),
 873                        indent_size: 1,
 874                        insert_spaces: !hard_tabs,
 875                        relative_path: relative_path.to_string_lossy().into(),
 876                        position: point_to_lsp(position),
 877                        version: version.try_into().unwrap(),
 878                    },
 879                })
 880                .await?;
 881            let completions = result
 882                .completions
 883                .into_iter()
 884                .map(|completion| {
 885                    let start = snapshot
 886                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 887                    let end =
 888                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 889                    Completion {
 890                        uuid: completion.uuid,
 891                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 892                        text: completion.text,
 893                    }
 894                })
 895                .collect();
 896            anyhow::Ok(completions)
 897        })
 898    }
 899
 900    pub fn status(&self) -> Status {
 901        match &self.server {
 902            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
 903            CopilotServer::Disabled => Status::Disabled,
 904            CopilotServer::Error(error) => Status::Error(error.clone()),
 905            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
 906                match sign_in_status {
 907                    SignInStatus::Authorized { .. } => Status::Authorized,
 908                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
 909                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
 910                        prompt: prompt.clone(),
 911                    },
 912                    SignInStatus::SignedOut => Status::SignedOut,
 913                }
 914            }
 915        }
 916    }
 917
 918    fn update_sign_in_status(
 919        &mut self,
 920        lsp_status: request::SignInStatus,
 921        cx: &mut ModelContext<Self>,
 922    ) {
 923        self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
 924
 925        if let Ok(server) = self.server.as_running() {
 926            match lsp_status {
 927                request::SignInStatus::Ok { .. }
 928                | request::SignInStatus::MaybeOk { .. }
 929                | request::SignInStatus::AlreadySignedIn { .. } => {
 930                    server.sign_in_status = SignInStatus::Authorized;
 931                    for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
 932                        if let Some(buffer) = buffer.upgrade(cx) {
 933                            self.register_buffer(&buffer, cx);
 934                        }
 935                    }
 936                }
 937                request::SignInStatus::NotAuthorized { .. } => {
 938                    server.sign_in_status = SignInStatus::Unauthorized;
 939                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
 940                        self.unregister_buffer(buffer_id);
 941                    }
 942                }
 943                request::SignInStatus::NotSignedIn => {
 944                    server.sign_in_status = SignInStatus::SignedOut;
 945                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
 946                        self.unregister_buffer(buffer_id);
 947                    }
 948                }
 949            }
 950
 951            cx.notify();
 952        }
 953    }
 954}
 955
 956fn id_for_language(language: Option<&Arc<Language>>) -> String {
 957    let language_name = language.map(|language| language.name());
 958    match language_name.as_deref() {
 959        Some("Plain Text") => "plaintext".to_string(),
 960        Some(language_name) => language_name.to_lowercase(),
 961        None => "plaintext".to_string(),
 962    }
 963}
 964
 965fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
 966    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
 967        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
 968    } else {
 969        format!("buffer://{}", buffer.read(cx).remote_id())
 970            .parse()
 971            .unwrap()
 972    }
 973}
 974
 975async fn clear_copilot_dir() {
 976    remove_matching(&paths::COPILOT_DIR, |_| true).await
 977}
 978
 979async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 980    const SERVER_PATH: &'static str = "dist/agent.js";
 981
 982    ///Check for the latest copilot language server and download it if we haven't already
 983    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 984        let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?;
 985
 986        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
 987
 988        fs::create_dir_all(version_dir).await?;
 989        let server_path = version_dir.join(SERVER_PATH);
 990
 991        if fs::metadata(&server_path).await.is_err() {
 992            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
 993            let dist_dir = version_dir.join("dist");
 994            fs::create_dir_all(dist_dir.as_path()).await?;
 995
 996            let url = &release
 997                .assets
 998                .get(0)
 999                .context("Github release for copilot contained no assets")?
1000                .browser_download_url;
1001
1002            let mut response = http
1003                .get(&url, Default::default(), true)
1004                .await
1005                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
1006            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
1007            let archive = Archive::new(decompressed_bytes);
1008            archive.unpack(dist_dir).await?;
1009
1010            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
1011        }
1012
1013        Ok(server_path)
1014    }
1015
1016    match fetch_latest(http).await {
1017        ok @ Result::Ok(..) => ok,
1018        e @ Err(..) => {
1019            e.log_err();
1020            // Fetch a cached binary, if it exists
1021            (|| async move {
1022                let mut last_version_dir = None;
1023                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
1024                while let Some(entry) = entries.next().await {
1025                    let entry = entry?;
1026                    if entry.file_type().await?.is_dir() {
1027                        last_version_dir = Some(entry.path());
1028                    }
1029                }
1030                let last_version_dir =
1031                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1032                let server_path = last_version_dir.join(SERVER_PATH);
1033                if server_path.exists() {
1034                    Ok(server_path)
1035                } else {
1036                    Err(anyhow!(
1037                        "missing executable in directory {:?}",
1038                        last_version_dir
1039                    ))
1040                }
1041            })()
1042            .await
1043        }
1044    }
1045}
1046
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{executor::Deterministic, TestAppContext};

    /// Checks that Copilot mirrors buffer lifecycle events to the language server as
    /// didOpen/didChange/didClose notifications. This covers registration, edits, a file
    /// change, sign-out and sign-in, and dropping a buffer.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();
        let (copilot, mut lsp) = Copilot::fake(cx);

        // Registering a file-less buffer opens it under a `buffer://<id>` URI with language
        // id "plaintext", version 0 and its full text.
        // NOTE(review): `uri_for_buffer` formats `remote_id()`, while this expectation uses the
        // model id; the test relies on the two being equal for `Buffer::new` — confirm.
        let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental didChange carrying the next version (1).
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // Expected: didClose on the old `buffer://` URI, then didOpen on a `file://` URI built
        // from the absolute path, keeping version 1 and the edited text.
        buffer_1
            .update(cx, |buffer, cx| {
                buffer.file_updated(
                    Arc::new(File {
                        abs_path: "/root/child/buffer-1".into(),
                        path: Path::new("child/buffer-1").into(),
                    }),
                    cx,
                )
            })
            .await;
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        // The assertions below expect buffer 2 to be closed before buffer 1.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Re-opened documents are expected to start again at version 0.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );

        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local-file test double. Only `as_local`, `path` and `abs_path` are
    /// implemented; every other method panics via `unimplemented!()` if called.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> std::time::SystemTime {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn buffer_reloaded(
            &self,
            _: u64,
            _: &clock::Global,
            _: language::RopeFingerprint,
            _: ::fs::LineEnding,
            _: std::time::SystemTime,
            _: &mut AppContext,
        ) {
            unimplemented!()
        }
    }
}