// copilot.rs

  1mod auth_modal;
  2mod request;
  3
  4use anyhow::{anyhow, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use auth_modal::AuthModal;
  7use client::Client;
  8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
  9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 10use lsp::LanguageServer;
 11use settings::Settings;
 12use smol::{fs, io::BufReader, stream::StreamExt};
 13use std::{
 14    env::consts,
 15    path::{Path, PathBuf},
 16    sync::Arc,
 17};
 18use util::{
 19    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 20};
 21use workspace::Workspace;
 22
 23actions!(copilot, [SignIn, SignOut, ToggleAuthStatus]);
 24
 25pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 26    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
 27    cx.set_global(copilot.clone());
 28    cx.add_action(|_workspace: &mut Workspace, _: &SignIn, cx| {
 29        let copilot = Copilot::global(cx);
 30        if copilot.read(cx).status() == Status::Authorized {
 31            return;
 32        }
 33
 34        if !copilot.read(cx).has_subscription() {
 35            let display_subscription =
 36                cx.subscribe(&copilot, |workspace, _copilot, e, cx| match e {
 37                    Event::PromptUserDeviceFlow => {
 38                        workspace.toggle_modal(cx, |_workspace, cx| build_auth_modal(cx));
 39                    }
 40                });
 41
 42            copilot.update(cx, |copilot, _cx| {
 43                copilot.set_subscription(display_subscription)
 44            })
 45        }
 46
 47        copilot
 48            .update(cx, |copilot, cx| copilot.sign_in(cx))
 49            .detach_and_log_err(cx);
 50    });
 51    cx.add_action(|workspace: &mut Workspace, _: &SignOut, cx| {
 52        let copilot = Copilot::global(cx);
 53
 54        copilot
 55            .update(cx, |copilot, cx| copilot.sign_out(cx))
 56            .detach_and_log_err(cx);
 57
 58        if workspace.modal::<AuthModal>().is_some() {
 59            workspace.dismiss_modal(cx)
 60        }
 61    });
 62    cx.add_action(|workspace: &mut Workspace, _: &ToggleAuthStatus, cx| {
 63        workspace.toggle_modal(cx, |_workspace, cx| build_auth_modal(cx))
 64    })
 65}
 66
 67fn build_auth_modal(cx: &mut gpui::ViewContext<Workspace>) -> gpui::ViewHandle<AuthModal> {
 68    let modal = cx.add_view(|cx| AuthModal::new(cx));
 69
 70    cx.subscribe(&modal, |workspace, _, e: &auth_modal::Event, cx| match e {
 71        auth_modal::Event::Dismiss => workspace.dismiss_modal(cx),
 72    })
 73    .detach();
 74
 75    modal
 76}
 77
 78enum CopilotServer {
 79    Downloading,
 80    Error(Arc<str>),
 81    Started {
 82        server: Arc<LanguageServer>,
 83        status: SignInStatus,
 84    },
 85}
 86
 87#[derive(Clone, Debug, PartialEq, Eq)]
 88struct PromptingUser {
 89    user_code: String,
 90    verification_uri: String,
 91}
 92
 93#[derive(Clone, Debug, PartialEq, Eq)]
 94enum SignInStatus {
 95    Authorized { user: String },
 96    Unauthorized { user: String },
 97    PromptingUser(PromptingUser),
 98    SignedOut,
 99}
100
/// Events emitted by the [`Copilot`] model.
#[derive(Debug)]
pub enum Event {
    /// A device-flow sign-in started; the UI should present the user code and verification URI.
    PromptUserDeviceFlow,
}
105
/// Public summary of Copilot's state, combining the server lifecycle and sign-in status.
///
/// `Clone` is derived because the type is handed out by value and its only payload is a
/// cheaply cloneable `Arc<str>`, so callers can keep or forward a copy.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server is still being downloaded/started.
    Downloading,
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// No user is signed in.
    SignedOut,
    /// A user is signed in but not authorized, or a device-flow sign-in is in progress.
    Unauthorized,
    /// Signed in and authorized; completions can be requested.
    Authorized,
}
114
115impl Status {
116    fn is_authorized(&self) -> bool {
117        matches!(self, Status::Authorized)
118    }
119}
120
121#[derive(Debug)]
122pub struct Completion {
123    pub position: Anchor,
124    pub text: String,
125}
126
127struct Copilot {
128    server: CopilotServer,
129    _display_subscription: Option<gpui::Subscription>,
130}
131
132impl Entity for Copilot {
133    type Event = Event;
134}
135
136impl Copilot {
137    fn global(cx: &AppContext) -> ModelHandle<Self> {
138        cx.global::<ModelHandle<Self>>().clone()
139    }
140
141    fn has_subscription(&self) -> bool {
142        self._display_subscription.is_some()
143    }
144
145    fn set_subscription(&mut self, display_subscription: gpui::Subscription) {
146        debug_assert!(self._display_subscription.is_none());
147        self._display_subscription = Some(display_subscription);
148    }
149
150    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
151        // TODO: Don't eagerly download the LSP
152        cx.spawn(|this, mut cx| async move {
153            let start_language_server = async {
154                let server_path = get_lsp_binary(http).await?;
155                let server =
156                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
157                let server = server.initialize(Default::default()).await?;
158                let status = server
159                    .request::<request::CheckStatus>(request::CheckStatusParams {
160                        local_checks_only: false,
161                    })
162                    .await?;
163                anyhow::Ok((server, status))
164            };
165
166            let server = start_language_server.await;
167            this.update(&mut cx, |this, cx| {
168                cx.notify();
169                match server {
170                    Ok((server, status)) => {
171                        this.server = CopilotServer::Started {
172                            server,
173                            status: SignInStatus::SignedOut,
174                        };
175                        this.update_sign_in_status(status, cx);
176                    }
177                    Err(error) => {
178                        this.server = CopilotServer::Error(error.to_string().into());
179                    }
180                }
181            })
182        })
183        .detach();
184
185        Self {
186            server: CopilotServer::Downloading,
187            _display_subscription: None,
188        }
189    }
190
191    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
192        if let CopilotServer::Started { server, .. } = &self.server {
193            let server = server.clone();
194            cx.spawn(|this, mut cx| async move {
195                let sign_in = server
196                    .request::<request::SignInInitiate>(request::SignInInitiateParams {})
197                    .await?;
198                if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
199                    this.update(&mut cx, |this, cx| {
200                        this.update_prompting_user(
201                            flow.user_code.clone(),
202                            flow.verification_uri,
203                            cx,
204                        );
205
206                        cx.emit(Event::PromptUserDeviceFlow)
207                    });
208                    // TODO: catch an error here and clear the corresponding user code
209                    let response = server
210                        .request::<request::SignInConfirm>(request::SignInConfirmParams {
211                            user_code: flow.user_code,
212                        })
213                        .await?;
214
215                    this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
216                }
217                anyhow::Ok(())
218            })
219        } else {
220            Task::ready(Err(anyhow!("copilot hasn't started yet")))
221        }
222    }
223
224    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
225        if let CopilotServer::Started { server, .. } = &self.server {
226            let server = server.clone();
227            cx.spawn(|this, mut cx| async move {
228                server
229                    .request::<request::SignOut>(request::SignOutParams {})
230                    .await?;
231                this.update(&mut cx, |this, cx| {
232                    if let CopilotServer::Started { status, .. } = &mut this.server {
233                        *status = SignInStatus::SignedOut;
234                        cx.notify();
235                    }
236                });
237
238                anyhow::Ok(())
239            })
240        } else {
241            Task::ready(Err(anyhow!("copilot hasn't started yet")))
242        }
243    }
244
245    pub fn completion<T>(
246        &self,
247        buffer: &ModelHandle<Buffer>,
248        position: T,
249        cx: &mut ModelContext<Self>,
250    ) -> Task<Result<Option<Completion>>>
251    where
252        T: ToPointUtf16,
253    {
254        let server = match self.authenticated_server() {
255            Ok(server) => server,
256            Err(error) => return Task::ready(Err(error)),
257        };
258
259        let buffer = buffer.read(cx).snapshot();
260        let request = server
261            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
262        cx.background().spawn(async move {
263            let result = request.await?;
264            let completion = result
265                .completions
266                .into_iter()
267                .next()
268                .map(|completion| completion_from_lsp(completion, &buffer));
269            anyhow::Ok(completion)
270        })
271    }
272
273    pub fn completions_cycling<T>(
274        &self,
275        buffer: &ModelHandle<Buffer>,
276        position: T,
277        cx: &mut ModelContext<Self>,
278    ) -> Task<Result<Vec<Completion>>>
279    where
280        T: ToPointUtf16,
281    {
282        let server = match self.authenticated_server() {
283            Ok(server) => server,
284            Err(error) => return Task::ready(Err(error)),
285        };
286
287        let buffer = buffer.read(cx).snapshot();
288        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
289            &buffer, position, cx,
290        ));
291        cx.background().spawn(async move {
292            let result = request.await?;
293            let completions = result
294                .completions
295                .into_iter()
296                .map(|completion| completion_from_lsp(completion, &buffer))
297                .collect();
298            anyhow::Ok(completions)
299        })
300    }
301
302    pub fn status(&self) -> Status {
303        match &self.server {
304            CopilotServer::Downloading => Status::Downloading,
305            CopilotServer::Error(error) => Status::Error(error.clone()),
306            CopilotServer::Started { status, .. } => match status {
307                SignInStatus::Authorized { .. } => Status::Authorized,
308                SignInStatus::Unauthorized { .. } | SignInStatus::PromptingUser { .. } => {
309                    Status::Unauthorized
310                }
311                SignInStatus::SignedOut => Status::SignedOut,
312            },
313        }
314    }
315
316    pub fn prompting_user(&self) -> Option<&PromptingUser> {
317        if let CopilotServer::Started { status, .. } = &self.server {
318            if let SignInStatus::PromptingUser(prompt) = status {
319                return Some(prompt);
320            }
321        }
322        None
323    }
324
325    fn update_prompting_user(
326        &mut self,
327        user_code: String,
328        verification_uri: String,
329        cx: &mut ModelContext<Self>,
330    ) {
331        if let CopilotServer::Started { status, .. } = &mut self.server {
332            *status = SignInStatus::PromptingUser(PromptingUser {
333                user_code,
334                verification_uri,
335            });
336            cx.notify();
337        }
338    }
339
340    fn update_sign_in_status(
341        &mut self,
342        lsp_status: request::SignInStatus,
343        cx: &mut ModelContext<Self>,
344    ) {
345        if let CopilotServer::Started { status, .. } = &mut self.server {
346            *status = match lsp_status {
347                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
348                    SignInStatus::Authorized { user }
349                }
350                request::SignInStatus::NotAuthorized { user } => {
351                    SignInStatus::Unauthorized { user }
352                }
353                _ => SignInStatus::SignedOut,
354            };
355            cx.notify();
356        }
357    }
358
359    fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
360        match &self.server {
361            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
362            CopilotServer::Error(error) => Err(anyhow!(
363                "copilot was not started because of an error: {}",
364                error
365            )),
366            CopilotServer::Started { server, status } => {
367                if matches!(status, SignInStatus::Authorized { .. }) {
368                    Ok(server.clone())
369                } else {
370                    Err(anyhow!("must sign in before using copilot"))
371                }
372            }
373        }
374    }
375}
376
377fn build_completion_params<T>(
378    buffer: &BufferSnapshot,
379    position: T,
380    cx: &AppContext,
381) -> request::GetCompletionsParams
382where
383    T: ToPointUtf16,
384{
385    let position = position.to_point_utf16(&buffer);
386    let language_name = buffer.language_at(position).map(|language| language.name());
387    let language_name = language_name.as_deref();
388
389    let path;
390    let relative_path;
391    if let Some(file) = buffer.file() {
392        if let Some(file) = file.as_local() {
393            path = file.abs_path(cx);
394        } else {
395            path = file.full_path(cx);
396        }
397        relative_path = file.path().to_path_buf();
398    } else {
399        path = PathBuf::from("/untitled");
400        relative_path = PathBuf::from("untitled");
401    }
402
403    let settings = cx.global::<Settings>();
404    let language_id = match language_name {
405        Some("Plain Text") => "plaintext".to_string(),
406        Some(language_name) => language_name.to_lowercase(),
407        None => "plaintext".to_string(),
408    };
409    request::GetCompletionsParams {
410        doc: request::GetCompletionsDocument {
411            source: buffer.text(),
412            tab_size: settings.tab_size(language_name).into(),
413            indent_size: 1,
414            insert_spaces: !settings.hard_tabs(language_name),
415            uri: lsp::Url::from_file_path(&path).unwrap(),
416            path: path.to_string_lossy().into(),
417            relative_path: relative_path.to_string_lossy().into(),
418            language_id,
419            position: point_to_lsp(position),
420            version: 0,
421        },
422    }
423}
424
425fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
426    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
427    Completion {
428        position: buffer.anchor_before(position),
429        text: completion.display_text,
430    }
431}
432
433async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
434    ///Check for the latest copilot language server and download it if we haven't already
435    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
436        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
437        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
438        let asset = release
439            .assets
440            .iter()
441            .find(|asset| asset.name == asset_name)
442            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
443
444        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
445        let destination_path =
446            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
447
448        if fs::metadata(&destination_path).await.is_err() {
449            let mut response = http
450                .get(&asset.browser_download_url, Default::default(), true)
451                .await
452                .map_err(|err| anyhow!("error downloading release: {}", err))?;
453            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
454            let mut file = fs::File::create(&destination_path).await?;
455            futures::io::copy(decompressed_bytes, &mut file).await?;
456            fs::set_permissions(
457                &destination_path,
458                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
459            )
460            .await?;
461
462            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
463        }
464
465        Ok(destination_path)
466    }
467
468    match fetch_latest(http).await {
469        ok @ Result::Ok(..) => ok,
470        e @ Err(..) => {
471            e.log_err();
472            // Fetch a cached binary, if it exists
473            (|| async move {
474                let mut last = None;
475                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
476                while let Some(entry) = entries.next().await {
477                    last = Some(entry?.path());
478                }
479                last.ok_or_else(|| anyhow!("no cached binary"))
480            })()
481            .await
482        }
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use std::time::Duration;
    use util::http;

    /// Smoke test against the real Copilot server: starts it, runs sign-in, and prints the
    /// resulting status and completions for a tiny Rust buffer.
    ///
    /// NOTE(review): this uses the network and the real sign-in flow, so it presumably needs an
    /// already signed-in machine (or someone completing the device flow) to pass.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http_client = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http_client, cx));

        // Crude wait for the background startup task spawned by `Copilot::start`.
        smol::Timer::after(Duration::from_secs(2)).await;

        let sign_in = copilot.update(cx, |copilot, cx| copilot.sign_in(cx));
        sign_in.await.unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        let completion = copilot.update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx));
        dbg!(completion.await.unwrap());
        let completions =
            copilot.update(cx, |copilot, cx| copilot.completions_cycling(&buffer, 12, cx));
        dbg!(completions.await.unwrap());
    }
}