copilot.rs

  1pub mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use collections::HashMap;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{
 10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 11    Task,
 12};
 13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 14use log::{debug, error};
 15use lsp::LanguageServer;
 16use node_runtime::NodeRuntime;
 17use request::{LogMessage, StatusNotification};
 18use settings::Settings;
 19use smol::{fs, io::BufReader, stream::StreamExt};
 20use std::{
 21    ffi::OsString,
 22    ops::Range,
 23    path::{Path, PathBuf},
 24    sync::Arc,
 25};
 26use util::{
 27    channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
 28    paths, ResultExt,
 29};
 30
 31const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 32actions!(copilot_auth, [SignIn, SignOut]);
 33
 34const COPILOT_NAMESPACE: &'static str = "copilot";
 35actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
 36
 37pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 38    // Disable Copilot for stable releases.
 39    if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
 40        cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 41            filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 42            filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 43        });
 44        return;
 45    }
 46
 47    let copilot = cx.add_model({
 48        let node_runtime = node_runtime.clone();
 49        move |cx| Copilot::start(http, node_runtime, cx)
 50    });
 51    cx.set_global(copilot.clone());
 52
 53    cx.observe(&copilot, |handle, cx| {
 54        let status = handle.read(cx).status();
 55        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 56            move |filter, _cx| match status {
 57                Status::Disabled => {
 58                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 59                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 60                }
 61                Status::Authorized => {
 62                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 63                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 64                }
 65                _ => {
 66                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 67                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 68                }
 69            },
 70        );
 71    })
 72    .detach();
 73
 74    sign_in::init(cx);
 75    cx.add_global_action(|_: &SignIn, cx| {
 76        if let Some(copilot) = Copilot::global(cx) {
 77            copilot
 78                .update(cx, |copilot, cx| copilot.sign_in(cx))
 79                .detach_and_log_err(cx);
 80        }
 81    });
 82    cx.add_global_action(|_: &SignOut, cx| {
 83        if let Some(copilot) = Copilot::global(cx) {
 84            copilot
 85                .update(cx, |copilot, cx| copilot.sign_out(cx))
 86                .detach_and_log_err(cx);
 87        }
 88    });
 89
 90    cx.add_global_action(|_: &Reinstall, cx| {
 91        if let Some(copilot) = Copilot::global(cx) {
 92            copilot
 93                .update(cx, |copilot, cx| copilot.reinstall(cx))
 94                .detach();
 95        }
 96    });
 97}
 98
 99enum CopilotServer {
100    Disabled,
101    Starting {
102        task: Shared<Task<()>>,
103    },
104    Error(Arc<str>),
105    Started {
106        server: Arc<LanguageServer>,
107        status: SignInStatus,
108        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
109    },
110}
111
112#[derive(Clone, Debug)]
113enum SignInStatus {
114    Authorized,
115    Unauthorized,
116    SigningIn {
117        prompt: Option<request::PromptUserDeviceFlow>,
118        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
119    },
120    SignedOut,
121}
122
123#[derive(Debug, Clone)]
124pub enum Status {
125    Starting {
126        task: Shared<Task<()>>,
127    },
128    Error(Arc<str>),
129    Disabled,
130    SignedOut,
131    SigningIn {
132        prompt: Option<request::PromptUserDeviceFlow>,
133    },
134    Unauthorized,
135    Authorized,
136}
137
138impl Status {
139    pub fn is_authorized(&self) -> bool {
140        matches!(self, Status::Authorized)
141    }
142}
143
144#[derive(Debug, PartialEq, Eq)]
145pub struct Completion {
146    pub range: Range<Anchor>,
147    pub text: String,
148}
149
150pub struct Copilot {
151    http: Arc<dyn HttpClient>,
152    node_runtime: Arc<NodeRuntime>,
153    server: CopilotServer,
154}
155
156impl Entity for Copilot {
157    type Event = ();
158}
159
160impl Copilot {
161    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
162        if cx.has_global::<ModelHandle<Self>>() {
163            Some(cx.global::<ModelHandle<Self>>().clone())
164        } else {
165            None
166        }
167    }
168
169    fn start(
170        http: Arc<dyn HttpClient>,
171        node_runtime: Arc<NodeRuntime>,
172        cx: &mut ModelContext<Self>,
173    ) -> Self {
174        cx.observe_global::<Settings, _>({
175            let http = http.clone();
176            let node_runtime = node_runtime.clone();
177            move |this, cx| {
178                if cx.global::<Settings>().enable_copilot_integration {
179                    if matches!(this.server, CopilotServer::Disabled) {
180                        let start_task = cx
181                            .spawn({
182                                let http = http.clone();
183                                let node_runtime = node_runtime.clone();
184                                move |this, cx| {
185                                    Self::start_language_server(http, node_runtime, this, cx)
186                                }
187                            })
188                            .shared();
189                        this.server = CopilotServer::Starting { task: start_task };
190                        cx.notify();
191                    }
192                } else {
193                    this.server = CopilotServer::Disabled;
194                    cx.notify();
195                }
196            }
197        })
198        .detach();
199
200        if cx.global::<Settings>().enable_copilot_integration {
201            let start_task = cx
202                .spawn({
203                    let http = http.clone();
204                    let node_runtime = node_runtime.clone();
205                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
206                })
207                .shared();
208
209            Self {
210                http,
211                node_runtime,
212                server: CopilotServer::Starting { task: start_task },
213            }
214        } else {
215            Self {
216                http,
217                node_runtime,
218                server: CopilotServer::Disabled,
219            }
220        }
221    }
222
223    #[cfg(any(test, feature = "test-support"))]
224    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
225        let (server, fake_server) =
226            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
227        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
228        let this = cx.add_model(|cx| Self {
229            http: http.clone(),
230            node_runtime: NodeRuntime::new(http, cx.background().clone()),
231            server: CopilotServer::Started {
232                server: Arc::new(server),
233                status: SignInStatus::Authorized,
234                subscriptions_by_buffer_id: Default::default(),
235            },
236        });
237        (this, fake_server)
238    }
239
240    fn start_language_server(
241        http: Arc<dyn HttpClient>,
242        node_runtime: Arc<NodeRuntime>,
243        this: ModelHandle<Self>,
244        mut cx: AsyncAppContext,
245    ) -> impl Future<Output = ()> {
246        async move {
247            let start_language_server = async {
248                let server_path = get_copilot_lsp(http).await?;
249                let node_path = node_runtime.binary_path().await?;
250                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
251                let server = LanguageServer::new(
252                    0,
253                    &node_path,
254                    arguments,
255                    Path::new("/"),
256                    None,
257                    cx.clone(),
258                )?;
259
260                let server = server.initialize(Default::default()).await?;
261                let status = server
262                    .request::<request::CheckStatus>(request::CheckStatusParams {
263                        local_checks_only: false,
264                    })
265                    .await?;
266
267                server
268                    .on_notification::<LogMessage, _>(|params, _cx| {
269                        match params.level {
270                            // Copilot is pretty agressive about logging
271                            0 => debug!("copilot: {}", params.message),
272                            1 => debug!("copilot: {}", params.message),
273                            _ => error!("copilot: {}", params.message),
274                        }
275
276                        debug!("copilot metadata: {}", params.metadata_str);
277                        debug!("copilot extra: {:?}", params.extra);
278                    })
279                    .detach();
280
281                server
282                    .on_notification::<StatusNotification, _>(
283                        |_, _| { /* Silence the notification */ },
284                    )
285                    .detach();
286
287                anyhow::Ok((server, status))
288            };
289
290            let server = start_language_server.await;
291            this.update(&mut cx, |this, cx| {
292                cx.notify();
293                match server {
294                    Ok((server, status)) => {
295                        this.server = CopilotServer::Started {
296                            server,
297                            status: SignInStatus::SignedOut,
298                            subscriptions_by_buffer_id: Default::default(),
299                        };
300                        this.update_sign_in_status(status, cx);
301                    }
302                    Err(error) => {
303                        this.server = CopilotServer::Error(error.to_string().into());
304                        cx.notify()
305                    }
306                }
307            })
308        }
309    }
310
311    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
312        if let CopilotServer::Started { server, status, .. } = &mut self.server {
313            let task = match status {
314                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
315                    Task::ready(Ok(())).shared()
316                }
317                SignInStatus::SigningIn { task, .. } => {
318                    cx.notify();
319                    task.clone()
320                }
321                SignInStatus::SignedOut => {
322                    let server = server.clone();
323                    let task = cx
324                        .spawn(|this, mut cx| async move {
325                            let sign_in = async {
326                                let sign_in = server
327                                    .request::<request::SignInInitiate>(
328                                        request::SignInInitiateParams {},
329                                    )
330                                    .await?;
331                                match sign_in {
332                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
333                                        Ok(request::SignInStatus::Ok { user })
334                                    }
335                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
336                                        this.update(&mut cx, |this, cx| {
337                                            if let CopilotServer::Started { status, .. } =
338                                                &mut this.server
339                                            {
340                                                if let SignInStatus::SigningIn {
341                                                    prompt: prompt_flow,
342                                                    ..
343                                                } = status
344                                                {
345                                                    *prompt_flow = Some(flow.clone());
346                                                    cx.notify();
347                                                }
348                                            }
349                                        });
350                                        let response = server
351                                            .request::<request::SignInConfirm>(
352                                                request::SignInConfirmParams {
353                                                    user_code: flow.user_code,
354                                                },
355                                            )
356                                            .await?;
357                                        Ok(response)
358                                    }
359                                }
360                            };
361
362                            let sign_in = sign_in.await;
363                            this.update(&mut cx, |this, cx| match sign_in {
364                                Ok(status) => {
365                                    this.update_sign_in_status(status, cx);
366                                    Ok(())
367                                }
368                                Err(error) => {
369                                    this.update_sign_in_status(
370                                        request::SignInStatus::NotSignedIn,
371                                        cx,
372                                    );
373                                    Err(Arc::new(error))
374                                }
375                            })
376                        })
377                        .shared();
378                    *status = SignInStatus::SigningIn {
379                        prompt: None,
380                        task: task.clone(),
381                    };
382                    cx.notify();
383                    task
384                }
385            };
386
387            cx.foreground()
388                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
389        } else {
390            // If we're downloading, wait until download is finished
391            // If we're in a stuck state, display to the user
392            Task::ready(Err(anyhow!("copilot hasn't started yet")))
393        }
394    }
395
396    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
397        if let CopilotServer::Started { server, status, .. } = &mut self.server {
398            *status = SignInStatus::SignedOut;
399            cx.notify();
400
401            let server = server.clone();
402            cx.background().spawn(async move {
403                server
404                    .request::<request::SignOut>(request::SignOutParams {})
405                    .await?;
406                anyhow::Ok(())
407            })
408        } else {
409            Task::ready(Err(anyhow!("copilot hasn't started yet")))
410        }
411    }
412
413    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
414        let start_task = cx
415            .spawn({
416                let http = self.http.clone();
417                let node_runtime = self.node_runtime.clone();
418                move |this, cx| async move {
419                    clear_copilot_dir().await;
420                    Self::start_language_server(http, node_runtime, this, cx).await
421                }
422            })
423            .shared();
424
425        self.server = CopilotServer::Starting {
426            task: start_task.clone(),
427        };
428
429        cx.notify();
430
431        cx.foreground().spawn(start_task)
432    }
433
434    pub fn completions<T>(
435        &mut self,
436        buffer: &ModelHandle<Buffer>,
437        position: T,
438        cx: &mut ModelContext<Self>,
439    ) -> Task<Result<Vec<Completion>>>
440    where
441        T: ToPointUtf16,
442    {
443        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
444    }
445
446    pub fn completions_cycling<T>(
447        &mut self,
448        buffer: &ModelHandle<Buffer>,
449        position: T,
450        cx: &mut ModelContext<Self>,
451    ) -> Task<Result<Vec<Completion>>>
452    where
453        T: ToPointUtf16,
454    {
455        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
456    }
457
458    fn request_completions<R, T>(
459        &mut self,
460        buffer: &ModelHandle<Buffer>,
461        position: T,
462        cx: &mut ModelContext<Self>,
463    ) -> Task<Result<Vec<Completion>>>
464    where
465        R: lsp::request::Request<
466            Params = request::GetCompletionsParams,
467            Result = request::GetCompletionsResult,
468        >,
469        T: ToPointUtf16,
470    {
471        let buffer_id = buffer.id();
472        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
473        let snapshot = buffer.read(cx).snapshot();
474        let server = match &mut self.server {
475            CopilotServer::Starting { .. } => {
476                return Task::ready(Err(anyhow!("copilot is still starting")))
477            }
478            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
479            CopilotServer::Error(error) => {
480                return Task::ready(Err(anyhow!(
481                    "copilot was not started because of an error: {}",
482                    error
483                )))
484            }
485            CopilotServer::Started {
486                server,
487                status,
488                subscriptions_by_buffer_id,
489            } => {
490                if matches!(status, SignInStatus::Authorized { .. }) {
491                    subscriptions_by_buffer_id
492                        .entry(buffer_id)
493                        .or_insert_with(|| {
494                            server
495                                .notify::<lsp::notification::DidOpenTextDocument>(
496                                    lsp::DidOpenTextDocumentParams {
497                                        text_document: lsp::TextDocumentItem {
498                                            uri: uri.clone(),
499                                            language_id: id_for_language(
500                                                buffer.read(cx).language(),
501                                            ),
502                                            version: 0,
503                                            text: snapshot.text(),
504                                        },
505                                    },
506                                )
507                                .log_err();
508
509                            let uri = uri.clone();
510                            cx.observe_release(buffer, move |this, _, _| {
511                                if let CopilotServer::Started {
512                                    server,
513                                    subscriptions_by_buffer_id,
514                                    ..
515                                } = &mut this.server
516                                {
517                                    server
518                                        .notify::<lsp::notification::DidCloseTextDocument>(
519                                            lsp::DidCloseTextDocumentParams {
520                                                text_document: lsp::TextDocumentIdentifier::new(
521                                                    uri.clone(),
522                                                ),
523                                            },
524                                        )
525                                        .log_err();
526                                    subscriptions_by_buffer_id.remove(&buffer_id);
527                                }
528                            })
529                        });
530
531                    server.clone()
532                } else {
533                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
534                }
535            }
536        };
537
538        let settings = cx.global::<Settings>();
539        let position = position.to_point_utf16(&snapshot);
540        let language = snapshot.language_at(position);
541        let language_name = language.map(|language| language.name());
542        let language_name = language_name.as_deref();
543        let tab_size = settings.tab_size(language_name);
544        let hard_tabs = settings.hard_tabs(language_name);
545        let language_id = id_for_language(language);
546
547        let path;
548        let relative_path;
549        if let Some(file) = snapshot.file() {
550            if let Some(file) = file.as_local() {
551                path = file.abs_path(cx);
552            } else {
553                path = file.full_path(cx);
554            }
555            relative_path = file.path().to_path_buf();
556        } else {
557            path = PathBuf::new();
558            relative_path = PathBuf::new();
559        }
560
561        cx.background().spawn(async move {
562            let result = server
563                .request::<R>(request::GetCompletionsParams {
564                    doc: request::GetCompletionsDocument {
565                        source: snapshot.text(),
566                        tab_size: tab_size.into(),
567                        indent_size: 1,
568                        insert_spaces: !hard_tabs,
569                        uri,
570                        path: path.to_string_lossy().into(),
571                        relative_path: relative_path.to_string_lossy().into(),
572                        language_id,
573                        position: point_to_lsp(position),
574                        version: 0,
575                    },
576                })
577                .await?;
578            let completions = result
579                .completions
580                .into_iter()
581                .map(|completion| {
582                    let start = snapshot
583                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
584                    let end =
585                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
586                    Completion {
587                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
588                        text: completion.text,
589                    }
590                })
591                .collect();
592            anyhow::Ok(completions)
593        })
594    }
595
596    pub fn status(&self) -> Status {
597        match &self.server {
598            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
599            CopilotServer::Disabled => Status::Disabled,
600            CopilotServer::Error(error) => Status::Error(error.clone()),
601            CopilotServer::Started { status, .. } => match status {
602                SignInStatus::Authorized { .. } => Status::Authorized,
603                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
604                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
605                    prompt: prompt.clone(),
606                },
607                SignInStatus::SignedOut => Status::SignedOut,
608            },
609        }
610    }
611
612    fn update_sign_in_status(
613        &mut self,
614        lsp_status: request::SignInStatus,
615        cx: &mut ModelContext<Self>,
616    ) {
617        if let CopilotServer::Started { status, .. } = &mut self.server {
618            *status = match lsp_status {
619                request::SignInStatus::Ok { .. }
620                | request::SignInStatus::MaybeOk { .. }
621                | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
622                request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
623                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
624            };
625            cx.notify();
626        }
627    }
628}
629
630fn id_for_language(language: Option<&Arc<Language>>) -> String {
631    let language_name = language.map(|language| language.name());
632    match language_name.as_deref() {
633        Some("Plain Text") => "plaintext".to_string(),
634        Some(language_name) => language_name.to_lowercase(),
635        None => "plaintext".to_string(),
636    }
637}
638
639async fn clear_copilot_dir() {
640    remove_matching(&paths::COPILOT_DIR, |_| true).await
641}
642
643async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
644    const SERVER_PATH: &'static str = "dist/agent.js";
645
646    ///Check for the latest copilot language server and download it if we haven't already
647    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
648        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
649
650        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
651
652        fs::create_dir_all(version_dir).await?;
653        let server_path = version_dir.join(SERVER_PATH);
654
655        if fs::metadata(&server_path).await.is_err() {
656            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
657            let dist_dir = version_dir.join("dist");
658            fs::create_dir_all(dist_dir.as_path()).await?;
659
660            let url = &release
661                .assets
662                .get(0)
663                .context("Github release for copilot contained no assets")?
664                .browser_download_url;
665
666            let mut response = http
667                .get(&url, Default::default(), true)
668                .await
669                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
670            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
671            let archive = Archive::new(decompressed_bytes);
672            archive.unpack(dist_dir).await?;
673
674            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
675        }
676
677        Ok(server_path)
678    }
679
680    match fetch_latest(http).await {
681        ok @ Result::Ok(..) => ok,
682        e @ Err(..) => {
683            e.log_err();
684            // Fetch a cached binary, if it exists
685            (|| async move {
686                let mut last_version_dir = None;
687                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
688                while let Some(entry) = entries.next().await {
689                    let entry = entry?;
690                    if entry.file_type().await?.is_dir() {
691                        last_version_dir = Some(entry.path());
692                    }
693                }
694                let last_version_dir =
695                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
696                let server_path = last_version_dir.join(SERVER_PATH);
697                if server_path.exists() {
698                    Ok(server_path)
699                } else {
700                    Err(anyhow!(
701                        "missing executable in directory {:?}",
702                        last_version_dir
703                    ))
704                }
705            })()
706            .await
707        }
708    }
709}