copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{
 10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 11    Task,
 12};
 13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 14use lsp::LanguageServer;
 15use node_runtime::NodeRuntime;
 16use settings::Settings;
 17use smol::{fs, io::BufReader, stream::StreamExt};
 18use std::{
 19    ffi::OsString,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use util::{
 24    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 25};
 26
 27const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 28actions!(copilot_auth, [SignIn, SignOut]);
 29
 30const COPILOT_NAMESPACE: &'static str = "copilot";
 31actions!(
 32    copilot,
 33    [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
 34);
 35
 36pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 37    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 38    cx.set_global(copilot.clone());
 39    cx.add_global_action(|_: &SignIn, cx| {
 40        let copilot = Copilot::global(cx).unwrap();
 41        copilot
 42            .update(cx, |copilot, cx| copilot.sign_in(cx))
 43            .detach_and_log_err(cx);
 44    });
 45    cx.add_global_action(|_: &SignOut, cx| {
 46        let copilot = Copilot::global(cx).unwrap();
 47        copilot
 48            .update(cx, |copilot, cx| copilot.sign_out(cx))
 49            .detach_and_log_err(cx);
 50    });
 51
 52    cx.add_global_action(|_: &Reinstall, cx| {
 53        let copilot = Copilot::global(cx).unwrap();
 54        copilot
 55            .update(cx, |copilot, cx| copilot.reinstall(cx))
 56            .detach();
 57    });
 58
 59    cx.observe(&copilot, |handle, cx| {
 60        let status = handle.read(cx).status();
 61        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 62            move |filter, _cx| match status {
 63                Status::Disabled => {
 64                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 65                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 66                }
 67                Status::Authorized => {
 68                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 69                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 70                }
 71                _ => {
 72                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 73                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 74                }
 75            },
 76        );
 77    })
 78    .detach();
 79
 80    sign_in::init(cx);
 81}
 82
 83enum CopilotServer {
 84    Disabled,
 85    Starting {
 86        task: Shared<Task<()>>,
 87    },
 88    Error(Arc<str>),
 89    Started {
 90        server: Arc<LanguageServer>,
 91        status: SignInStatus,
 92    },
 93}
 94
 95#[derive(Clone, Debug)]
 96enum SignInStatus {
 97    Authorized {
 98        _user: String,
 99    },
100    Unauthorized {
101        _user: String,
102    },
103    SigningIn {
104        prompt: Option<request::PromptUserDeviceFlow>,
105        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
106    },
107    SignedOut,
108}
109
110#[derive(Debug, Clone)]
111pub enum Status {
112    Starting {
113        task: Shared<Task<()>>,
114    },
115    Error(Arc<str>),
116    Disabled,
117    SignedOut,
118    SigningIn {
119        prompt: Option<request::PromptUserDeviceFlow>,
120    },
121    Unauthorized,
122    Authorized,
123}
124
125impl Status {
126    pub fn is_authorized(&self) -> bool {
127        matches!(self, Status::Authorized)
128    }
129}
130
131#[derive(Debug, PartialEq, Eq)]
132pub struct Completion {
133    pub position: Anchor,
134    pub text: String,
135}
136
137pub struct Copilot {
138    http: Arc<dyn HttpClient>,
139    node_runtime: Arc<NodeRuntime>,
140    server: CopilotServer,
141}
142
143impl Entity for Copilot {
144    type Event = ();
145}
146
147impl Copilot {
148    pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
149        match self.server {
150            CopilotServer::Starting { ref task } => Some(task.clone()),
151            _ => None,
152        }
153    }
154
155    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
156        if cx.has_global::<ModelHandle<Self>>() {
157            Some(cx.global::<ModelHandle<Self>>().clone())
158        } else {
159            None
160        }
161    }
162
163    fn start(
164        http: Arc<dyn HttpClient>,
165        node_runtime: Arc<NodeRuntime>,
166        cx: &mut ModelContext<Self>,
167    ) -> Self {
168        cx.observe_global::<Settings, _>({
169            let http = http.clone();
170            let node_runtime = node_runtime.clone();
171            move |this, cx| {
172                if cx.global::<Settings>().enable_copilot_integration {
173                    if matches!(this.server, CopilotServer::Disabled) {
174                        let start_task = cx
175                            .spawn({
176                                let http = http.clone();
177                                let node_runtime = node_runtime.clone();
178                                move |this, cx| {
179                                    Self::start_language_server(http, node_runtime, this, cx)
180                                }
181                            })
182                            .shared();
183                        this.server = CopilotServer::Starting { task: start_task };
184                        cx.notify();
185                    }
186                } else {
187                    this.server = CopilotServer::Disabled;
188                    cx.notify();
189                }
190            }
191        })
192        .detach();
193
194        if cx.global::<Settings>().enable_copilot_integration {
195            let start_task = cx
196                .spawn({
197                    let http = http.clone();
198                    let node_runtime = node_runtime.clone();
199                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
200                })
201                .shared();
202
203            Self {
204                http,
205                node_runtime,
206                server: CopilotServer::Starting { task: start_task },
207            }
208        } else {
209            Self {
210                http,
211                node_runtime,
212                server: CopilotServer::Disabled,
213            }
214        }
215    }
216
217    fn start_language_server(
218        http: Arc<dyn HttpClient>,
219        node_runtime: Arc<NodeRuntime>,
220        this: ModelHandle<Self>,
221        mut cx: AsyncAppContext,
222    ) -> impl Future<Output = ()> {
223        async move {
224            let start_language_server = async {
225                let server_path = get_copilot_lsp(http).await?;
226                let node_path = node_runtime.binary_path().await?;
227                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
228                let server = LanguageServer::new(
229                    0,
230                    &node_path,
231                    arguments,
232                    Path::new("/"),
233                    None,
234                    cx.clone(),
235                )?;
236
237                let server = server.initialize(Default::default()).await?;
238                let status = server
239                    .request::<request::CheckStatus>(request::CheckStatusParams {
240                        local_checks_only: false,
241                    })
242                    .await?;
243                anyhow::Ok((server, status))
244            };
245
246            let server = start_language_server.await;
247            this.update(&mut cx, |this, cx| {
248                cx.notify();
249                match server {
250                    Ok((server, status)) => {
251                        this.server = CopilotServer::Started {
252                            server,
253                            status: SignInStatus::SignedOut,
254                        };
255                        this.update_sign_in_status(status, cx);
256                    }
257                    Err(error) => {
258                        this.server = CopilotServer::Error(error.to_string().into());
259                        cx.notify()
260                    }
261                }
262            })
263        }
264    }
265
266    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
267        if let CopilotServer::Started { server, status } = &mut self.server {
268            let task = match status {
269                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
270                    Task::ready(Ok(())).shared()
271                }
272                SignInStatus::SigningIn { task, .. } => {
273                    cx.notify();
274                    task.clone()
275                }
276                SignInStatus::SignedOut => {
277                    let server = server.clone();
278                    let task = cx
279                        .spawn(|this, mut cx| async move {
280                            let sign_in = async {
281                                let sign_in = server
282                                    .request::<request::SignInInitiate>(
283                                        request::SignInInitiateParams {},
284                                    )
285                                    .await?;
286                                match sign_in {
287                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
288                                        Ok(request::SignInStatus::Ok { user })
289                                    }
290                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
291                                        this.update(&mut cx, |this, cx| {
292                                            if let CopilotServer::Started { status, .. } =
293                                                &mut this.server
294                                            {
295                                                if let SignInStatus::SigningIn {
296                                                    prompt: prompt_flow,
297                                                    ..
298                                                } = status
299                                                {
300                                                    *prompt_flow = Some(flow.clone());
301                                                    cx.notify();
302                                                }
303                                            }
304                                        });
305                                        let response = server
306                                            .request::<request::SignInConfirm>(
307                                                request::SignInConfirmParams {
308                                                    user_code: flow.user_code,
309                                                },
310                                            )
311                                            .await?;
312                                        Ok(response)
313                                    }
314                                }
315                            };
316
317                            let sign_in = sign_in.await;
318                            this.update(&mut cx, |this, cx| match sign_in {
319                                Ok(status) => {
320                                    this.update_sign_in_status(status, cx);
321                                    Ok(())
322                                }
323                                Err(error) => {
324                                    this.update_sign_in_status(
325                                        request::SignInStatus::NotSignedIn,
326                                        cx,
327                                    );
328                                    Err(Arc::new(error))
329                                }
330                            })
331                        })
332                        .shared();
333                    *status = SignInStatus::SigningIn {
334                        prompt: None,
335                        task: task.clone(),
336                    };
337                    cx.notify();
338                    task
339                }
340            };
341
342            cx.foreground()
343                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
344        } else {
345            // If we're downloading, wait until download is finished
346            // If we're in a stuck state, display to the user
347            Task::ready(Err(anyhow!("copilot hasn't started yet")))
348        }
349    }
350
351    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
352        if let CopilotServer::Started { server, status } = &mut self.server {
353            *status = SignInStatus::SignedOut;
354            cx.notify();
355
356            let server = server.clone();
357            cx.background().spawn(async move {
358                server
359                    .request::<request::SignOut>(request::SignOutParams {})
360                    .await?;
361                anyhow::Ok(())
362            })
363        } else {
364            Task::ready(Err(anyhow!("copilot hasn't started yet")))
365        }
366    }
367
368    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
369        let start_task = cx
370            .spawn({
371                let http = self.http.clone();
372                let node_runtime = self.node_runtime.clone();
373                move |this, cx| async move {
374                    clear_copilot_dir().await;
375                    Self::start_language_server(http, node_runtime, this, cx).await
376                }
377            })
378            .shared();
379
380        self.server = CopilotServer::Starting {
381            task: start_task.clone(),
382        };
383
384        cx.notify();
385
386        cx.foreground().spawn(start_task)
387    }
388
389    pub fn completion<T>(
390        &self,
391        buffer: &ModelHandle<Buffer>,
392        position: T,
393        cx: &mut ModelContext<Self>,
394    ) -> Task<Result<Option<Completion>>>
395    where
396        T: ToPointUtf16,
397    {
398        let server = match self.authorized_server() {
399            Ok(server) => server,
400            Err(error) => return Task::ready(Err(error)),
401        };
402
403        let buffer = buffer.read(cx).snapshot();
404        let request = server
405            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
406        cx.background().spawn(async move {
407            let result = request.await?;
408            let completion = result
409                .completions
410                .into_iter()
411                .next()
412                .map(|completion| completion_from_lsp(completion, &buffer));
413            anyhow::Ok(completion)
414        })
415    }
416
417    pub fn completions_cycling<T>(
418        &self,
419        buffer: &ModelHandle<Buffer>,
420        position: T,
421        cx: &mut ModelContext<Self>,
422    ) -> Task<Result<Vec<Completion>>>
423    where
424        T: ToPointUtf16,
425    {
426        let server = match self.authorized_server() {
427            Ok(server) => server,
428            Err(error) => return Task::ready(Err(error)),
429        };
430
431        let buffer = buffer.read(cx).snapshot();
432        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
433            &buffer, position, cx,
434        ));
435        cx.background().spawn(async move {
436            let result = request.await?;
437            let completions = result
438                .completions
439                .into_iter()
440                .map(|completion| completion_from_lsp(completion, &buffer))
441                .collect();
442            anyhow::Ok(completions)
443        })
444    }
445
446    pub fn status(&self) -> Status {
447        match &self.server {
448            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
449            CopilotServer::Disabled => Status::Disabled,
450            CopilotServer::Error(error) => Status::Error(error.clone()),
451            CopilotServer::Started { status, .. } => match status {
452                SignInStatus::Authorized { .. } => Status::Authorized,
453                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
454                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
455                    prompt: prompt.clone(),
456                },
457                SignInStatus::SignedOut => Status::SignedOut,
458            },
459        }
460    }
461
462    fn update_sign_in_status(
463        &mut self,
464        lsp_status: request::SignInStatus,
465        cx: &mut ModelContext<Self>,
466    ) {
467        if let CopilotServer::Started { status, .. } = &mut self.server {
468            *status = match lsp_status {
469                request::SignInStatus::Ok { user }
470                | request::SignInStatus::MaybeOk { user }
471                | request::SignInStatus::AlreadySignedIn { user } => {
472                    SignInStatus::Authorized { _user: user }
473                }
474                request::SignInStatus::NotAuthorized { user } => {
475                    SignInStatus::Unauthorized { _user: user }
476                }
477                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
478            };
479            cx.notify();
480        }
481    }
482
483    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
484        match &self.server {
485            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
486            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
487            CopilotServer::Error(error) => Err(anyhow!(
488                "copilot was not started because of an error: {}",
489                error
490            )),
491            CopilotServer::Started { server, status } => {
492                if matches!(status, SignInStatus::Authorized { .. }) {
493                    Ok(server.clone())
494                } else {
495                    Err(anyhow!("must sign in before using copilot"))
496                }
497            }
498        }
499    }
500}
501
502fn build_completion_params<T>(
503    buffer: &BufferSnapshot,
504    position: T,
505    cx: &AppContext,
506) -> request::GetCompletionsParams
507where
508    T: ToPointUtf16,
509{
510    let position = position.to_point_utf16(&buffer);
511    let language_name = buffer.language_at(position).map(|language| language.name());
512    let language_name = language_name.as_deref();
513
514    let path;
515    let relative_path;
516    if let Some(file) = buffer.file() {
517        if let Some(file) = file.as_local() {
518            path = file.abs_path(cx);
519        } else {
520            path = file.full_path(cx);
521        }
522        relative_path = file.path().to_path_buf();
523    } else {
524        path = PathBuf::from("/untitled");
525        relative_path = PathBuf::from("untitled");
526    }
527
528    let settings = cx.global::<Settings>();
529    let language_id = match language_name {
530        Some("Plain Text") => "plaintext".to_string(),
531        Some(language_name) => language_name.to_lowercase(),
532        None => "plaintext".to_string(),
533    };
534    request::GetCompletionsParams {
535        doc: request::GetCompletionsDocument {
536            source: buffer.text(),
537            tab_size: settings.tab_size(language_name).into(),
538            indent_size: 1,
539            insert_spaces: !settings.hard_tabs(language_name),
540            uri: lsp::Url::from_file_path(&path).unwrap(),
541            path: path.to_string_lossy().into(),
542            relative_path: relative_path.to_string_lossy().into(),
543            language_id,
544            position: point_to_lsp(position),
545            version: 0,
546        },
547    }
548}
549
550fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
551    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
552    Completion {
553        position: buffer.anchor_before(position),
554        text: completion.display_text,
555    }
556}
557
558async fn clear_copilot_dir() {
559    remove_matching(&paths::COPILOT_DIR, |_| true).await
560}
561
562async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
563    const SERVER_PATH: &'static str = "dist/agent.js";
564
565    ///Check for the latest copilot language server and download it if we haven't already
566    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
567        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
568
569        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
570
571        fs::create_dir_all(version_dir).await?;
572        let server_path = version_dir.join(SERVER_PATH);
573
574        if fs::metadata(&server_path).await.is_err() {
575            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
576            let dist_dir = version_dir.join("dist");
577            fs::create_dir_all(dist_dir.as_path()).await?;
578
579            let url = &release
580                .assets
581                .get(0)
582                .context("Github release for copilot contained no assets")?
583                .browser_download_url;
584
585            let mut response = http
586                .get(&url, Default::default(), true)
587                .await
588                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
589            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
590            let archive = Archive::new(decompressed_bytes);
591            archive.unpack(dist_dir).await?;
592
593            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
594        }
595
596        Ok(server_path)
597    }
598
599    match fetch_latest(http).await {
600        ok @ Result::Ok(..) => ok,
601        e @ Err(..) => {
602            e.log_err();
603            // Fetch a cached binary, if it exists
604            (|| async move {
605                let mut last_version_dir = None;
606                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
607                while let Some(entry) = entries.next().await {
608                    let entry = entry?;
609                    if entry.file_type().await?.is_dir() {
610                        last_version_dir = Some(entry.path());
611                    }
612                }
613                let last_version_dir =
614                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
615                let server_path = last_version_dir.join(SERVER_PATH);
616                if server_path.exists() {
617                    Ok(server_path)
618                } else {
619                    Err(anyhow!(
620                        "missing executable in directory {:?}",
621                        last_version_dir
622                    ))
623                }
624            })()
625            .await
626        }
627    }
628}