copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  9use gpui::{
 10    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 11    Task,
 12};
 13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 14use lsp::LanguageServer;
 15use node_runtime::NodeRuntime;
 16use settings::Settings;
 17use smol::{fs, io::BufReader, stream::StreamExt};
 18use std::{
 19    ffi::OsString,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use util::{
 24    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 25};
 26
 27const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 28actions!(copilot_auth, [SignIn, SignOut]);
 29
 30const COPILOT_NAMESPACE: &'static str = "copilot";
 31actions!(copilot, [NextSuggestion, PreviousSuggestion, Toggle]);
 32
 33pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 34    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 35    cx.set_global(copilot.clone());
 36    cx.add_global_action(|_: &SignIn, cx| {
 37        let copilot = Copilot::global(cx).unwrap();
 38        copilot
 39            .update(cx, |copilot, cx| copilot.sign_in(cx))
 40            .detach_and_log_err(cx);
 41    });
 42    cx.add_global_action(|_: &SignOut, cx| {
 43        let copilot = Copilot::global(cx).unwrap();
 44        copilot
 45            .update(cx, |copilot, cx| copilot.sign_out(cx))
 46            .detach_and_log_err(cx);
 47    });
 48
 49    cx.observe(&copilot, |handle, cx| {
 50        let status = handle.read(cx).status();
 51        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 52            move |filter, _cx| match status {
 53                Status::Disabled => {
 54                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 55                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 56                }
 57                Status::Authorized => {
 58                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 59                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 60                }
 61                _ => {
 62                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 63                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 64                }
 65            },
 66        );
 67    })
 68    .detach();
 69
 70    sign_in::init(cx);
 71}
 72
/// Lifecycle state of the copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// The `enable_copilot_integration` setting is off; no server is running.
    Disabled,
    /// The server is being downloaded / launched by the held task.
    Starting {
        // Kept only so the start task stays alive while in this state.
        _task: Shared<Task<()>>,
    },
    /// Starting the server failed; holds the stringified error.
    Error(Arc<str>),
    /// The server is running.
    Started {
        /// Handle used to send requests to the server.
        server: Arc<LanguageServer>,
        /// Current authentication state as tracked by this model.
        status: SignInStatus,
    },
}
 84
/// Internal authentication state of a started copilot server.
///
/// Derived from the server's `request::SignInStatus` replies in
/// `Copilot::update_sign_in_status`, plus the transient `SigningIn` state.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported the user as signed in (`Ok`, `MaybeOk` or `AlreadySignedIn`).
    Authorized {
        _user: String,
    },
    /// The server reported `NotAuthorized` for this user.
    Unauthorized {
        _user: String,
    },
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt received from the server, once available.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared task driving the flow, so repeated sign-in calls can await the same work.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No user is signed in.
    SignedOut,
}
 99
/// Public, flattened view of the copilot state, as returned by `Copilot::status`.
#[derive(Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server is still being started.
    Starting,
    /// Starting the language server failed; holds the error message.
    Error(Arc<str>),
    /// Copilot integration is turned off.
    Disabled,
    /// The server is running but no user is signed in.
    SignedOut,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to present to the user, once the server has provided it.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// A user is signed in but not authorized to use copilot.
    Unauthorized,
    /// A user is signed in and completions can be requested.
    Authorized,
}
112
113impl Status {
114    pub fn is_authorized(&self) -> bool {
115        matches!(self, Status::Authorized)
116    }
117}
118
/// A single completion suggestion converted from the server's reply.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer anchor for the position reported by the server (clipped to the buffer).
    pub position: Anchor,
    /// The server's display text for the suggestion.
    pub text: String,
}
124
/// Model that owns the copilot language server and tracks its sign-in state.
pub struct Copilot {
    // Current lifecycle state of the language server (see `CopilotServer`).
    server: CopilotServer,
}
128
// `Copilot` emits no typed events; interested parties observe it and read `status()`.
impl Entity for Copilot {
    type Event = ();
}
132
133impl Copilot {
134    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
135        if cx.has_global::<ModelHandle<Self>>() {
136            Some(cx.global::<ModelHandle<Self>>().clone())
137        } else {
138            None
139        }
140    }
141
142    fn start(
143        http: Arc<dyn HttpClient>,
144        node_runtime: Arc<NodeRuntime>,
145        cx: &mut ModelContext<Self>,
146    ) -> Self {
147        cx.observe_global::<Settings, _>({
148            let http = http.clone();
149            let node_runtime = node_runtime.clone();
150            move |this, cx| {
151                if cx.global::<Settings>().enable_copilot_integration {
152                    if matches!(this.server, CopilotServer::Disabled) {
153                        let start_task = cx
154                            .spawn({
155                                let http = http.clone();
156                                let node_runtime = node_runtime.clone();
157                                move |this, cx| {
158                                    Self::start_language_server(http, node_runtime, this, cx)
159                                }
160                            })
161                            .shared();
162                        this.server = CopilotServer::Starting { _task: start_task }
163                    }
164                } else {
165                    this.server = CopilotServer::Disabled
166                }
167            }
168        })
169        .detach();
170
171        if cx.global::<Settings>().enable_copilot_integration {
172            let start_task = cx
173                .spawn({
174                    let http = http.clone();
175                    let node_runtime = node_runtime.clone();
176                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
177                })
178                .shared();
179
180            Self {
181                server: CopilotServer::Starting { _task: start_task },
182            }
183        } else {
184            Self {
185                server: CopilotServer::Disabled,
186            }
187        }
188    }
189
190    fn start_language_server(
191        http: Arc<dyn HttpClient>,
192        node_runtime: Arc<NodeRuntime>,
193        this: ModelHandle<Self>,
194        mut cx: AsyncAppContext,
195    ) -> impl Future<Output = ()> {
196        async move {
197            let start_language_server = async {
198                let server_path = get_copilot_lsp(http).await?;
199                let node_path = node_runtime.binary_path().await?;
200                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
201                let server =
202                    LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
203
204                let server = server.initialize(Default::default()).await?;
205                let status = server
206                    .request::<request::CheckStatus>(request::CheckStatusParams {
207                        local_checks_only: false,
208                    })
209                    .await?;
210                anyhow::Ok((server, status))
211            };
212
213            let server = start_language_server.await;
214            this.update(&mut cx, |this, cx| {
215                cx.notify();
216                match server {
217                    Ok((server, status)) => {
218                        this.server = CopilotServer::Started {
219                            server,
220                            status: SignInStatus::SignedOut,
221                        };
222                        this.update_sign_in_status(status, cx);
223                    }
224                    Err(error) => {
225                        this.server = CopilotServer::Error(error.to_string().into());
226                        cx.notify()
227                    }
228                }
229            })
230        }
231    }
232
    /// Starts (or joins) a sign-in flow and returns a task that resolves when it finishes.
    ///
    /// - If the server is not in the `Started` state, returns a ready error.
    /// - If the status is already `Authorized` or `Unauthorized`, returns a ready `Ok(())`.
    /// - If a sign-in is already in progress, the existing shared task is awaited.
    /// - Otherwise a new flow is spawned: `SignInInitiate` is sent, and if the
    ///   server asks for a device flow, the prompt is stored on the model and
    ///   `SignInConfirm` is awaited. The final status (or `NotSignedIn` on
    ///   error) is applied via `update_sign_in_status`.
    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status } = &mut self.server {
            let task = match status {
                // Nothing to do: the server already knows who the user is.
                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                    Task::ready(Ok(())).shared()
                }
                // A flow is already running: notify observers and share its task.
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut => {
                    let server = server.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = server
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the prompt, but only if the model is
                                        // still in the `SigningIn` state.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Started { status, .. } =
                                                &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        // Await the server's confirmation for this user code.
                                        let response = server
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    // Any failure resets the model to signed-out.
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // Wrapped in `Arc` so the shared task's output is cloneable.
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    // Record the in-flight flow so later calls join it instead of starting another.
                    *status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // Convert the shared `Arc<anyhow::Error>` back into a plain `anyhow::Error`.
            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // If we're downloading, wait until download is finished
            // If we're in a stuck state, display to the user
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
317
318    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
319        if let CopilotServer::Started { server, status } = &mut self.server {
320            *status = SignInStatus::SignedOut;
321            cx.notify();
322
323            let server = server.clone();
324            cx.background().spawn(async move {
325                server
326                    .request::<request::SignOut>(request::SignOutParams {})
327                    .await?;
328                anyhow::Ok(())
329            })
330        } else {
331            Task::ready(Err(anyhow!("copilot hasn't started yet")))
332        }
333    }
334
335    pub fn completion<T>(
336        &self,
337        buffer: &ModelHandle<Buffer>,
338        position: T,
339        cx: &mut ModelContext<Self>,
340    ) -> Task<Result<Option<Completion>>>
341    where
342        T: ToPointUtf16,
343    {
344        let server = match self.authorized_server() {
345            Ok(server) => server,
346            Err(error) => return Task::ready(Err(error)),
347        };
348
349        let buffer = buffer.read(cx).snapshot();
350        let request = server
351            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
352        cx.background().spawn(async move {
353            let result = request.await?;
354            let completion = result
355                .completions
356                .into_iter()
357                .next()
358                .map(|completion| completion_from_lsp(completion, &buffer));
359            anyhow::Ok(completion)
360        })
361    }
362
363    pub fn completions_cycling<T>(
364        &self,
365        buffer: &ModelHandle<Buffer>,
366        position: T,
367        cx: &mut ModelContext<Self>,
368    ) -> Task<Result<Vec<Completion>>>
369    where
370        T: ToPointUtf16,
371    {
372        let server = match self.authorized_server() {
373            Ok(server) => server,
374            Err(error) => return Task::ready(Err(error)),
375        };
376
377        let buffer = buffer.read(cx).snapshot();
378        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
379            &buffer, position, cx,
380        ));
381        cx.background().spawn(async move {
382            let result = request.await?;
383            let completions = result
384                .completions
385                .into_iter()
386                .map(|completion| completion_from_lsp(completion, &buffer))
387                .collect();
388            anyhow::Ok(completions)
389        })
390    }
391
392    pub fn status(&self) -> Status {
393        match &self.server {
394            CopilotServer::Starting { .. } => Status::Starting,
395            CopilotServer::Disabled => Status::Disabled,
396            CopilotServer::Error(error) => Status::Error(error.clone()),
397            CopilotServer::Started { status, .. } => match status {
398                SignInStatus::Authorized { .. } => Status::Authorized,
399                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
400                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
401                    prompt: prompt.clone(),
402                },
403                SignInStatus::SignedOut => Status::SignedOut,
404            },
405        }
406    }
407
408    fn update_sign_in_status(
409        &mut self,
410        lsp_status: request::SignInStatus,
411        cx: &mut ModelContext<Self>,
412    ) {
413        if let CopilotServer::Started { status, .. } = &mut self.server {
414            *status = match lsp_status {
415                request::SignInStatus::Ok { user }
416                | request::SignInStatus::MaybeOk { user }
417                | request::SignInStatus::AlreadySignedIn { user } => {
418                    SignInStatus::Authorized { _user: user }
419                }
420                request::SignInStatus::NotAuthorized { user } => {
421                    SignInStatus::Unauthorized { _user: user }
422                }
423                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
424            };
425            cx.notify();
426        }
427    }
428
429    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
430        match &self.server {
431            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
432            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
433            CopilotServer::Error(error) => Err(anyhow!(
434                "copilot was not started because of an error: {}",
435                error
436            )),
437            CopilotServer::Started { server, status } => {
438                if matches!(status, SignInStatus::Authorized { .. }) {
439                    Ok(server.clone())
440                } else {
441                    Err(anyhow!("must sign in before using copilot"))
442                }
443            }
444        }
445    }
446}
447
448fn build_completion_params<T>(
449    buffer: &BufferSnapshot,
450    position: T,
451    cx: &AppContext,
452) -> request::GetCompletionsParams
453where
454    T: ToPointUtf16,
455{
456    let position = position.to_point_utf16(&buffer);
457    let language_name = buffer.language_at(position).map(|language| language.name());
458    let language_name = language_name.as_deref();
459
460    let path;
461    let relative_path;
462    if let Some(file) = buffer.file() {
463        if let Some(file) = file.as_local() {
464            path = file.abs_path(cx);
465        } else {
466            path = file.full_path(cx);
467        }
468        relative_path = file.path().to_path_buf();
469    } else {
470        path = PathBuf::from("/untitled");
471        relative_path = PathBuf::from("untitled");
472    }
473
474    let settings = cx.global::<Settings>();
475    let language_id = match language_name {
476        Some("Plain Text") => "plaintext".to_string(),
477        Some(language_name) => language_name.to_lowercase(),
478        None => "plaintext".to_string(),
479    };
480    request::GetCompletionsParams {
481        doc: request::GetCompletionsDocument {
482            source: buffer.text(),
483            tab_size: settings.tab_size(language_name).into(),
484            indent_size: 1,
485            insert_spaces: !settings.hard_tabs(language_name),
486            uri: lsp::Url::from_file_path(&path).unwrap(),
487            path: path.to_string_lossy().into(),
488            relative_path: relative_path.to_string_lossy().into(),
489            language_id,
490            position: point_to_lsp(position),
491            version: 0,
492        },
493    }
494}
495
496fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
497    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
498    Completion {
499        position: buffer.anchor_before(position),
500        text: completion.display_text,
501    }
502}
503
504async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
505    const SERVER_PATH: &'static str = "agent.js";
506
507    ///Check for the latest copilot language server and download it if we haven't already
508    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
509        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
510
511        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
512
513        fs::create_dir_all(version_dir).await?;
514        let server_path = version_dir.join(SERVER_PATH);
515
516        if fs::metadata(&server_path).await.is_err() {
517            let url = &release
518                .assets
519                .get(0)
520                .context("Github release for copilot contained no assets")?
521                .browser_download_url;
522
523            let mut response = http
524                .get(&url, Default::default(), true)
525                .await
526                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
527            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
528            let archive = Archive::new(decompressed_bytes);
529            archive.unpack(version_dir).await?;
530
531            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
532        }
533
534        Ok(server_path)
535    }
536
537    match fetch_latest(http).await {
538        ok @ Result::Ok(..) => ok,
539        e @ Err(..) => {
540            e.log_err();
541            // Fetch a cached binary, if it exists
542            (|| async move {
543                let mut last_version_dir = None;
544                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
545                while let Some(entry) = entries.next().await {
546                    let entry = entry?;
547                    if entry.file_type().await?.is_dir() {
548                        last_version_dir = Some(entry.path());
549                    }
550                }
551                let last_version_dir =
552                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
553                let server_path = last_version_dir.join(SERVER_PATH);
554                if server_path.exists() {
555                    Ok(server_path)
556                } else {
557                    Err(anyhow!(
558                        "missing executable in directory {:?}",
559                        last_version_dir
560                    ))
561                }
562            })()
563            .await
564        }
565    }
566}