copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use client::Client;
  7use futures::{future::Shared, FutureExt, TryFutureExt};
  8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
  9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 10use lsp::LanguageServer;
 11use node_runtime::NodeRuntime;
 12use settings::Settings;
 13use smol::{fs, io::BufReader, stream::StreamExt};
 14use std::{
 15    env::consts,
 16    path::{Path, PathBuf},
 17    sync::Arc,
 18};
 19use util::{
 20    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 21};
 22
 23actions!(copilot, [SignIn, SignOut, NextSuggestion]);
 24
 25pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 26    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 27    cx.set_global(copilot.clone());
 28    cx.add_global_action(|_: &SignIn, cx| {
 29        let copilot = Copilot::global(cx).unwrap();
 30        copilot
 31            .update(cx, |copilot, cx| copilot.sign_in(cx))
 32            .detach_and_log_err(cx);
 33    });
 34    cx.add_global_action(|_: &SignOut, cx| {
 35        let copilot = Copilot::global(cx).unwrap();
 36        copilot
 37            .update(cx, |copilot, cx| copilot.sign_out(cx))
 38            .detach_and_log_err(cx);
 39    });
 40    sign_in::init(cx);
 41}
 42
 43enum CopilotServer {
 44    Downloading,
 45    Error(Arc<str>),
 46    Started {
 47        server: Arc<LanguageServer>,
 48        status: SignInStatus,
 49    },
 50}
 51
 52#[derive(Clone, Debug)]
 53enum SignInStatus {
 54    Authorized {
 55        _user: String,
 56    },
 57    Unauthorized {
 58        _user: String,
 59    },
 60    SigningIn {
 61        prompt: Option<request::PromptUserDeviceFlow>,
 62        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 63    },
 64    SignedOut,
 65}
 66
 67#[derive(Debug, PartialEq, Eq)]
 68pub enum Status {
 69    Downloading,
 70    Error(Arc<str>),
 71    SignedOut,
 72    SigningIn {
 73        prompt: Option<request::PromptUserDeviceFlow>,
 74    },
 75    Unauthorized,
 76    Authorized,
 77}
 78
 79impl Status {
 80    pub fn is_authorized(&self) -> bool {
 81        matches!(self, Status::Authorized)
 82    }
 83}
 84
 85#[derive(Debug, PartialEq, Eq)]
 86pub struct Completion {
 87    pub position: Anchor,
 88    pub text: String,
 89}
 90
 91pub struct Copilot {
 92    server: CopilotServer,
 93}
 94
 95impl Entity for Copilot {
 96    type Event = ();
 97}
 98
 99impl Copilot {
100    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
101        if cx.has_global::<ModelHandle<Self>>() {
102            Some(cx.global::<ModelHandle<Self>>().clone())
103        } else {
104            None
105        }
106    }
107
108    fn start(
109        http: Arc<dyn HttpClient>,
110        node_runtime: Arc<NodeRuntime>,
111        cx: &mut ModelContext<Self>,
112    ) -> Self {
113        // TODO: Don't eagerly download the LSP
114        cx.spawn(|this, mut cx| async move {
115            let start_language_server = async {
116                let server_path = get_lsp_binary(http).await?;
117                let server =
118                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
119                let server = server.initialize(Default::default()).await?;
120                let status = server
121                    .request::<request::CheckStatus>(request::CheckStatusParams {
122                        local_checks_only: false,
123                    })
124                    .await?;
125                anyhow::Ok((server, status))
126            };
127
128            let server = start_language_server.await;
129            this.update(&mut cx, |this, cx| {
130                cx.notify();
131                match server {
132                    Ok((server, status)) => {
133                        this.server = CopilotServer::Started {
134                            server,
135                            status: SignInStatus::SignedOut,
136                        };
137                        this.update_sign_in_status(status, cx);
138                    }
139                    Err(error) => {
140                        this.server = CopilotServer::Error(error.to_string().into());
141                    }
142                }
143            })
144        })
145        .detach();
146
147        Self {
148            server: CopilotServer::Downloading,
149        }
150    }
151
152    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
153        if let CopilotServer::Started { server, status } = &mut self.server {
154            let task = match status {
155                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
156                    Task::ready(Ok(())).shared()
157                }
158                SignInStatus::SigningIn { task, .. } => {
159                    cx.notify(); // To re-show the prompt, just in case.
160                    task.clone()
161                }
162                SignInStatus::SignedOut => {
163                    let server = server.clone();
164                    let task = cx
165                        .spawn(|this, mut cx| async move {
166                            let sign_in = async {
167                                let sign_in = server
168                                    .request::<request::SignInInitiate>(
169                                        request::SignInInitiateParams {},
170                                    )
171                                    .await?;
172                                match sign_in {
173                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
174                                        Ok(request::SignInStatus::Ok { user })
175                                    }
176                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
177                                        this.update(&mut cx, |this, cx| {
178                                            if let CopilotServer::Started { status, .. } =
179                                                &mut this.server
180                                            {
181                                                if let SignInStatus::SigningIn {
182                                                    prompt: prompt_flow,
183                                                    ..
184                                                } = status
185                                                {
186                                                    *prompt_flow = Some(flow.clone());
187                                                    cx.notify();
188                                                }
189                                            }
190                                        });
191                                        let response = server
192                                            .request::<request::SignInConfirm>(
193                                                request::SignInConfirmParams {
194                                                    user_code: flow.user_code,
195                                                },
196                                            )
197                                            .await?;
198                                        Ok(response)
199                                    }
200                                }
201                            };
202
203                            let sign_in = sign_in.await;
204                            this.update(&mut cx, |this, cx| match sign_in {
205                                Ok(status) => {
206                                    this.update_sign_in_status(status, cx);
207                                    Ok(())
208                                }
209                                Err(error) => {
210                                    this.update_sign_in_status(
211                                        request::SignInStatus::NotSignedIn,
212                                        cx,
213                                    );
214                                    Err(Arc::new(error))
215                                }
216                            })
217                        })
218                        .shared();
219                    *status = SignInStatus::SigningIn {
220                        prompt: None,
221                        task: task.clone(),
222                    };
223                    cx.notify();
224                    task
225                }
226            };
227
228            cx.foreground()
229                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
230        } else {
231            Task::ready(Err(anyhow!("copilot hasn't started yet")))
232        }
233    }
234
235    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
236        if let CopilotServer::Started { server, status } = &mut self.server {
237            *status = SignInStatus::SignedOut;
238            cx.notify();
239
240            let server = server.clone();
241            cx.background().spawn(async move {
242                server
243                    .request::<request::SignOut>(request::SignOutParams {})
244                    .await?;
245                anyhow::Ok(())
246            })
247        } else {
248            Task::ready(Err(anyhow!("copilot hasn't started yet")))
249        }
250    }
251
252    pub fn completion<T>(
253        &self,
254        buffer: &ModelHandle<Buffer>,
255        position: T,
256        cx: &mut ModelContext<Self>,
257    ) -> Task<Result<Option<Completion>>>
258    where
259        T: ToPointUtf16,
260    {
261        let server = match self.authorized_server() {
262            Ok(server) => server,
263            Err(error) => return Task::ready(Err(error)),
264        };
265
266        let buffer = buffer.read(cx).snapshot();
267        let request = server
268            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
269        cx.background().spawn(async move {
270            let result = request.await?;
271            let completion = result
272                .completions
273                .into_iter()
274                .next()
275                .map(|completion| completion_from_lsp(completion, &buffer));
276            anyhow::Ok(completion)
277        })
278    }
279
280    pub fn completions_cycling<T>(
281        &self,
282        buffer: &ModelHandle<Buffer>,
283        position: T,
284        cx: &mut ModelContext<Self>,
285    ) -> Task<Result<Vec<Completion>>>
286    where
287        T: ToPointUtf16,
288    {
289        let server = match self.authorized_server() {
290            Ok(server) => server,
291            Err(error) => return Task::ready(Err(error)),
292        };
293
294        let buffer = buffer.read(cx).snapshot();
295        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
296            &buffer, position, cx,
297        ));
298        cx.background().spawn(async move {
299            let result = request.await?;
300            let completions = result
301                .completions
302                .into_iter()
303                .map(|completion| completion_from_lsp(completion, &buffer))
304                .collect();
305            anyhow::Ok(completions)
306        })
307    }
308
309    pub fn status(&self) -> Status {
310        match &self.server {
311            CopilotServer::Downloading => Status::Downloading,
312            CopilotServer::Error(error) => Status::Error(error.clone()),
313            CopilotServer::Started { status, .. } => match status {
314                SignInStatus::Authorized { .. } => Status::Authorized,
315                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
316                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
317                    prompt: prompt.clone(),
318                },
319                SignInStatus::SignedOut => Status::SignedOut,
320            },
321        }
322    }
323
324    fn update_sign_in_status(
325        &mut self,
326        lsp_status: request::SignInStatus,
327        cx: &mut ModelContext<Self>,
328    ) {
329        if let CopilotServer::Started { status, .. } = &mut self.server {
330            *status = match lsp_status {
331                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
332                    SignInStatus::Authorized { _user: user }
333                }
334                request::SignInStatus::NotAuthorized { user } => {
335                    SignInStatus::Unauthorized { _user: user }
336                }
337                _ => SignInStatus::SignedOut,
338            };
339            cx.notify();
340        }
341    }
342
343    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
344        match &self.server {
345            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
346            CopilotServer::Error(error) => Err(anyhow!(
347                "copilot was not started because of an error: {}",
348                error
349            )),
350            CopilotServer::Started { server, status } => {
351                if matches!(status, SignInStatus::Authorized { .. }) {
352                    Ok(server.clone())
353                } else {
354                    Err(anyhow!("must sign in before using copilot"))
355                }
356            }
357        }
358    }
359}
360
361fn build_completion_params<T>(
362    buffer: &BufferSnapshot,
363    position: T,
364    cx: &AppContext,
365) -> request::GetCompletionsParams
366where
367    T: ToPointUtf16,
368{
369    let position = position.to_point_utf16(&buffer);
370    let language_name = buffer.language_at(position).map(|language| language.name());
371    let language_name = language_name.as_deref();
372
373    let path;
374    let relative_path;
375    if let Some(file) = buffer.file() {
376        if let Some(file) = file.as_local() {
377            path = file.abs_path(cx);
378        } else {
379            path = file.full_path(cx);
380        }
381        relative_path = file.path().to_path_buf();
382    } else {
383        path = PathBuf::from("/untitled");
384        relative_path = PathBuf::from("untitled");
385    }
386
387    let settings = cx.global::<Settings>();
388    let language_id = match language_name {
389        Some("Plain Text") => "plaintext".to_string(),
390        Some(language_name) => language_name.to_lowercase(),
391        None => "plaintext".to_string(),
392    };
393    request::GetCompletionsParams {
394        doc: request::GetCompletionsDocument {
395            source: buffer.text(),
396            tab_size: settings.tab_size(language_name).into(),
397            indent_size: 1,
398            insert_spaces: !settings.hard_tabs(language_name),
399            uri: lsp::Url::from_file_path(&path).unwrap(),
400            path: path.to_string_lossy().into(),
401            relative_path: relative_path.to_string_lossy().into(),
402            language_id,
403            position: point_to_lsp(position),
404            version: 0,
405        },
406    }
407}
408
409fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
410    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
411    Completion {
412        position: buffer.anchor_before(position),
413        text: completion.display_text,
414    }
415}
416
417async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
418    ///Check for the latest copilot language server and download it if we haven't already
419    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
420        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
421        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
422        let asset = release
423            .assets
424            .iter()
425            .find(|asset| asset.name == asset_name)
426            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
427
428        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
429        let destination_path =
430            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
431
432        if fs::metadata(&destination_path).await.is_err() {
433            let mut response = http
434                .get(&asset.browser_download_url, Default::default(), true)
435                .await
436                .map_err(|err| anyhow!("error downloading release: {}", err))?;
437            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
438            let mut file = fs::File::create(&destination_path).await?;
439            futures::io::copy(decompressed_bytes, &mut file).await?;
440            fs::set_permissions(
441                &destination_path,
442                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
443            )
444            .await?;
445
446            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
447        }
448
449        Ok(destination_path)
450    }
451
452    match fetch_latest(http).await {
453        ok @ Result::Ok(..) => ok,
454        e @ Err(..) => {
455            e.log_err();
456            // Fetch a cached binary, if it exists
457            (|| async move {
458                let mut last = None;
459                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
460                while let Some(entry) = entries.next().await {
461                    last = Some(entry?.path());
462                }
463                last.ok_or_else(|| anyhow!("no cached binary"))
464            })()
465            .await
466        }
467    }
468}