copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, bail, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{
 10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 11    Task,
 12};
 13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 14use log::{debug, error};
 15use lsp::LanguageServer;
 16use node_runtime::NodeRuntime;
 17use request::{LogMessage, StatusNotification};
 18use settings::Settings;
 19use smol::{fs, io::BufReader, stream::StreamExt};
 20use std::{
 21    ffi::OsString,
 22    ops::Range,
 23    path::{Path, PathBuf},
 24    sync::Arc,
 25};
 26use util::{
 27    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 28};
 29
 30const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 31actions!(copilot_auth, [SignIn, SignOut]);
 32
 33const COPILOT_NAMESPACE: &'static str = "copilot";
 34actions!(
 35    copilot,
 36    [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
 37);
 38
 39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 40    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 41    cx.set_global(copilot.clone());
 42    cx.add_global_action(|_: &SignIn, cx| {
 43        let copilot = Copilot::global(cx).unwrap();
 44        copilot
 45            .update(cx, |copilot, cx| copilot.sign_in(cx))
 46            .detach_and_log_err(cx);
 47    });
 48    cx.add_global_action(|_: &SignOut, cx| {
 49        let copilot = Copilot::global(cx).unwrap();
 50        copilot
 51            .update(cx, |copilot, cx| copilot.sign_out(cx))
 52            .detach_and_log_err(cx);
 53    });
 54
 55    cx.add_global_action(|_: &Reinstall, cx| {
 56        let copilot = Copilot::global(cx).unwrap();
 57        copilot
 58            .update(cx, |copilot, cx| copilot.reinstall(cx))
 59            .detach();
 60    });
 61
 62    cx.observe(&copilot, |handle, cx| {
 63        let status = handle.read(cx).status();
 64        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 65            move |filter, _cx| match status {
 66                Status::Disabled => {
 67                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 68                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 69                }
 70                Status::Authorized => {
 71                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 72                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 73                }
 74                _ => {
 75                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 76                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 77                }
 78            },
 79        );
 80    })
 81    .detach();
 82
 83    sign_in::init(cx);
 84}
 85
 86enum CopilotServer {
 87    Disabled,
 88    Starting {
 89        task: Shared<Task<()>>,
 90    },
 91    Error(Arc<str>),
 92    Started {
 93        server: Arc<LanguageServer>,
 94        status: SignInStatus,
 95    },
 96}
 97
 98#[derive(Clone, Debug)]
 99enum SignInStatus {
100    Authorized {
101        _user: String,
102    },
103    Unauthorized {
104        _user: String,
105    },
106    SigningIn {
107        prompt: Option<request::PromptUserDeviceFlow>,
108        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
109    },
110    SignedOut,
111}
112
113#[derive(Debug, Clone)]
114pub enum Status {
115    Starting {
116        task: Shared<Task<()>>,
117    },
118    Error(Arc<str>),
119    Disabled,
120    SignedOut,
121    SigningIn {
122        prompt: Option<request::PromptUserDeviceFlow>,
123    },
124    Unauthorized,
125    Authorized,
126}
127
128impl Status {
129    pub fn is_authorized(&self) -> bool {
130        matches!(self, Status::Authorized)
131    }
132}
133
134#[derive(Debug, PartialEq, Eq)]
135pub struct Completion {
136    pub range: Range<Anchor>,
137    pub text: String,
138}
139
140pub struct Copilot {
141    http: Arc<dyn HttpClient>,
142    node_runtime: Arc<NodeRuntime>,
143    server: CopilotServer,
144}
145
146impl Entity for Copilot {
147    type Event = ();
148}
149
150impl Copilot {
151    pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
152        match self.server {
153            CopilotServer::Starting { ref task } => Some(task.clone()),
154            _ => None,
155        }
156    }
157
158    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
159        if cx.has_global::<ModelHandle<Self>>() {
160            Some(cx.global::<ModelHandle<Self>>().clone())
161        } else {
162            None
163        }
164    }
165
166    fn start(
167        http: Arc<dyn HttpClient>,
168        node_runtime: Arc<NodeRuntime>,
169        cx: &mut ModelContext<Self>,
170    ) -> Self {
171        cx.observe_global::<Settings, _>({
172            let http = http.clone();
173            let node_runtime = node_runtime.clone();
174            move |this, cx| {
175                if cx.global::<Settings>().enable_copilot_integration {
176                    if matches!(this.server, CopilotServer::Disabled) {
177                        let start_task = cx
178                            .spawn({
179                                let http = http.clone();
180                                let node_runtime = node_runtime.clone();
181                                move |this, cx| {
182                                    Self::start_language_server(http, node_runtime, this, cx)
183                                }
184                            })
185                            .shared();
186                        this.server = CopilotServer::Starting { task: start_task };
187                        cx.notify();
188                    }
189                } else {
190                    this.server = CopilotServer::Disabled;
191                    cx.notify();
192                }
193            }
194        })
195        .detach();
196
197        if cx.global::<Settings>().enable_copilot_integration {
198            let start_task = cx
199                .spawn({
200                    let http = http.clone();
201                    let node_runtime = node_runtime.clone();
202                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
203                })
204                .shared();
205
206            Self {
207                http,
208                node_runtime,
209                server: CopilotServer::Starting { task: start_task },
210            }
211        } else {
212            Self {
213                http,
214                node_runtime,
215                server: CopilotServer::Disabled,
216            }
217        }
218    }
219
220    fn start_language_server(
221        http: Arc<dyn HttpClient>,
222        node_runtime: Arc<NodeRuntime>,
223        this: ModelHandle<Self>,
224        mut cx: AsyncAppContext,
225    ) -> impl Future<Output = ()> {
226        async move {
227            let start_language_server = async {
228                let server_path = get_copilot_lsp(http).await?;
229                let node_path = node_runtime.binary_path().await?;
230                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
231                let server = LanguageServer::new(
232                    0,
233                    &node_path,
234                    arguments,
235                    Path::new("/"),
236                    None,
237                    cx.clone(),
238                )?;
239
240                let server = server.initialize(Default::default()).await?;
241                let status = server
242                    .request::<request::CheckStatus>(request::CheckStatusParams {
243                        local_checks_only: false,
244                    })
245                    .await?;
246
247                server
248                    .on_notification::<LogMessage, _>(|params, _cx| {
249                        match params.level {
250                            // Copilot is pretty agressive about logging
251                            0 => debug!("copilot: {}", params.message),
252                            1 => debug!("copilot: {}", params.message),
253                            _ => error!("copilot: {}", params.message),
254                        }
255
256                        debug!("copilot metadata: {}", params.metadata_str);
257                        debug!("copilot extra: {:?}", params.extra);
258                    })
259                    .detach();
260
261                server
262                    .on_notification::<StatusNotification, _>(
263                        |_, _| { /* Silence the notification */ },
264                    )
265                    .detach();
266
267                anyhow::Ok((server, status))
268            };
269
270            let server = start_language_server.await;
271            this.update(&mut cx, |this, cx| {
272                cx.notify();
273                match server {
274                    Ok((server, status)) => {
275                        this.server = CopilotServer::Started {
276                            server,
277                            status: SignInStatus::SignedOut,
278                        };
279                        this.update_sign_in_status(status, cx);
280                    }
281                    Err(error) => {
282                        this.server = CopilotServer::Error(error.to_string().into());
283                        cx.notify()
284                    }
285                }
286            })
287        }
288    }
289
290    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
291        if let CopilotServer::Started { server, status } = &mut self.server {
292            let task = match status {
293                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
294                    Task::ready(Ok(())).shared()
295                }
296                SignInStatus::SigningIn { task, .. } => {
297                    cx.notify();
298                    task.clone()
299                }
300                SignInStatus::SignedOut => {
301                    let server = server.clone();
302                    let task = cx
303                        .spawn(|this, mut cx| async move {
304                            let sign_in = async {
305                                let sign_in = server
306                                    .request::<request::SignInInitiate>(
307                                        request::SignInInitiateParams {},
308                                    )
309                                    .await?;
310                                match sign_in {
311                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
312                                        Ok(request::SignInStatus::Ok { user })
313                                    }
314                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
315                                        this.update(&mut cx, |this, cx| {
316                                            if let CopilotServer::Started { status, .. } =
317                                                &mut this.server
318                                            {
319                                                if let SignInStatus::SigningIn {
320                                                    prompt: prompt_flow,
321                                                    ..
322                                                } = status
323                                                {
324                                                    *prompt_flow = Some(flow.clone());
325                                                    cx.notify();
326                                                }
327                                            }
328                                        });
329                                        let response = server
330                                            .request::<request::SignInConfirm>(
331                                                request::SignInConfirmParams {
332                                                    user_code: flow.user_code,
333                                                },
334                                            )
335                                            .await?;
336                                        Ok(response)
337                                    }
338                                }
339                            };
340
341                            let sign_in = sign_in.await;
342                            this.update(&mut cx, |this, cx| match sign_in {
343                                Ok(status) => {
344                                    this.update_sign_in_status(status, cx);
345                                    Ok(())
346                                }
347                                Err(error) => {
348                                    this.update_sign_in_status(
349                                        request::SignInStatus::NotSignedIn,
350                                        cx,
351                                    );
352                                    Err(Arc::new(error))
353                                }
354                            })
355                        })
356                        .shared();
357                    *status = SignInStatus::SigningIn {
358                        prompt: None,
359                        task: task.clone(),
360                    };
361                    cx.notify();
362                    task
363                }
364            };
365
366            cx.foreground()
367                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
368        } else {
369            // If we're downloading, wait until download is finished
370            // If we're in a stuck state, display to the user
371            Task::ready(Err(anyhow!("copilot hasn't started yet")))
372        }
373    }
374
375    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
376        if let CopilotServer::Started { server, status } = &mut self.server {
377            *status = SignInStatus::SignedOut;
378            cx.notify();
379
380            let server = server.clone();
381            cx.background().spawn(async move {
382                server
383                    .request::<request::SignOut>(request::SignOutParams {})
384                    .await?;
385                anyhow::Ok(())
386            })
387        } else {
388            Task::ready(Err(anyhow!("copilot hasn't started yet")))
389        }
390    }
391
392    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
393        let start_task = cx
394            .spawn({
395                let http = self.http.clone();
396                let node_runtime = self.node_runtime.clone();
397                move |this, cx| async move {
398                    clear_copilot_dir().await;
399                    Self::start_language_server(http, node_runtime, this, cx).await
400                }
401            })
402            .shared();
403
404        self.server = CopilotServer::Starting {
405            task: start_task.clone(),
406        };
407
408        cx.notify();
409
410        cx.foreground().spawn(start_task)
411    }
412
413    pub fn completion<T>(
414        &self,
415        buffer: &ModelHandle<Buffer>,
416        position: T,
417        cx: &mut ModelContext<Self>,
418    ) -> Task<Result<Option<Completion>>>
419    where
420        T: ToPointUtf16,
421    {
422        let server = match self.authorized_server() {
423            Ok(server) => server,
424            Err(error) => return Task::ready(Err(error)),
425        };
426
427        let buffer = buffer.read(cx);
428
429        if !buffer.file().map(|file| file.is_local()).unwrap_or(true) {
430            return Task::ready(Err(anyhow!("Copilot only works locally")));
431        }
432
433        let buffer = buffer.snapshot();
434        let request = server.request::<request::GetCompletions>(
435            build_completion_params(&buffer, position, cx).unwrap(),
436        );
437        cx.background().spawn(async move {
438            let result = request.await?;
439            let completion = result
440                .completions
441                .into_iter()
442                .next()
443                .map(|completion| completion_from_lsp(completion, &buffer));
444            anyhow::Ok(completion)
445        })
446    }
447
448    pub fn completions_cycling<T>(
449        &self,
450        buffer: &ModelHandle<Buffer>,
451        position: T,
452        cx: &mut ModelContext<Self>,
453    ) -> Task<Result<Vec<Completion>>>
454    where
455        T: ToPointUtf16,
456    {
457        let server = match self.authorized_server() {
458            Ok(server) => server,
459            Err(error) => return Task::ready(Err(error)),
460        };
461
462        let buffer = buffer.read(cx);
463
464        if !buffer.file().map(|file| file.is_local()).unwrap_or(true) {
465            return Task::ready(Err(anyhow!("Copilot only works locally")));
466        }
467
468        let buffer = buffer.snapshot();
469        let request = server.request::<request::GetCompletionsCycling>(
470            build_completion_params(&buffer, position, cx).unwrap(),
471        );
472        cx.background().spawn(async move {
473            let result = request.await?;
474            let completions = result
475                .completions
476                .into_iter()
477                .map(|completion| completion_from_lsp(completion, &buffer))
478                .collect();
479            anyhow::Ok(completions)
480        })
481    }
482
483    pub fn status(&self) -> Status {
484        match &self.server {
485            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
486            CopilotServer::Disabled => Status::Disabled,
487            CopilotServer::Error(error) => Status::Error(error.clone()),
488            CopilotServer::Started { status, .. } => match status {
489                SignInStatus::Authorized { .. } => Status::Authorized,
490                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
491                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
492                    prompt: prompt.clone(),
493                },
494                SignInStatus::SignedOut => Status::SignedOut,
495            },
496        }
497    }
498
499    fn update_sign_in_status(
500        &mut self,
501        lsp_status: request::SignInStatus,
502        cx: &mut ModelContext<Self>,
503    ) {
504        if let CopilotServer::Started { status, .. } = &mut self.server {
505            *status = match lsp_status {
506                request::SignInStatus::Ok { user }
507                | request::SignInStatus::MaybeOk { user }
508                | request::SignInStatus::AlreadySignedIn { user } => {
509                    SignInStatus::Authorized { _user: user }
510                }
511                request::SignInStatus::NotAuthorized { user } => {
512                    SignInStatus::Unauthorized { _user: user }
513                }
514                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
515            };
516            cx.notify();
517        }
518    }
519
520    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
521        match &self.server {
522            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
523            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
524            CopilotServer::Error(error) => Err(anyhow!(
525                "copilot was not started because of an error: {}",
526                error
527            )),
528            CopilotServer::Started { server, status } => {
529                if matches!(status, SignInStatus::Authorized { .. }) {
530                    Ok(server.clone())
531                } else {
532                    Err(anyhow!("must sign in before using copilot"))
533                }
534            }
535        }
536    }
537}
538
539fn build_completion_params<T>(
540    buffer: &BufferSnapshot,
541    position: T,
542    cx: &AppContext,
543) -> anyhow::Result<request::GetCompletionsParams>
544where
545    T: ToPointUtf16,
546{
547    let position = position.to_point_utf16(&buffer);
548    let language_name = buffer.language_at(position).map(|language| language.name());
549    let language_name = language_name.as_deref();
550
551    let path;
552    let relative_path;
553    if let Some(file) = buffer.file() {
554        if let Some(file) = file.as_local() {
555            path = file.abs_path(cx);
556        } else {
557            path = file.full_path(cx);
558        }
559        relative_path = file.path().to_path_buf();
560    } else {
561        path = PathBuf::from("/untitled");
562        relative_path = PathBuf::from("untitled");
563    }
564
565    let settings = cx.global::<Settings>();
566    let language_id = match language_name {
567        Some("Plain Text") => "plaintext".to_string(),
568        Some(language_name) => language_name.to_lowercase(),
569        None => "plaintext".to_string(),
570    };
571
572    let Ok(uri) = lsp::Url::from_file_path(&path) else {
573        bail!("Failed convert file path")
574    };
575
576    Ok(request::GetCompletionsParams {
577        doc: request::GetCompletionsDocument {
578            source: buffer.text(),
579            tab_size: settings.tab_size(language_name).into(),
580            indent_size: 1,
581            insert_spaces: !settings.hard_tabs(language_name),
582            uri,
583            path: path.to_string_lossy().into(),
584            relative_path: relative_path.to_string_lossy().into(),
585            language_id,
586            position: point_to_lsp(position),
587            version: 0,
588        },
589    })
590}
591
592fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
593    let start = buffer.clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
594    let end = buffer.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
595    Completion {
596        range: buffer.anchor_before(start)..buffer.anchor_after(end),
597        text: completion.text,
598    }
599}
600
601async fn clear_copilot_dir() {
602    remove_matching(&paths::COPILOT_DIR, |_| true).await
603}
604
605async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
606    const SERVER_PATH: &'static str = "dist/agent.js";
607
608    ///Check for the latest copilot language server and download it if we haven't already
609    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
610        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
611
612        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
613
614        fs::create_dir_all(version_dir).await?;
615        let server_path = version_dir.join(SERVER_PATH);
616
617        if fs::metadata(&server_path).await.is_err() {
618            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
619            let dist_dir = version_dir.join("dist");
620            fs::create_dir_all(dist_dir.as_path()).await?;
621
622            let url = &release
623                .assets
624                .get(0)
625                .context("Github release for copilot contained no assets")?
626                .browser_download_url;
627
628            let mut response = http
629                .get(&url, Default::default(), true)
630                .await
631                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
632            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
633            let archive = Archive::new(decompressed_bytes);
634            archive.unpack(dist_dir).await?;
635
636            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
637        }
638
639        Ok(server_path)
640    }
641
642    match fetch_latest(http).await {
643        ok @ Result::Ok(..) => ok,
644        e @ Err(..) => {
645            e.log_err();
646            // Fetch a cached binary, if it exists
647            (|| async move {
648                let mut last_version_dir = None;
649                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
650                while let Some(entry) = entries.next().await {
651                    let entry = entry?;
652                    if entry.file_type().await?.is_dir() {
653                        last_version_dir = Some(entry.path());
654                    }
655                }
656                let last_version_dir =
657                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
658                let server_path = last_version_dir.join(SERVER_PATH);
659                if server_path.exists() {
660                    Ok(server_path)
661                } else {
662                    Err(anyhow!(
663                        "missing executable in directory {:?}",
664                        last_version_dir
665                    ))
666                }
667            })()
668            .await
669        }
670    }
671}