copilot2.rs

   1pub mod request;
   2mod sign_in;
   3
   4use anyhow::{anyhow, Context as _, Result};
   5use async_compression::futures::bufread::GzipDecoder;
   6use async_tar::Archive;
   7use collections::{HashMap, HashSet};
   8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
   9use gpui2::{
  10    AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Handle, ModelContext, Task,
  11    WeakHandle,
  12};
  13use language2::{
  14    language_settings::{all_language_settings, language_settings},
  15    point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language,
  16    LanguageServerName, PointUtf16, ToPointUtf16,
  17};
  18use lsp2::{LanguageServer, LanguageServerBinary, LanguageServerId};
  19use node_runtime::NodeRuntime;
  20use request::StatusNotification;
  21use settings2::SettingsStore;
  22use smol::{fs, io::BufReader, stream::StreamExt};
  23use std::{
  24    ffi::OsString,
  25    mem,
  26    ops::Range,
  27    path::{Path, PathBuf},
  28    sync::Arc,
  29};
  30use util::{
  31    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
  32};
  33
  34// todo!()
  35// const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
  36// actions!(copilot_auth, [SignIn, SignOut]);
  37
  38// todo!()
  39// const COPILOT_NAMESPACE: &'static str = "copilot";
  40// actions!(
  41//     copilot,
  42//     [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
  43// );
  44
  45pub fn init(
  46    new_server_id: LanguageServerId,
  47    http: Arc<dyn HttpClient>,
  48    node_runtime: Arc<dyn NodeRuntime>,
  49    cx: &mut AppContext,
  50) {
  51    let copilot = cx.entity({
  52        let node_runtime = node_runtime.clone();
  53        move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
  54    });
  55    cx.set_global(copilot.clone());
  56
  57    // TODO
  58    // cx.observe(&copilot, |handle, cx| {
  59    //     let status = handle.read(cx).status();
  60    //     cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
  61    //         match status {
  62    //             Status::Disabled => {
  63    //                 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  64    //                 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
  65    //             }
  66    //             Status::Authorized => {
  67    //                 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
  68    //                 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
  69    //             }
  70    //             _ => {
  71    //                 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  72    //                 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
  73    //             }
  74    //         }
  75    //     });
  76    // })
  77    // .detach();
  78
  79    // sign_in::init(cx);
  80    // cx.add_global_action(|_: &SignIn, cx| {
  81    //     if let Some(copilot) = Copilot::global(cx) {
  82    //         copilot
  83    //             .update(cx, |copilot, cx| copilot.sign_in(cx))
  84    //             .detach_and_log_err(cx);
  85    //     }
  86    // });
  87    // cx.add_global_action(|_: &SignOut, cx| {
  88    //     if let Some(copilot) = Copilot::global(cx) {
  89    //         copilot
  90    //             .update(cx, |copilot, cx| copilot.sign_out(cx))
  91    //             .detach_and_log_err(cx);
  92    //     }
  93    // });
  94
  95    // cx.add_global_action(|_: &Reinstall, cx| {
  96    //     if let Some(copilot) = Copilot::global(cx) {
  97    //         copilot
  98    //             .update(cx, |copilot, cx| copilot.reinstall(cx))
  99    //             .detach();
 100    //     }
 101    // });
 102}
 103
/// Lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// Copilot is disabled in settings, or the server has been shut down.
    Disabled,
    /// `start_language_server` is running; `task` completes once startup has finished
    /// (successfully or not).
    Starting { task: Shared<Task<()>> },
    /// Startup failed; holds the error message.
    Error(Arc<str>),
    /// The server has started and been initialized.
    Running(RunningCopilotServer),
}
 110
 111impl CopilotServer {
 112    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 113        let server = self.as_running()?;
 114        if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
 115            Ok(server)
 116        } else {
 117            Err(anyhow!("must sign in before using copilot"))
 118        }
 119    }
 120
 121    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 122        match self {
 123            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
 124            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
 125            CopilotServer::Error(error) => Err(anyhow!(
 126                "copilot was not started because of an error: {}",
 127                error
 128            )),
 129            CopilotServer::Running(server) => Ok(server),
 130        }
 131    }
 132}
 133
/// State kept while the Copilot language server is running.
struct RunningCopilotServer {
    /// Server name (constructed as "copilot" in `start_language_server`).
    name: LanguageServerName,
    /// Handle to the initialized language server.
    lsp: Arc<LanguageServer>,
    /// Current sign-in state with the Copilot service.
    sign_in_status: SignInStatus,
    /// Buffers opened on the server via `didOpen`, keyed by the buffer's entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
 140
/// Sign-in state of a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in and authorized; required by `CopilotServer::as_authenticated`.
    Authorized,
    /// Not authorized to use Copilot.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in task; shared so repeated `sign_in` calls await the same flow.
        /// The error is wrapped in `Arc` because a shared future's output must be `Clone`.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut,
}
 151
/// Public view of Copilot's state: server lifecycle combined with sign-in state.
#[derive(Debug, Clone)]
pub enum Status {
    /// The server is starting; `task` completes once startup has finished.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot is disabled.
    Disabled,
    /// The server is running but the user is not signed in.
    SignedOut,
    /// A sign-in flow is in progress; `prompt` is the device-flow prompt, once available.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// Signed in but not authorized to use Copilot.
    Unauthorized,
    /// Signed in and authorized.
    Authorized,
}
 166
impl Status {
    /// Returns `true` only for [`Status::Authorized`].
    pub fn is_authorized(&self) -> bool {
        matches!(self, Status::Authorized)
    }
}
 172
/// Bookkeeping for a buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// URI the document is currently open under on the server.
    uri: lsp2::Url,
    /// Language id the document is currently open with.
    language_id: String,
    /// Last buffer snapshot whose contents have been sent to the server.
    snapshot: BufferSnapshot,
    /// LSP document version matching `snapshot`; starts at 0 and is incremented for every
    /// `didChange` notification sent by `report_changes`.
    snapshot_version: i32,
    /// Buffer-event and buffer-release subscriptions created in `Copilot::register_buffer`;
    /// held here so they live as long as this entry.
    _subscriptions: [gpui2::Subscription; 2],
    /// Most recent change-report task; each new report awaits the previous one so changes
    /// reach the server in order.
    pending_buffer_change: Task<Option<()>>,
}
 181
impl RegisteredBuffer {
    /// Sends the edits made to `buffer` since the last reported snapshot to the server as a
    /// `textDocument/didChange` notification.
    ///
    /// Returns a receiver that yields the `(version, snapshot)` pair the server is in sync with:
    /// - If the buffer's version equals the last reported snapshot's version, the receiver is
    ///   resolved immediately with the current pair.
    /// - Otherwise a task is spawned (chained after any pending report) that diffs the buffer,
    ///   sends the changes if there are any, and then resolves the receiver.
    ///
    /// If the report is abandoned (Copilot not authenticated, buffer no longer registered, or
    /// buffer/model gone), the sender is dropped and the receiver resolves with `Canceled`.
    fn report_changes(
        &mut self,
        buffer: &Handle<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous report so that change notifications are sent in order.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the last reported version: a previous report in the chain may have
                // advanced it since this task was created.
                let old_version = copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?;

                // Build the LSP content changes in a separate task on the executor.
                let content_changes = cx
                    .executor()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // Changes are applied sequentially, so each range is
                                    // expressed in the document after the preceding changes:
                                    // the start is already in new coordinates, and the end is
                                    // that start plus the extent of the replaced (old) text.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp2::TextDocumentContentChangeEvent {
                                        range: Some(lsp2::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version / replace the snapshot when something was sent.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp2::notification::DidChangeTextDocument>(
                                    lsp2::DidChangeTextDocumentParams {
                                        text_document: lsp2::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
 268
/// A single completion suggested by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id; sent back by `accept_completion` / `discard_completions`.
    pub uuid: String,
    /// Buffer range the completion applies to.
    pub range: Range<Anchor>,
    /// Completion text.
    pub text: String,
}
 275
/// Model that owns the Copilot language server and tracks which buffers to sync with it.
pub struct Copilot {
    /// HTTP client passed to `get_copilot_lsp` when (re)starting the server.
    http: Arc<dyn HttpClient>,
    /// Node runtime whose binary runs the server.
    node_runtime: Arc<dyn NodeRuntime>,
    /// Current server lifecycle state.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, including ones not (yet) opened on the server.
    buffers: HashSet<WeakHandle<Buffer>>,
    /// Id given to the language server whenever it is (re)started.
    server_id: LanguageServerId,
    /// App-quit subscription that runs `shutdown_language_server`; held for the model's lifetime.
    _subscription: gpui2::Subscription,
}
 284
/// Events emitted by [`Copilot`].
pub enum Event {
    /// Emitted once the language server has started and been initialized.
    CopilotLanguageServerStarted,
}
 288
// Allows `Copilot` to emit `Event` values via `cx.emit`.
impl EventEmitter for Copilot {
    type Event = Event;
}
 292
 293impl Copilot {
 294    pub fn global(cx: &AppContext) -> Option<Handle<Self>> {
 295        if cx.has_global::<Handle<Self>>() {
 296            Some(cx.global::<Handle<Self>>().clone())
 297        } else {
 298            None
 299        }
 300    }
 301
 302    fn start(
 303        new_server_id: LanguageServerId,
 304        http: Arc<dyn HttpClient>,
 305        node_runtime: Arc<dyn NodeRuntime>,
 306        cx: &mut ModelContext<Self>,
 307    ) -> Self {
 308        let mut this = Self {
 309            server_id: new_server_id,
 310            http,
 311            node_runtime,
 312            server: CopilotServer::Disabled,
 313            buffers: Default::default(),
 314            _subscription: cx.on_app_quit(Self::shutdown_language_server),
 315        };
 316        this.enable_or_disable_copilot(cx);
 317        cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
 318            .detach();
 319        this
 320    }
 321
 322    fn shutdown_language_server(
 323        &mut self,
 324        _cx: &mut ModelContext<Self>,
 325    ) -> impl Future<Output = ()> {
 326        let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
 327            CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
 328            _ => None,
 329        };
 330
 331        async move {
 332            if let Some(shutdown) = shutdown {
 333                shutdown.await;
 334            }
 335        }
 336    }
 337
 338    fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Self>) {
 339        let server_id = self.server_id;
 340        let http = self.http.clone();
 341        let node_runtime = self.node_runtime.clone();
 342        if all_language_settings(None, cx).copilot_enabled(None, None) {
 343            if matches!(self.server, CopilotServer::Disabled) {
 344                let start_task = cx
 345                    .spawn(move |this, cx| {
 346                        Self::start_language_server(server_id, http, node_runtime, this, cx)
 347                    })
 348                    .shared();
 349                self.server = CopilotServer::Starting { task: start_task };
 350                cx.notify();
 351            }
 352        } else {
 353            self.server = CopilotServer::Disabled;
 354            cx.notify();
 355        }
 356    }
 357
 358    // #[cfg(any(test, feature = "test-support"))]
 359    // pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
 360    //     use node_runtime::FakeNodeRuntime;
 361
 362    //     let (server, fake_server) =
 363    //         LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
 364    //     let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
 365    //     let node_runtime = FakeNodeRuntime::new();
 366    //     let this = cx.add_model(|_| Self {
 367    //         server_id: LanguageServerId(0),
 368    //         http: http.clone(),
 369    //         node_runtime,
 370    //         server: CopilotServer::Running(RunningCopilotServer {
 371    //             name: LanguageServerName(Arc::from("copilot")),
 372    //             lsp: Arc::new(server),
 373    //             sign_in_status: SignInStatus::Authorized,
 374    //             registered_buffers: Default::default(),
 375    //         }),
 376    //         buffers: Default::default(),
 377    //     });
 378    //     (this, fake_server)
 379    // }
 380
    /// Launches and initializes the Copilot language server, then records the outcome on the
    /// model.
    ///
    /// Startup steps: resolve the server script via `get_copilot_lsp`, resolve the Node binary,
    /// spawn the server as `node <script> --stdio`, silence its `StatusNotification`s,
    /// initialize it, query its status (`CheckStatus` with `local_checks_only: false`), and
    /// send editor info (`SetEditorInfo`).
    ///
    /// On success the model moves to `Running` (initially `SignedOut`), emits
    /// [`Event::CopilotLanguageServerStarted`], and applies the queried status via
    /// `update_sign_in_status`. On failure it moves to `Error` with the error's message.
    /// The result of updating the model is ignored (`.ok()`).
    fn start_language_server(
        new_server_id: LanguageServerId,
        http: Arc<dyn HttpClient>,
        node_runtime: Arc<dyn NodeRuntime>,
        this: WeakHandle<Self>,
        mut cx: AsyncAppContext,
    ) -> impl Future<Output = ()> {
        async move {
            // All fallible startup steps live in one future so any failure turns into
            // `CopilotServer::Error` below.
            let start_language_server = async {
                let server_path = get_copilot_lsp(http).await?;
                let node_path = node_runtime.binary_path().await?;
                let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
                let binary = LanguageServerBinary {
                    path: node_path,
                    arguments,
                };
                let server =
                    LanguageServer::new(new_server_id, binary, Path::new("/"), None, cx.clone())?;

                server
                    .on_notification::<StatusNotification, _>(
                        |_, _| { /* Silence the notification */ },
                    )
                    .detach();

                let server = server.initialize(Default::default()).await?;

                let status = server
                    .request::<request::CheckStatus>(request::CheckStatusParams {
                        local_checks_only: false,
                    })
                    .await?;

                server
                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
                        editor_info: request::EditorInfo {
                            name: "zed".into(),
                            version: env!("CARGO_PKG_VERSION").into(),
                        },
                        editor_plugin_info: request::EditorPluginInfo {
                            name: "zed-copilot".into(),
                            version: "0.0.1".into(),
                        },
                    })
                    .await?;

                anyhow::Ok((server, status))
            };

            let server = start_language_server.await;
            this.update(&mut cx, |this, cx| {
                cx.notify();
                match server {
                    Ok((server, status)) => {
                        this.server = CopilotServer::Running(RunningCopilotServer {
                            name: LanguageServerName(Arc::from("copilot")),
                            lsp: server,
                            sign_in_status: SignInStatus::SignedOut,
                            registered_buffers: Default::default(),
                        });
                        cx.emit(Event::CopilotLanguageServerStarted);
                        // Replace the placeholder `SignedOut` with the status the server reported.
                        this.update_sign_in_status(status, cx);
                    }
                    Err(error) => {
                        this.server = CopilotServer::Error(error.to_string().into());
                        // NOTE(review): redundant; `cx.notify()` was already called above.
                        cx.notify()
                    }
                }
            })
            .ok();
        }
    }
 453
    /// Starts a sign-in flow, or joins the one already in progress.
    ///
    /// Behavior by current sign-in status of the running server:
    /// - `Authorized`: resolves immediately with `Ok(())`.
    /// - `SigningIn`: notifies observers and awaits the existing sign-in task.
    /// - `SignedOut` / `Unauthorized`: sends `SignInInitiate`. If the server asks for a device
    ///   flow, the prompt is stored in the `SigningIn` status (observers are notified), then
    ///   `SignInConfirm` is sent with the flow's user code. The resulting status is applied via
    ///   `update_sign_in_status`; on any error the status is set to `NotSignedIn`.
    ///
    /// # Errors
    /// Fails immediately with "copilot hasn't started yet" unless the server is `Running`.
    /// Errors from the sign-in flow are re-wrapped using their `Debug` formatting.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt while the server is
                                        // still in the `SigningIn` state.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // `Arc` makes the error cloneable for the shared task.
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    // Record the in-flight flow so later `sign_in` calls join it.
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            cx.executor()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if the server is still starting, wait for startup to finish instead of
            // failing; if it failed to start, surface that error to the user.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
 538
 539    #[allow(dead_code)] // todo!()
 540    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 541        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 542        if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
 543            let server = server.clone();
 544            cx.executor().spawn(async move {
 545                server
 546                    .request::<request::SignOut>(request::SignOutParams {})
 547                    .await?;
 548                anyhow::Ok(())
 549            })
 550        } else {
 551            Task::ready(Err(anyhow!("copilot hasn't started yet")))
 552        }
 553    }
 554
 555    pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 556        let start_task = cx
 557            .spawn({
 558                let http = self.http.clone();
 559                let node_runtime = self.node_runtime.clone();
 560                let server_id = self.server_id;
 561                move |this, cx| async move {
 562                    clear_copilot_dir().await;
 563                    Self::start_language_server(server_id, http, node_runtime, this, cx).await
 564                }
 565            })
 566            .shared();
 567
 568        self.server = CopilotServer::Starting {
 569            task: start_task.clone(),
 570        };
 571
 572        cx.notify();
 573
 574        cx.executor().spawn(start_task)
 575    }
 576
 577    pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc<LanguageServer>)> {
 578        if let CopilotServer::Running(server) = &self.server {
 579            Some((&server.name, &server.lsp))
 580        } else {
 581            None
 582        }
 583    }
 584
    /// Tracks `buffer` and, when signed in, opens it on the Copilot server.
    ///
    /// The buffer is always added to `self.buffers`. It is opened on the server
    /// (`textDocument/didOpen` with version 0) only if the server is `Running` and `Authorized`,
    /// and only the first time; already-registered buffers are left untouched. A newly
    /// registered buffer gets two subscriptions:
    /// - buffer events are forwarded to `handle_buffer_event` (errors are logged);
    /// - on release, the buffer is removed from `self.buffers` and unregistered.
    pub fn register_buffer(&mut self, buffer: &Handle<Buffer>, cx: &mut ModelContext<Self>) {
        let weak_buffer = buffer.downgrade();
        self.buffers.insert(weak_buffer.clone());

        if let CopilotServer::Running(RunningCopilotServer {
            lsp: server,
            sign_in_status: status,
            registered_buffers,
            ..
        }) = &mut self.server
        {
            // Opening requires an authorized session. The buffer stays in `self.buffers`,
            // presumably so it can be registered once sign-in completes.
            if !matches!(status, SignInStatus::Authorized { .. }) {
                return;
            }

            registered_buffers
                .entry(buffer.entity_id())
                .or_insert_with(|| {
                    let uri: lsp2::Url = uri_for_buffer(buffer, cx);
                    let language_id = id_for_language(buffer.read(cx).language());
                    let snapshot = buffer.read(cx).snapshot();
                    server
                        .notify::<lsp2::notification::DidOpenTextDocument>(
                            lsp2::DidOpenTextDocumentParams {
                                text_document: lsp2::TextDocumentItem {
                                    uri: uri.clone(),
                                    language_id: language_id.clone(),
                                    version: 0,
                                    text: snapshot.text(),
                                },
                            },
                        )
                        .log_err();

                    RegisteredBuffer {
                        uri,
                        language_id,
                        snapshot,
                        snapshot_version: 0,
                        pending_buffer_change: Task::ready(Some(())),
                        _subscriptions: [
                            cx.subscribe(buffer, |this, buffer, event, cx| {
                                this.handle_buffer_event(buffer, event, cx).log_err();
                            }),
                            cx.observe_release(buffer, move |this, _buffer, _cx| {
                                this.buffers.remove(&weak_buffer);
                                this.unregister_buffer(&weak_buffer);
                            }),
                        ],
                    }
                });
        }
    }
 638
    /// Forwards relevant events of a registered buffer to the Copilot server.
    ///
    /// - `Edited`: reports new edits via `report_changes` (its receiver is not awaited).
    /// - `Saved`: sends `textDocument/didSave`.
    /// - `FileHandleChanged` / `LanguageChanged`: if the URI or language id changed, closes the
    ///   document under the old URI and reopens it under the new URI/language.
    ///
    /// Events are ignored when the server isn't running or the buffer isn't registered.
    ///
    /// # Errors
    /// Returns the error from sending a `didSave`, `didClose`, or `didOpen` notification.
    fn handle_buffer_event(
        &mut self,
        buffer: Handle<Buffer>,
        event: &language2::Event,
        cx: &mut ModelContext<Self>,
    ) -> Result<()> {
        if let Ok(server) = self.server.as_running() {
            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
            {
                match event {
                    language2::Event::Edited => {
                        let _ = registered_buffer.report_changes(&buffer, cx);
                    }
                    language2::Event::Saved => {
                        server
                            .lsp
                            .notify::<lsp2::notification::DidSaveTextDocument>(
                                lsp2::DidSaveTextDocumentParams {
                                    text_document: lsp2::TextDocumentIdentifier::new(
                                        registered_buffer.uri.clone(),
                                    ),
                                    text: None,
                                },
                            )?;
                    }
                    language2::Event::FileHandleChanged | language2::Event::LanguageChanged => {
                        let new_language_id = id_for_language(buffer.read(cx).language());
                        let new_uri = uri_for_buffer(&buffer, cx);
                        if new_uri != registered_buffer.uri
                            || new_language_id != registered_buffer.language_id
                        {
                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
                            registered_buffer.language_id = new_language_id;
                            server
                                .lsp
                                .notify::<lsp2::notification::DidCloseTextDocument>(
                                    lsp2::DidCloseTextDocumentParams {
                                        text_document: lsp2::TextDocumentIdentifier::new(old_uri),
                                    },
                                )?;
                            // Reopen with the last snapshot sent to the server (not the live
                            // buffer) so `snapshot_version` stays the base that later
                            // `report_changes` diffs are computed from.
                            server
                                .lsp
                                .notify::<lsp2::notification::DidOpenTextDocument>(
                                    lsp2::DidOpenTextDocumentParams {
                                        text_document: lsp2::TextDocumentItem::new(
                                            registered_buffer.uri.clone(),
                                            registered_buffer.language_id.clone(),
                                            registered_buffer.snapshot_version,
                                            registered_buffer.snapshot.text(),
                                        ),
                                    },
                                )?;
                        }
                    }
                    _ => {}
                }
            }
        }

        Ok(())
    }
 700
 701    fn unregister_buffer(&mut self, buffer: &WeakHandle<Buffer>) {
 702        if let Ok(server) = self.server.as_running() {
 703            if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
 704                server
 705                    .lsp
 706                    .notify::<lsp2::notification::DidCloseTextDocument>(
 707                        lsp2::DidCloseTextDocumentParams {
 708                            text_document: lsp2::TextDocumentIdentifier::new(buffer.uri),
 709                        },
 710                    )
 711                    .log_err();
 712            }
 713        }
 714    }
 715
 716    pub fn completions<T>(
 717        &mut self,
 718        buffer: &Handle<Buffer>,
 719        position: T,
 720        cx: &mut ModelContext<Self>,
 721    ) -> Task<Result<Vec<Completion>>>
 722    where
 723        T: ToPointUtf16,
 724    {
 725        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 726    }
 727
 728    pub fn completions_cycling<T>(
 729        &mut self,
 730        buffer: &Handle<Buffer>,
 731        position: T,
 732        cx: &mut ModelContext<Self>,
 733    ) -> Task<Result<Vec<Completion>>>
 734    where
 735        T: ToPointUtf16,
 736    {
 737        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 738    }
 739
 740    pub fn accept_completion(
 741        &mut self,
 742        completion: &Completion,
 743        cx: &mut ModelContext<Self>,
 744    ) -> Task<Result<()>> {
 745        let server = match self.server.as_authenticated() {
 746            Ok(server) => server,
 747            Err(error) => return Task::ready(Err(error)),
 748        };
 749        let request =
 750            server
 751                .lsp
 752                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 753                    uuid: completion.uuid.clone(),
 754                });
 755        cx.executor().spawn(async move {
 756            request.await?;
 757            Ok(())
 758        })
 759    }
 760
 761    pub fn discard_completions(
 762        &mut self,
 763        completions: &[Completion],
 764        cx: &mut ModelContext<Self>,
 765    ) -> Task<Result<()>> {
 766        let server = match self.server.as_authenticated() {
 767            Ok(server) => server,
 768            Err(error) => return Task::ready(Err(error)),
 769        };
 770        let request =
 771            server
 772                .lsp
 773                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 774                    uuids: completions
 775                        .iter()
 776                        .map(|completion| completion.uuid.clone())
 777                        .collect(),
 778                });
 779        cx.executor().spawn(async move {
 780            request.await?;
 781            Ok(())
 782        })
 783    }
 784
 785    fn request_completions<R, T>(
 786        &mut self,
 787        buffer: &Handle<Buffer>,
 788        position: T,
 789        cx: &mut ModelContext<Self>,
 790    ) -> Task<Result<Vec<Completion>>>
 791    where
 792        R: 'static
 793            + lsp2::request::Request<
 794                Params = request::GetCompletionsParams,
 795                Result = request::GetCompletionsResult,
 796            >,
 797        T: ToPointUtf16,
 798    {
 799        self.register_buffer(buffer, cx);
 800
 801        let server = match self.server.as_authenticated() {
 802            Ok(server) => server,
 803            Err(error) => return Task::ready(Err(error)),
 804        };
 805        let lsp = server.lsp.clone();
 806        let registered_buffer = server
 807            .registered_buffers
 808            .get_mut(&buffer.entity_id())
 809            .unwrap();
 810        let snapshot = registered_buffer.report_changes(buffer, cx);
 811        let buffer = buffer.read(cx);
 812        let uri = registered_buffer.uri.clone();
 813        let position = position.to_point_utf16(buffer);
 814        let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
 815        let tab_size = settings.tab_size;
 816        let hard_tabs = settings.hard_tabs;
 817        let relative_path = buffer
 818            .file()
 819            .map(|file| file.path().to_path_buf())
 820            .unwrap_or_default();
 821
 822        cx.executor().spawn(async move {
 823            let (version, snapshot) = snapshot.await?;
 824            let result = lsp
 825                .request::<R>(request::GetCompletionsParams {
 826                    doc: request::GetCompletionsDocument {
 827                        uri,
 828                        tab_size: tab_size.into(),
 829                        indent_size: 1,
 830                        insert_spaces: !hard_tabs,
 831                        relative_path: relative_path.to_string_lossy().into(),
 832                        position: point_to_lsp(position),
 833                        version: version.try_into().unwrap(),
 834                    },
 835                })
 836                .await?;
 837            let completions = result
 838                .completions
 839                .into_iter()
 840                .map(|completion| {
 841                    let start = snapshot
 842                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 843                    let end =
 844                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 845                    Completion {
 846                        uuid: completion.uuid,
 847                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 848                        text: completion.text,
 849                    }
 850                })
 851                .collect();
 852            anyhow::Ok(completions)
 853        })
 854    }
 855
 856    pub fn status(&self) -> Status {
 857        match &self.server {
 858            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
 859            CopilotServer::Disabled => Status::Disabled,
 860            CopilotServer::Error(error) => Status::Error(error.clone()),
 861            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
 862                match sign_in_status {
 863                    SignInStatus::Authorized { .. } => Status::Authorized,
 864                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
 865                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
 866                        prompt: prompt.clone(),
 867                    },
 868                    SignInStatus::SignedOut => Status::SignedOut,
 869                }
 870            }
 871        }
 872    }
 873
 874    fn update_sign_in_status(
 875        &mut self,
 876        lsp_status: request::SignInStatus,
 877        cx: &mut ModelContext<Self>,
 878    ) {
 879        self.buffers.retain(|buffer| buffer.is_upgradable());
 880
 881        if let Ok(server) = self.server.as_running() {
 882            match lsp_status {
 883                request::SignInStatus::Ok { .. }
 884                | request::SignInStatus::MaybeOk { .. }
 885                | request::SignInStatus::AlreadySignedIn { .. } => {
 886                    server.sign_in_status = SignInStatus::Authorized;
 887                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 888                        if let Some(buffer) = buffer.upgrade() {
 889                            self.register_buffer(&buffer, cx);
 890                        }
 891                    }
 892                }
 893                request::SignInStatus::NotAuthorized { .. } => {
 894                    server.sign_in_status = SignInStatus::Unauthorized;
 895                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 896                        self.unregister_buffer(&buffer);
 897                    }
 898                }
 899                request::SignInStatus::NotSignedIn => {
 900                    server.sign_in_status = SignInStatus::SignedOut;
 901                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
 902                        self.unregister_buffer(&buffer);
 903                    }
 904                }
 905            }
 906
 907            cx.notify();
 908        }
 909    }
 910}
 911
 912fn id_for_language(language: Option<&Arc<Language>>) -> String {
 913    let language_name = language.map(|language| language.name());
 914    match language_name.as_deref() {
 915        Some("Plain Text") => "plaintext".to_string(),
 916        Some(language_name) => language_name.to_lowercase(),
 917        None => "plaintext".to_string(),
 918    }
 919}
 920
 921fn uri_for_buffer(buffer: &Handle<Buffer>, cx: &AppContext) -> lsp2::Url {
 922    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
 923        lsp2::Url::from_file_path(file.abs_path(cx)).unwrap()
 924    } else {
 925        format!("buffer://{}", buffer.entity_id()).parse().unwrap()
 926    }
 927}
 928
 929async fn clear_copilot_dir() {
 930    remove_matching(&paths::COPILOT_DIR, |_| true).await
 931}
 932
 933async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 934    const SERVER_PATH: &'static str = "dist/agent.js";
 935
 936    ///Check for the latest copilot language server and download it if we haven't already
 937    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 938        let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?;
 939
 940        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
 941
 942        fs::create_dir_all(version_dir).await?;
 943        let server_path = version_dir.join(SERVER_PATH);
 944
 945        if fs::metadata(&server_path).await.is_err() {
 946            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
 947            let dist_dir = version_dir.join("dist");
 948            fs::create_dir_all(dist_dir.as_path()).await?;
 949
 950            let url = &release
 951                .assets
 952                .get(0)
 953                .context("Github release for copilot contained no assets")?
 954                .browser_download_url;
 955
 956            let mut response = http
 957                .get(&url, Default::default(), true)
 958                .await
 959                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
 960            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
 961            let archive = Archive::new(decompressed_bytes);
 962            archive.unpack(dist_dir).await?;
 963
 964            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
 965        }
 966
 967        Ok(server_path)
 968    }
 969
 970    match fetch_latest(http).await {
 971        ok @ Result::Ok(..) => ok,
 972        e @ Err(..) => {
 973            e.log_err();
 974            // Fetch a cached binary, if it exists
 975            (|| async move {
 976                let mut last_version_dir = None;
 977                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
 978                while let Some(entry) = entries.next().await {
 979                    let entry = entry?;
 980                    if entry.file_type().await?.is_dir() {
 981                        last_version_dir = Some(entry.path());
 982                    }
 983                }
 984                let last_version_dir =
 985                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
 986                let server_path = last_version_dir.join(SERVER_PATH);
 987                if server_path.exists() {
 988                    Ok(server_path)
 989                } else {
 990                    Err(anyhow!(
 991                        "missing executable in directory {:?}",
 992                        last_version_dir
 993                    ))
 994                }
 995            })()
 996            .await
 997        }
 998    }
 999}
1000
1001// #[cfg(test)]
1002// mod tests {
1003//     use super::*;
1004//     use gpui::{executor::Deterministic, TestAppContext};
1005
1006//     #[gpui::test(iterations = 10)]
1007//     async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
1008//         deterministic.forbid_parking();
1009//         let (copilot, mut lsp) = Copilot::fake(cx);
1010
1011//         let buffer_1 = cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "Hello"));
1012//         let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
1013//         copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
1014//         assert_eq!(
1015//             lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1016//                 .await,
1017//             lsp::DidOpenTextDocumentParams {
1018//                 text_document: lsp::TextDocumentItem::new(
1019//                     buffer_1_uri.clone(),
1020//                     "plaintext".into(),
1021//                     0,
1022//                     "Hello".into()
1023//                 ),
1024//             }
1025//         );
1026
1027//         let buffer_2 = cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "Goodbye"));
1028//         let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
1029//         copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
1030//         assert_eq!(
1031//             lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1032//                 .await,
1033//             lsp::DidOpenTextDocumentParams {
1034//                 text_document: lsp::TextDocumentItem::new(
1035//                     buffer_2_uri.clone(),
1036//                     "plaintext".into(),
1037//                     0,
1038//                     "Goodbye".into()
1039//                 ),
1040//             }
1041//         );
1042
1043//         buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
1044//         assert_eq!(
1045//             lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
1046//                 .await,
1047//             lsp::DidChangeTextDocumentParams {
1048//                 text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
1049//                 content_changes: vec![lsp::TextDocumentContentChangeEvent {
1050//                     range: Some(lsp::Range::new(
1051//                         lsp::Position::new(0, 5),
1052//                         lsp::Position::new(0, 5)
1053//                     )),
1054//                     range_length: None,
1055//                     text: " world".into(),
1056//                 }],
1057//             }
1058//         );
1059
1060//         // Ensure updates to the file are reflected in the LSP.
1061//         buffer_1
1062//             .update(cx, |buffer, cx| {
1063//                 buffer.file_updated(
1064//                     Arc::new(File {
1065//                         abs_path: "/root/child/buffer-1".into(),
1066//                         path: Path::new("child/buffer-1").into(),
1067//                     }),
1068//                     cx,
1069//                 )
1070//             })
1071//             .await;
1072//         assert_eq!(
1073//             lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1074//                 .await,
1075//             lsp::DidCloseTextDocumentParams {
1076//                 text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
1077//             }
1078//         );
1079//         let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
1080//         assert_eq!(
1081//             lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1082//                 .await,
1083//             lsp::DidOpenTextDocumentParams {
1084//                 text_document: lsp::TextDocumentItem::new(
1085//                     buffer_1_uri.clone(),
1086//                     "plaintext".into(),
1087//                     1,
1088//                     "Hello world".into()
1089//                 ),
1090//             }
1091//         );
1092
1093//         // Ensure all previously-registered buffers are closed when signing out.
1094//         lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
1095//             Ok(request::SignOutResult {})
1096//         });
1097//         copilot
1098//             .update(cx, |copilot, cx| copilot.sign_out(cx))
1099//             .await
1100//             .unwrap();
1101//         assert_eq!(
1102//             lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1103//                 .await,
1104//             lsp::DidCloseTextDocumentParams {
1105//                 text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
1106//             }
1107//         );
1108//         assert_eq!(
1109//             lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1110//                 .await,
1111//             lsp::DidCloseTextDocumentParams {
1112//                 text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
1113//             }
1114//         );
1115
1116//         // Ensure all previously-registered buffers are re-opened when signing in.
1117//         lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
1118//             Ok(request::SignInInitiateResult::AlreadySignedIn {
1119//                 user: "user-1".into(),
1120//             })
1121//         });
1122//         copilot
1123//             .update(cx, |copilot, cx| copilot.sign_in(cx))
1124//             .await
1125//             .unwrap();
1126//         assert_eq!(
1127//             lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1128//                 .await,
1129//             lsp::DidOpenTextDocumentParams {
1130//                 text_document: lsp::TextDocumentItem::new(
1131//                     buffer_2_uri.clone(),
1132//                     "plaintext".into(),
1133//                     0,
1134//                     "Goodbye".into()
1135//                 ),
1136//             }
1137//         );
1138//         assert_eq!(
1139//             lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1140//                 .await,
1141//             lsp::DidOpenTextDocumentParams {
1142//                 text_document: lsp::TextDocumentItem::new(
1143//                     buffer_1_uri.clone(),
1144//                     "plaintext".into(),
1145//                     0,
1146//                     "Hello world".into()
1147//                 ),
1148//             }
1149//         );
1150
1151//         // Dropping a buffer causes it to be closed on the LSP side as well.
1152//         cx.update(|_| drop(buffer_2));
1153//         assert_eq!(
1154//             lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1155//                 .await,
1156//             lsp::DidCloseTextDocumentParams {
1157//                 text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
1158//             }
1159//         );
1160//     }
1161
1162//     struct File {
1163//         abs_path: PathBuf,
1164//         path: Arc<Path>,
1165//     }
1166
1167//     impl language2::File for File {
1168//         fn as_local(&self) -> Option<&dyn language2::LocalFile> {
1169//             Some(self)
1170//         }
1171
1172//         fn mtime(&self) -> std::time::SystemTime {
1173//             unimplemented!()
1174//         }
1175
1176//         fn path(&self) -> &Arc<Path> {
1177//             &self.path
1178//         }
1179
1180//         fn full_path(&self, _: &AppContext) -> PathBuf {
1181//             unimplemented!()
1182//         }
1183
1184//         fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
1185//             unimplemented!()
1186//         }
1187
1188//         fn is_deleted(&self) -> bool {
1189//             unimplemented!()
1190//         }
1191
1192//         fn as_any(&self) -> &dyn std::any::Any {
1193//             unimplemented!()
1194//         }
1195
1196//         fn to_proto(&self) -> rpc::proto::File {
1197//             unimplemented!()
1198//         }
1199
1200//         fn worktree_id(&self) -> usize {
1201//             0
1202//         }
1203//     }
1204
1205//     impl language::LocalFile for File {
1206//         fn abs_path(&self, _: &AppContext) -> PathBuf {
1207//             self.abs_path.clone()
1208//         }
1209
1210//         fn load(&self, _: &AppContext) -> Task<Result<String>> {
1211//             unimplemented!()
1212//         }
1213
1214//         fn buffer_reloaded(
1215//             &self,
1216//             _: u64,
1217//             _: &clock::Global,
1218//             _: language::RopeFingerprint,
1219//             _: language::LineEnding,
1220//             _: std::time::SystemTime,
1221//             _: &mut AppContext,
1222//         ) {
1223//             unimplemented!()
1224//         }
1225//     }
1226// }