copilot.rs

  1pub mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use collections::HashMap;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{
 10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 11    Task,
 12};
 13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 14use log::{debug, error};
 15use lsp::LanguageServer;
 16use node_runtime::NodeRuntime;
 17use request::{LogMessage, StatusNotification};
 18use settings::Settings;
 19use smol::{fs, io::BufReader, stream::StreamExt};
 20use std::{
 21    ffi::OsString,
 22    ops::Range,
 23    path::{Path, PathBuf},
 24    sync::Arc,
 25};
 26use util::{
 27    channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
 28    paths, ResultExt,
 29};
 30
 31const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 32actions!(copilot_auth, [SignIn, SignOut]);
 33
 34const COPILOT_NAMESPACE: &'static str = "copilot";
 35actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
 36
 37pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 38    // Disable Copilot for stable releases.
 39    if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
 40        cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 41            filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 42            filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 43        });
 44        return;
 45    }
 46
 47    let copilot = cx.add_model({
 48        let node_runtime = node_runtime.clone();
 49        move |cx| Copilot::start(http, node_runtime, cx)
 50    });
 51    cx.set_global(copilot.clone());
 52
 53    sign_in::init(cx);
 54    cx.add_global_action(|_: &SignIn, cx| {
 55        if let Some(copilot) = Copilot::global(cx) {
 56            copilot
 57                .update(cx, |copilot, cx| copilot.sign_in(cx))
 58                .detach_and_log_err(cx);
 59        }
 60    });
 61    cx.add_global_action(|_: &SignOut, cx| {
 62        if let Some(copilot) = Copilot::global(cx) {
 63            copilot
 64                .update(cx, |copilot, cx| copilot.sign_out(cx))
 65                .detach_and_log_err(cx);
 66        }
 67    });
 68
 69    cx.add_global_action(|_: &Reinstall, cx| {
 70        if let Some(copilot) = Copilot::global(cx) {
 71            copilot
 72                .update(cx, |copilot, cx| copilot.reinstall(cx))
 73                .detach();
 74        }
 75    });
 76}
 77
 78enum CopilotServer {
 79    Disabled,
 80    Starting {
 81        task: Shared<Task<()>>,
 82    },
 83    Error(Arc<str>),
 84    Started {
 85        server: Arc<LanguageServer>,
 86        status: SignInStatus,
 87        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
 88    },
 89}
 90
 91#[derive(Clone, Debug)]
 92enum SignInStatus {
 93    Authorized,
 94    Unauthorized,
 95    SigningIn {
 96        prompt: Option<request::PromptUserDeviceFlow>,
 97        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 98    },
 99    SignedOut,
100}
101
102#[derive(Debug, Clone)]
103pub enum Status {
104    Starting {
105        task: Shared<Task<()>>,
106    },
107    Error(Arc<str>),
108    Disabled,
109    SignedOut,
110    SigningIn {
111        prompt: Option<request::PromptUserDeviceFlow>,
112    },
113    Unauthorized,
114    Authorized,
115}
116
117impl Status {
118    pub fn is_authorized(&self) -> bool {
119        matches!(self, Status::Authorized)
120    }
121}
122
123#[derive(Debug, PartialEq, Eq)]
124pub struct Completion {
125    pub range: Range<Anchor>,
126    pub text: String,
127}
128
129pub struct Copilot {
130    http: Arc<dyn HttpClient>,
131    node_runtime: Arc<NodeRuntime>,
132    server: CopilotServer,
133}
134
135impl Entity for Copilot {
136    type Event = ();
137}
138
139impl Copilot {
140    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
141        if cx.has_global::<ModelHandle<Self>>() {
142            Some(cx.global::<ModelHandle<Self>>().clone())
143        } else {
144            None
145        }
146    }
147
148    fn start(
149        http: Arc<dyn HttpClient>,
150        node_runtime: Arc<NodeRuntime>,
151        cx: &mut ModelContext<Self>,
152    ) -> Self {
153        cx.observe_global::<Settings, _>({
154            let http = http.clone();
155            let node_runtime = node_runtime.clone();
156            move |this, cx| {
157                if cx.global::<Settings>().enable_copilot_integration {
158                    if matches!(this.server, CopilotServer::Disabled) {
159                        let start_task = cx
160                            .spawn({
161                                let http = http.clone();
162                                let node_runtime = node_runtime.clone();
163                                move |this, cx| {
164                                    Self::start_language_server(http, node_runtime, this, cx)
165                                }
166                            })
167                            .shared();
168                        this.server = CopilotServer::Starting { task: start_task };
169                        cx.notify();
170                    }
171                } else {
172                    this.server = CopilotServer::Disabled;
173                    cx.notify();
174                }
175            }
176        })
177        .detach();
178
179        if cx.global::<Settings>().enable_copilot_integration {
180            let start_task = cx
181                .spawn({
182                    let http = http.clone();
183                    let node_runtime = node_runtime.clone();
184                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
185                })
186                .shared();
187
188            Self {
189                http,
190                node_runtime,
191                server: CopilotServer::Starting { task: start_task },
192            }
193        } else {
194            Self {
195                http,
196                node_runtime,
197                server: CopilotServer::Disabled,
198            }
199        }
200    }
201
202    #[cfg(any(test, feature = "test-support"))]
203    pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
204        let (server, fake_server) =
205            LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
206        let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
207        let this = cx.add_model(|cx| Self {
208            http: http.clone(),
209            node_runtime: NodeRuntime::new(http, cx.background().clone()),
210            server: CopilotServer::Started {
211                server: Arc::new(server),
212                status: SignInStatus::Authorized,
213                subscriptions_by_buffer_id: Default::default(),
214            },
215        });
216        (this, fake_server)
217    }
218
219    fn start_language_server(
220        http: Arc<dyn HttpClient>,
221        node_runtime: Arc<NodeRuntime>,
222        this: ModelHandle<Self>,
223        mut cx: AsyncAppContext,
224    ) -> impl Future<Output = ()> {
225        async move {
226            let start_language_server = async {
227                let server_path = get_copilot_lsp(http).await?;
228                let node_path = node_runtime.binary_path().await?;
229                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
230                let server = LanguageServer::new(
231                    0,
232                    &node_path,
233                    arguments,
234                    Path::new("/"),
235                    None,
236                    cx.clone(),
237                )?;
238
239                let server = server.initialize(Default::default()).await?;
240                let status = server
241                    .request::<request::CheckStatus>(request::CheckStatusParams {
242                        local_checks_only: false,
243                    })
244                    .await?;
245
246                server
247                    .on_notification::<LogMessage, _>(|params, _cx| {
248                        match params.level {
249                            // Copilot is pretty agressive about logging
250                            0 => debug!("copilot: {}", params.message),
251                            1 => debug!("copilot: {}", params.message),
252                            _ => error!("copilot: {}", params.message),
253                        }
254
255                        debug!("copilot metadata: {}", params.metadata_str);
256                        debug!("copilot extra: {:?}", params.extra);
257                    })
258                    .detach();
259
260                server
261                    .on_notification::<StatusNotification, _>(
262                        |_, _| { /* Silence the notification */ },
263                    )
264                    .detach();
265
266                anyhow::Ok((server, status))
267            };
268
269            let server = start_language_server.await;
270            this.update(&mut cx, |this, cx| {
271                cx.notify();
272                match server {
273                    Ok((server, status)) => {
274                        this.server = CopilotServer::Started {
275                            server,
276                            status: SignInStatus::SignedOut,
277                            subscriptions_by_buffer_id: Default::default(),
278                        };
279                        this.update_sign_in_status(status, cx);
280                    }
281                    Err(error) => {
282                        this.server = CopilotServer::Error(error.to_string().into());
283                        cx.notify()
284                    }
285                }
286            })
287        }
288    }
289
290    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
291        if let CopilotServer::Started { server, status, .. } = &mut self.server {
292            let task = match status {
293                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
294                    Task::ready(Ok(())).shared()
295                }
296                SignInStatus::SigningIn { task, .. } => {
297                    cx.notify();
298                    task.clone()
299                }
300                SignInStatus::SignedOut => {
301                    let server = server.clone();
302                    let task = cx
303                        .spawn(|this, mut cx| async move {
304                            let sign_in = async {
305                                let sign_in = server
306                                    .request::<request::SignInInitiate>(
307                                        request::SignInInitiateParams {},
308                                    )
309                                    .await?;
310                                match sign_in {
311                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
312                                        Ok(request::SignInStatus::Ok { user })
313                                    }
314                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
315                                        this.update(&mut cx, |this, cx| {
316                                            if let CopilotServer::Started { status, .. } =
317                                                &mut this.server
318                                            {
319                                                if let SignInStatus::SigningIn {
320                                                    prompt: prompt_flow,
321                                                    ..
322                                                } = status
323                                                {
324                                                    *prompt_flow = Some(flow.clone());
325                                                    cx.notify();
326                                                }
327                                            }
328                                        });
329                                        let response = server
330                                            .request::<request::SignInConfirm>(
331                                                request::SignInConfirmParams {
332                                                    user_code: flow.user_code,
333                                                },
334                                            )
335                                            .await?;
336                                        Ok(response)
337                                    }
338                                }
339                            };
340
341                            let sign_in = sign_in.await;
342                            this.update(&mut cx, |this, cx| match sign_in {
343                                Ok(status) => {
344                                    this.update_sign_in_status(status, cx);
345                                    Ok(())
346                                }
347                                Err(error) => {
348                                    this.update_sign_in_status(
349                                        request::SignInStatus::NotSignedIn,
350                                        cx,
351                                    );
352                                    Err(Arc::new(error))
353                                }
354                            })
355                        })
356                        .shared();
357                    *status = SignInStatus::SigningIn {
358                        prompt: None,
359                        task: task.clone(),
360                    };
361                    cx.notify();
362                    task
363                }
364            };
365
366            cx.foreground()
367                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
368        } else {
369            // If we're downloading, wait until download is finished
370            // If we're in a stuck state, display to the user
371            Task::ready(Err(anyhow!("copilot hasn't started yet")))
372        }
373    }
374
375    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
376        if let CopilotServer::Started { server, status, .. } = &mut self.server {
377            *status = SignInStatus::SignedOut;
378            cx.notify();
379
380            let server = server.clone();
381            cx.background().spawn(async move {
382                server
383                    .request::<request::SignOut>(request::SignOutParams {})
384                    .await?;
385                anyhow::Ok(())
386            })
387        } else {
388            Task::ready(Err(anyhow!("copilot hasn't started yet")))
389        }
390    }
391
392    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
393        let start_task = cx
394            .spawn({
395                let http = self.http.clone();
396                let node_runtime = self.node_runtime.clone();
397                move |this, cx| async move {
398                    clear_copilot_dir().await;
399                    Self::start_language_server(http, node_runtime, this, cx).await
400                }
401            })
402            .shared();
403
404        self.server = CopilotServer::Starting {
405            task: start_task.clone(),
406        };
407
408        cx.notify();
409
410        cx.foreground().spawn(start_task)
411    }
412
413    pub fn completions<T>(
414        &mut self,
415        buffer: &ModelHandle<Buffer>,
416        position: T,
417        cx: &mut ModelContext<Self>,
418    ) -> Task<Result<Vec<Completion>>>
419    where
420        T: ToPointUtf16,
421    {
422        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
423    }
424
425    pub fn completions_cycling<T>(
426        &mut self,
427        buffer: &ModelHandle<Buffer>,
428        position: T,
429        cx: &mut ModelContext<Self>,
430    ) -> Task<Result<Vec<Completion>>>
431    where
432        T: ToPointUtf16,
433    {
434        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
435    }
436
437    fn request_completions<R, T>(
438        &mut self,
439        buffer: &ModelHandle<Buffer>,
440        position: T,
441        cx: &mut ModelContext<Self>,
442    ) -> Task<Result<Vec<Completion>>>
443    where
444        R: lsp::request::Request<
445            Params = request::GetCompletionsParams,
446            Result = request::GetCompletionsResult,
447        >,
448        T: ToPointUtf16,
449    {
450        let buffer_id = buffer.id();
451        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
452        let snapshot = buffer.read(cx).snapshot();
453        let server = match &mut self.server {
454            CopilotServer::Starting { .. } => {
455                return Task::ready(Err(anyhow!("copilot is still starting")))
456            }
457            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
458            CopilotServer::Error(error) => {
459                return Task::ready(Err(anyhow!(
460                    "copilot was not started because of an error: {}",
461                    error
462                )))
463            }
464            CopilotServer::Started {
465                server,
466                status,
467                subscriptions_by_buffer_id,
468            } => {
469                if matches!(status, SignInStatus::Authorized { .. }) {
470                    subscriptions_by_buffer_id
471                        .entry(buffer_id)
472                        .or_insert_with(|| {
473                            server
474                                .notify::<lsp::notification::DidOpenTextDocument>(
475                                    lsp::DidOpenTextDocumentParams {
476                                        text_document: lsp::TextDocumentItem {
477                                            uri: uri.clone(),
478                                            language_id: id_for_language(
479                                                buffer.read(cx).language(),
480                                            ),
481                                            version: 0,
482                                            text: snapshot.text(),
483                                        },
484                                    },
485                                )
486                                .log_err();
487
488                            let uri = uri.clone();
489                            cx.observe_release(buffer, move |this, _, _| {
490                                if let CopilotServer::Started {
491                                    server,
492                                    subscriptions_by_buffer_id,
493                                    ..
494                                } = &mut this.server
495                                {
496                                    server
497                                        .notify::<lsp::notification::DidCloseTextDocument>(
498                                            lsp::DidCloseTextDocumentParams {
499                                                text_document: lsp::TextDocumentIdentifier::new(
500                                                    uri.clone(),
501                                                ),
502                                            },
503                                        )
504                                        .log_err();
505                                    subscriptions_by_buffer_id.remove(&buffer_id);
506                                }
507                            })
508                        });
509
510                    server.clone()
511                } else {
512                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
513                }
514            }
515        };
516
517        let settings = cx.global::<Settings>();
518        let position = position.to_point_utf16(&snapshot);
519        let language = snapshot.language_at(position);
520        let language_name = language.map(|language| language.name());
521        let language_name = language_name.as_deref();
522        let tab_size = settings.tab_size(language_name);
523        let hard_tabs = settings.hard_tabs(language_name);
524        let language_id = id_for_language(language);
525
526        let path;
527        let relative_path;
528        if let Some(file) = snapshot.file() {
529            if let Some(file) = file.as_local() {
530                path = file.abs_path(cx);
531            } else {
532                path = file.full_path(cx);
533            }
534            relative_path = file.path().to_path_buf();
535        } else {
536            path = PathBuf::new();
537            relative_path = PathBuf::new();
538        }
539
540        cx.background().spawn(async move {
541            let result = server
542                .request::<R>(request::GetCompletionsParams {
543                    doc: request::GetCompletionsDocument {
544                        source: snapshot.text(),
545                        tab_size: tab_size.into(),
546                        indent_size: 1,
547                        insert_spaces: !hard_tabs,
548                        uri,
549                        path: path.to_string_lossy().into(),
550                        relative_path: relative_path.to_string_lossy().into(),
551                        language_id,
552                        position: point_to_lsp(position),
553                        version: 0,
554                    },
555                })
556                .await?;
557            let completions = result
558                .completions
559                .into_iter()
560                .map(|completion| {
561                    let start = snapshot
562                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
563                    let end =
564                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
565                    Completion {
566                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
567                        text: completion.text,
568                    }
569                })
570                .collect();
571            anyhow::Ok(completions)
572        })
573    }
574
575    pub fn status(&self) -> Status {
576        match &self.server {
577            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
578            CopilotServer::Disabled => Status::Disabled,
579            CopilotServer::Error(error) => Status::Error(error.clone()),
580            CopilotServer::Started { status, .. } => match status {
581                SignInStatus::Authorized { .. } => Status::Authorized,
582                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
583                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
584                    prompt: prompt.clone(),
585                },
586                SignInStatus::SignedOut => Status::SignedOut,
587            },
588        }
589    }
590
591    fn update_sign_in_status(
592        &mut self,
593        lsp_status: request::SignInStatus,
594        cx: &mut ModelContext<Self>,
595    ) {
596        if let CopilotServer::Started { status, .. } = &mut self.server {
597            *status = match lsp_status {
598                request::SignInStatus::Ok { .. }
599                | request::SignInStatus::MaybeOk { .. }
600                | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
601                request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
602                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
603            };
604            cx.notify();
605        }
606    }
607}
608
609fn id_for_language(language: Option<&Arc<Language>>) -> String {
610    let language_name = language.map(|language| language.name());
611    match language_name.as_deref() {
612        Some("Plain Text") => "plaintext".to_string(),
613        Some(language_name) => language_name.to_lowercase(),
614        None => "plaintext".to_string(),
615    }
616}
617
618async fn clear_copilot_dir() {
619    remove_matching(&paths::COPILOT_DIR, |_| true).await
620}
621
622async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
623    const SERVER_PATH: &'static str = "dist/agent.js";
624
625    ///Check for the latest copilot language server and download it if we haven't already
626    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
627        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
628
629        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
630
631        fs::create_dir_all(version_dir).await?;
632        let server_path = version_dir.join(SERVER_PATH);
633
634        if fs::metadata(&server_path).await.is_err() {
635            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
636            let dist_dir = version_dir.join("dist");
637            fs::create_dir_all(dist_dir.as_path()).await?;
638
639            let url = &release
640                .assets
641                .get(0)
642                .context("Github release for copilot contained no assets")?
643                .browser_download_url;
644
645            let mut response = http
646                .get(&url, Default::default(), true)
647                .await
648                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
649            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
650            let archive = Archive::new(decompressed_bytes);
651            archive.unpack(dist_dir).await?;
652
653            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
654        }
655
656        Ok(server_path)
657    }
658
659    match fetch_latest(http).await {
660        ok @ Result::Ok(..) => ok,
661        e @ Err(..) => {
662            e.log_err();
663            // Fetch a cached binary, if it exists
664            (|| async move {
665                let mut last_version_dir = None;
666                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
667                while let Some(entry) = entries.next().await {
668                    let entry = entry?;
669                    if entry.file_type().await?.is_dir() {
670                        last_version_dir = Some(entry.path());
671                    }
672                }
673                let last_version_dir =
674                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
675                let server_path = last_version_dir.join(SERVER_PATH);
676                if server_path.exists() {
677                    Ok(server_path)
678                } else {
679                    Err(anyhow!(
680                        "missing executable in directory {:?}",
681                        last_version_dir
682                    ))
683                }
684            })()
685            .await
686        }
687    }
688}