copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use client::Client;
  7use futures::{future::Shared, FutureExt, TryFutureExt};
  8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
  9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 10use lsp::LanguageServer;
 11use settings::Settings;
 12use smol::{fs, io::BufReader, stream::StreamExt};
 13use std::{
 14    env::consts,
 15    path::{Path, PathBuf},
 16    sync::Arc,
 17};
 18use util::{
 19    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 20};
 21
// Declares the `SignIn` and `SignOut` action types used by the global handlers that
// `init` registers. The first macro argument is presumably the action namespace —
// confirm against the `gpui::actions!` definition.
actions!(copilot, [SignIn, SignOut]);
 23
 24pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 25    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
 26    cx.set_global(copilot.clone());
 27    cx.add_global_action(|_: &SignIn, cx| {
 28        let copilot = Copilot::global(cx);
 29        copilot
 30            .update(cx, |copilot, cx| copilot.sign_in(cx))
 31            .detach_and_log_err(cx);
 32    });
 33    cx.add_global_action(|_: &SignOut, cx| {
 34        let copilot = Copilot::global(cx);
 35        copilot
 36            .update(cx, |copilot, cx| copilot.sign_out(cx))
 37            .detach_and_log_err(cx);
 38    });
 39    sign_in::init(cx);
 40}
 41
/// Lifecycle state of the Copilot language server, as tracked by `Copilot`.
enum CopilotServer {
    /// Initial state: `Copilot::start` is still obtaining and launching the server.
    Downloading,
    /// Startup failed; holds the stringified error.
    Error(Arc<str>),
    /// The server was launched and initialized successfully.
    Started {
        /// Handle used to send requests to the running server.
        server: Arc<LanguageServer>,
        /// Current sign-in state for this server.
        status: SignInStatus,
    },
}
 50
/// Internal sign-in state kept alongside a started server.
///
/// `Copilot::update_sign_in_status` derives it from the status the server reports, and
/// `Copilot::sign_in` / `Copilot::sign_out` update it directly.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported `Ok` or `MaybeOk` for `user`; completions may be requested.
    Authorized {
        user: String,
    },
    /// The server reported `NotAuthorized` for `user`.
    Unauthorized {
        user: String,
    },
    /// A sign-in attempt is in flight.
    SigningIn {
        /// Device-flow prompt, filled in once the server asks for one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared handle to the running sign-in task, so repeated `sign_in` calls
        /// wait on the same attempt instead of starting another.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No user is signed in.
    SignedOut,
}
 65
/// Public summary of Copilot's state, as returned by `Copilot::status`.
///
/// Combines the server lifecycle (`CopilotServer`) with the sign-in state
/// (`SignInStatus`), dropping the server handle, user names and in-flight task.
#[derive(Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server has not finished starting yet.
    Downloading,
    /// Starting the language server failed; holds the error message.
    Error(Arc<str>),
    /// The server is running and no user is signed in.
    SignedOut,
    /// A sign-in attempt is in progress.
    SigningIn {
        /// Device-flow prompt, if the server has issued one.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server reported the signed-in user as not authorized.
    Unauthorized,
    /// The user is signed in and authorized.
    Authorized,
}
 77
/// A single completion suggestion, converted from the server's response by
/// `completion_from_lsp`.
#[derive(Debug)]
pub struct Completion {
    /// Buffer anchor built from the position the server reported (clipped to the buffer).
    pub position: Anchor,
    /// The server's `display_text` for this suggestion.
    pub text: String,
}
 83
/// Model that owns the Copilot language server and its sign-in state.
struct Copilot {
    /// Current lifecycle state of the language server.
    server: CopilotServer,
}
 87
/// Makes `Copilot` usable as a gpui model.
impl Entity for Copilot {
    // No typed events are emitted; state changes are signalled with `cx.notify()`.
    type Event = ();
}
 91
 92impl Copilot {
 93    fn global(cx: &AppContext) -> ModelHandle<Self> {
 94        cx.global::<ModelHandle<Self>>().clone()
 95    }
 96
 97    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
 98        // TODO: Don't eagerly download the LSP
 99        cx.spawn(|this, mut cx| async move {
100            let start_language_server = async {
101                let server_path = get_lsp_binary(http).await?;
102                let server =
103                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
104                let server = server.initialize(Default::default()).await?;
105                let status = server
106                    .request::<request::CheckStatus>(request::CheckStatusParams {
107                        local_checks_only: false,
108                    })
109                    .await?;
110                anyhow::Ok((server, status))
111            };
112
113            let server = start_language_server.await;
114            this.update(&mut cx, |this, cx| {
115                cx.notify();
116                match server {
117                    Ok((server, status)) => {
118                        this.server = CopilotServer::Started {
119                            server,
120                            status: SignInStatus::SignedOut,
121                        };
122                        this.update_sign_in_status(status, cx);
123                    }
124                    Err(error) => {
125                        this.server = CopilotServer::Error(error.to_string().into());
126                    }
127                }
128            })
129        })
130        .detach();
131
132        Self {
133            server: CopilotServer::Downloading,
134        }
135    }
136
137    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
138        if let CopilotServer::Started { server, status } = &mut self.server {
139            let task = match status {
140                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
141                    Task::ready(Ok(())).shared()
142                }
143                SignInStatus::SigningIn { task, .. } => task.clone(),
144                SignInStatus::SignedOut => {
145                    let server = server.clone();
146                    let task = cx
147                        .spawn(|this, mut cx| async move {
148                            let sign_in = async {
149                                let sign_in = server
150                                    .request::<request::SignInInitiate>(
151                                        request::SignInInitiateParams {},
152                                    )
153                                    .await?;
154                                match sign_in {
155                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
156                                        Ok(request::SignInStatus::Ok { user })
157                                    }
158                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
159                                        this.update(&mut cx, |this, cx| {
160                                            if let CopilotServer::Started { status, .. } =
161                                                &mut this.server
162                                            {
163                                                if let SignInStatus::SigningIn {
164                                                    prompt: prompt_flow,
165                                                    ..
166                                                } = status
167                                                {
168                                                    *prompt_flow = Some(flow.clone());
169                                                    cx.notify();
170                                                }
171                                            }
172                                        });
173                                        let response = server
174                                            .request::<request::SignInConfirm>(
175                                                request::SignInConfirmParams {
176                                                    user_code: flow.user_code,
177                                                },
178                                            )
179                                            .await?;
180                                        Ok(response)
181                                    }
182                                }
183                            };
184
185                            let sign_in = sign_in.await;
186                            this.update(&mut cx, |this, cx| match sign_in {
187                                Ok(status) => {
188                                    this.update_sign_in_status(status, cx);
189                                    Ok(())
190                                }
191                                Err(error) => {
192                                    this.update_sign_in_status(
193                                        request::SignInStatus::NotSignedIn,
194                                        cx,
195                                    );
196                                    Err(Arc::new(error))
197                                }
198                            })
199                        })
200                        .shared();
201                    *status = SignInStatus::SigningIn {
202                        prompt: None,
203                        task: task.clone(),
204                    };
205                    cx.notify();
206                    task
207                }
208            };
209
210            cx.foreground()
211                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
212        } else {
213            Task::ready(Err(anyhow!("copilot hasn't started yet")))
214        }
215    }
216
217    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
218        if let CopilotServer::Started { server, status } = &mut self.server {
219            *status = SignInStatus::SignedOut;
220            cx.notify();
221
222            let server = server.clone();
223            cx.background().spawn(async move {
224                server
225                    .request::<request::SignOut>(request::SignOutParams {})
226                    .await?;
227                anyhow::Ok(())
228            })
229        } else {
230            Task::ready(Err(anyhow!("copilot hasn't started yet")))
231        }
232    }
233
234    pub fn completion<T>(
235        &self,
236        buffer: &ModelHandle<Buffer>,
237        position: T,
238        cx: &mut ModelContext<Self>,
239    ) -> Task<Result<Option<Completion>>>
240    where
241        T: ToPointUtf16,
242    {
243        let server = match self.authenticated_server() {
244            Ok(server) => server,
245            Err(error) => return Task::ready(Err(error)),
246        };
247
248        let buffer = buffer.read(cx).snapshot();
249        let request = server
250            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
251        cx.background().spawn(async move {
252            let result = request.await?;
253            let completion = result
254                .completions
255                .into_iter()
256                .next()
257                .map(|completion| completion_from_lsp(completion, &buffer));
258            anyhow::Ok(completion)
259        })
260    }
261
262    pub fn completions_cycling<T>(
263        &self,
264        buffer: &ModelHandle<Buffer>,
265        position: T,
266        cx: &mut ModelContext<Self>,
267    ) -> Task<Result<Vec<Completion>>>
268    where
269        T: ToPointUtf16,
270    {
271        let server = match self.authenticated_server() {
272            Ok(server) => server,
273            Err(error) => return Task::ready(Err(error)),
274        };
275
276        let buffer = buffer.read(cx).snapshot();
277        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
278            &buffer, position, cx,
279        ));
280        cx.background().spawn(async move {
281            let result = request.await?;
282            let completions = result
283                .completions
284                .into_iter()
285                .map(|completion| completion_from_lsp(completion, &buffer))
286                .collect();
287            anyhow::Ok(completions)
288        })
289    }
290
291    pub fn status(&self) -> Status {
292        match &self.server {
293            CopilotServer::Downloading => Status::Downloading,
294            CopilotServer::Error(error) => Status::Error(error.clone()),
295            CopilotServer::Started { status, .. } => match status {
296                SignInStatus::Authorized { .. } => Status::Authorized,
297                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
298                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
299                    prompt: prompt.clone(),
300                },
301                SignInStatus::SignedOut => Status::SignedOut,
302            },
303        }
304    }
305
306    fn update_sign_in_status(
307        &mut self,
308        lsp_status: request::SignInStatus,
309        cx: &mut ModelContext<Self>,
310    ) {
311        if let CopilotServer::Started { status, .. } = &mut self.server {
312            *status = match lsp_status {
313                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
314                    SignInStatus::Authorized { user }
315                }
316                request::SignInStatus::NotAuthorized { user } => {
317                    SignInStatus::Unauthorized { user }
318                }
319                _ => SignInStatus::SignedOut,
320            };
321            cx.notify();
322        }
323    }
324
325    fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
326        match &self.server {
327            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
328            CopilotServer::Error(error) => Err(anyhow!(
329                "copilot was not started because of an error: {}",
330                error
331            )),
332            CopilotServer::Started { server, status } => {
333                if matches!(status, SignInStatus::Authorized { .. }) {
334                    Ok(server.clone())
335                } else {
336                    Err(anyhow!("must sign in before using copilot"))
337                }
338            }
339        }
340    }
341}
342
343fn build_completion_params<T>(
344    buffer: &BufferSnapshot,
345    position: T,
346    cx: &AppContext,
347) -> request::GetCompletionsParams
348where
349    T: ToPointUtf16,
350{
351    let position = position.to_point_utf16(&buffer);
352    let language_name = buffer.language_at(position).map(|language| language.name());
353    let language_name = language_name.as_deref();
354
355    let path;
356    let relative_path;
357    if let Some(file) = buffer.file() {
358        if let Some(file) = file.as_local() {
359            path = file.abs_path(cx);
360        } else {
361            path = file.full_path(cx);
362        }
363        relative_path = file.path().to_path_buf();
364    } else {
365        path = PathBuf::from("/untitled");
366        relative_path = PathBuf::from("untitled");
367    }
368
369    let settings = cx.global::<Settings>();
370    let language_id = match language_name {
371        Some("Plain Text") => "plaintext".to_string(),
372        Some(language_name) => language_name.to_lowercase(),
373        None => "plaintext".to_string(),
374    };
375    request::GetCompletionsParams {
376        doc: request::GetCompletionsDocument {
377            source: buffer.text(),
378            tab_size: settings.tab_size(language_name).into(),
379            indent_size: 1,
380            insert_spaces: !settings.hard_tabs(language_name),
381            uri: lsp::Url::from_file_path(&path).unwrap(),
382            path: path.to_string_lossy().into(),
383            relative_path: relative_path.to_string_lossy().into(),
384            language_id,
385            position: point_to_lsp(position),
386            version: 0,
387        },
388    }
389}
390
391fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
392    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
393    Completion {
394        position: buffer.anchor_before(position),
395        text: completion.display_text,
396    }
397}
398
399async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
400    ///Check for the latest copilot language server and download it if we haven't already
401    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
402        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
403        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
404        let asset = release
405            .assets
406            .iter()
407            .find(|asset| asset.name == asset_name)
408            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
409
410        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
411        let destination_path =
412            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
413
414        if fs::metadata(&destination_path).await.is_err() {
415            let mut response = http
416                .get(&asset.browser_download_url, Default::default(), true)
417                .await
418                .map_err(|err| anyhow!("error downloading release: {}", err))?;
419            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
420            let mut file = fs::File::create(&destination_path).await?;
421            futures::io::copy(decompressed_bytes, &mut file).await?;
422            fs::set_permissions(
423                &destination_path,
424                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
425            )
426            .await?;
427
428            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
429        }
430
431        Ok(destination_path)
432    }
433
434    match fetch_latest(http).await {
435        ok @ Result::Ok(..) => ok,
436        e @ Err(..) => {
437            e.log_err();
438            // Fetch a cached binary, if it exists
439            (|| async move {
440                let mut last = None;
441                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
442                while let Some(entry) = entries.next().await {
443                    last = Some(entry?.path());
444                }
445                last.ok_or_else(|| anyhow!("no cached binary"))
446            })()
447            .await
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    // Smoke test that drives `Copilot` end to end: start, sign in, then request
    // completions. It only prints results with `dbg!`; the sole checks are the
    // `unwrap` calls on each step.
    //
    // NOTE(review): `http::client()` is presumably a real HTTP client, so this test
    // would need network access and a working Copilot account — confirm before relying
    // on it in CI.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http, cx));
        // Give the server a fixed two seconds to start; `sign_in` fails with
        // "copilot hasn't started yet" if it is still downloading at that point.
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        // Offset 12 is the end of the buffer text "fn foo() -> ".
        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        dbg!(copilot
            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
            .await
            .unwrap());
        dbg!(copilot
            .update(cx, |copilot, cx| copilot
                .completions_cycling(&buffer, 12, cx))
            .await
            .unwrap());
    }
}