copilot.rs

  1mod request;
  2
  3use anyhow::{anyhow, Result};
  4use async_compression::futures::bufread::GzipDecoder;
  5use client::Client;
  6use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
  7use language::{point_to_lsp, Buffer, ToPointUtf16};
  8use lsp::LanguageServer;
  9use settings::Settings;
 10use smol::{fs, io::BufReader, stream::StreamExt};
 11use std::{
 12    env::consts,
 13    path::{Path, PathBuf},
 14    sync::Arc,
 15};
 16use util::{
 17    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 18};
 19
 20actions!(copilot, [SignIn, SignOut]);
 21
 22pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 23    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
 24    cx.set_global(copilot);
 25    cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| {
 26        if let Some(copilot) = Copilot::global(cx) {
 27            copilot
 28                .update(cx, |copilot, cx| copilot.sign_in(cx))
 29                .detach_and_log_err(cx);
 30        }
 31    });
 32    cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| {
 33        if let Some(copilot) = Copilot::global(cx) {
 34            copilot
 35                .update(cx, |copilot, cx| copilot.sign_out(cx))
 36                .detach_and_log_err(cx);
 37        }
 38    });
 39}
 40
/// Lifecycle of the Copilot language server as tracked by `Copilot`.
enum CopilotServer {
    /// Initial state set by `Copilot::start` while the binary is fetched and the server launched.
    Downloading,
    /// Fetching, launching or initializing the server failed; holds the error message.
    Error(Arc<str>),
    /// The server was launched and `initialize` completed.
    Started {
        server: Arc<LanguageServer>,
        /// Last sign-in status derived from the server's replies (or reset by `sign_out`).
        status: SignInStatus,
    },
}
 49
/// Internal sign-in state, mapped from the server's `request::SignInStatus` replies by
/// `Copilot::update_sign_in_status`.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SignInStatus {
    /// The server replied `Ok` or `MaybeOk` for `user`.
    Authorized { user: String },
    /// The server replied `NotAuthorized` for `user`.
    Unauthorized { user: String },
    /// Set after `sign_out`, and for any other server status.
    SignedOut,
}
 56
/// Events emitted by the `Copilot` model.
#[derive(Debug)]
pub enum Event {
    /// Emitted by `Copilot::sign_in` when the server answers `SignInInitiate` with a device
    /// flow. Carries the flow's `user_code` and `verification_uri`; presumably these are shown
    /// to the user so they can complete the flow (NOTE(review): confirm with subscribers).
    PromptUserDeviceFlow {
        user_code: String,
        verification_uri: String,
    },
}
 64
 65#[derive(Debug)]
 66pub enum Status {
 67    Downloading,
 68    Error(Arc<str>),
 69    SignedOut,
 70    Unauthorized,
 71    Authorized,
 72}
 73
 74impl Status {
 75    fn is_authorized(&self) -> bool {
 76        matches!(self, Status::Authorized)
 77    }
 78}
 79
/// Model that owns the Copilot language server and the user's sign-in state.
struct Copilot {
    /// Server lifecycle state; `Copilot::start` initializes it to `Downloading`.
    server: CopilotServer,
}

impl Entity for Copilot {
    /// Events the model emits to its subscribers.
    type Event = Event;
}
 87
 88impl Copilot {
 89    fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
 90        if cx.has_global::<ModelHandle<Self>>() {
 91            let copilot = cx.global::<ModelHandle<Self>>().clone();
 92            if copilot.read(cx).status().is_authorized() {
 93                Some(copilot)
 94            } else {
 95                None
 96            }
 97        } else {
 98            None
 99        }
100    }
101
102    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
103        cx.spawn(|this, mut cx| async move {
104            let start_language_server = async {
105                let server_path = get_lsp_binary(http).await?;
106                let server =
107                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
108                let server = server.initialize(Default::default()).await?;
109                let status = server
110                    .request::<request::CheckStatus>(request::CheckStatusParams {
111                        local_checks_only: false,
112                    })
113                    .await?;
114                anyhow::Ok((server, status))
115            };
116
117            let server = start_language_server.await;
118            this.update(&mut cx, |this, cx| {
119                cx.notify();
120                match server {
121                    Ok((server, status)) => {
122                        this.server = CopilotServer::Started {
123                            server,
124                            status: SignInStatus::SignedOut,
125                        };
126                        this.update_sign_in_status(status, cx);
127                    }
128                    Err(error) => {
129                        this.server = CopilotServer::Error(error.to_string().into());
130                    }
131                }
132            })
133        })
134        .detach();
135        Self {
136            server: CopilotServer::Downloading,
137        }
138    }
139
140    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
141        if let CopilotServer::Started { server, .. } = &self.server {
142            let server = server.clone();
143            cx.spawn(|this, mut cx| async move {
144                let sign_in = server
145                    .request::<request::SignInInitiate>(request::SignInInitiateParams {})
146                    .await?;
147                if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
148                    this.update(&mut cx, |_, cx| {
149                        cx.emit(Event::PromptUserDeviceFlow {
150                            user_code: flow.user_code.clone(),
151                            verification_uri: flow.verification_uri,
152                        });
153                    });
154                    let response = server
155                        .request::<request::SignInConfirm>(request::SignInConfirmParams {
156                            user_code: flow.user_code,
157                        })
158                        .await?;
159                    this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
160                }
161                anyhow::Ok(())
162            })
163        } else {
164            Task::ready(Err(anyhow!("copilot hasn't started yet")))
165        }
166    }
167
168    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
169        if let CopilotServer::Started { server, .. } = &self.server {
170            let server = server.clone();
171            cx.spawn(|this, mut cx| async move {
172                server
173                    .request::<request::SignOut>(request::SignOutParams {})
174                    .await?;
175                this.update(&mut cx, |this, cx| {
176                    if let CopilotServer::Started { status, .. } = &mut this.server {
177                        *status = SignInStatus::SignedOut;
178                        cx.notify();
179                    }
180                });
181
182                anyhow::Ok(())
183            })
184        } else {
185            Task::ready(Err(anyhow!("copilot hasn't started yet")))
186        }
187    }
188
189    pub fn completions<T>(
190        &self,
191        buffer: &ModelHandle<Buffer>,
192        position: T,
193        cx: &mut ModelContext<Self>,
194    ) -> Task<Result<()>>
195    where
196        T: ToPointUtf16,
197    {
198        let server = match self.authenticated_server() {
199            Ok(server) => server,
200            Err(error) => return Task::ready(Err(error)),
201        };
202
203        let buffer = buffer.read(cx).snapshot();
204        let position = position.to_point_utf16(&buffer);
205        let language_name = buffer.language_at(position).map(|language| language.name());
206        let language_name = language_name.as_deref();
207
208        let path;
209        let relative_path;
210        if let Some(file) = buffer.file() {
211            if let Some(file) = file.as_local() {
212                path = file.abs_path(cx);
213            } else {
214                path = file.full_path(cx);
215            }
216            relative_path = file.path().to_path_buf();
217        } else {
218            path = PathBuf::from("/untitled");
219            relative_path = PathBuf::from("untitled");
220        }
221
222        let settings = cx.global::<Settings>();
223        let request = server.request::<request::GetCompletions>(request::GetCompletionsParams {
224            doc: request::GetCompletionsDocument {
225                source: buffer.text(),
226                tab_size: settings.tab_size(language_name).into(),
227                indent_size: 1,
228                insert_spaces: !settings.hard_tabs(language_name),
229                uri: lsp::Url::from_file_path(&path).unwrap(),
230                path: path.to_string_lossy().into(),
231                relative_path: relative_path.to_string_lossy().into(),
232                language_id: "csharp".into(),
233                position: point_to_lsp(position),
234                version: 0,
235            },
236        });
237        cx.spawn(|this, cx| async move {
238            dbg!(request.await?);
239
240            anyhow::Ok(())
241        })
242    }
243
244    pub fn status(&self) -> Status {
245        match &self.server {
246            CopilotServer::Downloading => Status::Downloading,
247            CopilotServer::Error(error) => Status::Error(error.clone()),
248            CopilotServer::Started { status, .. } => match status {
249                SignInStatus::Authorized { .. } => Status::Authorized,
250                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
251                SignInStatus::SignedOut => Status::SignedOut,
252            },
253        }
254    }
255
256    fn update_sign_in_status(
257        &mut self,
258        lsp_status: request::SignInStatus,
259        cx: &mut ModelContext<Self>,
260    ) {
261        if let CopilotServer::Started { status, .. } = &mut self.server {
262            *status = match lsp_status {
263                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
264                    SignInStatus::Authorized { user }
265                }
266                request::SignInStatus::NotAuthorized { user } => {
267                    SignInStatus::Unauthorized { user }
268                }
269                _ => SignInStatus::SignedOut,
270            };
271            cx.notify();
272        }
273    }
274
275    fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
276        match &self.server {
277            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
278            CopilotServer::Error(error) => Err(anyhow!(
279                "copilot was not started because of an error: {}",
280                error
281            )),
282            CopilotServer::Started { server, status } => {
283                if matches!(status, SignInStatus::Authorized { .. }) {
284                    Ok(server.clone())
285                } else {
286                    Err(anyhow!("must sign in before using copilot"))
287                }
288            }
289        }
290    }
291}
292
293async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
294    ///Check for the latest copilot language server and download it if we haven't already
295    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
296        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
297        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
298        let asset = release
299            .assets
300            .iter()
301            .find(|asset| asset.name == asset_name)
302            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
303
304        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
305        let destination_path =
306            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
307
308        if fs::metadata(&destination_path).await.is_err() {
309            let mut response = http
310                .get(&asset.browser_download_url, Default::default(), true)
311                .await
312                .map_err(|err| anyhow!("error downloading release: {}", err))?;
313            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
314            let mut file = fs::File::create(&destination_path).await?;
315            futures::io::copy(decompressed_bytes, &mut file).await?;
316            fs::set_permissions(
317                &destination_path,
318                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
319            )
320            .await?;
321
322            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
323        }
324
325        Ok(destination_path)
326    }
327
328    match fetch_latest(http).await {
329        ok @ Result::Ok(..) => ok,
330        e @ Err(..) => {
331            e.log_err();
332            // Fetch a cached binary, if it exists
333            (|| async move {
334                let mut last = None;
335                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
336                while let Some(entry) = entries.next().await {
337                    last = Some(entry?.path());
338                }
339                last.ok_or_else(|| anyhow!("no cached binary"))
340            })()
341            .await
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    // End-to-end smoke test: start Copilot, sign in, then request completions.
    // NOTE(review): `http::client()` is presumably a real HTTP client, so this downloads the
    // server from GitHub. Also, `sign_in` waits for `SignInConfirm` after emitting the
    // device-flow prompt, which likely requires a human to finish the flow. Probably not
    // suitable for unattended CI runs as written.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http, cx));
        // Fixed delay instead of waiting for a status change: gives `start` time to fetch and
        // launch the server; `sign_in` errors if it has not started by then.
        smol::Timer::after(std::time::Duration::from_secs(5)).await;
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        // Position 15 is the end of the 15-character buffer text. `completions` errors (and
        // the `unwrap` panics) unless the sign-in above left the user authorized.
        let buffer = cx.add_model(|cx| language::Buffer::new(0, "Lorem ipsum dol", cx));
        copilot
            .update(cx, |copilot, cx| copilot.completions(&buffer, 15, cx))
            .await
            .unwrap();
    }
}