copilot.rs

  1pub mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use collections::HashMap;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
 10use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 11use log::{debug, error};
 12use lsp::LanguageServer;
 13use node_runtime::NodeRuntime;
 14use request::{LogMessage, StatusNotification};
 15use settings::Settings;
 16use smol::{fs, io::BufReader, stream::StreamExt};
 17use std::{
 18    ffi::OsString,
 19    ops::Range,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use util::{
 24    channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
 25    paths, ResultExt,
 26};
 27
 28const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 29actions!(copilot_auth, [SignIn, SignOut]);
 30
 31const COPILOT_NAMESPACE: &'static str = "copilot";
 32actions!(
 33    copilot,
 34    [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
 35);
 36
 37pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
 38    // Disable Copilot for stable releases.
 39    if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
 40        cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 41            filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 42            filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 43        });
 44        return;
 45    }
 46
 47    let copilot = cx.add_model({
 48        let node_runtime = node_runtime.clone();
 49        move |cx| Copilot::start(http, node_runtime, cx)
 50    });
 51    cx.set_global(copilot.clone());
 52
 53    cx.observe(&copilot, |handle, cx| {
 54        let status = handle.read(cx).status();
 55        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 56            move |filter, _cx| match status {
 57                Status::Disabled => {
 58                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 59                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 60                }
 61                Status::Authorized => {
 62                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 63                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 64                }
 65                _ => {
 66                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 67                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 68                }
 69            },
 70        );
 71    })
 72    .detach();
 73
 74    sign_in::init(cx);
 75    cx.add_global_action(|_: &SignIn, cx| {
 76        if let Some(copilot) = Copilot::global(cx) {
 77            copilot
 78                .update(cx, |copilot, cx| copilot.sign_in(cx))
 79                .detach_and_log_err(cx);
 80        }
 81    });
 82    cx.add_global_action(|_: &SignOut, cx| {
 83        if let Some(copilot) = Copilot::global(cx) {
 84            copilot
 85                .update(cx, |copilot, cx| copilot.sign_out(cx))
 86                .detach_and_log_err(cx);
 87        }
 88    });
 89
 90    cx.add_global_action(|_: &Reinstall, cx| {
 91        if let Some(copilot) = Copilot::global(cx) {
 92            copilot
 93                .update(cx, |copilot, cx| copilot.reinstall(cx))
 94                .detach();
 95        }
 96    });
 97}
 98
 99enum CopilotServer {
100    Disabled,
101    Starting {
102        task: Shared<Task<()>>,
103    },
104    Error(Arc<str>),
105    Started {
106        server: Arc<LanguageServer>,
107        status: SignInStatus,
108        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
109    },
110}
111
112#[derive(Clone, Debug)]
113enum SignInStatus {
114    Authorized,
115    Unauthorized,
116    SigningIn {
117        prompt: Option<request::PromptUserDeviceFlow>,
118        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
119    },
120    SignedOut,
121}
122
123#[derive(Debug, Clone)]
124pub enum Status {
125    Starting {
126        task: Shared<Task<()>>,
127    },
128    Error(Arc<str>),
129    Disabled,
130    SignedOut,
131    SigningIn {
132        prompt: Option<request::PromptUserDeviceFlow>,
133    },
134    Unauthorized,
135    Authorized,
136}
137
138impl Status {
139    pub fn is_authorized(&self) -> bool {
140        matches!(self, Status::Authorized)
141    }
142}
143
144#[derive(Debug, PartialEq, Eq)]
145pub struct Completion {
146    pub range: Range<Anchor>,
147    pub text: String,
148}
149
150pub struct Copilot {
151    http: Arc<dyn HttpClient>,
152    node_runtime: Arc<NodeRuntime>,
153    server: CopilotServer,
154}
155
156impl Entity for Copilot {
157    type Event = ();
158}
159
160impl Copilot {
161    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
162        if cx.has_global::<ModelHandle<Self>>() {
163            Some(cx.global::<ModelHandle<Self>>().clone())
164        } else {
165            None
166        }
167    }
168
169    fn start(
170        http: Arc<dyn HttpClient>,
171        node_runtime: Arc<NodeRuntime>,
172        cx: &mut ModelContext<Self>,
173    ) -> Self {
174        cx.observe_global::<Settings, _>({
175            let http = http.clone();
176            let node_runtime = node_runtime.clone();
177            move |this, cx| {
178                if cx.global::<Settings>().features.copilot {
179                    if matches!(this.server, CopilotServer::Disabled) {
180                        let start_task = cx
181                            .spawn({
182                                let http = http.clone();
183                                let node_runtime = node_runtime.clone();
184                                move |this, cx| {
185                                    Self::start_language_server(http, node_runtime, this, cx)
186                                }
187                            })
188                            .shared();
189                        this.server = CopilotServer::Starting { task: start_task };
190                        cx.notify();
191                    }
192                } else {
193                    this.server = CopilotServer::Disabled;
194                    cx.notify();
195                }
196            }
197        })
198        .detach();
199
200        if cx.global::<Settings>().features.copilot {
201            let start_task = cx
202                .spawn({
203                    let http = http.clone();
204                    let node_runtime = node_runtime.clone();
205                    move |this, cx| async {
206                        Self::start_language_server(http, node_runtime, this, cx).await
207                    }
208                })
209                .shared();
210
211            Self {
212                http,
213                node_runtime,
214                server: CopilotServer::Starting { task: start_task },
215            }
216        } else {
217            Self {
218                http,
219                node_runtime,
220                server: CopilotServer::Disabled,
221            }
222        }
223    }
224
225    #[cfg(any(test, feature = "test-support"))]
226    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
227        let (server, fake_server) =
228            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
229        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
230        let this = cx.add_model(|cx| Self {
231            http: http.clone(),
232            node_runtime: NodeRuntime::new(http, cx.background().clone()),
233            server: CopilotServer::Started {
234                server: Arc::new(server),
235                status: SignInStatus::Authorized,
236                subscriptions_by_buffer_id: Default::default(),
237            },
238        });
239        (this, fake_server)
240    }
241
242    fn start_language_server(
243        http: Arc<dyn HttpClient>,
244        node_runtime: Arc<NodeRuntime>,
245        this: ModelHandle<Self>,
246        mut cx: AsyncAppContext,
247    ) -> impl Future<Output = ()> {
248        async move {
249            let start_language_server = async {
250                let server_path = get_copilot_lsp(http).await?;
251                let node_path = node_runtime.binary_path().await?;
252                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
253                let server = LanguageServer::new(
254                    0,
255                    &node_path,
256                    arguments,
257                    Path::new("/"),
258                    None,
259                    cx.clone(),
260                )?;
261
262                let server = server.initialize(Default::default()).await?;
263                let status = server
264                    .request::<request::CheckStatus>(request::CheckStatusParams {
265                        local_checks_only: false,
266                    })
267                    .await?;
268
269                server
270                    .on_notification::<LogMessage, _>(|params, _cx| {
271                        match params.level {
272                            // Copilot is pretty agressive about logging
273                            0 => debug!("copilot: {}", params.message),
274                            1 => debug!("copilot: {}", params.message),
275                            _ => error!("copilot: {}", params.message),
276                        }
277
278                        debug!("copilot metadata: {}", params.metadata_str);
279                        debug!("copilot extra: {:?}", params.extra);
280                    })
281                    .detach();
282
283                server
284                    .on_notification::<StatusNotification, _>(
285                        |_, _| { /* Silence the notification */ },
286                    )
287                    .detach();
288
289                anyhow::Ok((server, status))
290            };
291
292            let server = start_language_server.await;
293            this.update(&mut cx, |this, cx| {
294                cx.notify();
295                match server {
296                    Ok((server, status)) => {
297                        this.server = CopilotServer::Started {
298                            server,
299                            status: SignInStatus::SignedOut,
300                            subscriptions_by_buffer_id: Default::default(),
301                        };
302                        this.update_sign_in_status(status, cx);
303                    }
304                    Err(error) => {
305                        this.server = CopilotServer::Error(error.to_string().into());
306                        cx.notify()
307                    }
308                }
309            })
310        }
311    }
312
313    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
314        if let CopilotServer::Started { server, status, .. } = &mut self.server {
315            let task = match status {
316                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
317                    Task::ready(Ok(())).shared()
318                }
319                SignInStatus::SigningIn { task, .. } => {
320                    cx.notify();
321                    task.clone()
322                }
323                SignInStatus::SignedOut => {
324                    let server = server.clone();
325                    let task = cx
326                        .spawn(|this, mut cx| async move {
327                            let sign_in = async {
328                                let sign_in = server
329                                    .request::<request::SignInInitiate>(
330                                        request::SignInInitiateParams {},
331                                    )
332                                    .await?;
333                                match sign_in {
334                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
335                                        Ok(request::SignInStatus::Ok { user })
336                                    }
337                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
338                                        this.update(&mut cx, |this, cx| {
339                                            if let CopilotServer::Started { status, .. } =
340                                                &mut this.server
341                                            {
342                                                if let SignInStatus::SigningIn {
343                                                    prompt: prompt_flow,
344                                                    ..
345                                                } = status
346                                                {
347                                                    *prompt_flow = Some(flow.clone());
348                                                    cx.notify();
349                                                }
350                                            }
351                                        });
352                                        let response = server
353                                            .request::<request::SignInConfirm>(
354                                                request::SignInConfirmParams {
355                                                    user_code: flow.user_code,
356                                                },
357                                            )
358                                            .await?;
359                                        Ok(response)
360                                    }
361                                }
362                            };
363
364                            let sign_in = sign_in.await;
365                            this.update(&mut cx, |this, cx| match sign_in {
366                                Ok(status) => {
367                                    this.update_sign_in_status(status, cx);
368                                    Ok(())
369                                }
370                                Err(error) => {
371                                    this.update_sign_in_status(
372                                        request::SignInStatus::NotSignedIn,
373                                        cx,
374                                    );
375                                    Err(Arc::new(error))
376                                }
377                            })
378                        })
379                        .shared();
380                    *status = SignInStatus::SigningIn {
381                        prompt: None,
382                        task: task.clone(),
383                    };
384                    cx.notify();
385                    task
386                }
387            };
388
389            cx.foreground()
390                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
391        } else {
392            // If we're downloading, wait until download is finished
393            // If we're in a stuck state, display to the user
394            Task::ready(Err(anyhow!("copilot hasn't started yet")))
395        }
396    }
397
398    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
399        if let CopilotServer::Started { server, status, .. } = &mut self.server {
400            *status = SignInStatus::SignedOut;
401            cx.notify();
402
403            let server = server.clone();
404            cx.background().spawn(async move {
405                server
406                    .request::<request::SignOut>(request::SignOutParams {})
407                    .await?;
408                anyhow::Ok(())
409            })
410        } else {
411            Task::ready(Err(anyhow!("copilot hasn't started yet")))
412        }
413    }
414
415    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
416        let start_task = cx
417            .spawn({
418                let http = self.http.clone();
419                let node_runtime = self.node_runtime.clone();
420                move |this, cx| async move {
421                    clear_copilot_dir().await;
422                    Self::start_language_server(http, node_runtime, this, cx).await
423                }
424            })
425            .shared();
426
427        self.server = CopilotServer::Starting {
428            task: start_task.clone(),
429        };
430
431        cx.notify();
432
433        cx.foreground().spawn(start_task)
434    }
435
436    pub fn completions<T>(
437        &mut self,
438        buffer: &ModelHandle<Buffer>,
439        position: T,
440        cx: &mut ModelContext<Self>,
441    ) -> Task<Result<Vec<Completion>>>
442    where
443        T: ToPointUtf16,
444    {
445        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
446    }
447
448    pub fn completions_cycling<T>(
449        &mut self,
450        buffer: &ModelHandle<Buffer>,
451        position: T,
452        cx: &mut ModelContext<Self>,
453    ) -> Task<Result<Vec<Completion>>>
454    where
455        T: ToPointUtf16,
456    {
457        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
458    }
459
460    fn request_completions<R, T>(
461        &mut self,
462        buffer: &ModelHandle<Buffer>,
463        position: T,
464        cx: &mut ModelContext<Self>,
465    ) -> Task<Result<Vec<Completion>>>
466    where
467        R: lsp::request::Request<
468            Params = request::GetCompletionsParams,
469            Result = request::GetCompletionsResult,
470        >,
471        T: ToPointUtf16,
472    {
473        let buffer_id = buffer.id();
474        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
475        let snapshot = buffer.read(cx).snapshot();
476        let server = match &mut self.server {
477            CopilotServer::Starting { .. } => {
478                return Task::ready(Err(anyhow!("copilot is still starting")))
479            }
480            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
481            CopilotServer::Error(error) => {
482                return Task::ready(Err(anyhow!(
483                    "copilot was not started because of an error: {}",
484                    error
485                )))
486            }
487            CopilotServer::Started {
488                server,
489                status,
490                subscriptions_by_buffer_id,
491            } => {
492                if matches!(status, SignInStatus::Authorized { .. }) {
493                    subscriptions_by_buffer_id
494                        .entry(buffer_id)
495                        .or_insert_with(|| {
496                            server
497                                .notify::<lsp::notification::DidOpenTextDocument>(
498                                    lsp::DidOpenTextDocumentParams {
499                                        text_document: lsp::TextDocumentItem {
500                                            uri: uri.clone(),
501                                            language_id: id_for_language(
502                                                buffer.read(cx).language(),
503                                            ),
504                                            version: 0,
505                                            text: snapshot.text(),
506                                        },
507                                    },
508                                )
509                                .log_err();
510
511                            let uri = uri.clone();
512                            cx.observe_release(buffer, move |this, _, _| {
513                                if let CopilotServer::Started {
514                                    server,
515                                    subscriptions_by_buffer_id,
516                                    ..
517                                } = &mut this.server
518                                {
519                                    server
520                                        .notify::<lsp::notification::DidCloseTextDocument>(
521                                            lsp::DidCloseTextDocumentParams {
522                                                text_document: lsp::TextDocumentIdentifier::new(
523                                                    uri.clone(),
524                                                ),
525                                            },
526                                        )
527                                        .log_err();
528                                    subscriptions_by_buffer_id.remove(&buffer_id);
529                                }
530                            })
531                        });
532
533                    server.clone()
534                } else {
535                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
536                }
537            }
538        };
539
540        let settings = cx.global::<Settings>();
541        let position = position.to_point_utf16(&snapshot);
542        let language = snapshot.language_at(position);
543        let language_name = language.map(|language| language.name());
544        let language_name = language_name.as_deref();
545        let tab_size = settings.tab_size(language_name);
546        let hard_tabs = settings.hard_tabs(language_name);
547        let language_id = id_for_language(language);
548
549        let path;
550        let relative_path;
551        if let Some(file) = snapshot.file() {
552            if let Some(file) = file.as_local() {
553                path = file.abs_path(cx);
554            } else {
555                path = file.full_path(cx);
556            }
557            relative_path = file.path().to_path_buf();
558        } else {
559            path = PathBuf::new();
560            relative_path = PathBuf::new();
561        }
562
563        cx.background().spawn(async move {
564            let result = server
565                .request::<R>(request::GetCompletionsParams {
566                    doc: request::GetCompletionsDocument {
567                        source: snapshot.text(),
568                        tab_size: tab_size.into(),
569                        indent_size: 1,
570                        insert_spaces: !hard_tabs,
571                        uri,
572                        path: path.to_string_lossy().into(),
573                        relative_path: relative_path.to_string_lossy().into(),
574                        language_id,
575                        position: point_to_lsp(position),
576                        version: 0,
577                    },
578                })
579                .await?;
580            let completions = result
581                .completions
582                .into_iter()
583                .map(|completion| {
584                    let start = snapshot
585                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
586                    let end =
587                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
588                    Completion {
589                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
590                        text: completion.text,
591                    }
592                })
593                .collect();
594            anyhow::Ok(completions)
595        })
596    }
597
598    pub fn status(&self) -> Status {
599        match &self.server {
600            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
601            CopilotServer::Disabled => Status::Disabled,
602            CopilotServer::Error(error) => Status::Error(error.clone()),
603            CopilotServer::Started { status, .. } => match status {
604                SignInStatus::Authorized { .. } => Status::Authorized,
605                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
606                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
607                    prompt: prompt.clone(),
608                },
609                SignInStatus::SignedOut => Status::SignedOut,
610            },
611        }
612    }
613
614    fn update_sign_in_status(
615        &mut self,
616        lsp_status: request::SignInStatus,
617        cx: &mut ModelContext<Self>,
618    ) {
619        if let CopilotServer::Started { status, .. } = &mut self.server {
620            *status = match lsp_status {
621                request::SignInStatus::Ok { .. }
622                | request::SignInStatus::MaybeOk { .. }
623                | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
624                request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
625                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
626            };
627            cx.notify();
628        }
629    }
630}
631
632fn id_for_language(language: Option<&Arc<Language>>) -> String {
633    let language_name = language.map(|language| language.name());
634    match language_name.as_deref() {
635        Some("Plain Text") => "plaintext".to_string(),
636        Some(language_name) => language_name.to_lowercase(),
637        None => "plaintext".to_string(),
638    }
639}
640
641async fn clear_copilot_dir() {
642    remove_matching(&paths::COPILOT_DIR, |_| true).await
643}
644
645async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
646    const SERVER_PATH: &'static str = "dist/agent.js";
647
648    ///Check for the latest copilot language server and download it if we haven't already
649    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
650        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
651
652        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
653
654        fs::create_dir_all(version_dir).await?;
655        let server_path = version_dir.join(SERVER_PATH);
656
657        if fs::metadata(&server_path).await.is_err() {
658            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
659            let dist_dir = version_dir.join("dist");
660            fs::create_dir_all(dist_dir.as_path()).await?;
661
662            let url = &release
663                .assets
664                .get(0)
665                .context("Github release for copilot contained no assets")?
666                .browser_download_url;
667
668            let mut response = http
669                .get(&url, Default::default(), true)
670                .await
671                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
672            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
673            let archive = Archive::new(decompressed_bytes);
674            archive.unpack(dist_dir).await?;
675
676            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
677        }
678
679        Ok(server_path)
680    }
681
682    match fetch_latest(http).await {
683        ok @ Result::Ok(..) => ok,
684        e @ Err(..) => {
685            e.log_err();
686            // Fetch a cached binary, if it exists
687            (|| async move {
688                let mut last_version_dir = None;
689                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
690                while let Some(entry) = entries.next().await {
691                    let entry = entry?;
692                    if entry.file_type().await?.is_dir() {
693                        last_version_dir = Some(entry.path());
694                    }
695                }
696                let last_version_dir =
697                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
698                let server_path = last_version_dir.join(SERVER_PATH);
699                if server_path.exists() {
700                    Ok(server_path)
701                } else {
702                    Err(anyhow!(
703                        "missing executable in directory {:?}",
704                        last_version_dir
705                    ))
706                }
707            })()
708            .await
709        }
710    }
711}