copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use collections::HashMap;
  9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
 10use gpui::{
 11    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 12    Task,
 13};
 14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 15use log::{debug, error};
 16use lsp::LanguageServer;
 17use node_runtime::NodeRuntime;
 18use request::{LogMessage, StatusNotification};
 19use settings::Settings;
 20use smol::{fs, io::BufReader, stream::StreamExt};
 21use staff_mode::{not_staff_mode, staff_mode};
 22
 23use std::{
 24    ffi::OsString,
 25    ops::Range,
 26    path::{Path, PathBuf},
 27    sync::Arc,
 28};
 29use util::{
 30    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 31};
 32
 33const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 34actions!(copilot_auth, [SignIn, SignOut]);
 35
 36const COPILOT_NAMESPACE: &'static str = "copilot";
 37actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
 38
 39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 40    staff_mode(cx, {
 41        move |cx| {
 42            cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 43                filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 44                filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 45            });
 46
 47            let copilot = cx.add_model({
 48                let node_runtime = node_runtime.clone();
 49                let http = client.http_client().clone();
 50                move |cx| Copilot::start(http, node_runtime, cx)
 51            });
 52            cx.set_global(copilot.clone());
 53
 54            observe_namespaces(cx, copilot);
 55
 56            sign_in::init(cx);
 57        }
 58    });
 59    not_staff_mode(cx, |cx| {
 60        cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 61            filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 62            filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 63        });
 64    });
 65
 66    cx.add_global_action(|_: &SignIn, cx| {
 67        if let Some(copilot) = Copilot::global(cx) {
 68            copilot
 69                .update(cx, |copilot, cx| copilot.sign_in(cx))
 70                .detach_and_log_err(cx);
 71        }
 72    });
 73    cx.add_global_action(|_: &SignOut, cx| {
 74        if let Some(copilot) = Copilot::global(cx) {
 75            copilot
 76                .update(cx, |copilot, cx| copilot.sign_out(cx))
 77                .detach_and_log_err(cx);
 78        }
 79    });
 80
 81    cx.add_global_action(|_: &Reinstall, cx| {
 82        if let Some(copilot) = Copilot::global(cx) {
 83            copilot
 84                .update(cx, |copilot, cx| copilot.reinstall(cx))
 85                .detach();
 86        }
 87    });
 88}
 89
 90fn observe_namespaces(cx: &mut MutableAppContext, copilot: ModelHandle<Copilot>) {
 91    cx.observe(&copilot, |handle, cx| {
 92        let status = handle.read(cx).status();
 93        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 94            move |filter, _cx| match status {
 95                Status::Disabled => {
 96                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 97                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 98                }
 99                Status::Authorized => {
100                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
101                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
102                }
103                _ => {
104                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
105                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
106                }
107            },
108        );
109    })
110    .detach();
111}
112
113enum CopilotServer {
114    Disabled,
115    Starting {
116        task: Shared<Task<()>>,
117    },
118    Error(Arc<str>),
119    Started {
120        server: Arc<LanguageServer>,
121        status: SignInStatus,
122        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
123    },
124}
125
126#[derive(Clone, Debug)]
127enum SignInStatus {
128    Authorized {
129        _user: String,
130    },
131    Unauthorized {
132        _user: String,
133    },
134    SigningIn {
135        prompt: Option<request::PromptUserDeviceFlow>,
136        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
137    },
138    SignedOut,
139}
140
141#[derive(Debug, Clone)]
142pub enum Status {
143    Starting {
144        task: Shared<Task<()>>,
145    },
146    Error(Arc<str>),
147    Disabled,
148    SignedOut,
149    SigningIn {
150        prompt: Option<request::PromptUserDeviceFlow>,
151    },
152    Unauthorized,
153    Authorized,
154}
155
156impl Status {
157    pub fn is_authorized(&self) -> bool {
158        matches!(self, Status::Authorized)
159    }
160}
161
162#[derive(Debug, PartialEq, Eq)]
163pub struct Completion {
164    pub range: Range<Anchor>,
165    pub text: String,
166}
167
168pub struct Copilot {
169    http: Arc<dyn HttpClient>,
170    node_runtime: Arc<NodeRuntime>,
171    server: CopilotServer,
172}
173
174impl Entity for Copilot {
175    type Event = ();
176}
177
178impl Copilot {
179    pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
180        match self.server {
181            CopilotServer::Starting { ref task } => Some(task.clone()),
182            _ => None,
183        }
184    }
185
186    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
187        if cx.has_global::<ModelHandle<Self>>() {
188            Some(cx.global::<ModelHandle<Self>>().clone())
189        } else {
190            None
191        }
192    }
193
194    fn start(
195        http: Arc<dyn HttpClient>,
196        node_runtime: Arc<NodeRuntime>,
197        cx: &mut ModelContext<Self>,
198    ) -> Self {
199        cx.observe_global::<Settings, _>({
200            let http = http.clone();
201            let node_runtime = node_runtime.clone();
202            move |this, cx| {
203                if cx.global::<Settings>().enable_copilot_integration {
204                    if matches!(this.server, CopilotServer::Disabled) {
205                        let start_task = cx
206                            .spawn({
207                                let http = http.clone();
208                                let node_runtime = node_runtime.clone();
209                                move |this, cx| {
210                                    Self::start_language_server(http, node_runtime, this, cx)
211                                }
212                            })
213                            .shared();
214                        this.server = CopilotServer::Starting { task: start_task };
215                        cx.notify();
216                    }
217                } else {
218                    this.server = CopilotServer::Disabled;
219                    cx.notify();
220                }
221            }
222        })
223        .detach();
224
225        if cx.global::<Settings>().enable_copilot_integration {
226            let start_task = cx
227                .spawn({
228                    let http = http.clone();
229                    let node_runtime = node_runtime.clone();
230                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
231                })
232                .shared();
233
234            Self {
235                http,
236                node_runtime,
237                server: CopilotServer::Starting { task: start_task },
238            }
239        } else {
240            Self {
241                http,
242                node_runtime,
243                server: CopilotServer::Disabled,
244            }
245        }
246    }
247
248    fn start_language_server(
249        http: Arc<dyn HttpClient>,
250        node_runtime: Arc<NodeRuntime>,
251        this: ModelHandle<Self>,
252        mut cx: AsyncAppContext,
253    ) -> impl Future<Output = ()> {
254        async move {
255            let start_language_server = async {
256                let server_path = get_copilot_lsp(http).await?;
257                let node_path = node_runtime.binary_path().await?;
258                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
259                let server = LanguageServer::new(
260                    0,
261                    &node_path,
262                    arguments,
263                    Path::new("/"),
264                    None,
265                    cx.clone(),
266                )?;
267
268                let server = server.initialize(Default::default()).await?;
269                let status = server
270                    .request::<request::CheckStatus>(request::CheckStatusParams {
271                        local_checks_only: false,
272                    })
273                    .await?;
274
275                server
276                    .on_notification::<LogMessage, _>(|params, _cx| {
277                        match params.level {
278                            // Copilot is pretty agressive about logging
279                            0 => debug!("copilot: {}", params.message),
280                            1 => debug!("copilot: {}", params.message),
281                            _ => error!("copilot: {}", params.message),
282                        }
283
284                        debug!("copilot metadata: {}", params.metadata_str);
285                        debug!("copilot extra: {:?}", params.extra);
286                    })
287                    .detach();
288
289                server
290                    .on_notification::<StatusNotification, _>(
291                        |_, _| { /* Silence the notification */ },
292                    )
293                    .detach();
294
295                anyhow::Ok((server, status))
296            };
297
298            let server = start_language_server.await;
299            this.update(&mut cx, |this, cx| {
300                cx.notify();
301                match server {
302                    Ok((server, status)) => {
303                        this.server = CopilotServer::Started {
304                            server,
305                            status: SignInStatus::SignedOut,
306                            subscriptions_by_buffer_id: Default::default(),
307                        };
308                        this.update_sign_in_status(status, cx);
309                    }
310                    Err(error) => {
311                        this.server = CopilotServer::Error(error.to_string().into());
312                        cx.notify()
313                    }
314                }
315            })
316        }
317    }
318
319    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
320        if let CopilotServer::Started { server, status, .. } = &mut self.server {
321            let task = match status {
322                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
323                    Task::ready(Ok(())).shared()
324                }
325                SignInStatus::SigningIn { task, .. } => {
326                    cx.notify();
327                    task.clone()
328                }
329                SignInStatus::SignedOut => {
330                    let server = server.clone();
331                    let task = cx
332                        .spawn(|this, mut cx| async move {
333                            let sign_in = async {
334                                let sign_in = server
335                                    .request::<request::SignInInitiate>(
336                                        request::SignInInitiateParams {},
337                                    )
338                                    .await?;
339                                match sign_in {
340                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
341                                        Ok(request::SignInStatus::Ok { user })
342                                    }
343                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
344                                        this.update(&mut cx, |this, cx| {
345                                            if let CopilotServer::Started { status, .. } =
346                                                &mut this.server
347                                            {
348                                                if let SignInStatus::SigningIn {
349                                                    prompt: prompt_flow,
350                                                    ..
351                                                } = status
352                                                {
353                                                    *prompt_flow = Some(flow.clone());
354                                                    cx.notify();
355                                                }
356                                            }
357                                        });
358                                        let response = server
359                                            .request::<request::SignInConfirm>(
360                                                request::SignInConfirmParams {
361                                                    user_code: flow.user_code,
362                                                },
363                                            )
364                                            .await?;
365                                        Ok(response)
366                                    }
367                                }
368                            };
369
370                            let sign_in = sign_in.await;
371                            this.update(&mut cx, |this, cx| match sign_in {
372                                Ok(status) => {
373                                    this.update_sign_in_status(status, cx);
374                                    Ok(())
375                                }
376                                Err(error) => {
377                                    this.update_sign_in_status(
378                                        request::SignInStatus::NotSignedIn,
379                                        cx,
380                                    );
381                                    Err(Arc::new(error))
382                                }
383                            })
384                        })
385                        .shared();
386                    *status = SignInStatus::SigningIn {
387                        prompt: None,
388                        task: task.clone(),
389                    };
390                    cx.notify();
391                    task
392                }
393            };
394
395            cx.foreground()
396                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
397        } else {
398            // If we're downloading, wait until download is finished
399            // If we're in a stuck state, display to the user
400            Task::ready(Err(anyhow!("copilot hasn't started yet")))
401        }
402    }
403
404    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
405        if let CopilotServer::Started { server, status, .. } = &mut self.server {
406            *status = SignInStatus::SignedOut;
407            cx.notify();
408
409            let server = server.clone();
410            cx.background().spawn(async move {
411                server
412                    .request::<request::SignOut>(request::SignOutParams {})
413                    .await?;
414                anyhow::Ok(())
415            })
416        } else {
417            Task::ready(Err(anyhow!("copilot hasn't started yet")))
418        }
419    }
420
421    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
422        let start_task = cx
423            .spawn({
424                let http = self.http.clone();
425                let node_runtime = self.node_runtime.clone();
426                move |this, cx| async move {
427                    clear_copilot_dir().await;
428                    Self::start_language_server(http, node_runtime, this, cx).await
429                }
430            })
431            .shared();
432
433        self.server = CopilotServer::Starting {
434            task: start_task.clone(),
435        };
436
437        cx.notify();
438
439        cx.foreground().spawn(start_task)
440    }
441
442    pub fn completions<T>(
443        &mut self,
444        buffer: &ModelHandle<Buffer>,
445        position: T,
446        cx: &mut ModelContext<Self>,
447    ) -> Task<Result<Vec<Completion>>>
448    where
449        T: ToPointUtf16,
450    {
451        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
452    }
453
454    pub fn completions_cycling<T>(
455        &mut self,
456        buffer: &ModelHandle<Buffer>,
457        position: T,
458        cx: &mut ModelContext<Self>,
459    ) -> Task<Result<Vec<Completion>>>
460    where
461        T: ToPointUtf16,
462    {
463        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
464    }
465
466    fn request_completions<R, T>(
467        &mut self,
468        buffer: &ModelHandle<Buffer>,
469        position: T,
470        cx: &mut ModelContext<Self>,
471    ) -> Task<Result<Vec<Completion>>>
472    where
473        R: lsp::request::Request<
474            Params = request::GetCompletionsParams,
475            Result = request::GetCompletionsResult,
476        >,
477        T: ToPointUtf16,
478    {
479        let buffer_id = buffer.id();
480        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
481        let snapshot = buffer.read(cx).snapshot();
482        let server = match &mut self.server {
483            CopilotServer::Starting { .. } => {
484                return Task::ready(Err(anyhow!("copilot is still starting")))
485            }
486            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
487            CopilotServer::Error(error) => {
488                return Task::ready(Err(anyhow!(
489                    "copilot was not started because of an error: {}",
490                    error
491                )))
492            }
493            CopilotServer::Started {
494                server,
495                status,
496                subscriptions_by_buffer_id,
497            } => {
498                if matches!(status, SignInStatus::Authorized { .. }) {
499                    subscriptions_by_buffer_id
500                        .entry(buffer_id)
501                        .or_insert_with(|| {
502                            server
503                                .notify::<lsp::notification::DidOpenTextDocument>(
504                                    lsp::DidOpenTextDocumentParams {
505                                        text_document: lsp::TextDocumentItem {
506                                            uri: uri.clone(),
507                                            language_id: id_for_language(
508                                                buffer.read(cx).language(),
509                                            ),
510                                            version: 0,
511                                            text: snapshot.text(),
512                                        },
513                                    },
514                                )
515                                .log_err();
516
517                            let uri = uri.clone();
518                            cx.observe_release(buffer, move |this, _, _| {
519                                if let CopilotServer::Started {
520                                    server,
521                                    subscriptions_by_buffer_id,
522                                    ..
523                                } = &mut this.server
524                                {
525                                    server
526                                        .notify::<lsp::notification::DidCloseTextDocument>(
527                                            lsp::DidCloseTextDocumentParams {
528                                                text_document: lsp::TextDocumentIdentifier::new(
529                                                    uri.clone(),
530                                                ),
531                                            },
532                                        )
533                                        .log_err();
534                                    subscriptions_by_buffer_id.remove(&buffer_id);
535                                }
536                            })
537                        });
538
539                    server.clone()
540                } else {
541                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
542                }
543            }
544        };
545
546        let settings = cx.global::<Settings>();
547        let position = position.to_point_utf16(&snapshot);
548        let language = snapshot.language_at(position);
549        let language_name = language.map(|language| language.name());
550        let language_name = language_name.as_deref();
551        let tab_size = settings.tab_size(language_name);
552        let hard_tabs = settings.hard_tabs(language_name);
553        let language_id = id_for_language(language);
554
555        let path;
556        let relative_path;
557        if let Some(file) = snapshot.file() {
558            if let Some(file) = file.as_local() {
559                path = file.abs_path(cx);
560            } else {
561                path = file.full_path(cx);
562            }
563            relative_path = file.path().to_path_buf();
564        } else {
565            path = PathBuf::new();
566            relative_path = PathBuf::new();
567        }
568
569        cx.background().spawn(async move {
570            let result = server
571                .request::<R>(request::GetCompletionsParams {
572                    doc: request::GetCompletionsDocument {
573                        source: snapshot.text(),
574                        tab_size: tab_size.into(),
575                        indent_size: 1,
576                        insert_spaces: !hard_tabs,
577                        uri,
578                        path: path.to_string_lossy().into(),
579                        relative_path: relative_path.to_string_lossy().into(),
580                        language_id,
581                        position: point_to_lsp(position),
582                        version: 0,
583                    },
584                })
585                .await?;
586            let completions = result
587                .completions
588                .into_iter()
589                .map(|completion| {
590                    let start = snapshot
591                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
592                    let end =
593                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
594                    Completion {
595                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
596                        text: completion.text,
597                    }
598                })
599                .collect();
600            anyhow::Ok(completions)
601        })
602    }
603
604    pub fn status(&self) -> Status {
605        match &self.server {
606            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
607            CopilotServer::Disabled => Status::Disabled,
608            CopilotServer::Error(error) => Status::Error(error.clone()),
609            CopilotServer::Started { status, .. } => match status {
610                SignInStatus::Authorized { .. } => Status::Authorized,
611                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
612                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
613                    prompt: prompt.clone(),
614                },
615                SignInStatus::SignedOut => Status::SignedOut,
616            },
617        }
618    }
619
620    fn update_sign_in_status(
621        &mut self,
622        lsp_status: request::SignInStatus,
623        cx: &mut ModelContext<Self>,
624    ) {
625        if let CopilotServer::Started { status, .. } = &mut self.server {
626            *status = match lsp_status {
627                request::SignInStatus::Ok { user }
628                | request::SignInStatus::MaybeOk { user }
629                | request::SignInStatus::AlreadySignedIn { user } => {
630                    SignInStatus::Authorized { _user: user }
631                }
632                request::SignInStatus::NotAuthorized { user } => {
633                    SignInStatus::Unauthorized { _user: user }
634                }
635                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
636            };
637            cx.notify();
638        }
639    }
640}
641
642fn id_for_language(language: Option<&Arc<Language>>) -> String {
643    let language_name = language.map(|language| language.name());
644    match language_name.as_deref() {
645        Some("Plain Text") => "plaintext".to_string(),
646        Some(language_name) => language_name.to_lowercase(),
647        None => "plaintext".to_string(),
648    }
649}
650
651async fn clear_copilot_dir() {
652    remove_matching(&paths::COPILOT_DIR, |_| true).await
653}
654
655async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
656    const SERVER_PATH: &'static str = "dist/agent.js";
657
658    ///Check for the latest copilot language server and download it if we haven't already
659    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
660        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
661
662        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
663
664        fs::create_dir_all(version_dir).await?;
665        let server_path = version_dir.join(SERVER_PATH);
666
667        if fs::metadata(&server_path).await.is_err() {
668            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
669            let dist_dir = version_dir.join("dist");
670            fs::create_dir_all(dist_dir.as_path()).await?;
671
672            let url = &release
673                .assets
674                .get(0)
675                .context("Github release for copilot contained no assets")?
676                .browser_download_url;
677
678            let mut response = http
679                .get(&url, Default::default(), true)
680                .await
681                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
682            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
683            let archive = Archive::new(decompressed_bytes);
684            archive.unpack(dist_dir).await?;
685
686            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
687        }
688
689        Ok(server_path)
690    }
691
692    match fetch_latest(http).await {
693        ok @ Result::Ok(..) => ok,
694        e @ Err(..) => {
695            e.log_err();
696            // Fetch a cached binary, if it exists
697            (|| async move {
698                let mut last_version_dir = None;
699                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
700                while let Some(entry) = entries.next().await {
701                    let entry = entry?;
702                    if entry.file_type().await?.is_dir() {
703                        last_version_dir = Some(entry.path());
704                    }
705                }
706                let last_version_dir =
707                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
708                let server_path = last_version_dir.join(SERVER_PATH);
709                if server_path.exists() {
710                    Ok(server_path)
711                } else {
712                    Err(anyhow!(
713                        "missing executable in directory {:?}",
714                        last_version_dir
715                    ))
716                }
717            })()
718            .await
719        }
720    }
721}