copilot.rs

  1pub mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use collections::HashMap;
  9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
 10use gpui::{
 11    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 12    Task,
 13};
 14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 15use log::{debug, error};
 16use lsp::LanguageServer;
 17use node_runtime::NodeRuntime;
 18use request::{LogMessage, StatusNotification};
 19use settings::Settings;
 20use smol::{fs, io::BufReader, stream::StreamExt};
 21use staff_mode::{not_staff_mode, staff_mode};
 22
 23use std::{
 24    ffi::OsString,
 25    ops::Range,
 26    path::{Path, PathBuf},
 27    sync::Arc,
 28};
 29use util::{
 30    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 31};
 32
// Action namespace for the sign-in / sign-out actions. `init` and `observe_namespaces`
// insert it into, or remove it from, `CommandPaletteFilter::filtered_namespaces`.
const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
actions!(copilot_auth, [SignIn, SignOut]);

// Action namespace for the remaining copilot actions; it is toggled in the same
// `CommandPaletteFilter::filtered_namespaces` set as the auth namespace.
const COPILOT_NAMESPACE: &'static str = "copilot";
actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
 38
 39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 40    staff_mode(cx, {
 41        move |cx| {
 42            cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 43                filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 44                filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 45            });
 46
 47            let copilot = cx.add_model({
 48                let node_runtime = node_runtime.clone();
 49                let http = client.http_client().clone();
 50                move |cx| Copilot::start(http, node_runtime, cx)
 51            });
 52            cx.set_global(copilot.clone());
 53
 54            observe_namespaces(cx, copilot);
 55
 56            sign_in::init(cx);
 57        }
 58    });
 59    not_staff_mode(cx, |cx| {
 60        cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 61            filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 62            filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 63        });
 64    });
 65
 66    cx.add_global_action(|_: &SignIn, cx| {
 67        if let Some(copilot) = Copilot::global(cx) {
 68            copilot
 69                .update(cx, |copilot, cx| copilot.sign_in(cx))
 70                .detach_and_log_err(cx);
 71        }
 72    });
 73    cx.add_global_action(|_: &SignOut, cx| {
 74        if let Some(copilot) = Copilot::global(cx) {
 75            copilot
 76                .update(cx, |copilot, cx| copilot.sign_out(cx))
 77                .detach_and_log_err(cx);
 78        }
 79    });
 80
 81    cx.add_global_action(|_: &Reinstall, cx| {
 82        if let Some(copilot) = Copilot::global(cx) {
 83            copilot
 84                .update(cx, |copilot, cx| copilot.reinstall(cx))
 85                .detach();
 86        }
 87    });
 88}
 89
 90fn observe_namespaces(cx: &mut MutableAppContext, copilot: ModelHandle<Copilot>) {
 91    cx.observe(&copilot, |handle, cx| {
 92        let status = handle.read(cx).status();
 93        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 94            move |filter, _cx| match status {
 95                Status::Disabled => {
 96                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 97                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 98                }
 99                Status::Authorized => {
100                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
101                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
102                }
103                _ => {
104                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
105                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
106                }
107            },
108        );
109    })
110    .detach();
111}
112
/// Lifecycle state of the copilot language server held by [`Copilot`].
enum CopilotServer {
    /// Set when the `enable_copilot_integration` setting is off.
    Disabled,
    /// A start (or reinstall) is in flight.
    Starting {
        /// Shared handle to the task that runs `Copilot::start_language_server`.
        task: Shared<Task<()>>,
    },
    /// Starting the server failed; holds the error rendered with `to_string()`.
    Error(Arc<str>),
    /// The server was started and initialized.
    Started {
        /// Handle used to send requests and notifications to the server.
        server: Arc<LanguageServer>,
        /// Current sign-in state as tracked by this model.
        status: SignInStatus,
        /// One release-observation per buffer id that has been announced to the
        /// server with `DidOpenTextDocument` (see `request_completions`).
        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
    },
}
125
/// Internal sign-in state of a started server; mapped to the public [`Status`] by
/// `Copilot::status`.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Completions may be requested (checked in `request_completions`).
    Authorized,
    /// Set when the server reports `NotAuthorized`.
    Unauthorized,
    /// A sign-in flow started by `Copilot::sign_in` is in progress.
    SigningIn {
        /// Device-flow prompt, filled in once the server returns one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared sign-in task, so repeated `sign_in` calls await the same flow.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Set on `sign_out`, when the server reports `NotSignedIn`, and as the
    /// initial value right after the server starts.
    SignedOut,
}
136
/// Public, flattened view of the copilot state, produced by `Copilot::status` from
/// the private server and sign-in states.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is being started.
    Starting {
        /// Shared handle to the in-flight start task.
        task: Shared<Task<()>>,
    },
    /// Starting the language server failed; holds the error message.
    Error(Arc<str>),
    /// Copilot integration is disabled.
    Disabled,
    /// The server is running and no user is signed in.
    SignedOut,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt, if the server has provided one yet.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server reported the user as not authorized.
    Unauthorized,
    /// The user is signed in and authorized.
    Authorized,
}
151
152impl Status {
153    pub fn is_authorized(&self) -> bool {
154        matches!(self, Status::Authorized)
155    }
156}
157
/// A single completion returned by the copilot server, expressed in buffer terms.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer range the completion applies to, anchored in the snapshot that was
    /// taken when the request was made.
    pub range: Range<Anchor>,
    /// Completion text as returned by the server.
    pub text: String,
}
163
/// Model that owns the copilot language server and its sign-in state.
pub struct Copilot {
    /// HTTP client passed to `get_copilot_lsp` when (re)starting the server.
    http: Arc<dyn HttpClient>,
    /// Provides the node binary path used to launch the server.
    node_runtime: Arc<NodeRuntime>,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
}
169
// `Copilot` emits no typed events; the code in this file reacts to it through
// `cx.notify()` / `cx.observe` only.
impl Entity for Copilot {
    type Event = ();
}
173
174impl Copilot {
175    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
176        if cx.has_global::<ModelHandle<Self>>() {
177            Some(cx.global::<ModelHandle<Self>>().clone())
178        } else {
179            None
180        }
181    }
182
183    fn start(
184        http: Arc<dyn HttpClient>,
185        node_runtime: Arc<NodeRuntime>,
186        cx: &mut ModelContext<Self>,
187    ) -> Self {
188        cx.observe_global::<Settings, _>({
189            let http = http.clone();
190            let node_runtime = node_runtime.clone();
191            move |this, cx| {
192                if cx.global::<Settings>().enable_copilot_integration {
193                    if matches!(this.server, CopilotServer::Disabled) {
194                        let start_task = cx
195                            .spawn({
196                                let http = http.clone();
197                                let node_runtime = node_runtime.clone();
198                                move |this, cx| {
199                                    Self::start_language_server(http, node_runtime, this, cx)
200                                }
201                            })
202                            .shared();
203                        this.server = CopilotServer::Starting { task: start_task };
204                        cx.notify();
205                    }
206                } else {
207                    this.server = CopilotServer::Disabled;
208                    cx.notify();
209                }
210            }
211        })
212        .detach();
213
214        if cx.global::<Settings>().enable_copilot_integration {
215            let start_task = cx
216                .spawn({
217                    let http = http.clone();
218                    let node_runtime = node_runtime.clone();
219                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
220                })
221                .shared();
222
223            Self {
224                http,
225                node_runtime,
226                server: CopilotServer::Starting { task: start_task },
227            }
228        } else {
229            Self {
230                http,
231                node_runtime,
232                server: CopilotServer::Disabled,
233            }
234        }
235    }
236
237    #[cfg(any(test, feature = "test-support"))]
238    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
239        let (server, fake_server) =
240            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
241        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
242        let this = cx.add_model(|cx| Self {
243            http: http.clone(),
244            node_runtime: NodeRuntime::new(http, cx.background().clone()),
245            server: CopilotServer::Started {
246                server: Arc::new(server),
247                status: SignInStatus::Authorized,
248                subscriptions_by_buffer_id: Default::default(),
249            },
250        });
251        (this, fake_server)
252    }
253
254    fn start_language_server(
255        http: Arc<dyn HttpClient>,
256        node_runtime: Arc<NodeRuntime>,
257        this: ModelHandle<Self>,
258        mut cx: AsyncAppContext,
259    ) -> impl Future<Output = ()> {
260        async move {
261            let start_language_server = async {
262                let server_path = get_copilot_lsp(http).await?;
263                let node_path = node_runtime.binary_path().await?;
264                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
265                let server = LanguageServer::new(
266                    0,
267                    &node_path,
268                    arguments,
269                    Path::new("/"),
270                    None,
271                    cx.clone(),
272                )?;
273
274                let server = server.initialize(Default::default()).await?;
275                let status = server
276                    .request::<request::CheckStatus>(request::CheckStatusParams {
277                        local_checks_only: false,
278                    })
279                    .await?;
280
281                server
282                    .on_notification::<LogMessage, _>(|params, _cx| {
283                        match params.level {
284                            // Copilot is pretty agressive about logging
285                            0 => debug!("copilot: {}", params.message),
286                            1 => debug!("copilot: {}", params.message),
287                            _ => error!("copilot: {}", params.message),
288                        }
289
290                        debug!("copilot metadata: {}", params.metadata_str);
291                        debug!("copilot extra: {:?}", params.extra);
292                    })
293                    .detach();
294
295                server
296                    .on_notification::<StatusNotification, _>(
297                        |_, _| { /* Silence the notification */ },
298                    )
299                    .detach();
300
301                anyhow::Ok((server, status))
302            };
303
304            let server = start_language_server.await;
305            this.update(&mut cx, |this, cx| {
306                cx.notify();
307                match server {
308                    Ok((server, status)) => {
309                        this.server = CopilotServer::Started {
310                            server,
311                            status: SignInStatus::SignedOut,
312                            subscriptions_by_buffer_id: Default::default(),
313                        };
314                        this.update_sign_in_status(status, cx);
315                    }
316                    Err(error) => {
317                        this.server = CopilotServer::Error(error.to_string().into());
318                        cx.notify()
319                    }
320                }
321            })
322        }
323    }
324
    /// Starts (or joins) a sign-in flow against the started language server.
    ///
    /// Behavior by current sign-in status:
    /// * `Authorized` / `Unauthorized`: resolves immediately with `Ok(())`.
    /// * `SigningIn`: notifies observers and awaits the already-running shared task.
    /// * `SignedOut`: spawns a new shared task that sends `SignInInitiate`; if the
    ///   server answers with a device-flow prompt, the prompt is stored on the
    ///   `SigningIn` status and `SignInConfirm` is sent with the flow's user code.
    ///   The final server status (or `NotSignedIn` on error) is applied through
    ///   `update_sign_in_status`.
    ///
    /// Returns an error task ("copilot hasn't started yet") when the server is not in
    /// the `Started` state.
    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status, .. } = &mut self.server {
            let task = match status {
                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                    Task::ready(Ok(())).shared()
                }
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut => {
                    let server = server.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            // Inner future: talks to the server and yields the
                            // resulting `request::SignInStatus` or an error.
                            let sign_in = async {
                                let sign_in = server
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the prompt, but only if the model is
                                        // still in the `SigningIn` state.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Started { status, .. } =
                                                &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        let response = server
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    // A failed flow leaves the model signed out.
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // Wrapped in `Arc` to match the error type of the
                                    // shared task stored in `SignInStatus::SigningIn`.
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    // Record the in-flight task so later calls join it instead of
                    // starting a second flow.
                    *status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // Convert the shared `Arc<anyhow::Error>` back into a plain `anyhow::Error`.
            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // NOTE(review): the two lines below read as follow-up ideas rather than a
            // description of current behavior; today this branch simply returns an error.
            // If we're downloading, wait until download is finished
            // If we're in a stuck state, display to the user
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
409
410    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
411        if let CopilotServer::Started { server, status, .. } = &mut self.server {
412            *status = SignInStatus::SignedOut;
413            cx.notify();
414
415            let server = server.clone();
416            cx.background().spawn(async move {
417                server
418                    .request::<request::SignOut>(request::SignOutParams {})
419                    .await?;
420                anyhow::Ok(())
421            })
422        } else {
423            Task::ready(Err(anyhow!("copilot hasn't started yet")))
424        }
425    }
426
427    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
428        let start_task = cx
429            .spawn({
430                let http = self.http.clone();
431                let node_runtime = self.node_runtime.clone();
432                move |this, cx| async move {
433                    clear_copilot_dir().await;
434                    Self::start_language_server(http, node_runtime, this, cx).await
435                }
436            })
437            .shared();
438
439        self.server = CopilotServer::Starting {
440            task: start_task.clone(),
441        };
442
443        cx.notify();
444
445        cx.foreground().spawn(start_task)
446    }
447
448    pub fn completions<T>(
449        &mut self,
450        buffer: &ModelHandle<Buffer>,
451        position: T,
452        cx: &mut ModelContext<Self>,
453    ) -> Task<Result<Vec<Completion>>>
454    where
455        T: ToPointUtf16,
456    {
457        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
458    }
459
460    pub fn completions_cycling<T>(
461        &mut self,
462        buffer: &ModelHandle<Buffer>,
463        position: T,
464        cx: &mut ModelContext<Self>,
465    ) -> Task<Result<Vec<Completion>>>
466    where
467        T: ToPointUtf16,
468    {
469        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
470    }
471
    /// Sends completion request `R` for `position` in `buffer` and converts the
    /// response into [`Completion`]s anchored in the buffer snapshot.
    ///
    /// Returns an immediately-failed task when the server is starting, disabled,
    /// errored, or started but not `Authorized`.
    ///
    /// The first authorized request for a given buffer also announces the buffer to
    /// the server with `DidOpenTextDocument` (URI `buffer://<buffer id>`) and registers
    /// a release observer that sends `DidCloseTextDocument` and drops the
    /// subscription once the buffer is released.
    fn request_completions<R, T>(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        R: lsp::request::Request<
            Params = request::GetCompletionsParams,
            Result = request::GetCompletionsResult,
        >,
        T: ToPointUtf16,
    {
        let buffer_id = buffer.id();
        // Synthetic URI that identifies the buffer to the copilot server.
        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
        let snapshot = buffer.read(cx).snapshot();
        let server = match &mut self.server {
            CopilotServer::Starting { .. } => {
                return Task::ready(Err(anyhow!("copilot is still starting")))
            }
            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
            CopilotServer::Error(error) => {
                return Task::ready(Err(anyhow!(
                    "copilot was not started because of an error: {}",
                    error
                )))
            }
            CopilotServer::Started {
                server,
                status,
                subscriptions_by_buffer_id,
            } => {
                if matches!(status, SignInStatus::Authorized { .. }) {
                    // `or_insert_with` runs only for buffers not seen before, so the
                    // open notification is sent once per buffer.
                    subscriptions_by_buffer_id
                        .entry(buffer_id)
                        .or_insert_with(|| {
                            server
                                .notify::<lsp::notification::DidOpenTextDocument>(
                                    lsp::DidOpenTextDocumentParams {
                                        text_document: lsp::TextDocumentItem {
                                            uri: uri.clone(),
                                            language_id: id_for_language(
                                                buffer.read(cx).language(),
                                            ),
                                            version: 0,
                                            text: snapshot.text(),
                                        },
                                    },
                                )
                                .log_err();

                            let uri = uri.clone();
                            // When the buffer is released, tell the server it closed and
                            // forget the subscription (only if still `Started`).
                            cx.observe_release(buffer, move |this, _, _| {
                                if let CopilotServer::Started {
                                    server,
                                    subscriptions_by_buffer_id,
                                    ..
                                } = &mut this.server
                                {
                                    server
                                        .notify::<lsp::notification::DidCloseTextDocument>(
                                            lsp::DidCloseTextDocumentParams {
                                                text_document: lsp::TextDocumentIdentifier::new(
                                                    uri.clone(),
                                                ),
                                            },
                                        )
                                        .log_err();
                                    subscriptions_by_buffer_id.remove(&buffer_id);
                                }
                            })
                        });

                    server.clone()
                } else {
                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
                }
            }
        };

        // Editor settings are looked up for the language at the requested position.
        let settings = cx.global::<Settings>();
        let position = position.to_point_utf16(&snapshot);
        let language = snapshot.language_at(position);
        let language_name = language.map(|language| language.name());
        let language_name = language_name.as_deref();
        let tab_size = settings.tab_size(language_name);
        let hard_tabs = settings.hard_tabs(language_name);
        let language_id = id_for_language(language);

        // Local files report their absolute path, other files their full path;
        // buffers without a file send empty paths.
        let path;
        let relative_path;
        if let Some(file) = snapshot.file() {
            if let Some(file) = file.as_local() {
                path = file.abs_path(cx);
            } else {
                path = file.full_path(cx);
            }
            relative_path = file.path().to_path_buf();
        } else {
            path = PathBuf::new();
            relative_path = PathBuf::new();
        }

        cx.background().spawn(async move {
            // The full snapshot text is sent with every request; `indent_size` and
            // `version` are fixed values here.
            let result = server
                .request::<R>(request::GetCompletionsParams {
                    doc: request::GetCompletionsDocument {
                        source: snapshot.text(),
                        tab_size: tab_size.into(),
                        indent_size: 1,
                        insert_spaces: !hard_tabs,
                        uri,
                        path: path.to_string_lossy().into(),
                        relative_path: relative_path.to_string_lossy().into(),
                        language_id,
                        position: point_to_lsp(position),
                        version: 0,
                    },
                })
                .await?;
            let completions = result
                .completions
                .into_iter()
                .map(|completion| {
                    // Clip the server-provided points to valid positions in the
                    // snapshot before anchoring them.
                    let start = snapshot
                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
                    let end =
                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
                    Completion {
                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
                        text: completion.text,
                    }
                })
                .collect();
            anyhow::Ok(completions)
        })
    }
609
610    pub fn status(&self) -> Status {
611        match &self.server {
612            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
613            CopilotServer::Disabled => Status::Disabled,
614            CopilotServer::Error(error) => Status::Error(error.clone()),
615            CopilotServer::Started { status, .. } => match status {
616                SignInStatus::Authorized { .. } => Status::Authorized,
617                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
618                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
619                    prompt: prompt.clone(),
620                },
621                SignInStatus::SignedOut => Status::SignedOut,
622            },
623        }
624    }
625
626    fn update_sign_in_status(
627        &mut self,
628        lsp_status: request::SignInStatus,
629        cx: &mut ModelContext<Self>,
630    ) {
631        if let CopilotServer::Started { status, .. } = &mut self.server {
632            *status = match lsp_status {
633                request::SignInStatus::Ok { .. }
634                | request::SignInStatus::MaybeOk { .. }
635                | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
636                request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
637                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
638            };
639            cx.notify();
640        }
641    }
642}
643
644fn id_for_language(language: Option<&Arc<Language>>) -> String {
645    let language_name = language.map(|language| language.name());
646    match language_name.as_deref() {
647        Some("Plain Text") => "plaintext".to_string(),
648        Some(language_name) => language_name.to_lowercase(),
649        None => "plaintext".to_string(),
650    }
651}
652
/// Runs `remove_matching` on `paths::COPILOT_DIR` with a predicate that accepts every
/// entry; used by `Copilot::reinstall` before the server is started again.
async fn clear_copilot_dir() {
    remove_matching(&paths::COPILOT_DIR, |_| true).await
}
656
657async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
658    const SERVER_PATH: &'static str = "dist/agent.js";
659
660    ///Check for the latest copilot language server and download it if we haven't already
661    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
662        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
663
664        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
665
666        fs::create_dir_all(version_dir).await?;
667        let server_path = version_dir.join(SERVER_PATH);
668
669        if fs::metadata(&server_path).await.is_err() {
670            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
671            let dist_dir = version_dir.join("dist");
672            fs::create_dir_all(dist_dir.as_path()).await?;
673
674            let url = &release
675                .assets
676                .get(0)
677                .context("Github release for copilot contained no assets")?
678                .browser_download_url;
679
680            let mut response = http
681                .get(&url, Default::default(), true)
682                .await
683                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
684            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
685            let archive = Archive::new(decompressed_bytes);
686            archive.unpack(dist_dir).await?;
687
688            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
689        }
690
691        Ok(server_path)
692    }
693
694    match fetch_latest(http).await {
695        ok @ Result::Ok(..) => ok,
696        e @ Err(..) => {
697            e.log_err();
698            // Fetch a cached binary, if it exists
699            (|| async move {
700                let mut last_version_dir = None;
701                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
702                while let Some(entry) = entries.next().await {
703                    let entry = entry?;
704                    if entry.file_type().await?.is_dir() {
705                        last_version_dir = Some(entry.path());
706                    }
707                }
708                let last_version_dir =
709                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
710                let server_path = last_version_dir.join(SERVER_PATH);
711                if server_path.exists() {
712                    Ok(server_path)
713                } else {
714                    Err(anyhow!(
715                        "missing executable in directory {:?}",
716                        last_version_dir
717                    ))
718                }
719            })()
720            .await
721        }
722    }
723}