copilot.rs

   1pub mod request;
   2mod sign_in;
   3
   4use anyhow::{anyhow, Context, Result};
   5use async_compression::futures::bufread::GzipDecoder;
   6use async_tar::Archive;
   7use collections::HashMap;
   8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
   9use gpui::{
  10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
  11};
  12use language::{
  13    point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
  14    ToPointUtf16,
  15};
  16use log::{debug, error};
  17use lsp::LanguageServer;
  18use node_runtime::NodeRuntime;
  19use request::{LogMessage, StatusNotification};
  20use settings::Settings;
  21use smol::{fs, io::BufReader, stream::StreamExt};
  22use std::{
  23    ffi::OsString,
  24    mem,
  25    ops::Range,
  26    path::{Path, PathBuf},
  27    sync::Arc,
  28};
  29use util::{
  30    channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
  31    paths, ResultExt,
  32};
  33
  34const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
  35actions!(copilot_auth, [SignIn, SignOut]);
  36
  37const COPILOT_NAMESPACE: &'static str = "copilot";
  38actions!(
  39    copilot,
  40    [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
  41);
  42
  43pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
  44    // Disable Copilot for stable releases.
  45    if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
  46        cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
  47            filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  48            filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
  49        });
  50        return;
  51    }
  52
  53    let copilot = cx.add_model({
  54        let node_runtime = node_runtime.clone();
  55        move |cx| Copilot::start(http, node_runtime, cx)
  56    });
  57    cx.set_global(copilot.clone());
  58
  59    cx.observe(&copilot, |handle, cx| {
  60        let status = handle.read(cx).status();
  61        cx.update_global::<collections::CommandPaletteFilter, _, _>(
  62            move |filter, _cx| match status {
  63                Status::Disabled => {
  64                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  65                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
  66                }
  67                Status::Authorized => {
  68                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
  69                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
  70                }
  71                _ => {
  72                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
  73                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
  74                }
  75            },
  76        );
  77    })
  78    .detach();
  79
  80    sign_in::init(cx);
  81    cx.add_global_action(|_: &SignIn, cx| {
  82        if let Some(copilot) = Copilot::global(cx) {
  83            copilot
  84                .update(cx, |copilot, cx| copilot.sign_in(cx))
  85                .detach_and_log_err(cx);
  86        }
  87    });
  88    cx.add_global_action(|_: &SignOut, cx| {
  89        if let Some(copilot) = Copilot::global(cx) {
  90            copilot
  91                .update(cx, |copilot, cx| copilot.sign_out(cx))
  92                .detach_and_log_err(cx);
  93        }
  94    });
  95
  96    cx.add_global_action(|_: &Reinstall, cx| {
  97        if let Some(copilot) = Copilot::global(cx) {
  98            copilot
  99                .update(cx, |copilot, cx| copilot.reinstall(cx))
 100                .detach();
 101        }
 102    });
 103}
 104
 105enum CopilotServer {
 106    Disabled,
 107    Starting { task: Shared<Task<()>> },
 108    Error(Arc<str>),
 109    Running(RunningCopilotServer),
 110}
 111
 112impl CopilotServer {
 113    fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
 114        let server = self.as_running()?;
 115        if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
 116            Ok(server)
 117        } else {
 118            Err(anyhow!("must sign in before using copilot"))
 119        }
 120    }
 121
 122    fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
 123        match self {
 124            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
 125            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
 126            CopilotServer::Error(error) => Err(anyhow!(
 127                "copilot was not started because of an error: {}",
 128                error
 129            )),
 130            CopilotServer::Running(server) => Ok(server),
 131        }
 132    }
 133}
 134
/// State kept while the Copilot language server is running.
struct RunningCopilotServer {
    /// Connection used to send requests and notifications to the server.
    lsp: Arc<LanguageServer>,
    /// Current sign-in state; buffer sync and completion requests require `Authorized`.
    sign_in_status: SignInStatus,
    /// Buffers opened on the server, keyed by buffer model id.
    registered_buffers: HashMap<usize, RegisteredBuffer>,
}
 140
 141#[derive(Clone, Debug)]
 142enum SignInStatus {
 143    Authorized,
 144    Unauthorized,
 145    SigningIn {
 146        prompt: Option<request::PromptUserDeviceFlow>,
 147        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 148    },
 149    SignedOut,
 150}
 151
 152#[derive(Debug, Clone)]
 153pub enum Status {
 154    Starting {
 155        task: Shared<Task<()>>,
 156    },
 157    Error(Arc<str>),
 158    Disabled,
 159    SignedOut,
 160    SigningIn {
 161        prompt: Option<request::PromptUserDeviceFlow>,
 162    },
 163    Unauthorized,
 164    Authorized,
 165}
 166
 167impl Status {
 168    pub fn is_authorized(&self) -> bool {
 169        matches!(self, Status::Authorized)
 170    }
 171}
 172
/// Bookkeeping for a buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// Model id of the buffer; also its key in `registered_buffers`.
    id: usize,
    /// URI the document is currently open under on the server.
    uri: lsp::Url,
    /// Language id reported to the server for this document.
    language_id: String,
    /// Last snapshot whose contents were reported to the server.
    snapshot: BufferSnapshot,
    /// LSP document version corresponding to `snapshot`; bumped on every change notification.
    snapshot_version: i32,
    /// Buffer event and release subscriptions created in `register_buffer`.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recent change-reporting task; the next report awaits it so changes go out in order.
    pending_buffer_change: Task<Option<()>>,
}
 182
impl RegisteredBuffer {
    /// Brings the server's copy of this buffer up to date and reports the result.
    ///
    /// If the buffer's version equals the last reported snapshot's version, the returned
    /// receiver is fulfilled immediately with the current `(snapshot_version, snapshot)`.
    /// Otherwise a task is chained after the previous pending report. That task:
    /// - diffs the latest buffer snapshot against the last reported one on the background
    ///   executor;
    /// - sends the edits as an incremental `DidChangeTextDocument` notification;
    /// - bumps `snapshot_version` and fulfills the receiver.
    ///
    /// If the copilot model or the buffer has been released, Copilot is no longer authenticated
    /// (which is logged), or the buffer is no longer registered, the sender is dropped without
    /// sending. The receiver then resolves to `Err(Canceled)`.
    fn report_changes(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let id = self.id;
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing new since the last report: answer right away.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            // Chain onto the previous report so notifications reach the server in order and
            // each diff starts from the snapshot the previous report left behind.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
                prev_pending_change.await;

                // Read the last reported version only after the previous report has finished.
                let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    Some(buffer.snapshot.version.clone())
                })?;
                let new_snapshot = buffer
                    .upgrade(&cx)?
                    .read_with(&cx, |buffer, _| buffer.snapshot());

                // Convert the edits into LSP content changes on the background executor.
                let content_changes = cx
                    .background()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // LSP applies content changes one after another, so each
                                    // range is expressed in the document as it stands after the
                                    // preceding changes (assuming edits come in ascending
                                    // order): the start is already in new coordinates, and the
                                    // end is that start plus the extent of the replaced text.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    // Only bump the version and notify when there is something to send;
                    // otherwise the previously reported snapshot stays current.
                    if !content_changes.is_empty() {
                        buffer.snapshot_version += 1;
                        buffer.snapshot = new_snapshot;
                        server
                            .lsp
                            .notify::<lsp::notification::DidChangeTextDocument>(
                                lsp::DidChangeTextDocumentParams {
                                    text_document: lsp::VersionedTextDocumentIdentifier::new(
                                        buffer.uri.clone(),
                                        buffer.snapshot_version,
                                    ),
                                    content_changes,
                                },
                            )
                            .log_err();
                    }
                    // The caller may have dropped the receiver; that's fine.
                    let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                    Some(())
                })?;

                Some(())
            });
        }

        done_rx
    }
}
 267
/// A single Copilot suggestion.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, sent back via `NotifyAccepted` / `NotifyRejected`.
    uuid: String,
    /// Buffer range the suggestion applies to (presumably the text it replaces; verify where
    /// completions are built in `request_completions`).
    pub range: Range<Anchor>,
    /// Suggested text.
    pub text: String,
}
 274
/// Model owning the Copilot language server and the buffers synchronized with it.
pub struct Copilot {
    /// HTTP client passed to `get_copilot_lsp` when (re)starting the server.
    http: Arc<dyn HttpClient>,
    /// Node runtime whose binary runs the server.
    node_runtime: Arc<NodeRuntime>,
    /// Current lifecycle state of the server.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, keyed by model id, including buffers that
    /// could not be registered with the server yet.
    buffers: HashMap<usize, WeakModelHandle<Buffer>>,
}

// Copilot emits no events of its own; observers react to `cx.notify()` instead.
impl Entity for Copilot {
    type Event = ();
}
 285
 286impl Copilot {
 287    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
 288        if cx.has_global::<ModelHandle<Self>>() {
 289            Some(cx.global::<ModelHandle<Self>>().clone())
 290        } else {
 291            None
 292        }
 293    }
 294
 295    fn start(
 296        http: Arc<dyn HttpClient>,
 297        node_runtime: Arc<NodeRuntime>,
 298        cx: &mut ModelContext<Self>,
 299    ) -> Self {
 300        cx.observe_global::<Settings, _>({
 301            let http = http.clone();
 302            let node_runtime = node_runtime.clone();
 303            move |this, cx| {
 304                if cx.global::<Settings>().features.copilot {
 305                    if matches!(this.server, CopilotServer::Disabled) {
 306                        let start_task = cx
 307                            .spawn({
 308                                let http = http.clone();
 309                                let node_runtime = node_runtime.clone();
 310                                move |this, cx| {
 311                                    Self::start_language_server(http, node_runtime, this, cx)
 312                                }
 313                            })
 314                            .shared();
 315                        this.server = CopilotServer::Starting { task: start_task };
 316                        cx.notify();
 317                    }
 318                } else {
 319                    this.server = CopilotServer::Disabled;
 320                    cx.notify();
 321                }
 322            }
 323        })
 324        .detach();
 325
 326        if cx.global::<Settings>().features.copilot {
 327            let start_task = cx
 328                .spawn({
 329                    let http = http.clone();
 330                    let node_runtime = node_runtime.clone();
 331                    move |this, cx| async {
 332                        Self::start_language_server(http, node_runtime, this, cx).await
 333                    }
 334                })
 335                .shared();
 336
 337            Self {
 338                http,
 339                node_runtime,
 340                server: CopilotServer::Starting { task: start_task },
 341                buffers: Default::default(),
 342            }
 343        } else {
 344            Self {
 345                http,
 346                node_runtime,
 347                server: CopilotServer::Disabled,
 348                buffers: Default::default(),
 349            }
 350        }
 351    }
 352
 353    #[cfg(any(test, feature = "test-support"))]
 354    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
 355        let (server, fake_server) =
 356            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
 357        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
 358        let this = cx.add_model(|cx| Self {
 359            http: http.clone(),
 360            node_runtime: NodeRuntime::new(http, cx.background().clone()),
 361            server: CopilotServer::Running(RunningCopilotServer {
 362                lsp: Arc::new(server),
 363                sign_in_status: SignInStatus::Authorized,
 364                registered_buffers: Default::default(),
 365            }),
 366            buffers: Default::default(),
 367        });
 368        (this, fake_server)
 369    }
 370
 371    fn start_language_server(
 372        http: Arc<dyn HttpClient>,
 373        node_runtime: Arc<NodeRuntime>,
 374        this: ModelHandle<Self>,
 375        mut cx: AsyncAppContext,
 376    ) -> impl Future<Output = ()> {
 377        async move {
 378            let start_language_server = async {
 379                let server_path = get_copilot_lsp(http).await?;
 380                let node_path = node_runtime.binary_path().await?;
 381                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
 382                let server = LanguageServer::new(
 383                    0,
 384                    &node_path,
 385                    arguments,
 386                    Path::new("/"),
 387                    None,
 388                    cx.clone(),
 389                )?;
 390
 391                let server = server.initialize(Default::default()).await?;
 392                let status = server
 393                    .request::<request::CheckStatus>(request::CheckStatusParams {
 394                        local_checks_only: false,
 395                    })
 396                    .await?;
 397
 398                server
 399                    .on_notification::<LogMessage, _>(|params, _cx| {
 400                        match params.level {
 401                            // Copilot is pretty agressive about logging
 402                            0 => debug!("copilot: {}", params.message),
 403                            1 => debug!("copilot: {}", params.message),
 404                            _ => error!("copilot: {}", params.message),
 405                        }
 406
 407                        debug!("copilot metadata: {}", params.metadata_str);
 408                        debug!("copilot extra: {:?}", params.extra);
 409                    })
 410                    .detach();
 411
 412                server
 413                    .on_notification::<StatusNotification, _>(
 414                        |_, _| { /* Silence the notification */ },
 415                    )
 416                    .detach();
 417
 418                server
 419                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
 420                        editor_info: request::EditorInfo {
 421                            name: "zed".into(),
 422                            version: env!("CARGO_PKG_VERSION").into(),
 423                        },
 424                        editor_plugin_info: request::EditorPluginInfo {
 425                            name: "zed-copilot".into(),
 426                            version: "0.0.1".into(),
 427                        },
 428                    })
 429                    .await?;
 430
 431                anyhow::Ok((server, status))
 432            };
 433
 434            let server = start_language_server.await;
 435            this.update(&mut cx, |this, cx| {
 436                cx.notify();
 437                match server {
 438                    Ok((server, status)) => {
 439                        this.server = CopilotServer::Running(RunningCopilotServer {
 440                            lsp: server,
 441                            sign_in_status: SignInStatus::SignedOut,
 442                            registered_buffers: Default::default(),
 443                        });
 444                        this.update_sign_in_status(status, cx);
 445                    }
 446                    Err(error) => {
 447                        this.server = CopilotServer::Error(error.to_string().into());
 448                        cx.notify()
 449                    }
 450                }
 451            })
 452        }
 453    }
 454
 455    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 456        if let CopilotServer::Running(server) = &mut self.server {
 457            let task = match &server.sign_in_status {
 458                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
 459                    Task::ready(Ok(())).shared()
 460                }
 461                SignInStatus::SigningIn { task, .. } => {
 462                    cx.notify();
 463                    task.clone()
 464                }
 465                SignInStatus::SignedOut => {
 466                    let lsp = server.lsp.clone();
 467                    let task = cx
 468                        .spawn(|this, mut cx| async move {
 469                            let sign_in = async {
 470                                let sign_in = lsp
 471                                    .request::<request::SignInInitiate>(
 472                                        request::SignInInitiateParams {},
 473                                    )
 474                                    .await?;
 475                                match sign_in {
 476                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
 477                                        Ok(request::SignInStatus::Ok { user })
 478                                    }
 479                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
 480                                        this.update(&mut cx, |this, cx| {
 481                                            if let CopilotServer::Running(RunningCopilotServer {
 482                                                sign_in_status: status,
 483                                                ..
 484                                            }) = &mut this.server
 485                                            {
 486                                                if let SignInStatus::SigningIn {
 487                                                    prompt: prompt_flow,
 488                                                    ..
 489                                                } = status
 490                                                {
 491                                                    *prompt_flow = Some(flow.clone());
 492                                                    cx.notify();
 493                                                }
 494                                            }
 495                                        });
 496                                        let response = lsp
 497                                            .request::<request::SignInConfirm>(
 498                                                request::SignInConfirmParams {
 499                                                    user_code: flow.user_code,
 500                                                },
 501                                            )
 502                                            .await?;
 503                                        Ok(response)
 504                                    }
 505                                }
 506                            };
 507
 508                            let sign_in = sign_in.await;
 509                            this.update(&mut cx, |this, cx| match sign_in {
 510                                Ok(status) => {
 511                                    this.update_sign_in_status(status, cx);
 512                                    Ok(())
 513                                }
 514                                Err(error) => {
 515                                    this.update_sign_in_status(
 516                                        request::SignInStatus::NotSignedIn,
 517                                        cx,
 518                                    );
 519                                    Err(Arc::new(error))
 520                                }
 521                            })
 522                        })
 523                        .shared();
 524                    server.sign_in_status = SignInStatus::SigningIn {
 525                        prompt: None,
 526                        task: task.clone(),
 527                    };
 528                    cx.notify();
 529                    task
 530                }
 531            };
 532
 533            cx.foreground()
 534                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
 535        } else {
 536            // If we're downloading, wait until download is finished
 537            // If we're in a stuck state, display to the user
 538            Task::ready(Err(anyhow!("copilot hasn't started yet")))
 539        }
 540    }
 541
 542    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 543        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
 544        if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
 545            let server = server.clone();
 546            cx.background().spawn(async move {
 547                server
 548                    .request::<request::SignOut>(request::SignOutParams {})
 549                    .await?;
 550                anyhow::Ok(())
 551            })
 552        } else {
 553            Task::ready(Err(anyhow!("copilot hasn't started yet")))
 554        }
 555    }
 556
 557    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 558        let start_task = cx
 559            .spawn({
 560                let http = self.http.clone();
 561                let node_runtime = self.node_runtime.clone();
 562                move |this, cx| async move {
 563                    clear_copilot_dir().await;
 564                    Self::start_language_server(http, node_runtime, this, cx).await
 565                }
 566            })
 567            .shared();
 568
 569        self.server = CopilotServer::Starting {
 570            task: start_task.clone(),
 571        };
 572
 573        cx.notify();
 574
 575        cx.foreground().spawn(start_task)
 576    }
 577
 578    pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
 579        let buffer_id = buffer.id();
 580        self.buffers.insert(buffer_id, buffer.downgrade());
 581
 582        if let CopilotServer::Running(RunningCopilotServer {
 583            lsp: server,
 584            sign_in_status: status,
 585            registered_buffers,
 586            ..
 587        }) = &mut self.server
 588        {
 589            if !matches!(status, SignInStatus::Authorized { .. }) {
 590                return;
 591            }
 592
 593            registered_buffers.entry(buffer.id()).or_insert_with(|| {
 594                let uri: lsp::Url = uri_for_buffer(buffer, cx);
 595                let language_id = id_for_language(buffer.read(cx).language());
 596                let snapshot = buffer.read(cx).snapshot();
 597                server
 598                    .notify::<lsp::notification::DidOpenTextDocument>(
 599                        lsp::DidOpenTextDocumentParams {
 600                            text_document: lsp::TextDocumentItem {
 601                                uri: uri.clone(),
 602                                language_id: language_id.clone(),
 603                                version: 0,
 604                                text: snapshot.text(),
 605                            },
 606                        },
 607                    )
 608                    .log_err();
 609
 610                RegisteredBuffer {
 611                    id: buffer_id,
 612                    uri,
 613                    language_id,
 614                    snapshot,
 615                    snapshot_version: 0,
 616                    pending_buffer_change: Task::ready(Some(())),
 617                    _subscriptions: [
 618                        cx.subscribe(buffer, |this, buffer, event, cx| {
 619                            this.handle_buffer_event(buffer, event, cx).log_err();
 620                        }),
 621                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 622                            this.buffers.remove(&buffer_id);
 623                            this.unregister_buffer(buffer_id);
 624                        }),
 625                    ],
 626                }
 627            });
 628        }
 629    }
 630
 631    fn handle_buffer_event(
 632        &mut self,
 633        buffer: ModelHandle<Buffer>,
 634        event: &language::Event,
 635        cx: &mut ModelContext<Self>,
 636    ) -> Result<()> {
 637        if let Ok(server) = self.server.as_running() {
 638            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) {
 639                match event {
 640                    language::Event::Edited => {
 641                        let _ = registered_buffer.report_changes(&buffer, cx);
 642                    }
 643                    language::Event::Saved => {
 644                        server
 645                            .lsp
 646                            .notify::<lsp::notification::DidSaveTextDocument>(
 647                                lsp::DidSaveTextDocumentParams {
 648                                    text_document: lsp::TextDocumentIdentifier::new(
 649                                        registered_buffer.uri.clone(),
 650                                    ),
 651                                    text: None,
 652                                },
 653                            )?;
 654                    }
 655                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
 656                        let new_language_id = id_for_language(buffer.read(cx).language());
 657                        let new_uri = uri_for_buffer(&buffer, cx);
 658                        if new_uri != registered_buffer.uri
 659                            || new_language_id != registered_buffer.language_id
 660                        {
 661                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
 662                            registered_buffer.language_id = new_language_id;
 663                            server
 664                                .lsp
 665                                .notify::<lsp::notification::DidCloseTextDocument>(
 666                                    lsp::DidCloseTextDocumentParams {
 667                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
 668                                    },
 669                                )?;
 670                            server
 671                                .lsp
 672                                .notify::<lsp::notification::DidOpenTextDocument>(
 673                                    lsp::DidOpenTextDocumentParams {
 674                                        text_document: lsp::TextDocumentItem::new(
 675                                            registered_buffer.uri.clone(),
 676                                            registered_buffer.language_id.clone(),
 677                                            registered_buffer.snapshot_version,
 678                                            registered_buffer.snapshot.text(),
 679                                        ),
 680                                    },
 681                                )?;
 682                        }
 683                    }
 684                    _ => {}
 685                }
 686            }
 687        }
 688
 689        Ok(())
 690    }
 691
 692    fn unregister_buffer(&mut self, buffer_id: usize) {
 693        if let Ok(server) = self.server.as_running() {
 694            if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
 695                server
 696                    .lsp
 697                    .notify::<lsp::notification::DidCloseTextDocument>(
 698                        lsp::DidCloseTextDocumentParams {
 699                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
 700                        },
 701                    )
 702                    .log_err();
 703            }
 704        }
 705    }
 706
 707    pub fn completions<T>(
 708        &mut self,
 709        buffer: &ModelHandle<Buffer>,
 710        position: T,
 711        cx: &mut ModelContext<Self>,
 712    ) -> Task<Result<Vec<Completion>>>
 713    where
 714        T: ToPointUtf16,
 715    {
 716        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
 717    }
 718
 719    pub fn completions_cycling<T>(
 720        &mut self,
 721        buffer: &ModelHandle<Buffer>,
 722        position: T,
 723        cx: &mut ModelContext<Self>,
 724    ) -> Task<Result<Vec<Completion>>>
 725    where
 726        T: ToPointUtf16,
 727    {
 728        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
 729    }
 730
 731    pub fn accept_completion(
 732        &mut self,
 733        completion: &Completion,
 734        cx: &mut ModelContext<Self>,
 735    ) -> Task<Result<()>> {
 736        let server = match self.server.as_authenticated() {
 737            Ok(server) => server,
 738            Err(error) => return Task::ready(Err(error)),
 739        };
 740        let request =
 741            server
 742                .lsp
 743                .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
 744                    uuid: completion.uuid.clone(),
 745                });
 746        cx.background().spawn(async move {
 747            request.await?;
 748            Ok(())
 749        })
 750    }
 751
 752    pub fn discard_completions(
 753        &mut self,
 754        completions: &[Completion],
 755        cx: &mut ModelContext<Self>,
 756    ) -> Task<Result<()>> {
 757        let server = match self.server.as_authenticated() {
 758            Ok(server) => server,
 759            Err(error) => return Task::ready(Err(error)),
 760        };
 761        let request =
 762            server
 763                .lsp
 764                .request::<request::NotifyRejected>(request::NotifyRejectedParams {
 765                    uuids: completions
 766                        .iter()
 767                        .map(|completion| completion.uuid.clone())
 768                        .collect(),
 769                });
 770        cx.background().spawn(async move {
 771            request.await?;
 772            Ok(())
 773        })
 774    }
 775
 776    fn request_completions<R, T>(
 777        &mut self,
 778        buffer: &ModelHandle<Buffer>,
 779        position: T,
 780        cx: &mut ModelContext<Self>,
 781    ) -> Task<Result<Vec<Completion>>>
 782    where
 783        R: 'static
 784            + lsp::request::Request<
 785                Params = request::GetCompletionsParams,
 786                Result = request::GetCompletionsResult,
 787            >,
 788        T: ToPointUtf16,
 789    {
 790        self.register_buffer(buffer, cx);
 791
 792        let server = match self.server.as_authenticated() {
 793            Ok(server) => server,
 794            Err(error) => return Task::ready(Err(error)),
 795        };
 796        let lsp = server.lsp.clone();
 797        let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap();
 798        let snapshot = registered_buffer.report_changes(buffer, cx);
 799        let buffer = buffer.read(cx);
 800        let uri = registered_buffer.uri.clone();
 801        let settings = cx.global::<Settings>();
 802        let position = position.to_point_utf16(buffer);
 803        let language = buffer.language_at(position);
 804        let language_name = language.map(|language| language.name());
 805        let language_name = language_name.as_deref();
 806        let tab_size = settings.tab_size(language_name);
 807        let hard_tabs = settings.hard_tabs(language_name);
 808        let relative_path = buffer
 809            .file()
 810            .map(|file| file.path().to_path_buf())
 811            .unwrap_or_default();
 812
 813        cx.foreground().spawn(async move {
 814            let (version, snapshot) = snapshot.await?;
 815            let result = lsp
 816                .request::<R>(request::GetCompletionsParams {
 817                    doc: request::GetCompletionsDocument {
 818                        uri,
 819                        tab_size: tab_size.into(),
 820                        indent_size: 1,
 821                        insert_spaces: !hard_tabs,
 822                        relative_path: relative_path.to_string_lossy().into(),
 823                        position: point_to_lsp(position),
 824                        version: version.try_into().unwrap(),
 825                    },
 826                })
 827                .await?;
 828            let completions = result
 829                .completions
 830                .into_iter()
 831                .map(|completion| {
 832                    let start = snapshot
 833                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
 834                    let end =
 835                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
 836                    Completion {
 837                        uuid: completion.uuid,
 838                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
 839                        text: completion.text,
 840                    }
 841                })
 842                .collect();
 843            anyhow::Ok(completions)
 844        })
 845    }
 846
 847    pub fn status(&self) -> Status {
 848        match &self.server {
 849            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
 850            CopilotServer::Disabled => Status::Disabled,
 851            CopilotServer::Error(error) => Status::Error(error.clone()),
 852            CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
 853                match sign_in_status {
 854                    SignInStatus::Authorized { .. } => Status::Authorized,
 855                    SignInStatus::Unauthorized { .. } => Status::Unauthorized,
 856                    SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
 857                        prompt: prompt.clone(),
 858                    },
 859                    SignInStatus::SignedOut => Status::SignedOut,
 860                }
 861            }
 862        }
 863    }
 864
 865    fn update_sign_in_status(
 866        &mut self,
 867        lsp_status: request::SignInStatus,
 868        cx: &mut ModelContext<Self>,
 869    ) {
 870        self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
 871
 872        if let Ok(server) = self.server.as_running() {
 873            match lsp_status {
 874                request::SignInStatus::Ok { .. }
 875                | request::SignInStatus::MaybeOk { .. }
 876                | request::SignInStatus::AlreadySignedIn { .. } => {
 877                    server.sign_in_status = SignInStatus::Authorized;
 878                    for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
 879                        if let Some(buffer) = buffer.upgrade(cx) {
 880                            self.register_buffer(&buffer, cx);
 881                        }
 882                    }
 883                }
 884                request::SignInStatus::NotAuthorized { .. } => {
 885                    server.sign_in_status = SignInStatus::Unauthorized;
 886                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
 887                        self.unregister_buffer(buffer_id);
 888                    }
 889                }
 890                request::SignInStatus::NotSignedIn => {
 891                    server.sign_in_status = SignInStatus::SignedOut;
 892                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
 893                        self.unregister_buffer(buffer_id);
 894                    }
 895                }
 896            }
 897
 898            cx.notify();
 899        }
 900    }
 901}
 902
 903fn id_for_language(language: Option<&Arc<Language>>) -> String {
 904    let language_name = language.map(|language| language.name());
 905    match language_name.as_deref() {
 906        Some("Plain Text") => "plaintext".to_string(),
 907        Some(language_name) => language_name.to_lowercase(),
 908        None => "plaintext".to_string(),
 909    }
 910}
 911
 912fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
 913    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
 914        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
 915    } else {
 916        format!("buffer://{}", buffer.id()).parse().unwrap()
 917    }
 918}
 919
 920async fn clear_copilot_dir() {
 921    remove_matching(&paths::COPILOT_DIR, |_| true).await
 922}
 923
 924async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 925    const SERVER_PATH: &'static str = "dist/agent.js";
 926
 927    ///Check for the latest copilot language server and download it if we haven't already
 928    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
 929        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
 930
 931        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
 932
 933        fs::create_dir_all(version_dir).await?;
 934        let server_path = version_dir.join(SERVER_PATH);
 935
 936        if fs::metadata(&server_path).await.is_err() {
 937            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
 938            let dist_dir = version_dir.join("dist");
 939            fs::create_dir_all(dist_dir.as_path()).await?;
 940
 941            let url = &release
 942                .assets
 943                .get(0)
 944                .context("Github release for copilot contained no assets")?
 945                .browser_download_url;
 946
 947            let mut response = http
 948                .get(&url, Default::default(), true)
 949                .await
 950                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
 951            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
 952            let archive = Archive::new(decompressed_bytes);
 953            archive.unpack(dist_dir).await?;
 954
 955            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
 956        }
 957
 958        Ok(server_path)
 959    }
 960
 961    match fetch_latest(http).await {
 962        ok @ Result::Ok(..) => ok,
 963        e @ Err(..) => {
 964            e.log_err();
 965            // Fetch a cached binary, if it exists
 966            (|| async move {
 967                let mut last_version_dir = None;
 968                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
 969                while let Some(entry) = entries.next().await {
 970                    let entry = entry?;
 971                    if entry.file_type().await?.is_dir() {
 972                        last_version_dir = Some(entry.path());
 973                    }
 974                }
 975                let last_version_dir =
 976                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
 977                let server_path = last_version_dir.join(SERVER_PATH);
 978                if server_path.exists() {
 979                    Ok(server_path)
 980                } else {
 981                    Err(anyhow!(
 982                        "missing executable in directory {:?}",
 983                        last_version_dir
 984                    ))
 985                }
 986            })()
 987            .await
 988        }
 989    }
 990}
 991
// Tests for how Copilot keeps the language server's set of open documents in
// sync with Zed's buffers. They drive the fake LSP handle returned by
// `Copilot::fake`.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{executor::Deterministic, TestAppContext};

    /// Checks the document notifications for every buffer lifecycle event.
    ///
    /// Covered events:
    /// - opening buffers (`DidOpen`);
    /// - editing a buffer (`DidChange`);
    /// - giving a buffer a file path (close and reopen under the new URI);
    /// - signing out (every buffer closed);
    /// - signing back in (every buffer reopened);
    /// - dropping a buffer (`DidClose`).
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A buffer without a file is identified by a synthetic `buffer://<id>` URI.
        let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is sent as a ranged change and bumps the document version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // The URI changes from `buffer://` to `file://`, so the old document is
        // closed and the buffer is reopened under its new URI.
        buffer_1
            .update(cx, |buffer, cx| {
                buffer.file_updated(
                    Arc::new(File {
                        abs_path: "/root/child/buffer-1".into(),
                        path: Path::new("child/buffer-1").into(),
                    }),
                    cx,
                )
            })
            .await;
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        // The close order asserted below (buffer_2, then buffer_1) is the order
        // this implementation produces.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Re-opened documents start again at version 0, with their current text.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );

        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local-file test double used to give a buffer a path.
    ///
    /// Only `as_local`, `path`, and `abs_path` are implemented. Every other
    /// method panics via `todo!()`.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> std::time::SystemTime {
            todo!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            todo!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            todo!()
        }

        fn is_deleted(&self) -> bool {
            todo!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            todo!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            todo!()
        }
    }

    impl language::LocalFile for File {
        // Used by `uri_for_buffer` to build the buffer's `file://` URI.
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            todo!()
        }

        fn buffer_reloaded(
            &self,
            _: u64,
            _: &clock::Global,
            _: language::RopeFingerprint,
            _: ::fs::LineEnding,
            _: std::time::SystemTime,
            _: &mut AppContext,
        ) {
            todo!()
        }
    }
}