copilot.rs

  1pub mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use collections::HashMap;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
 10use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 11use log::{debug, error};
 12use lsp::LanguageServer;
 13use node_runtime::NodeRuntime;
 14use request::{LogMessage, StatusNotification};
 15use settings::Settings;
 16use smol::{fs, io::BufReader, stream::StreamExt};
 17use std::{
 18    ffi::OsString,
 19    ops::Range,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use util::{
 24    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 25};
 26
 27const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 28actions!(copilot_auth, [SignIn, SignOut]);
 29
 30const COPILOT_NAMESPACE: &'static str = "copilot";
 31actions!(
 32    copilot,
 33    [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
 34);
 35
 36pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
 37    let copilot = cx.add_model({
 38        let node_runtime = node_runtime.clone();
 39        move |cx| Copilot::start(http, node_runtime, cx)
 40    });
 41    cx.set_global(copilot.clone());
 42
 43    cx.observe(&copilot, |handle, cx| {
 44        let status = handle.read(cx).status();
 45        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 46            move |filter, _cx| match status {
 47                Status::Disabled => {
 48                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 49                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 50                }
 51                Status::Authorized => {
 52                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 53                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 54                }
 55                _ => {
 56                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 57                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 58                }
 59            },
 60        );
 61    })
 62    .detach();
 63
 64    sign_in::init(cx);
 65    cx.add_global_action(|_: &SignIn, cx| {
 66        if let Some(copilot) = Copilot::global(cx) {
 67            copilot
 68                .update(cx, |copilot, cx| copilot.sign_in(cx))
 69                .detach_and_log_err(cx);
 70        }
 71    });
 72    cx.add_global_action(|_: &SignOut, cx| {
 73        if let Some(copilot) = Copilot::global(cx) {
 74            copilot
 75                .update(cx, |copilot, cx| copilot.sign_out(cx))
 76                .detach_and_log_err(cx);
 77        }
 78    });
 79
 80    cx.add_global_action(|_: &Reinstall, cx| {
 81        if let Some(copilot) = Copilot::global(cx) {
 82            copilot
 83                .update(cx, |copilot, cx| copilot.reinstall(cx))
 84                .detach();
 85        }
 86    });
 87}
 88
 89enum CopilotServer {
 90    Disabled,
 91    Starting {
 92        task: Shared<Task<()>>,
 93    },
 94    Error(Arc<str>),
 95    Started {
 96        server: Arc<LanguageServer>,
 97        status: SignInStatus,
 98        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
 99    },
100}
101
102#[derive(Clone, Debug)]
103enum SignInStatus {
104    Authorized,
105    Unauthorized,
106    SigningIn {
107        prompt: Option<request::PromptUserDeviceFlow>,
108        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
109    },
110    SignedOut,
111}
112
113#[derive(Debug, Clone)]
114pub enum Status {
115    Starting {
116        task: Shared<Task<()>>,
117    },
118    Error(Arc<str>),
119    Disabled,
120    SignedOut,
121    SigningIn {
122        prompt: Option<request::PromptUserDeviceFlow>,
123    },
124    Unauthorized,
125    Authorized,
126}
127
128impl Status {
129    pub fn is_authorized(&self) -> bool {
130        matches!(self, Status::Authorized)
131    }
132}
133
134#[derive(Debug, PartialEq, Eq)]
135pub struct Completion {
136    pub range: Range<Anchor>,
137    pub text: String,
138}
139
140pub struct Copilot {
141    http: Arc<dyn HttpClient>,
142    node_runtime: Arc<NodeRuntime>,
143    server: CopilotServer,
144}
145
146impl Entity for Copilot {
147    type Event = ();
148}
149
150impl Copilot {
151    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
152        if cx.has_global::<ModelHandle<Self>>() {
153            Some(cx.global::<ModelHandle<Self>>().clone())
154        } else {
155            None
156        }
157    }
158
159    fn start(
160        http: Arc<dyn HttpClient>,
161        node_runtime: Arc<NodeRuntime>,
162        cx: &mut ModelContext<Self>,
163    ) -> Self {
164        cx.observe_global::<Settings, _>({
165            let http = http.clone();
166            let node_runtime = node_runtime.clone();
167            move |this, cx| {
168                if cx.global::<Settings>().features.copilot {
169                    if matches!(this.server, CopilotServer::Disabled) {
170                        let start_task = cx
171                            .spawn({
172                                let http = http.clone();
173                                let node_runtime = node_runtime.clone();
174                                move |this, cx| {
175                                    Self::start_language_server(http, node_runtime, this, cx)
176                                }
177                            })
178                            .shared();
179                        this.server = CopilotServer::Starting { task: start_task };
180                        cx.notify();
181                    }
182                } else {
183                    this.server = CopilotServer::Disabled;
184                    cx.notify();
185                }
186            }
187        })
188        .detach();
189
190        if cx.global::<Settings>().features.copilot {
191            let start_task = cx
192                .spawn({
193                    let http = http.clone();
194                    let node_runtime = node_runtime.clone();
195                    move |this, cx| async {
196                        Self::start_language_server(http, node_runtime, this, cx).await
197                    }
198                })
199                .shared();
200
201            Self {
202                http,
203                node_runtime,
204                server: CopilotServer::Starting { task: start_task },
205            }
206        } else {
207            Self {
208                http,
209                node_runtime,
210                server: CopilotServer::Disabled,
211            }
212        }
213    }
214
215    #[cfg(any(test, feature = "test-support"))]
216    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
217        let (server, fake_server) =
218            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
219        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
220        let this = cx.add_model(|cx| Self {
221            http: http.clone(),
222            node_runtime: NodeRuntime::new(http, cx.background().clone()),
223            server: CopilotServer::Started {
224                server: Arc::new(server),
225                status: SignInStatus::Authorized,
226                subscriptions_by_buffer_id: Default::default(),
227            },
228        });
229        (this, fake_server)
230    }
231
232    fn start_language_server(
233        http: Arc<dyn HttpClient>,
234        node_runtime: Arc<NodeRuntime>,
235        this: ModelHandle<Self>,
236        mut cx: AsyncAppContext,
237    ) -> impl Future<Output = ()> {
238        async move {
239            let start_language_server = async {
240                let server_path = get_copilot_lsp(http).await?;
241                let node_path = node_runtime.binary_path().await?;
242                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
243                let server = LanguageServer::new(
244                    0,
245                    &node_path,
246                    arguments,
247                    Path::new("/"),
248                    None,
249                    cx.clone(),
250                )?;
251
252                let server = server.initialize(Default::default()).await?;
253                let status = server
254                    .request::<request::CheckStatus>(request::CheckStatusParams {
255                        local_checks_only: false,
256                    })
257                    .await?;
258
259                server
260                    .on_notification::<LogMessage, _>(|params, _cx| {
261                        match params.level {
262                            // Copilot is pretty agressive about logging
263                            0 => debug!("copilot: {}", params.message),
264                            1 => debug!("copilot: {}", params.message),
265                            _ => error!("copilot: {}", params.message),
266                        }
267
268                        debug!("copilot metadata: {}", params.metadata_str);
269                        debug!("copilot extra: {:?}", params.extra);
270                    })
271                    .detach();
272
273                server
274                    .on_notification::<StatusNotification, _>(
275                        |_, _| { /* Silence the notification */ },
276                    )
277                    .detach();
278
279                anyhow::Ok((server, status))
280            };
281
282            let server = start_language_server.await;
283            this.update(&mut cx, |this, cx| {
284                cx.notify();
285                match server {
286                    Ok((server, status)) => {
287                        this.server = CopilotServer::Started {
288                            server,
289                            status: SignInStatus::SignedOut,
290                            subscriptions_by_buffer_id: Default::default(),
291                        };
292                        this.update_sign_in_status(status, cx);
293                    }
294                    Err(error) => {
295                        this.server = CopilotServer::Error(error.to_string().into());
296                        cx.notify()
297                    }
298                }
299            })
300        }
301    }
302
303    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
304        if let CopilotServer::Started { server, status, .. } = &mut self.server {
305            let task = match status {
306                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
307                    Task::ready(Ok(())).shared()
308                }
309                SignInStatus::SigningIn { task, .. } => {
310                    cx.notify();
311                    task.clone()
312                }
313                SignInStatus::SignedOut => {
314                    let server = server.clone();
315                    let task = cx
316                        .spawn(|this, mut cx| async move {
317                            let sign_in = async {
318                                let sign_in = server
319                                    .request::<request::SignInInitiate>(
320                                        request::SignInInitiateParams {},
321                                    )
322                                    .await?;
323                                match sign_in {
324                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
325                                        Ok(request::SignInStatus::Ok { user })
326                                    }
327                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
328                                        this.update(&mut cx, |this, cx| {
329                                            if let CopilotServer::Started { status, .. } =
330                                                &mut this.server
331                                            {
332                                                if let SignInStatus::SigningIn {
333                                                    prompt: prompt_flow,
334                                                    ..
335                                                } = status
336                                                {
337                                                    *prompt_flow = Some(flow.clone());
338                                                    cx.notify();
339                                                }
340                                            }
341                                        });
342                                        let response = server
343                                            .request::<request::SignInConfirm>(
344                                                request::SignInConfirmParams {
345                                                    user_code: flow.user_code,
346                                                },
347                                            )
348                                            .await?;
349                                        Ok(response)
350                                    }
351                                }
352                            };
353
354                            let sign_in = sign_in.await;
355                            this.update(&mut cx, |this, cx| match sign_in {
356                                Ok(status) => {
357                                    this.update_sign_in_status(status, cx);
358                                    Ok(())
359                                }
360                                Err(error) => {
361                                    this.update_sign_in_status(
362                                        request::SignInStatus::NotSignedIn,
363                                        cx,
364                                    );
365                                    Err(Arc::new(error))
366                                }
367                            })
368                        })
369                        .shared();
370                    *status = SignInStatus::SigningIn {
371                        prompt: None,
372                        task: task.clone(),
373                    };
374                    cx.notify();
375                    task
376                }
377            };
378
379            cx.foreground()
380                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
381        } else {
382            // If we're downloading, wait until download is finished
383            // If we're in a stuck state, display to the user
384            Task::ready(Err(anyhow!("copilot hasn't started yet")))
385        }
386    }
387
388    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
389        if let CopilotServer::Started { server, status, .. } = &mut self.server {
390            *status = SignInStatus::SignedOut;
391            cx.notify();
392
393            let server = server.clone();
394            cx.background().spawn(async move {
395                server
396                    .request::<request::SignOut>(request::SignOutParams {})
397                    .await?;
398                anyhow::Ok(())
399            })
400        } else {
401            Task::ready(Err(anyhow!("copilot hasn't started yet")))
402        }
403    }
404
405    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
406        let start_task = cx
407            .spawn({
408                let http = self.http.clone();
409                let node_runtime = self.node_runtime.clone();
410                move |this, cx| async move {
411                    clear_copilot_dir().await;
412                    Self::start_language_server(http, node_runtime, this, cx).await
413                }
414            })
415            .shared();
416
417        self.server = CopilotServer::Starting {
418            task: start_task.clone(),
419        };
420
421        cx.notify();
422
423        cx.foreground().spawn(start_task)
424    }
425
426    pub fn completions<T>(
427        &mut self,
428        buffer: &ModelHandle<Buffer>,
429        position: T,
430        cx: &mut ModelContext<Self>,
431    ) -> Task<Result<Vec<Completion>>>
432    where
433        T: ToPointUtf16,
434    {
435        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
436    }
437
438    pub fn completions_cycling<T>(
439        &mut self,
440        buffer: &ModelHandle<Buffer>,
441        position: T,
442        cx: &mut ModelContext<Self>,
443    ) -> Task<Result<Vec<Completion>>>
444    where
445        T: ToPointUtf16,
446    {
447        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
448    }
449
450    fn request_completions<R, T>(
451        &mut self,
452        buffer: &ModelHandle<Buffer>,
453        position: T,
454        cx: &mut ModelContext<Self>,
455    ) -> Task<Result<Vec<Completion>>>
456    where
457        R: lsp::request::Request<
458            Params = request::GetCompletionsParams,
459            Result = request::GetCompletionsResult,
460        >,
461        T: ToPointUtf16,
462    {
463        let buffer_id = buffer.id();
464        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
465        let snapshot = buffer.read(cx).snapshot();
466        let server = match &mut self.server {
467            CopilotServer::Starting { .. } => {
468                return Task::ready(Err(anyhow!("copilot is still starting")))
469            }
470            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
471            CopilotServer::Error(error) => {
472                return Task::ready(Err(anyhow!(
473                    "copilot was not started because of an error: {}",
474                    error
475                )))
476            }
477            CopilotServer::Started {
478                server,
479                status,
480                subscriptions_by_buffer_id,
481            } => {
482                if matches!(status, SignInStatus::Authorized { .. }) {
483                    subscriptions_by_buffer_id
484                        .entry(buffer_id)
485                        .or_insert_with(|| {
486                            server
487                                .notify::<lsp::notification::DidOpenTextDocument>(
488                                    lsp::DidOpenTextDocumentParams {
489                                        text_document: lsp::TextDocumentItem {
490                                            uri: uri.clone(),
491                                            language_id: id_for_language(
492                                                buffer.read(cx).language(),
493                                            ),
494                                            version: 0,
495                                            text: snapshot.text(),
496                                        },
497                                    },
498                                )
499                                .log_err();
500
501                            let uri = uri.clone();
502                            cx.observe_release(buffer, move |this, _, _| {
503                                if let CopilotServer::Started {
504                                    server,
505                                    subscriptions_by_buffer_id,
506                                    ..
507                                } = &mut this.server
508                                {
509                                    server
510                                        .notify::<lsp::notification::DidCloseTextDocument>(
511                                            lsp::DidCloseTextDocumentParams {
512                                                text_document: lsp::TextDocumentIdentifier::new(
513                                                    uri.clone(),
514                                                ),
515                                            },
516                                        )
517                                        .log_err();
518                                    subscriptions_by_buffer_id.remove(&buffer_id);
519                                }
520                            })
521                        });
522
523                    server.clone()
524                } else {
525                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
526                }
527            }
528        };
529
530        let settings = cx.global::<Settings>();
531        let position = position.to_point_utf16(&snapshot);
532        let language = snapshot.language_at(position);
533        let language_name = language.map(|language| language.name());
534        let language_name = language_name.as_deref();
535        let tab_size = settings.tab_size(language_name);
536        let hard_tabs = settings.hard_tabs(language_name);
537        let language_id = id_for_language(language);
538
539        let path;
540        let relative_path;
541        if let Some(file) = snapshot.file() {
542            if let Some(file) = file.as_local() {
543                path = file.abs_path(cx);
544            } else {
545                path = file.full_path(cx);
546            }
547            relative_path = file.path().to_path_buf();
548        } else {
549            path = PathBuf::new();
550            relative_path = PathBuf::new();
551        }
552
553        cx.background().spawn(async move {
554            let result = server
555                .request::<R>(request::GetCompletionsParams {
556                    doc: request::GetCompletionsDocument {
557                        source: snapshot.text(),
558                        tab_size: tab_size.into(),
559                        indent_size: 1,
560                        insert_spaces: !hard_tabs,
561                        uri,
562                        path: path.to_string_lossy().into(),
563                        relative_path: relative_path.to_string_lossy().into(),
564                        language_id,
565                        position: point_to_lsp(position),
566                        version: 0,
567                    },
568                })
569                .await?;
570            let completions = result
571                .completions
572                .into_iter()
573                .map(|completion| {
574                    let start = snapshot
575                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
576                    let end =
577                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
578                    Completion {
579                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
580                        text: completion.text,
581                    }
582                })
583                .collect();
584            anyhow::Ok(completions)
585        })
586    }
587
588    pub fn status(&self) -> Status {
589        match &self.server {
590            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
591            CopilotServer::Disabled => Status::Disabled,
592            CopilotServer::Error(error) => Status::Error(error.clone()),
593            CopilotServer::Started { status, .. } => match status {
594                SignInStatus::Authorized { .. } => Status::Authorized,
595                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
596                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
597                    prompt: prompt.clone(),
598                },
599                SignInStatus::SignedOut => Status::SignedOut,
600            },
601        }
602    }
603
604    fn update_sign_in_status(
605        &mut self,
606        lsp_status: request::SignInStatus,
607        cx: &mut ModelContext<Self>,
608    ) {
609        if let CopilotServer::Started { status, .. } = &mut self.server {
610            *status = match lsp_status {
611                request::SignInStatus::Ok { .. }
612                | request::SignInStatus::MaybeOk { .. }
613                | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
614                request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
615                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
616            };
617            cx.notify();
618        }
619    }
620}
621
622fn id_for_language(language: Option<&Arc<Language>>) -> String {
623    let language_name = language.map(|language| language.name());
624    match language_name.as_deref() {
625        Some("Plain Text") => "plaintext".to_string(),
626        Some(language_name) => language_name.to_lowercase(),
627        None => "plaintext".to_string(),
628    }
629}
630
631async fn clear_copilot_dir() {
632    remove_matching(&paths::COPILOT_DIR, |_| true).await
633}
634
635async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
636    const SERVER_PATH: &'static str = "dist/agent.js";
637
638    ///Check for the latest copilot language server and download it if we haven't already
639    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
640        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
641
642        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
643
644        fs::create_dir_all(version_dir).await?;
645        let server_path = version_dir.join(SERVER_PATH);
646
647        if fs::metadata(&server_path).await.is_err() {
648            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
649            let dist_dir = version_dir.join("dist");
650            fs::create_dir_all(dist_dir.as_path()).await?;
651
652            let url = &release
653                .assets
654                .get(0)
655                .context("Github release for copilot contained no assets")?
656                .browser_download_url;
657
658            let mut response = http
659                .get(&url, Default::default(), true)
660                .await
661                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
662            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
663            let archive = Archive::new(decompressed_bytes);
664            archive.unpack(dist_dir).await?;
665
666            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
667        }
668
669        Ok(server_path)
670    }
671
672    match fetch_latest(http).await {
673        ok @ Result::Ok(..) => ok,
674        e @ Err(..) => {
675            e.log_err();
676            // Fetch a cached binary, if it exists
677            (|| async move {
678                let mut last_version_dir = None;
679                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
680                while let Some(entry) = entries.next().await {
681                    let entry = entry?;
682                    if entry.file_type().await?.is_dir() {
683                        last_version_dir = Some(entry.path());
684                    }
685                }
686                let last_version_dir =
687                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
688                let server_path = last_version_dir.join(SERVER_PATH);
689                if server_path.exists() {
690                    Ok(server_path)
691                } else {
692                    Err(anyhow!(
693                        "missing executable in directory {:?}",
694                        last_version_dir
695                    ))
696                }
697            })()
698            .await
699        }
700    }
701}