copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{
 10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 11    Task,
 12};
 13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 14use lsp::LanguageServer;
 15use node_runtime::NodeRuntime;
 16use settings::Settings;
 17use smol::{fs, io::BufReader, stream::StreamExt};
 18use std::{
 19    ffi::OsString,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use util::{
 24    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 25};
 26
 27const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 28actions!(copilot_auth, [SignIn, SignOut]);
 29
 30const COPILOT_NAMESPACE: &'static str = "copilot";
 31actions!(
 32    copilot,
 33    [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
 34);
 35
 36pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 37    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 38    cx.set_global(copilot.clone());
 39    cx.add_global_action(|_: &SignIn, cx| {
 40        let copilot = Copilot::global(cx).unwrap();
 41        copilot
 42            .update(cx, |copilot, cx| copilot.sign_in(cx))
 43            .detach_and_log_err(cx);
 44    });
 45    cx.add_global_action(|_: &SignOut, cx| {
 46        let copilot = Copilot::global(cx).unwrap();
 47        copilot
 48            .update(cx, |copilot, cx| copilot.sign_out(cx))
 49            .detach_and_log_err(cx);
 50    });
 51
 52    cx.add_global_action(|_: &Reinstall, cx| {
 53        let copilot = Copilot::global(cx).unwrap();
 54        copilot
 55            .update(cx, |copilot, cx| copilot.reinstall(cx))
 56            .detach();
 57    });
 58
 59    cx.observe(&copilot, |handle, cx| {
 60        let status = handle.read(cx).status();
 61        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 62            move |filter, _cx| match status {
 63                Status::Disabled => {
 64                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 65                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 66                }
 67                Status::Authorized => {
 68                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 69                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 70                }
 71                _ => {
 72                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 73                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 74                }
 75            },
 76        );
 77    })
 78    .detach();
 79
 80    sign_in::init(cx);
 81}
 82
 83enum CopilotServer {
 84    Disabled,
 85    Starting {
 86        task: Shared<Task<()>>,
 87    },
 88    Error(Arc<str>),
 89    Started {
 90        server: Arc<LanguageServer>,
 91        status: SignInStatus,
 92    },
 93}
 94
 95#[derive(Clone, Debug)]
 96enum SignInStatus {
 97    Authorized {
 98        _user: String,
 99    },
100    Unauthorized {
101        _user: String,
102    },
103    SigningIn {
104        prompt: Option<request::PromptUserDeviceFlow>,
105        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
106    },
107    SignedOut,
108}
109
110#[derive(Debug, Clone)]
111pub enum Status {
112    Starting {
113        task: Shared<Task<()>>,
114    },
115    Error(Arc<str>),
116    Disabled,
117    SignedOut,
118    SigningIn {
119        prompt: Option<request::PromptUserDeviceFlow>,
120    },
121    Unauthorized,
122    Authorized,
123}
124
125impl Status {
126    pub fn is_authorized(&self) -> bool {
127        matches!(self, Status::Authorized)
128    }
129}
130
131#[derive(Debug, PartialEq, Eq)]
132pub struct Completion {
133    pub position: Anchor,
134    pub text: String,
135}
136
137pub struct Copilot {
138    http: Arc<dyn HttpClient>,
139    node_runtime: Arc<NodeRuntime>,
140    server: CopilotServer,
141}
142
143impl Entity for Copilot {
144    type Event = ();
145}
146
impl Copilot {
    /// Returns a clone of the shared start-up task while the server is in the
    /// `Starting` state, and `None` in every other state.
    pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
        match self.server {
            CopilotServer::Starting { ref task } => Some(task.clone()),
            _ => None,
        }
    }

    /// Returns the `Copilot` model stored as a global (`init` registers it), if any.
    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
        if cx.has_global::<ModelHandle<Self>>() {
            Some(cx.global::<ModelHandle<Self>>().clone())
        } else {
            None
        }
    }

    /// Builds the model, starting the language server right away when
    /// `enable_copilot_integration` is set in `Settings`.
    ///
    /// Also observes `Settings`, so that turning the integration on starts the server
    /// (only from `Disabled`) and turning it off switches the state to `Disabled`.
    fn start(
        http: Arc<dyn HttpClient>,
        node_runtime: Arc<NodeRuntime>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        cx.observe_global::<Settings, _>({
            let http = http.clone();
            let node_runtime = node_runtime.clone();
            move |this, cx| {
                if cx.global::<Settings>().enable_copilot_integration {
                    // Only start from `Disabled`; a server that is starting, running or
                    // errored is left as it is.
                    if matches!(this.server, CopilotServer::Disabled) {
                        let start_task = cx
                            .spawn({
                                let http = http.clone();
                                let node_runtime = node_runtime.clone();
                                move |this, cx| {
                                    Self::start_language_server(http, node_runtime, this, cx)
                                }
                            })
                            .shared();
                        this.server = CopilotServer::Starting { task: start_task };
                        cx.notify();
                    }
                } else {
                    // Replacing the state drops this model's handle to any running server
                    // or start-up task.
                    // NOTE(review): this notifies on every settings change while disabled,
                    // even when the state already was `Disabled`. Observers (such as the
                    // command-palette filter set up in `init`) may depend on it; check
                    // before removing.
                    this.server = CopilotServer::Disabled;
                    cx.notify();
                }
            }
        })
        .detach();

        // Initial state: start immediately if the integration is enabled.
        if cx.global::<Settings>().enable_copilot_integration {
            let start_task = cx
                .spawn({
                    let http = http.clone();
                    let node_runtime = node_runtime.clone();
                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
                })
                .shared();

            Self {
                http,
                node_runtime,
                server: CopilotServer::Starting { task: start_task },
            }
        } else {
            Self {
                http,
                node_runtime,
                server: CopilotServer::Disabled,
            }
        }
    }

    /// Obtains the server script via `get_copilot_lsp`, launches it with the Node
    /// binary, initializes it and asks for its sign-in status.
    ///
    /// On success `this.server` becomes `Started` with the status reported by
    /// `CheckStatus`; on any failure it becomes `Error` with the error's message.
    /// NOTE(review): the result is written unconditionally, so a start-up task that is
    /// still alive after a reinstall/disable would overwrite the newer state — verify
    /// whether any clone of the shared task can outlive that.
    fn start_language_server(
        http: Arc<dyn HttpClient>,
        node_runtime: Arc<NodeRuntime>,
        this: ModelHandle<Self>,
        mut cx: AsyncAppContext,
    ) -> impl Future<Output = ()> {
        async move {
            let start_language_server = async {
                let server_path = get_copilot_lsp(http).await?;
                let node_path = node_runtime.binary_path().await?;
                // Equivalent to running `node <server_path> --stdio`.
                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
                let server =
                    LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;

                let server = server.initialize(Default::default()).await?;
                let status = server
                    .request::<request::CheckStatus>(request::CheckStatusParams {
                        local_checks_only: false,
                    })
                    .await?;
                anyhow::Ok((server, status))
            };

            let server = start_language_server.await;
            this.update(&mut cx, |this, cx| {
                cx.notify();
                match server {
                    Ok((server, status)) => {
                        // Start as `SignedOut`, then overwrite with what the server reported.
                        this.server = CopilotServer::Started {
                            server,
                            status: SignInStatus::SignedOut,
                        };
                        this.update_sign_in_status(status, cx);
                    }
                    Err(error) => {
                        this.server = CopilotServer::Error(error.to_string().into());
                        cx.notify()
                    }
                }
            })
        }
    }

    /// Starts, or joins, a sign-in on the running server.
    ///
    /// - `Authorized` / `Unauthorized`: resolves immediately with `Ok(())`.
    /// - `SigningIn`: returns a task awaiting the attempt already in flight.
    /// - `SignedOut`: sends `SignInInitiate`. If the server asks for a device flow, the
    ///   prompt is stored on the `SigningIn` state (visible through `status()`) and the
    ///   task then waits for `SignInConfirm`.
    ///
    /// # Errors
    /// Fails if the server is not `Started`, or if a sign-in request fails; in the
    /// latter case the status is reset to `SignedOut`.
    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status } = &mut self.server {
            let task = match status {
                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                    Task::ready(Ok(())).shared()
                }
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut => {
                    let server = server.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = server
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the prompt on the `SigningIn` state before
                                        // waiting for confirmation.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Started { status, .. } =
                                                &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        let response = server
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            // Record the outcome; a failed request resets to `SignedOut`.
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // `Arc` because the shared task's output must be `Clone`.
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    *status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // Convert the shared `Arc<anyhow::Error>` back into a plain `anyhow::Error`.
            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if we're downloading, wait until the download is finished.
            // TODO: if we're in a stuck state, display it to the user.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }

    /// Marks the server as signed out immediately, then sends `SignOut` to the server
    /// on the background executor.
    ///
    /// NOTE(review): a sign-in task still in flight can later overwrite this status
    /// through `update_sign_in_status`.
    ///
    /// # Errors
    /// Fails if the server is not `Started` or the `SignOut` request fails.
    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status } = &mut self.server {
            *status = SignInStatus::SignedOut;
            cx.notify();

            let server = server.clone();
            cx.background().spawn(async move {
                server
                    .request::<request::SignOut>(request::SignOutParams {})
                    .await?;
                anyhow::Ok(())
            })
        } else {
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }

    /// Empties the Copilot directory and restarts the language server, which goes
    /// through `get_copilot_lsp` again. The returned task resolves when that start-up
    /// attempt is over.
    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
        let start_task = cx
            .spawn({
                let http = self.http.clone();
                let node_runtime = self.node_runtime.clone();
                move |this, cx| async move {
                    clear_copilot_dir().await;
                    Self::start_language_server(http, node_runtime, this, cx).await
                }
            })
            .shared();

        self.server = CopilotServer::Starting {
            task: start_task.clone(),
        };

        cx.notify();

        cx.foreground().spawn(start_task)
    }

    /// Requests completions at `position` in `buffer` and returns the first one, if any.
    ///
    /// # Errors
    /// Fails if Copilot is not started and authorized, or if the request fails.
    pub fn completion<T>(
        &self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Option<Completion>>>
    where
        T: ToPointUtf16,
    {
        let server = match self.authorized_server() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };

        let buffer = buffer.read(cx).snapshot();
        let request = server
            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
        cx.background().spawn(async move {
            let result = request.await?;
            let completion = result
                .completions
                .into_iter()
                .next()
                .map(|completion| completion_from_lsp(completion, &buffer));
            anyhow::Ok(completion)
        })
    }

    /// Requests the cycling completion list at `position` in `buffer` and returns all
    /// of the completions the server sent.
    ///
    /// # Errors
    /// Fails if Copilot is not started and authorized, or if the request fails.
    pub fn completions_cycling<T>(
        &self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        T: ToPointUtf16,
    {
        let server = match self.authorized_server() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };

        let buffer = buffer.read(cx).snapshot();
        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
            &buffer, position, cx,
        ));
        cx.background().spawn(async move {
            let result = request.await?;
            let completions = result
                .completions
                .into_iter()
                .map(|completion| completion_from_lsp(completion, &buffer))
                .collect();
            anyhow::Ok(completions)
        })
    }

    /// Maps the internal server / sign-in state onto the public `Status`.
    pub fn status(&self) -> Status {
        match &self.server {
            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
            CopilotServer::Disabled => Status::Disabled,
            CopilotServer::Error(error) => Status::Error(error.clone()),
            CopilotServer::Started { status, .. } => match status {
                SignInStatus::Authorized { .. } => Status::Authorized,
                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
                    prompt: prompt.clone(),
                },
                SignInStatus::SignedOut => Status::SignedOut,
            },
        }
    }

    /// Applies a sign-in status reported by the server and notifies observers.
    /// Does nothing unless the server is `Started`.
    fn update_sign_in_status(
        &mut self,
        lsp_status: request::SignInStatus,
        cx: &mut ModelContext<Self>,
    ) {
        if let CopilotServer::Started { status, .. } = &mut self.server {
            *status = match lsp_status {
                request::SignInStatus::Ok { user }
                | request::SignInStatus::MaybeOk { user }
                | request::SignInStatus::AlreadySignedIn { user } => {
                    SignInStatus::Authorized { _user: user }
                }
                request::SignInStatus::NotAuthorized { user } => {
                    SignInStatus::Unauthorized { _user: user }
                }
                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
            };
            cx.notify();
        }
    }

    /// Returns the running server only if it is `Started` and `Authorized`; otherwise
    /// an error describing why completions can't be requested yet.
    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
        match &self.server {
            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
            CopilotServer::Error(error) => Err(anyhow!(
                "copilot was not started because of an error: {}",
                error
            )),
            CopilotServer::Started { server, status } => {
                if matches!(status, SignInStatus::Authorized { .. }) {
                    Ok(server.clone())
                } else {
                    Err(anyhow!("must sign in before using copilot"))
                }
            }
        }
    }
}
495
496fn build_completion_params<T>(
497    buffer: &BufferSnapshot,
498    position: T,
499    cx: &AppContext,
500) -> request::GetCompletionsParams
501where
502    T: ToPointUtf16,
503{
504    let position = position.to_point_utf16(&buffer);
505    let language_name = buffer.language_at(position).map(|language| language.name());
506    let language_name = language_name.as_deref();
507
508    let path;
509    let relative_path;
510    if let Some(file) = buffer.file() {
511        if let Some(file) = file.as_local() {
512            path = file.abs_path(cx);
513        } else {
514            path = file.full_path(cx);
515        }
516        relative_path = file.path().to_path_buf();
517    } else {
518        path = PathBuf::from("/untitled");
519        relative_path = PathBuf::from("untitled");
520    }
521
522    let settings = cx.global::<Settings>();
523    let language_id = match language_name {
524        Some("Plain Text") => "plaintext".to_string(),
525        Some(language_name) => language_name.to_lowercase(),
526        None => "plaintext".to_string(),
527    };
528    request::GetCompletionsParams {
529        doc: request::GetCompletionsDocument {
530            source: buffer.text(),
531            tab_size: settings.tab_size(language_name).into(),
532            indent_size: 1,
533            insert_spaces: !settings.hard_tabs(language_name),
534            uri: lsp::Url::from_file_path(&path).unwrap(),
535            path: path.to_string_lossy().into(),
536            relative_path: relative_path.to_string_lossy().into(),
537            language_id,
538            position: point_to_lsp(position),
539            version: 0,
540        },
541    }
542}
543
544fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
545    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
546    Completion {
547        position: buffer.anchor_before(position),
548        text: completion.display_text,
549    }
550}
551
552async fn clear_copilot_dir() {
553    remove_matching(&paths::COPILOT_DIR, |_| true).await
554}
555
556async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
557    const SERVER_PATH: &'static str = "dist/agent.js";
558
559    ///Check for the latest copilot language server and download it if we haven't already
560    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
561        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
562
563        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
564
565        fs::create_dir_all(version_dir).await?;
566        let server_path = version_dir.join(SERVER_PATH);
567
568        if fs::metadata(&server_path).await.is_err() {
569            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
570            let dist_dir = version_dir.join("dist");
571            fs::create_dir_all(dist_dir.as_path()).await?;
572
573            let url = &release
574                .assets
575                .get(0)
576                .context("Github release for copilot contained no assets")?
577                .browser_download_url;
578
579            let mut response = http
580                .get(&url, Default::default(), true)
581                .await
582                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
583            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
584            let archive = Archive::new(decompressed_bytes);
585            archive.unpack(dist_dir).await?;
586
587            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
588        }
589
590        Ok(server_path)
591    }
592
593    match fetch_latest(http).await {
594        ok @ Result::Ok(..) => ok,
595        e @ Err(..) => {
596            e.log_err();
597            // Fetch a cached binary, if it exists
598            (|| async move {
599                let mut last_version_dir = None;
600                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
601                while let Some(entry) = entries.next().await {
602                    let entry = entry?;
603                    if entry.file_type().await?.is_dir() {
604                        last_version_dir = Some(entry.path());
605                    }
606                }
607                let last_version_dir =
608                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
609                let server_path = last_version_dir.join(SERVER_PATH);
610                if server_path.exists() {
611                    Ok(server_path)
612                } else {
613                    Err(anyhow!(
614                        "missing executable in directory {:?}",
615                        last_version_dir
616                    ))
617                }
618            })()
619            .await
620        }
621    }
622}