copilot.rs

  1pub mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use collections::HashMap;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
 10use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 11use log::{debug, error};
 12use lsp::LanguageServer;
 13use node_runtime::NodeRuntime;
 14use request::{LogMessage, StatusNotification};
 15use settings::Settings;
 16use smol::{fs, io::BufReader, stream::StreamExt};
 17use std::{
 18    ffi::OsString,
 19    ops::Range,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use util::{
 24    channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
 25    paths, ResultExt,
 26};
 27
 28const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 29actions!(copilot_auth, [SignIn, SignOut]);
 30
 31const COPILOT_NAMESPACE: &'static str = "copilot";
 32actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
 33
 34pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
 35    // Disable Copilot for stable releases.
 36    if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
 37        cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 38            filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 39            filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 40        });
 41        return;
 42    }
 43
 44    let copilot = cx.add_model({
 45        let node_runtime = node_runtime.clone();
 46        move |cx| Copilot::start(http, node_runtime, cx)
 47    });
 48    cx.set_global(copilot.clone());
 49
 50    cx.observe(&copilot, |handle, cx| {
 51        let status = handle.read(cx).status();
 52        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 53            move |filter, _cx| match status {
 54                Status::Disabled => {
 55                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 56                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 57                }
 58                Status::Authorized => {
 59                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 60                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 61                }
 62                _ => {
 63                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 64                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 65                }
 66            },
 67        );
 68    })
 69    .detach();
 70
 71    sign_in::init(cx);
 72    cx.add_global_action(|_: &SignIn, cx| {
 73        if let Some(copilot) = Copilot::global(cx) {
 74            copilot
 75                .update(cx, |copilot, cx| copilot.sign_in(cx))
 76                .detach_and_log_err(cx);
 77        }
 78    });
 79    cx.add_global_action(|_: &SignOut, cx| {
 80        if let Some(copilot) = Copilot::global(cx) {
 81            copilot
 82                .update(cx, |copilot, cx| copilot.sign_out(cx))
 83                .detach_and_log_err(cx);
 84        }
 85    });
 86
 87    cx.add_global_action(|_: &Reinstall, cx| {
 88        if let Some(copilot) = Copilot::global(cx) {
 89            copilot
 90                .update(cx, |copilot, cx| copilot.reinstall(cx))
 91                .detach();
 92        }
 93    });
 94}
 95
 96enum CopilotServer {
 97    Disabled,
 98    Starting {
 99        task: Shared<Task<()>>,
100    },
101    Error(Arc<str>),
102    Started {
103        server: Arc<LanguageServer>,
104        status: SignInStatus,
105        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
106    },
107}
108
109#[derive(Clone, Debug)]
110enum SignInStatus {
111    Authorized,
112    Unauthorized,
113    SigningIn {
114        prompt: Option<request::PromptUserDeviceFlow>,
115        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
116    },
117    SignedOut,
118}
119
120#[derive(Debug, Clone)]
121pub enum Status {
122    Starting {
123        task: Shared<Task<()>>,
124    },
125    Error(Arc<str>),
126    Disabled,
127    SignedOut,
128    SigningIn {
129        prompt: Option<request::PromptUserDeviceFlow>,
130    },
131    Unauthorized,
132    Authorized,
133}
134
135impl Status {
136    pub fn is_authorized(&self) -> bool {
137        matches!(self, Status::Authorized)
138    }
139}
140
141#[derive(Debug, PartialEq, Eq)]
142pub struct Completion {
143    pub range: Range<Anchor>,
144    pub text: String,
145}
146
147pub struct Copilot {
148    http: Arc<dyn HttpClient>,
149    node_runtime: Arc<NodeRuntime>,
150    server: CopilotServer,
151}
152
153impl Entity for Copilot {
154    type Event = ();
155}
156
157impl Copilot {
158    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
159        if cx.has_global::<ModelHandle<Self>>() {
160            Some(cx.global::<ModelHandle<Self>>().clone())
161        } else {
162            None
163        }
164    }
165
166    fn start(
167        http: Arc<dyn HttpClient>,
168        node_runtime: Arc<NodeRuntime>,
169        cx: &mut ModelContext<Self>,
170    ) -> Self {
171        cx.observe_global::<Settings, _>({
172            let http = http.clone();
173            let node_runtime = node_runtime.clone();
174            move |this, cx| {
175                if cx.global::<Settings>().enable_copilot_integration {
176                    if matches!(this.server, CopilotServer::Disabled) {
177                        let start_task = cx
178                            .spawn({
179                                let http = http.clone();
180                                let node_runtime = node_runtime.clone();
181                                move |this, cx| {
182                                    Self::start_language_server(http, node_runtime, this, cx)
183                                }
184                            })
185                            .shared();
186                        this.server = CopilotServer::Starting { task: start_task };
187                        cx.notify();
188                    }
189                } else {
190                    this.server = CopilotServer::Disabled;
191                    cx.notify();
192                }
193            }
194        })
195        .detach();
196
197        if cx.global::<Settings>().enable_copilot_integration {
198            let start_task = cx
199                .spawn({
200                    let http = http.clone();
201                    let node_runtime = node_runtime.clone();
202                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
203                })
204                .shared();
205
206            Self {
207                http,
208                node_runtime,
209                server: CopilotServer::Starting { task: start_task },
210            }
211        } else {
212            Self {
213                http,
214                node_runtime,
215                server: CopilotServer::Disabled,
216            }
217        }
218    }
219
220    #[cfg(any(test, feature = "test-support"))]
221    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
222        let (server, fake_server) =
223            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
224        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
225        let this = cx.add_model(|cx| Self {
226            http: http.clone(),
227            node_runtime: NodeRuntime::new(http, cx.background().clone()),
228            server: CopilotServer::Started {
229                server: Arc::new(server),
230                status: SignInStatus::Authorized,
231                subscriptions_by_buffer_id: Default::default(),
232            },
233        });
234        (this, fake_server)
235    }
236
237    fn start_language_server(
238        http: Arc<dyn HttpClient>,
239        node_runtime: Arc<NodeRuntime>,
240        this: ModelHandle<Self>,
241        mut cx: AsyncAppContext,
242    ) -> impl Future<Output = ()> {
243        async move {
244            let start_language_server = async {
245                let server_path = get_copilot_lsp(http).await?;
246                let node_path = node_runtime.binary_path().await?;
247                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
248                let server = LanguageServer::new(
249                    0,
250                    &node_path,
251                    arguments,
252                    Path::new("/"),
253                    None,
254                    cx.clone(),
255                )?;
256
257                let server = server.initialize(Default::default()).await?;
258                let status = server
259                    .request::<request::CheckStatus>(request::CheckStatusParams {
260                        local_checks_only: false,
261                    })
262                    .await?;
263
264                server
265                    .on_notification::<LogMessage, _>(|params, _cx| {
266                        match params.level {
267                            // Copilot is pretty agressive about logging
268                            0 => debug!("copilot: {}", params.message),
269                            1 => debug!("copilot: {}", params.message),
270                            _ => error!("copilot: {}", params.message),
271                        }
272
273                        debug!("copilot metadata: {}", params.metadata_str);
274                        debug!("copilot extra: {:?}", params.extra);
275                    })
276                    .detach();
277
278                server
279                    .on_notification::<StatusNotification, _>(
280                        |_, _| { /* Silence the notification */ },
281                    )
282                    .detach();
283
284                anyhow::Ok((server, status))
285            };
286
287            let server = start_language_server.await;
288            this.update(&mut cx, |this, cx| {
289                cx.notify();
290                match server {
291                    Ok((server, status)) => {
292                        this.server = CopilotServer::Started {
293                            server,
294                            status: SignInStatus::SignedOut,
295                            subscriptions_by_buffer_id: Default::default(),
296                        };
297                        this.update_sign_in_status(status, cx);
298                    }
299                    Err(error) => {
300                        this.server = CopilotServer::Error(error.to_string().into());
301                        cx.notify()
302                    }
303                }
304            })
305        }
306    }
307
308    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
309        if let CopilotServer::Started { server, status, .. } = &mut self.server {
310            let task = match status {
311                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
312                    Task::ready(Ok(())).shared()
313                }
314                SignInStatus::SigningIn { task, .. } => {
315                    cx.notify();
316                    task.clone()
317                }
318                SignInStatus::SignedOut => {
319                    let server = server.clone();
320                    let task = cx
321                        .spawn(|this, mut cx| async move {
322                            let sign_in = async {
323                                let sign_in = server
324                                    .request::<request::SignInInitiate>(
325                                        request::SignInInitiateParams {},
326                                    )
327                                    .await?;
328                                match sign_in {
329                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
330                                        Ok(request::SignInStatus::Ok { user })
331                                    }
332                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
333                                        this.update(&mut cx, |this, cx| {
334                                            if let CopilotServer::Started { status, .. } =
335                                                &mut this.server
336                                            {
337                                                if let SignInStatus::SigningIn {
338                                                    prompt: prompt_flow,
339                                                    ..
340                                                } = status
341                                                {
342                                                    *prompt_flow = Some(flow.clone());
343                                                    cx.notify();
344                                                }
345                                            }
346                                        });
347                                        let response = server
348                                            .request::<request::SignInConfirm>(
349                                                request::SignInConfirmParams {
350                                                    user_code: flow.user_code,
351                                                },
352                                            )
353                                            .await?;
354                                        Ok(response)
355                                    }
356                                }
357                            };
358
359                            let sign_in = sign_in.await;
360                            this.update(&mut cx, |this, cx| match sign_in {
361                                Ok(status) => {
362                                    this.update_sign_in_status(status, cx);
363                                    Ok(())
364                                }
365                                Err(error) => {
366                                    this.update_sign_in_status(
367                                        request::SignInStatus::NotSignedIn,
368                                        cx,
369                                    );
370                                    Err(Arc::new(error))
371                                }
372                            })
373                        })
374                        .shared();
375                    *status = SignInStatus::SigningIn {
376                        prompt: None,
377                        task: task.clone(),
378                    };
379                    cx.notify();
380                    task
381                }
382            };
383
384            cx.foreground()
385                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
386        } else {
387            // If we're downloading, wait until download is finished
388            // If we're in a stuck state, display to the user
389            Task::ready(Err(anyhow!("copilot hasn't started yet")))
390        }
391    }
392
393    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
394        if let CopilotServer::Started { server, status, .. } = &mut self.server {
395            *status = SignInStatus::SignedOut;
396            cx.notify();
397
398            let server = server.clone();
399            cx.background().spawn(async move {
400                server
401                    .request::<request::SignOut>(request::SignOutParams {})
402                    .await?;
403                anyhow::Ok(())
404            })
405        } else {
406            Task::ready(Err(anyhow!("copilot hasn't started yet")))
407        }
408    }
409
410    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
411        let start_task = cx
412            .spawn({
413                let http = self.http.clone();
414                let node_runtime = self.node_runtime.clone();
415                move |this, cx| async move {
416                    clear_copilot_dir().await;
417                    Self::start_language_server(http, node_runtime, this, cx).await
418                }
419            })
420            .shared();
421
422        self.server = CopilotServer::Starting {
423            task: start_task.clone(),
424        };
425
426        cx.notify();
427
428        cx.foreground().spawn(start_task)
429    }
430
431    pub fn completions<T>(
432        &mut self,
433        buffer: &ModelHandle<Buffer>,
434        position: T,
435        cx: &mut ModelContext<Self>,
436    ) -> Task<Result<Vec<Completion>>>
437    where
438        T: ToPointUtf16,
439    {
440        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
441    }
442
443    pub fn completions_cycling<T>(
444        &mut self,
445        buffer: &ModelHandle<Buffer>,
446        position: T,
447        cx: &mut ModelContext<Self>,
448    ) -> Task<Result<Vec<Completion>>>
449    where
450        T: ToPointUtf16,
451    {
452        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
453    }
454
455    fn request_completions<R, T>(
456        &mut self,
457        buffer: &ModelHandle<Buffer>,
458        position: T,
459        cx: &mut ModelContext<Self>,
460    ) -> Task<Result<Vec<Completion>>>
461    where
462        R: lsp::request::Request<
463            Params = request::GetCompletionsParams,
464            Result = request::GetCompletionsResult,
465        >,
466        T: ToPointUtf16,
467    {
468        let buffer_id = buffer.id();
469        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
470        let snapshot = buffer.read(cx).snapshot();
471        let server = match &mut self.server {
472            CopilotServer::Starting { .. } => {
473                return Task::ready(Err(anyhow!("copilot is still starting")))
474            }
475            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
476            CopilotServer::Error(error) => {
477                return Task::ready(Err(anyhow!(
478                    "copilot was not started because of an error: {}",
479                    error
480                )))
481            }
482            CopilotServer::Started {
483                server,
484                status,
485                subscriptions_by_buffer_id,
486            } => {
487                if matches!(status, SignInStatus::Authorized { .. }) {
488                    subscriptions_by_buffer_id
489                        .entry(buffer_id)
490                        .or_insert_with(|| {
491                            server
492                                .notify::<lsp::notification::DidOpenTextDocument>(
493                                    lsp::DidOpenTextDocumentParams {
494                                        text_document: lsp::TextDocumentItem {
495                                            uri: uri.clone(),
496                                            language_id: id_for_language(
497                                                buffer.read(cx).language(),
498                                            ),
499                                            version: 0,
500                                            text: snapshot.text(),
501                                        },
502                                    },
503                                )
504                                .log_err();
505
506                            let uri = uri.clone();
507                            cx.observe_release(buffer, move |this, _, _| {
508                                if let CopilotServer::Started {
509                                    server,
510                                    subscriptions_by_buffer_id,
511                                    ..
512                                } = &mut this.server
513                                {
514                                    server
515                                        .notify::<lsp::notification::DidCloseTextDocument>(
516                                            lsp::DidCloseTextDocumentParams {
517                                                text_document: lsp::TextDocumentIdentifier::new(
518                                                    uri.clone(),
519                                                ),
520                                            },
521                                        )
522                                        .log_err();
523                                    subscriptions_by_buffer_id.remove(&buffer_id);
524                                }
525                            })
526                        });
527
528                    server.clone()
529                } else {
530                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
531                }
532            }
533        };
534
535        let settings = cx.global::<Settings>();
536        let position = position.to_point_utf16(&snapshot);
537        let language = snapshot.language_at(position);
538        let language_name = language.map(|language| language.name());
539        let language_name = language_name.as_deref();
540        let tab_size = settings.tab_size(language_name);
541        let hard_tabs = settings.hard_tabs(language_name);
542        let language_id = id_for_language(language);
543
544        let path;
545        let relative_path;
546        if let Some(file) = snapshot.file() {
547            if let Some(file) = file.as_local() {
548                path = file.abs_path(cx);
549            } else {
550                path = file.full_path(cx);
551            }
552            relative_path = file.path().to_path_buf();
553        } else {
554            path = PathBuf::new();
555            relative_path = PathBuf::new();
556        }
557
558        cx.background().spawn(async move {
559            let result = server
560                .request::<R>(request::GetCompletionsParams {
561                    doc: request::GetCompletionsDocument {
562                        source: snapshot.text(),
563                        tab_size: tab_size.into(),
564                        indent_size: 1,
565                        insert_spaces: !hard_tabs,
566                        uri,
567                        path: path.to_string_lossy().into(),
568                        relative_path: relative_path.to_string_lossy().into(),
569                        language_id,
570                        position: point_to_lsp(position),
571                        version: 0,
572                    },
573                })
574                .await?;
575            let completions = result
576                .completions
577                .into_iter()
578                .map(|completion| {
579                    let start = snapshot
580                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
581                    let end =
582                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
583                    Completion {
584                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
585                        text: completion.text,
586                    }
587                })
588                .collect();
589            anyhow::Ok(completions)
590        })
591    }
592
593    pub fn status(&self) -> Status {
594        match &self.server {
595            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
596            CopilotServer::Disabled => Status::Disabled,
597            CopilotServer::Error(error) => Status::Error(error.clone()),
598            CopilotServer::Started { status, .. } => match status {
599                SignInStatus::Authorized { .. } => Status::Authorized,
600                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
601                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
602                    prompt: prompt.clone(),
603                },
604                SignInStatus::SignedOut => Status::SignedOut,
605            },
606        }
607    }
608
609    fn update_sign_in_status(
610        &mut self,
611        lsp_status: request::SignInStatus,
612        cx: &mut ModelContext<Self>,
613    ) {
614        if let CopilotServer::Started { status, .. } = &mut self.server {
615            *status = match lsp_status {
616                request::SignInStatus::Ok { .. }
617                | request::SignInStatus::MaybeOk { .. }
618                | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
619                request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
620                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
621            };
622            cx.notify();
623        }
624    }
625}
626
627fn id_for_language(language: Option<&Arc<Language>>) -> String {
628    let language_name = language.map(|language| language.name());
629    match language_name.as_deref() {
630        Some("Plain Text") => "plaintext".to_string(),
631        Some(language_name) => language_name.to_lowercase(),
632        None => "plaintext".to_string(),
633    }
634}
635
636async fn clear_copilot_dir() {
637    remove_matching(&paths::COPILOT_DIR, |_| true).await
638}
639
640async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
641    const SERVER_PATH: &'static str = "dist/agent.js";
642
643    ///Check for the latest copilot language server and download it if we haven't already
644    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
645        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
646
647        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
648
649        fs::create_dir_all(version_dir).await?;
650        let server_path = version_dir.join(SERVER_PATH);
651
652        if fs::metadata(&server_path).await.is_err() {
653            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
654            let dist_dir = version_dir.join("dist");
655            fs::create_dir_all(dist_dir.as_path()).await?;
656
657            let url = &release
658                .assets
659                .get(0)
660                .context("Github release for copilot contained no assets")?
661                .browser_download_url;
662
663            let mut response = http
664                .get(&url, Default::default(), true)
665                .await
666                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
667            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
668            let archive = Archive::new(decompressed_bytes);
669            archive.unpack(dist_dir).await?;
670
671            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
672        }
673
674        Ok(server_path)
675    }
676
677    match fetch_latest(http).await {
678        ok @ Result::Ok(..) => ok,
679        e @ Err(..) => {
680            e.log_err();
681            // Fetch a cached binary, if it exists
682            (|| async move {
683                let mut last_version_dir = None;
684                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
685                while let Some(entry) = entries.next().await {
686                    let entry = entry?;
687                    if entry.file_type().await?.is_dir() {
688                        last_version_dir = Some(entry.path());
689                    }
690                }
691                let last_version_dir =
692                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
693                let server_path = last_version_dir.join(SERVER_PATH);
694                if server_path.exists() {
695                    Ok(server_path)
696                } else {
697                    Err(anyhow!(
698                        "missing executable in directory {:?}",
699                        last_version_dir
700                    ))
701                }
702            })()
703            .await
704        }
705    }
706}