copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use client::Client;
  7use futures::{future::Shared, FutureExt, TryFutureExt};
  8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
  9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 10use lsp::LanguageServer;
 11use settings::Settings;
 12use smol::{fs, io::BufReader, stream::StreamExt};
 13use std::{
 14    env::consts,
 15    path::{Path, PathBuf},
 16    sync::Arc,
 17};
 18use util::{
 19    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 20};
 21
 22actions!(copilot, [SignIn, SignOut, NextSuggestion]);
 23
 24pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 25    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
 26    cx.set_global(copilot.clone());
 27    cx.add_global_action(|_: &SignIn, cx| {
 28        let copilot = Copilot::global(cx).unwrap();
 29        copilot
 30            .update(cx, |copilot, cx| copilot.sign_in(cx))
 31            .detach_and_log_err(cx);
 32    });
 33    cx.add_global_action(|_: &SignOut, cx| {
 34        let copilot = Copilot::global(cx).unwrap();
 35        copilot
 36            .update(cx, |copilot, cx| copilot.sign_out(cx))
 37            .detach_and_log_err(cx);
 38    });
 39    sign_in::init(cx);
 40}
 41
 42enum CopilotServer {
 43    Downloading,
 44    Error(Arc<str>),
 45    Started {
 46        server: Arc<LanguageServer>,
 47        status: SignInStatus,
 48    },
 49}
 50
 51#[derive(Clone, Debug)]
 52enum SignInStatus {
 53    Authorized {
 54        user: String,
 55    },
 56    Unauthorized {
 57        user: String,
 58    },
 59    SigningIn {
 60        prompt: Option<request::PromptUserDeviceFlow>,
 61        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 62    },
 63    SignedOut,
 64}
 65
 66#[derive(Debug, PartialEq, Eq)]
 67pub enum Status {
 68    Downloading,
 69    Error(Arc<str>),
 70    SignedOut,
 71    SigningIn {
 72        prompt: Option<request::PromptUserDeviceFlow>,
 73    },
 74    Unauthorized,
 75    Authorized,
 76}
 77
 78impl Status {
 79    pub fn is_authorized(&self) -> bool {
 80        matches!(self, Status::Authorized)
 81    }
 82}
 83
 84#[derive(Debug, PartialEq, Eq)]
 85pub struct Completion {
 86    pub position: Anchor,
 87    pub text: String,
 88}
 89
 90pub struct Copilot {
 91    server: CopilotServer,
 92}
 93
 94impl Entity for Copilot {
 95    type Event = ();
 96}
 97
impl Copilot {
    /// Returns the `ModelHandle<Copilot>` registered as a global (see `init`),
    /// or `None` if no such global has been set.
    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
        if cx.has_global::<ModelHandle<Self>>() {
            Some(cx.global::<ModelHandle<Self>>().clone())
        } else {
            None
        }
    }

    /// Creates the model in the `Downloading` state and spawns a detached task
    /// that obtains the server binary, launches it with `--stdio`, initializes
    /// it, and sends `CheckStatus`.
    ///
    /// When that task finishes, the model moves to `Started` (applying the
    /// status the server reported) or to `Error` (holding the error message),
    /// and observers are notified.
    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
        // TODO: Don't eagerly download the LSP
        cx.spawn(|this, mut cx| async move {
            let start_language_server = async {
                let server_path = get_lsp_binary(http).await?;
                let server =
                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
                let server = server.initialize(Default::default()).await?;
                let status = server
                    .request::<request::CheckStatus>(request::CheckStatusParams {
                        local_checks_only: false,
                    })
                    .await?;
                anyhow::Ok((server, status))
            };

            let server = start_language_server.await;
            this.update(&mut cx, |this, cx| {
                cx.notify();
                match server {
                    Ok((server, status)) => {
                        // Enter `Started` first: `update_sign_in_status` only
                        // acts on a started server, so it can then apply the
                        // status reported by `CheckStatus`.
                        this.server = CopilotServer::Started {
                            server,
                            status: SignInStatus::SignedOut,
                        };
                        this.update_sign_in_status(status, cx);
                    }
                    Err(error) => {
                        this.server = CopilotServer::Error(error.to_string().into());
                    }
                }
            })
        })
        .detach();

        Self {
            server: CopilotServer::Downloading,
        }
    }

    /// Starts a sign-in, or joins the one already in progress.
    ///
    /// - `Authorized` / `Unauthorized`: resolves immediately with `Ok(())`.
    /// - `SigningIn`: awaits the existing shared sign-in task.
    /// - `SignedOut`: sends `SignInInitiate`. If the server reports the user is
    ///   already signed in, that becomes the new status. If it asks for a
    ///   device flow, the prompt is stored in the `SigningIn` status (with a
    ///   notify) and the user code is sent with `SignInConfirm`. The server's
    ///   response is applied as the new status; on any error the status is set
    ///   from `NotSignedIn`, which maps to `SignedOut`.
    ///
    /// # Errors
    ///
    /// Fails if the server has not started, or if a sign-in request fails.
    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status } = &mut self.server {
            let task = match status {
                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                    Task::ready(Ok(())).shared()
                }
                SignInStatus::SigningIn { task, .. } => task.clone(),
                SignInStatus::SignedOut => {
                    let server = server.clone();
                    // The task is made `shared` so later `sign_in` calls can
                    // await the same sign-in; its error is an `Arc` because a
                    // shared future's output must be `Clone`.
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = server
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Expose the device-flow prompt via
                                        // `status()`, but only if we are still
                                        // in the `SigningIn` state.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Started { status, .. } =
                                                &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        let response = server
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    // Record the in-flight sign-in so concurrent callers join it
                    // (see the `SigningIn` arm above).
                    *status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // Turn the shared `Arc<anyhow::Error>` back into an owned error by
            // formatting it with `{:?}`.
            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }

    /// Marks the account signed out immediately (and notifies), then sends
    /// `SignOut` to the server on the background executor.
    ///
    /// NOTE(review): a sign-in task that is still running is not cancelled
    /// here; if something is still awaiting it, its completion may later
    /// overwrite this status. Confirm whether that is intended.
    ///
    /// # Errors
    ///
    /// Fails if the server has not started or the `SignOut` request fails.
    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status } = &mut self.server {
            *status = SignInStatus::SignedOut;
            cx.notify();

            let server = server.clone();
            cx.background().spawn(async move {
                server
                    .request::<request::SignOut>(request::SignOutParams {})
                    .await?;
                anyhow::Ok(())
            })
        } else {
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }

    /// Requests completions at `position` in `buffer` and returns the first
    /// one the server offers, or `None` if it offers none.
    ///
    /// The request is built synchronously from a snapshot of the buffer. The
    /// response is awaited on the background executor, and positions are
    /// converted against that same snapshot.
    ///
    /// # Errors
    ///
    /// Fails immediately unless the server is started and `Authorized`, and
    /// later if the request fails.
    pub fn completion<T>(
        &self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Option<Completion>>>
    where
        T: ToPointUtf16,
    {
        let server = match self.authorized_server() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };

        let buffer = buffer.read(cx).snapshot();
        let request = server
            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
        cx.background().spawn(async move {
            let result = request.await?;
            let completion = result
                .completions
                .into_iter()
                .next()
                .map(|completion| completion_from_lsp(completion, &buffer));
            anyhow::Ok(completion)
        })
    }

    /// Like [`Copilot::completion`], but sends `GetCompletionsCycling` and
    /// returns every completion the server offers, in the server's order.
    ///
    /// # Errors
    ///
    /// Fails immediately unless the server is started and `Authorized`, and
    /// later if the request fails.
    pub fn completions_cycling<T>(
        &self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        T: ToPointUtf16,
    {
        let server = match self.authorized_server() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };

        let buffer = buffer.read(cx).snapshot();
        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
            &buffer, position, cx,
        ));
        cx.background().spawn(async move {
            let result = request.await?;
            let completions = result
                .completions
                .into_iter()
                .map(|completion| completion_from_lsp(completion, &buffer))
                .collect();
            anyhow::Ok(completions)
        })
    }

    /// Returns a public summary of the server and sign-in state. Internal
    /// details (the server handle, user names and the sign-in task) are
    /// omitted.
    pub fn status(&self) -> Status {
        match &self.server {
            CopilotServer::Downloading => Status::Downloading,
            CopilotServer::Error(error) => Status::Error(error.clone()),
            CopilotServer::Started { status, .. } => match status {
                SignInStatus::Authorized { .. } => Status::Authorized,
                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
                    prompt: prompt.clone(),
                },
                SignInStatus::SignedOut => Status::SignedOut,
            },
        }
    }

    /// Applies a sign-in status reported by the server and notifies observers.
    ///
    /// The mapping is:
    /// - `Ok` / `MaybeOk` become `Authorized`.
    /// - `NotAuthorized` becomes `Unauthorized`.
    /// - Anything else becomes `SignedOut`.
    ///
    /// This is a no-op unless the server has started.
    fn update_sign_in_status(
        &mut self,
        lsp_status: request::SignInStatus,
        cx: &mut ModelContext<Self>,
    ) {
        if let CopilotServer::Started { status, .. } = &mut self.server {
            *status = match lsp_status {
                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
                    SignInStatus::Authorized { user }
                }
                request::SignInStatus::NotAuthorized { user } => {
                    SignInStatus::Unauthorized { user }
                }
                _ => SignInStatus::SignedOut,
            };
            cx.notify();
        }
    }

    /// Returns the server handle when the server has started and the account
    /// is `Authorized`. Otherwise returns an error explaining why completions
    /// are unavailable.
    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
        match &self.server {
            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
            CopilotServer::Error(error) => Err(anyhow!(
                "copilot was not started because of an error: {}",
                error
            )),
            CopilotServer::Started { server, status } => {
                if matches!(status, SignInStatus::Authorized { .. }) {
                    Ok(server.clone())
                } else {
                    Err(anyhow!("must sign in before using copilot"))
                }
            }
        }
    }
}
352
353fn build_completion_params<T>(
354    buffer: &BufferSnapshot,
355    position: T,
356    cx: &AppContext,
357) -> request::GetCompletionsParams
358where
359    T: ToPointUtf16,
360{
361    let position = position.to_point_utf16(&buffer);
362    let language_name = buffer.language_at(position).map(|language| language.name());
363    let language_name = language_name.as_deref();
364
365    let path;
366    let relative_path;
367    if let Some(file) = buffer.file() {
368        if let Some(file) = file.as_local() {
369            path = file.abs_path(cx);
370        } else {
371            path = file.full_path(cx);
372        }
373        relative_path = file.path().to_path_buf();
374    } else {
375        path = PathBuf::from("/untitled");
376        relative_path = PathBuf::from("untitled");
377    }
378
379    let settings = cx.global::<Settings>();
380    let language_id = match language_name {
381        Some("Plain Text") => "plaintext".to_string(),
382        Some(language_name) => language_name.to_lowercase(),
383        None => "plaintext".to_string(),
384    };
385    request::GetCompletionsParams {
386        doc: request::GetCompletionsDocument {
387            source: buffer.text(),
388            tab_size: settings.tab_size(language_name).into(),
389            indent_size: 1,
390            insert_spaces: !settings.hard_tabs(language_name),
391            uri: lsp::Url::from_file_path(&path).unwrap(),
392            path: path.to_string_lossy().into(),
393            relative_path: relative_path.to_string_lossy().into(),
394            language_id,
395            position: point_to_lsp(position),
396            version: 0,
397        },
398    }
399}
400
401fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
402    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
403    Completion {
404        position: buffer.anchor_before(position),
405        text: completion.display_text,
406    }
407}
408
409async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
410    ///Check for the latest copilot language server and download it if we haven't already
411    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
412        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
413        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
414        let asset = release
415            .assets
416            .iter()
417            .find(|asset| asset.name == asset_name)
418            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
419
420        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
421        let destination_path =
422            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
423
424        if fs::metadata(&destination_path).await.is_err() {
425            let mut response = http
426                .get(&asset.browser_download_url, Default::default(), true)
427                .await
428                .map_err(|err| anyhow!("error downloading release: {}", err))?;
429            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
430            let mut file = fs::File::create(&destination_path).await?;
431            futures::io::copy(decompressed_bytes, &mut file).await?;
432            fs::set_permissions(
433                &destination_path,
434                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
435            )
436            .await?;
437
438            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
439        }
440
441        Ok(destination_path)
442    }
443
444    match fetch_latest(http).await {
445        ok @ Result::Ok(..) => ok,
446        e @ Err(..) => {
447            e.log_err();
448            // Fetch a cached binary, if it exists
449            (|| async move {
450                let mut last = None;
451                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
452                while let Some(entry) = entries.next().await {
453                    last = Some(entry?.path());
454                }
455                last.ok_or_else(|| anyhow!("no cached binary"))
456            })()
457            .await
458        }
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    /// End-to-end smoke test against the real Copilot server.
    ///
    /// It starts the server, signs in, and prints the first completion and
    /// the cycling completions for a tiny Rust buffer. It needs the server
    /// binary (downloaded or cached) and an account the server reports as
    /// authorized; otherwise the `unwrap`s fail.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http_client = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http_client, cx));
        // `sign_in` fails until the startup task has moved the model to
        // `Started`, so give that task time to finish.
        smol::Timer::after(std::time::Duration::from_secs(2)).await;

        let sign_in = copilot.update(cx, |copilot, cx| copilot.sign_in(cx));
        sign_in.await.unwrap();
        let _ = copilot.read_with(cx, |copilot, _| copilot.status());

        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        let first = copilot.update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx));
        dbg!(first.await.unwrap());
        let cycling =
            copilot.update(cx, |copilot, cx| copilot.completions_cycling(&buffer, 12, cx));
        dbg!(cycling.await.unwrap());
    }
}