copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use collections::HashMap;
  9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
 10use gpui::{
 11    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 12    Task,
 13};
 14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 15use log::{debug, error};
 16use lsp::LanguageServer;
 17use node_runtime::NodeRuntime;
 18use request::{LogMessage, StatusNotification};
 19use settings::Settings;
 20use smol::{fs, io::BufReader, stream::StreamExt};
 21use std::{
 22    ffi::OsString,
 23    ops::Range,
 24    path::{Path, PathBuf},
 25    sync::Arc,
 26};
 27use util::{
 28    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 29};
 30
 31const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 32actions!(copilot_auth, [SignIn, SignOut]);
 33
 34const COPILOT_NAMESPACE: &'static str = "copilot";
 35actions!(
 36    copilot,
 37    [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
 38);
 39
 40pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 41    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 42    cx.set_global(copilot.clone());
 43    cx.add_global_action(|_: &SignIn, cx| {
 44        let copilot = Copilot::global(cx).unwrap();
 45        copilot
 46            .update(cx, |copilot, cx| copilot.sign_in(cx))
 47            .detach_and_log_err(cx);
 48    });
 49    cx.add_global_action(|_: &SignOut, cx| {
 50        let copilot = Copilot::global(cx).unwrap();
 51        copilot
 52            .update(cx, |copilot, cx| copilot.sign_out(cx))
 53            .detach_and_log_err(cx);
 54    });
 55
 56    cx.add_global_action(|_: &Reinstall, cx| {
 57        let copilot = Copilot::global(cx).unwrap();
 58        copilot
 59            .update(cx, |copilot, cx| copilot.reinstall(cx))
 60            .detach();
 61    });
 62
 63    cx.observe(&copilot, |handle, cx| {
 64        let status = handle.read(cx).status();
 65        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 66            move |filter, _cx| match status {
 67                Status::Disabled => {
 68                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 69                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 70                }
 71                Status::Authorized => {
 72                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 73                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 74                }
 75                _ => {
 76                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 77                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 78                }
 79            },
 80        );
 81    })
 82    .detach();
 83
 84    sign_in::init(cx);
 85}
 86
/// Internal lifecycle state of the Copilot language server held by [`Copilot`].
enum CopilotServer {
    /// Used while the `enable_copilot_integration` setting is off.
    Disabled,
    /// A start-up task has been spawned and has not yet replaced this state.
    Starting {
        /// Shared handle to the start-up task.
        task: Shared<Task<()>>,
    },
    /// Start-up failed; holds the error rendered as a string.
    Error(Arc<str>),
    /// The server was launched and initialized successfully.
    Started {
        /// Handle used to send requests and notifications to the server.
        server: Arc<LanguageServer>,
        /// Current sign-in state for this server.
        status: SignInStatus,
        /// Release observers for buffers announced to the server, keyed by buffer id.
        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
    },
}
 99
/// Internal sign-in state, tracked only while the server is `Started`.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported a signed-in, usable account.
    Authorized {
        // User name reported by the server; not read anywhere in this file.
        _user: String,
    },
    /// The server reported an account that is not authorized.
    Unauthorized {
        // User name reported by the server; not read anywhere in this file.
        _user: String,
    },
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt, filled in once the server asks for one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared handle to the running sign-in task.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No account is signed in.
    SignedOut,
}
114
/// Public snapshot of the Copilot state, produced by `Copilot::status`.
///
/// Flattens the internal server and sign-in states into a single enum.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is still starting.
    Starting {
        /// Shared handle to the start-up task.
        task: Shared<Task<()>>,
    },
    /// Start-up failed; holds the error message.
    Error(Arc<str>),
    /// Copilot integration is disabled.
    Disabled,
    /// The server is running but nobody is signed in.
    SignedOut,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt, if the server has supplied one yet.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The signed-in account is not authorized.
    Unauthorized,
    /// The signed-in account is authorized.
    Authorized,
}
129
impl Status {
    /// Returns `true` only when the status is [`Status::Authorized`].
    pub fn is_authorized(&self) -> bool {
        matches!(self, Status::Authorized)
    }
}
135
/// A single completion returned by the Copilot server, resolved against a buffer snapshot.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Anchored buffer range the completion applies to.
    pub range: Range<Anchor>,
    /// Completion text returned by the server.
    pub text: String,
}
141
/// Model owning the Copilot language server and the dependencies needed to (re)start it.
pub struct Copilot {
    // HTTP client passed to the server download step.
    http: Arc<dyn HttpClient>,
    // Provides the path of the Node binary used to launch the server.
    node_runtime: Arc<NodeRuntime>,
    // Current lifecycle state of the language server.
    server: CopilotServer,
}
147
/// Makes [`Copilot`] usable as a gpui model; it declares no event payload.
impl Entity for Copilot {
    type Event = ();
}
151
152impl Copilot {
153    pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
154        match self.server {
155            CopilotServer::Starting { ref task } => Some(task.clone()),
156            _ => None,
157        }
158    }
159
160    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
161        if cx.has_global::<ModelHandle<Self>>() {
162            Some(cx.global::<ModelHandle<Self>>().clone())
163        } else {
164            None
165        }
166    }
167
168    fn start(
169        http: Arc<dyn HttpClient>,
170        node_runtime: Arc<NodeRuntime>,
171        cx: &mut ModelContext<Self>,
172    ) -> Self {
173        cx.observe_global::<Settings, _>({
174            let http = http.clone();
175            let node_runtime = node_runtime.clone();
176            move |this, cx| {
177                if cx.global::<Settings>().enable_copilot_integration {
178                    if matches!(this.server, CopilotServer::Disabled) {
179                        let start_task = cx
180                            .spawn({
181                                let http = http.clone();
182                                let node_runtime = node_runtime.clone();
183                                move |this, cx| {
184                                    Self::start_language_server(http, node_runtime, this, cx)
185                                }
186                            })
187                            .shared();
188                        this.server = CopilotServer::Starting { task: start_task };
189                        cx.notify();
190                    }
191                } else {
192                    this.server = CopilotServer::Disabled;
193                    cx.notify();
194                }
195            }
196        })
197        .detach();
198
199        if cx.global::<Settings>().enable_copilot_integration {
200            let start_task = cx
201                .spawn({
202                    let http = http.clone();
203                    let node_runtime = node_runtime.clone();
204                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
205                })
206                .shared();
207
208            Self {
209                http,
210                node_runtime,
211                server: CopilotServer::Starting { task: start_task },
212            }
213        } else {
214            Self {
215                http,
216                node_runtime,
217                server: CopilotServer::Disabled,
218            }
219        }
220    }
221
222    fn start_language_server(
223        http: Arc<dyn HttpClient>,
224        node_runtime: Arc<NodeRuntime>,
225        this: ModelHandle<Self>,
226        mut cx: AsyncAppContext,
227    ) -> impl Future<Output = ()> {
228        async move {
229            let start_language_server = async {
230                let server_path = get_copilot_lsp(http).await?;
231                let node_path = node_runtime.binary_path().await?;
232                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
233                let server = LanguageServer::new(
234                    0,
235                    &node_path,
236                    arguments,
237                    Path::new("/"),
238                    None,
239                    cx.clone(),
240                )?;
241
242                let server = server.initialize(Default::default()).await?;
243                let status = server
244                    .request::<request::CheckStatus>(request::CheckStatusParams {
245                        local_checks_only: false,
246                    })
247                    .await?;
248
249                server
250                    .on_notification::<LogMessage, _>(|params, _cx| {
251                        match params.level {
252                            // Copilot is pretty agressive about logging
253                            0 => debug!("copilot: {}", params.message),
254                            1 => debug!("copilot: {}", params.message),
255                            _ => error!("copilot: {}", params.message),
256                        }
257
258                        debug!("copilot metadata: {}", params.metadata_str);
259                        debug!("copilot extra: {:?}", params.extra);
260                    })
261                    .detach();
262
263                server
264                    .on_notification::<StatusNotification, _>(
265                        |_, _| { /* Silence the notification */ },
266                    )
267                    .detach();
268
269                anyhow::Ok((server, status))
270            };
271
272            let server = start_language_server.await;
273            this.update(&mut cx, |this, cx| {
274                cx.notify();
275                match server {
276                    Ok((server, status)) => {
277                        this.server = CopilotServer::Started {
278                            server,
279                            status: SignInStatus::SignedOut,
280                            subscriptions_by_buffer_id: Default::default(),
281                        };
282                        this.update_sign_in_status(status, cx);
283                    }
284                    Err(error) => {
285                        this.server = CopilotServer::Error(error.to_string().into());
286                        cx.notify()
287                    }
288                }
289            })
290        }
291    }
292
293    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
294        if let CopilotServer::Started { server, status, .. } = &mut self.server {
295            let task = match status {
296                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
297                    Task::ready(Ok(())).shared()
298                }
299                SignInStatus::SigningIn { task, .. } => {
300                    cx.notify();
301                    task.clone()
302                }
303                SignInStatus::SignedOut => {
304                    let server = server.clone();
305                    let task = cx
306                        .spawn(|this, mut cx| async move {
307                            let sign_in = async {
308                                let sign_in = server
309                                    .request::<request::SignInInitiate>(
310                                        request::SignInInitiateParams {},
311                                    )
312                                    .await?;
313                                match sign_in {
314                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
315                                        Ok(request::SignInStatus::Ok { user })
316                                    }
317                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
318                                        this.update(&mut cx, |this, cx| {
319                                            if let CopilotServer::Started { status, .. } =
320                                                &mut this.server
321                                            {
322                                                if let SignInStatus::SigningIn {
323                                                    prompt: prompt_flow,
324                                                    ..
325                                                } = status
326                                                {
327                                                    *prompt_flow = Some(flow.clone());
328                                                    cx.notify();
329                                                }
330                                            }
331                                        });
332                                        let response = server
333                                            .request::<request::SignInConfirm>(
334                                                request::SignInConfirmParams {
335                                                    user_code: flow.user_code,
336                                                },
337                                            )
338                                            .await?;
339                                        Ok(response)
340                                    }
341                                }
342                            };
343
344                            let sign_in = sign_in.await;
345                            this.update(&mut cx, |this, cx| match sign_in {
346                                Ok(status) => {
347                                    this.update_sign_in_status(status, cx);
348                                    Ok(())
349                                }
350                                Err(error) => {
351                                    this.update_sign_in_status(
352                                        request::SignInStatus::NotSignedIn,
353                                        cx,
354                                    );
355                                    Err(Arc::new(error))
356                                }
357                            })
358                        })
359                        .shared();
360                    *status = SignInStatus::SigningIn {
361                        prompt: None,
362                        task: task.clone(),
363                    };
364                    cx.notify();
365                    task
366                }
367            };
368
369            cx.foreground()
370                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
371        } else {
372            // If we're downloading, wait until download is finished
373            // If we're in a stuck state, display to the user
374            Task::ready(Err(anyhow!("copilot hasn't started yet")))
375        }
376    }
377
378    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
379        if let CopilotServer::Started { server, status, .. } = &mut self.server {
380            *status = SignInStatus::SignedOut;
381            cx.notify();
382
383            let server = server.clone();
384            cx.background().spawn(async move {
385                server
386                    .request::<request::SignOut>(request::SignOutParams {})
387                    .await?;
388                anyhow::Ok(())
389            })
390        } else {
391            Task::ready(Err(anyhow!("copilot hasn't started yet")))
392        }
393    }
394
395    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
396        let start_task = cx
397            .spawn({
398                let http = self.http.clone();
399                let node_runtime = self.node_runtime.clone();
400                move |this, cx| async move {
401                    clear_copilot_dir().await;
402                    Self::start_language_server(http, node_runtime, this, cx).await
403                }
404            })
405            .shared();
406
407        self.server = CopilotServer::Starting {
408            task: start_task.clone(),
409        };
410
411        cx.notify();
412
413        cx.foreground().spawn(start_task)
414    }
415
416    pub fn completions<T>(
417        &mut self,
418        buffer: &ModelHandle<Buffer>,
419        position: T,
420        cx: &mut ModelContext<Self>,
421    ) -> Task<Result<Vec<Completion>>>
422    where
423        T: ToPointUtf16,
424    {
425        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
426    }
427
428    pub fn completions_cycling<T>(
429        &mut self,
430        buffer: &ModelHandle<Buffer>,
431        position: T,
432        cx: &mut ModelContext<Self>,
433    ) -> Task<Result<Vec<Completion>>>
434    where
435        T: ToPointUtf16,
436    {
437        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
438    }
439
440    fn request_completions<R, T>(
441        &mut self,
442        buffer: &ModelHandle<Buffer>,
443        position: T,
444        cx: &mut ModelContext<Self>,
445    ) -> Task<Result<Vec<Completion>>>
446    where
447        R: lsp::request::Request<
448            Params = request::GetCompletionsParams,
449            Result = request::GetCompletionsResult,
450        >,
451        T: ToPointUtf16,
452    {
453        let buffer_id = buffer.id();
454        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
455        let snapshot = buffer.read(cx).snapshot();
456        let server = match &mut self.server {
457            CopilotServer::Starting { .. } => {
458                return Task::ready(Err(anyhow!("copilot is still starting")))
459            }
460            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
461            CopilotServer::Error(error) => {
462                return Task::ready(Err(anyhow!(
463                    "copilot was not started because of an error: {}",
464                    error
465                )))
466            }
467            CopilotServer::Started {
468                server,
469                status,
470                subscriptions_by_buffer_id,
471            } => {
472                if matches!(status, SignInStatus::Authorized { .. }) {
473                    subscriptions_by_buffer_id
474                        .entry(buffer_id)
475                        .or_insert_with(|| {
476                            server
477                                .notify::<lsp::notification::DidOpenTextDocument>(
478                                    lsp::DidOpenTextDocumentParams {
479                                        text_document: lsp::TextDocumentItem {
480                                            uri: uri.clone(),
481                                            language_id: id_for_language(
482                                                buffer.read(cx).language(),
483                                            ),
484                                            version: 0,
485                                            text: snapshot.text(),
486                                        },
487                                    },
488                                )
489                                .log_err();
490
491                            let uri = uri.clone();
492                            cx.observe_release(buffer, move |this, _, _| {
493                                if let CopilotServer::Started {
494                                    server,
495                                    subscriptions_by_buffer_id,
496                                    ..
497                                } = &mut this.server
498                                {
499                                    server
500                                        .notify::<lsp::notification::DidCloseTextDocument>(
501                                            lsp::DidCloseTextDocumentParams {
502                                                text_document: lsp::TextDocumentIdentifier::new(
503                                                    uri.clone(),
504                                                ),
505                                            },
506                                        )
507                                        .log_err();
508                                    subscriptions_by_buffer_id.remove(&buffer_id);
509                                }
510                            })
511                        });
512
513                    server.clone()
514                } else {
515                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
516                }
517            }
518        };
519
520        let settings = cx.global::<Settings>();
521        let position = position.to_point_utf16(&snapshot);
522        let language = snapshot.language_at(position);
523        let language_name = language.map(|language| language.name());
524        let language_name = language_name.as_deref();
525        let tab_size = settings.tab_size(language_name);
526        let hard_tabs = settings.hard_tabs(language_name);
527        let language_id = id_for_language(language);
528
529        let path;
530        let relative_path;
531        if let Some(file) = snapshot.file() {
532            if let Some(file) = file.as_local() {
533                path = file.abs_path(cx);
534            } else {
535                path = file.full_path(cx);
536            }
537            relative_path = file.path().to_path_buf();
538        } else {
539            path = PathBuf::new();
540            relative_path = PathBuf::new();
541        }
542
543        cx.background().spawn(async move {
544            let result = server
545                .request::<R>(request::GetCompletionsParams {
546                    doc: request::GetCompletionsDocument {
547                        source: snapshot.text(),
548                        tab_size: tab_size.into(),
549                        indent_size: 1,
550                        insert_spaces: !hard_tabs,
551                        uri,
552                        path: path.to_string_lossy().into(),
553                        relative_path: relative_path.to_string_lossy().into(),
554                        language_id,
555                        position: point_to_lsp(position),
556                        version: 0,
557                    },
558                })
559                .await?;
560            let completions = result
561                .completions
562                .into_iter()
563                .map(|completion| {
564                    let start = snapshot
565                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
566                    let end =
567                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
568                    Completion {
569                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
570                        text: completion.text,
571                    }
572                })
573                .collect();
574            anyhow::Ok(completions)
575        })
576    }
577
578    pub fn status(&self) -> Status {
579        match &self.server {
580            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
581            CopilotServer::Disabled => Status::Disabled,
582            CopilotServer::Error(error) => Status::Error(error.clone()),
583            CopilotServer::Started { status, .. } => match status {
584                SignInStatus::Authorized { .. } => Status::Authorized,
585                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
586                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
587                    prompt: prompt.clone(),
588                },
589                SignInStatus::SignedOut => Status::SignedOut,
590            },
591        }
592    }
593
594    fn update_sign_in_status(
595        &mut self,
596        lsp_status: request::SignInStatus,
597        cx: &mut ModelContext<Self>,
598    ) {
599        if let CopilotServer::Started { status, .. } = &mut self.server {
600            *status = match lsp_status {
601                request::SignInStatus::Ok { user }
602                | request::SignInStatus::MaybeOk { user }
603                | request::SignInStatus::AlreadySignedIn { user } => {
604                    SignInStatus::Authorized { _user: user }
605                }
606                request::SignInStatus::NotAuthorized { user } => {
607                    SignInStatus::Unauthorized { _user: user }
608                }
609                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
610            };
611            cx.notify();
612        }
613    }
614}
615
616fn id_for_language(language: Option<&Arc<Language>>) -> String {
617    let language_name = language.map(|language| language.name());
618    match language_name.as_deref() {
619        Some("Plain Text") => "plaintext".to_string(),
620        Some(language_name) => language_name.to_lowercase(),
621        None => "plaintext".to_string(),
622    }
623}
624
/// Clears the Copilot directory by calling `remove_matching` with a predicate
/// that accepts every entry (presumably deleting all of them — confirm against
/// `util::fs::remove_matching`).
async fn clear_copilot_dir() {
    remove_matching(&paths::COPILOT_DIR, |_| true).await
}
628
629async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
630    const SERVER_PATH: &'static str = "dist/agent.js";
631
632    ///Check for the latest copilot language server and download it if we haven't already
633    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
634        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
635
636        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
637
638        fs::create_dir_all(version_dir).await?;
639        let server_path = version_dir.join(SERVER_PATH);
640
641        if fs::metadata(&server_path).await.is_err() {
642            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
643            let dist_dir = version_dir.join("dist");
644            fs::create_dir_all(dist_dir.as_path()).await?;
645
646            let url = &release
647                .assets
648                .get(0)
649                .context("Github release for copilot contained no assets")?
650                .browser_download_url;
651
652            let mut response = http
653                .get(&url, Default::default(), true)
654                .await
655                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
656            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
657            let archive = Archive::new(decompressed_bytes);
658            archive.unpack(dist_dir).await?;
659
660            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
661        }
662
663        Ok(server_path)
664    }
665
666    match fetch_latest(http).await {
667        ok @ Result::Ok(..) => ok,
668        e @ Err(..) => {
669            e.log_err();
670            // Fetch a cached binary, if it exists
671            (|| async move {
672                let mut last_version_dir = None;
673                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
674                while let Some(entry) = entries.next().await {
675                    let entry = entry?;
676                    if entry.file_type().await?.is_dir() {
677                        last_version_dir = Some(entry.path());
678                    }
679                }
680                let last_version_dir =
681                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
682                let server_path = last_version_dir.join(SERVER_PATH);
683                if server_path.exists() {
684                    Ok(server_path)
685                } else {
686                    Err(anyhow!(
687                        "missing executable in directory {:?}",
688                        last_version_dir
689                    ))
690                }
691            })()
692            .await
693        }
694    }
695}