copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{
 10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 11    Task,
 12};
 13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 14use lsp::LanguageServer;
 15use node_runtime::NodeRuntime;
 16use settings::Settings;
 17use smol::{fs, io::BufReader, stream::StreamExt};
 18use std::{
 19    ffi::OsString,
 20    ops::Range,
 21    path::{Path, PathBuf},
 22    sync::Arc,
 23};
 24use util::{
 25    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 26};
 27
 28const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 29actions!(copilot_auth, [SignIn, SignOut]);
 30
 31const COPILOT_NAMESPACE: &'static str = "copilot";
 32actions!(
 33    copilot,
 34    [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
 35);
 36
 37pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 38    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 39    cx.set_global(copilot.clone());
 40    cx.add_global_action(|_: &SignIn, cx| {
 41        let copilot = Copilot::global(cx).unwrap();
 42        copilot
 43            .update(cx, |copilot, cx| copilot.sign_in(cx))
 44            .detach_and_log_err(cx);
 45    });
 46    cx.add_global_action(|_: &SignOut, cx| {
 47        let copilot = Copilot::global(cx).unwrap();
 48        copilot
 49            .update(cx, |copilot, cx| copilot.sign_out(cx))
 50            .detach_and_log_err(cx);
 51    });
 52
 53    cx.add_global_action(|_: &Reinstall, cx| {
 54        let copilot = Copilot::global(cx).unwrap();
 55        copilot
 56            .update(cx, |copilot, cx| copilot.reinstall(cx))
 57            .detach();
 58    });
 59
 60    cx.observe(&copilot, |handle, cx| {
 61        let status = handle.read(cx).status();
 62        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 63            move |filter, _cx| match status {
 64                Status::Disabled => {
 65                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 66                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 67                }
 68                Status::Authorized => {
 69                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 70                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 71                }
 72                _ => {
 73                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 74                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 75                }
 76            },
 77        );
 78    })
 79    .detach();
 80
 81    sign_in::init(cx);
 82}
 83
 84enum CopilotServer {
 85    Disabled,
 86    Starting {
 87        task: Shared<Task<()>>,
 88    },
 89    Error(Arc<str>),
 90    Started {
 91        server: Arc<LanguageServer>,
 92        status: SignInStatus,
 93    },
 94}
 95
 96#[derive(Clone, Debug)]
 97enum SignInStatus {
 98    Authorized {
 99        _user: String,
100    },
101    Unauthorized {
102        _user: String,
103    },
104    SigningIn {
105        prompt: Option<request::PromptUserDeviceFlow>,
106        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
107    },
108    SignedOut,
109}
110
111#[derive(Debug, Clone)]
112pub enum Status {
113    Starting {
114        task: Shared<Task<()>>,
115    },
116    Error(Arc<str>),
117    Disabled,
118    SignedOut,
119    SigningIn {
120        prompt: Option<request::PromptUserDeviceFlow>,
121    },
122    Unauthorized,
123    Authorized,
124}
125
126impl Status {
127    pub fn is_authorized(&self) -> bool {
128        matches!(self, Status::Authorized)
129    }
130}
131
132#[derive(Debug, PartialEq, Eq)]
133pub struct Completion {
134    pub range: Range<Anchor>,
135    pub text: String,
136}
137
138pub struct Copilot {
139    http: Arc<dyn HttpClient>,
140    node_runtime: Arc<NodeRuntime>,
141    server: CopilotServer,
142}
143
144impl Entity for Copilot {
145    type Event = ();
146}
147
148impl Copilot {
149    pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
150        match self.server {
151            CopilotServer::Starting { ref task } => Some(task.clone()),
152            _ => None,
153        }
154    }
155
156    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
157        if cx.has_global::<ModelHandle<Self>>() {
158            Some(cx.global::<ModelHandle<Self>>().clone())
159        } else {
160            None
161        }
162    }
163
164    fn start(
165        http: Arc<dyn HttpClient>,
166        node_runtime: Arc<NodeRuntime>,
167        cx: &mut ModelContext<Self>,
168    ) -> Self {
169        cx.observe_global::<Settings, _>({
170            let http = http.clone();
171            let node_runtime = node_runtime.clone();
172            move |this, cx| {
173                if cx.global::<Settings>().enable_copilot_integration {
174                    if matches!(this.server, CopilotServer::Disabled) {
175                        let start_task = cx
176                            .spawn({
177                                let http = http.clone();
178                                let node_runtime = node_runtime.clone();
179                                move |this, cx| {
180                                    Self::start_language_server(http, node_runtime, this, cx)
181                                }
182                            })
183                            .shared();
184                        this.server = CopilotServer::Starting { task: start_task };
185                        cx.notify();
186                    }
187                } else {
188                    this.server = CopilotServer::Disabled;
189                    cx.notify();
190                }
191            }
192        })
193        .detach();
194
195        if cx.global::<Settings>().enable_copilot_integration {
196            let start_task = cx
197                .spawn({
198                    let http = http.clone();
199                    let node_runtime = node_runtime.clone();
200                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
201                })
202                .shared();
203
204            Self {
205                http,
206                node_runtime,
207                server: CopilotServer::Starting { task: start_task },
208            }
209        } else {
210            Self {
211                http,
212                node_runtime,
213                server: CopilotServer::Disabled,
214            }
215        }
216    }
217
218    fn start_language_server(
219        http: Arc<dyn HttpClient>,
220        node_runtime: Arc<NodeRuntime>,
221        this: ModelHandle<Self>,
222        mut cx: AsyncAppContext,
223    ) -> impl Future<Output = ()> {
224        async move {
225            let start_language_server = async {
226                let server_path = get_copilot_lsp(http).await?;
227                let node_path = node_runtime.binary_path().await?;
228                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
229                let server = LanguageServer::new(
230                    0,
231                    &node_path,
232                    arguments,
233                    Path::new("/"),
234                    None,
235                    cx.clone(),
236                )?;
237
238                let server = server.initialize(Default::default()).await?;
239                let status = server
240                    .request::<request::CheckStatus>(request::CheckStatusParams {
241                        local_checks_only: false,
242                    })
243                    .await?;
244                anyhow::Ok((server, status))
245            };
246
247            let server = start_language_server.await;
248            this.update(&mut cx, |this, cx| {
249                cx.notify();
250                match server {
251                    Ok((server, status)) => {
252                        this.server = CopilotServer::Started {
253                            server,
254                            status: SignInStatus::SignedOut,
255                        };
256                        this.update_sign_in_status(status, cx);
257                    }
258                    Err(error) => {
259                        this.server = CopilotServer::Error(error.to_string().into());
260                        cx.notify()
261                    }
262                }
263            })
264        }
265    }
266
267    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
268        if let CopilotServer::Started { server, status } = &mut self.server {
269            let task = match status {
270                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
271                    Task::ready(Ok(())).shared()
272                }
273                SignInStatus::SigningIn { task, .. } => {
274                    cx.notify();
275                    task.clone()
276                }
277                SignInStatus::SignedOut => {
278                    let server = server.clone();
279                    let task = cx
280                        .spawn(|this, mut cx| async move {
281                            let sign_in = async {
282                                let sign_in = server
283                                    .request::<request::SignInInitiate>(
284                                        request::SignInInitiateParams {},
285                                    )
286                                    .await?;
287                                match sign_in {
288                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
289                                        Ok(request::SignInStatus::Ok { user })
290                                    }
291                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
292                                        this.update(&mut cx, |this, cx| {
293                                            if let CopilotServer::Started { status, .. } =
294                                                &mut this.server
295                                            {
296                                                if let SignInStatus::SigningIn {
297                                                    prompt: prompt_flow,
298                                                    ..
299                                                } = status
300                                                {
301                                                    *prompt_flow = Some(flow.clone());
302                                                    cx.notify();
303                                                }
304                                            }
305                                        });
306                                        let response = server
307                                            .request::<request::SignInConfirm>(
308                                                request::SignInConfirmParams {
309                                                    user_code: flow.user_code,
310                                                },
311                                            )
312                                            .await?;
313                                        Ok(response)
314                                    }
315                                }
316                            };
317
318                            let sign_in = sign_in.await;
319                            this.update(&mut cx, |this, cx| match sign_in {
320                                Ok(status) => {
321                                    this.update_sign_in_status(status, cx);
322                                    Ok(())
323                                }
324                                Err(error) => {
325                                    this.update_sign_in_status(
326                                        request::SignInStatus::NotSignedIn,
327                                        cx,
328                                    );
329                                    Err(Arc::new(error))
330                                }
331                            })
332                        })
333                        .shared();
334                    *status = SignInStatus::SigningIn {
335                        prompt: None,
336                        task: task.clone(),
337                    };
338                    cx.notify();
339                    task
340                }
341            };
342
343            cx.foreground()
344                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
345        } else {
346            // If we're downloading, wait until download is finished
347            // If we're in a stuck state, display to the user
348            Task::ready(Err(anyhow!("copilot hasn't started yet")))
349        }
350    }
351
352    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
353        if let CopilotServer::Started { server, status } = &mut self.server {
354            *status = SignInStatus::SignedOut;
355            cx.notify();
356
357            let server = server.clone();
358            cx.background().spawn(async move {
359                server
360                    .request::<request::SignOut>(request::SignOutParams {})
361                    .await?;
362                anyhow::Ok(())
363            })
364        } else {
365            Task::ready(Err(anyhow!("copilot hasn't started yet")))
366        }
367    }
368
369    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
370        let start_task = cx
371            .spawn({
372                let http = self.http.clone();
373                let node_runtime = self.node_runtime.clone();
374                move |this, cx| async move {
375                    clear_copilot_dir().await;
376                    Self::start_language_server(http, node_runtime, this, cx).await
377                }
378            })
379            .shared();
380
381        self.server = CopilotServer::Starting {
382            task: start_task.clone(),
383        };
384
385        cx.notify();
386
387        cx.foreground().spawn(start_task)
388    }
389
390    pub fn completion<T>(
391        &self,
392        buffer: &ModelHandle<Buffer>,
393        position: T,
394        cx: &mut ModelContext<Self>,
395    ) -> Task<Result<Option<Completion>>>
396    where
397        T: ToPointUtf16,
398    {
399        let server = match self.authorized_server() {
400            Ok(server) => server,
401            Err(error) => return Task::ready(Err(error)),
402        };
403
404        let buffer = buffer.read(cx).snapshot();
405        let request = server
406            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
407        cx.background().spawn(async move {
408            let result = request.await?;
409            let completion = result
410                .completions
411                .into_iter()
412                .next()
413                .map(|completion| completion_from_lsp(completion, &buffer));
414            anyhow::Ok(completion)
415        })
416    }
417
418    pub fn completions_cycling<T>(
419        &self,
420        buffer: &ModelHandle<Buffer>,
421        position: T,
422        cx: &mut ModelContext<Self>,
423    ) -> Task<Result<Vec<Completion>>>
424    where
425        T: ToPointUtf16,
426    {
427        let server = match self.authorized_server() {
428            Ok(server) => server,
429            Err(error) => return Task::ready(Err(error)),
430        };
431
432        let buffer = buffer.read(cx).snapshot();
433        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
434            &buffer, position, cx,
435        ));
436        cx.background().spawn(async move {
437            let result = request.await?;
438            let completions = result
439                .completions
440                .into_iter()
441                .map(|completion| completion_from_lsp(completion, &buffer))
442                .collect();
443            anyhow::Ok(completions)
444        })
445    }
446
447    pub fn status(&self) -> Status {
448        match &self.server {
449            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
450            CopilotServer::Disabled => Status::Disabled,
451            CopilotServer::Error(error) => Status::Error(error.clone()),
452            CopilotServer::Started { status, .. } => match status {
453                SignInStatus::Authorized { .. } => Status::Authorized,
454                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
455                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
456                    prompt: prompt.clone(),
457                },
458                SignInStatus::SignedOut => Status::SignedOut,
459            },
460        }
461    }
462
463    fn update_sign_in_status(
464        &mut self,
465        lsp_status: request::SignInStatus,
466        cx: &mut ModelContext<Self>,
467    ) {
468        if let CopilotServer::Started { status, .. } = &mut self.server {
469            *status = match lsp_status {
470                request::SignInStatus::Ok { user }
471                | request::SignInStatus::MaybeOk { user }
472                | request::SignInStatus::AlreadySignedIn { user } => {
473                    SignInStatus::Authorized { _user: user }
474                }
475                request::SignInStatus::NotAuthorized { user } => {
476                    SignInStatus::Unauthorized { _user: user }
477                }
478                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
479            };
480            cx.notify();
481        }
482    }
483
484    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
485        match &self.server {
486            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
487            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
488            CopilotServer::Error(error) => Err(anyhow!(
489                "copilot was not started because of an error: {}",
490                error
491            )),
492            CopilotServer::Started { server, status } => {
493                if matches!(status, SignInStatus::Authorized { .. }) {
494                    Ok(server.clone())
495                } else {
496                    Err(anyhow!("must sign in before using copilot"))
497                }
498            }
499        }
500    }
501}
502
503fn build_completion_params<T>(
504    buffer: &BufferSnapshot,
505    position: T,
506    cx: &AppContext,
507) -> request::GetCompletionsParams
508where
509    T: ToPointUtf16,
510{
511    let position = position.to_point_utf16(&buffer);
512    let language_name = buffer.language_at(position).map(|language| language.name());
513    let language_name = language_name.as_deref();
514
515    let path;
516    let relative_path;
517    if let Some(file) = buffer.file() {
518        if let Some(file) = file.as_local() {
519            path = file.abs_path(cx);
520        } else {
521            path = file.full_path(cx);
522        }
523        relative_path = file.path().to_path_buf();
524    } else {
525        path = PathBuf::from("/untitled");
526        relative_path = PathBuf::from("untitled");
527    }
528
529    let settings = cx.global::<Settings>();
530    let language_id = match language_name {
531        Some("Plain Text") => "plaintext".to_string(),
532        Some(language_name) => language_name.to_lowercase(),
533        None => "plaintext".to_string(),
534    };
535    request::GetCompletionsParams {
536        doc: request::GetCompletionsDocument {
537            source: buffer.text(),
538            tab_size: settings.tab_size(language_name).into(),
539            indent_size: 1,
540            insert_spaces: !settings.hard_tabs(language_name),
541            uri: lsp::Url::from_file_path(&path).unwrap(),
542            path: path.to_string_lossy().into(),
543            relative_path: relative_path.to_string_lossy().into(),
544            language_id,
545            position: point_to_lsp(position),
546            version: 0,
547        },
548    }
549}
550
551fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
552    let start = buffer.clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
553    let end = buffer.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
554    Completion {
555        range: buffer.anchor_before(start)..buffer.anchor_after(end),
556        text: completion.text,
557    }
558}
559
560async fn clear_copilot_dir() {
561    remove_matching(&paths::COPILOT_DIR, |_| true).await
562}
563
564async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
565    const SERVER_PATH: &'static str = "dist/agent.js";
566
567    ///Check for the latest copilot language server and download it if we haven't already
568    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
569        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
570
571        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
572
573        fs::create_dir_all(version_dir).await?;
574        let server_path = version_dir.join(SERVER_PATH);
575
576        if fs::metadata(&server_path).await.is_err() {
577            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
578            let dist_dir = version_dir.join("dist");
579            fs::create_dir_all(dist_dir.as_path()).await?;
580
581            let url = &release
582                .assets
583                .get(0)
584                .context("Github release for copilot contained no assets")?
585                .browser_download_url;
586
587            let mut response = http
588                .get(&url, Default::default(), true)
589                .await
590                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
591            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
592            let archive = Archive::new(decompressed_bytes);
593            archive.unpack(dist_dir).await?;
594
595            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
596        }
597
598        Ok(server_path)
599    }
600
601    match fetch_latest(http).await {
602        ok @ Result::Ok(..) => ok,
603        e @ Err(..) => {
604            e.log_err();
605            // Fetch a cached binary, if it exists
606            (|| async move {
607                let mut last_version_dir = None;
608                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
609                while let Some(entry) = entries.next().await {
610                    let entry = entry?;
611                    if entry.file_type().await?.is_dir() {
612                        last_version_dir = Some(entry.path());
613                    }
614                }
615                let last_version_dir =
616                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
617                let server_path = last_version_dir.join(SERVER_PATH);
618                if server_path.exists() {
619                    Ok(server_path)
620                } else {
621                    Err(anyhow!(
622                        "missing executable in directory {:?}",
623                        last_version_dir
624                    ))
625                }
626            })()
627            .await
628        }
629    }
630}