copilot.rs

  1mod auth_modal;
  2mod request;
  3
  4use anyhow::{anyhow, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use auth_modal::AuthModal;
  7use client::Client;
  8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
  9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 10use lsp::LanguageServer;
 11use settings::Settings;
 12use smol::{fs, io::BufReader, stream::StreamExt};
 13use std::{
 14    env::consts,
 15    path::{Path, PathBuf},
 16    sync::Arc,
 17};
 18use util::{
 19    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 20};
 21use workspace::Workspace;
 22
// Declares the `SignIn`, `SignOut` and `ToggleAuthStatus` actions under the
// `copilot` namespace; their handlers are registered in `init` below.
actions!(copilot, [SignIn, SignOut, ToggleAuthStatus]);
 24
 25pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 26    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
 27    cx.set_global(copilot.clone());
 28    cx.add_action(|workspace: &mut Workspace, _: &SignIn, cx| {
 29        if let Some(copilot) = Copilot::global(cx) {
 30            if copilot.read(cx).status() == Status::Authorized {
 31                return;
 32            }
 33
 34            copilot
 35                .update(cx, |copilot, cx| copilot.sign_in(cx))
 36                .detach_and_log_err(cx);
 37
 38            workspace.toggle_modal(cx, |_workspace, cx| cx.add_view(|_cx| AuthModal {}));
 39        }
 40    });
 41    cx.add_action(|_: &mut Workspace, _: &SignOut, cx| {
 42        if let Some(copilot) = Copilot::global(cx) {
 43            copilot
 44                .update(cx, |copilot, cx| copilot.sign_out(cx))
 45                .detach_and_log_err(cx);
 46        }
 47    });
 48    cx.add_action(|workspace: &mut Workspace, _: &ToggleAuthStatus, cx| {
 49        workspace.toggle_modal(cx, |_workspace, cx| cx.add_view(|_cx| AuthModal {}))
 50    })
 51}
 52
/// State of the Copilot language server as tracked by [`Copilot`].
enum CopilotServer {
    /// Initial state assigned by `Copilot::start`; it is replaced once the
    /// background startup task finishes.
    Downloading,
    /// Startup failed; holds the error rendered as text.
    Error(Arc<str>),
    /// The server was launched, initialized and answered the status check.
    Started {
        /// Handle used to send requests to the server.
        server: Arc<LanguageServer>,
        /// Sign-in state most recently recorded for this server.
        status: SignInStatus,
    },
}
 61
/// Sign-in state recorded for a started server (see
/// `Copilot::update_sign_in_status` for how server replies map onto it).
#[derive(Clone, Debug, PartialEq, Eq)]
enum SignInStatus {
    /// The server reported `Ok` or `MaybeOk` for `user`.
    Authorized { user: String },
    /// The server reported `NotAuthorized` for `user`.
    Unauthorized { user: String },
    /// Nobody is signed in; also used for any other reply from the server.
    SignedOut,
}
 68
/// Events emitted by the [`Copilot`] model.
#[derive(Debug)]
pub enum Event {
    /// Emitted by `Copilot::sign_in` when the server answers the sign-in
    /// request with a device flow; carries the values the server returned.
    PromptUserDeviceFlow {
        /// Code from the server's device-flow reply; the same code is sent
        /// back in the follow-up confirmation request.
        user_code: String,
        /// Verification URI from the server's device-flow reply.
        verification_uri: String,
    },
}
 76
 77#[derive(Debug, PartialEq, Eq)]
 78pub enum Status {
 79    Downloading,
 80    Error(Arc<str>),
 81    SignedOut,
 82    Unauthorized,
 83    Authorized,
 84}
 85
 86impl Status {
 87    fn is_authorized(&self) -> bool {
 88        matches!(self, Status::Authorized)
 89    }
 90}
 91
/// A single suggestion returned by the Copilot server, translated into
/// buffer terms by `completion_from_lsp`.
#[derive(Debug)]
pub struct Completion {
    /// Anchor in the buffer snapshot for the position the server reported
    /// (clipped to the snapshot before anchoring).
    pub position: Anchor,
    /// The completion's `display_text` as returned by the server.
    pub text: String,
}
 97
/// Model that owns the Copilot language server and its sign-in state.
struct Copilot {
    /// Current server state; starts as `CopilotServer::Downloading`.
    server: CopilotServer,
}
101
// Makes `Copilot` usable as a model that emits `Event` values.
impl Entity for Copilot {
    type Event = Event;
}
105
106impl Copilot {
107    fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
108        if cx.has_global::<ModelHandle<Self>>() {
109            let copilot = cx.global::<ModelHandle<Self>>().clone();
110            if copilot.read(cx).status().is_authorized() {
111                Some(copilot)
112            } else {
113                None
114            }
115        } else {
116            None
117        }
118    }
119
120    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
121        cx.spawn(|this, mut cx| async move {
122            let start_language_server = async {
123                let server_path = get_lsp_binary(http).await?;
124                let server =
125                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
126                let server = server.initialize(Default::default()).await?;
127                let status = server
128                    .request::<request::CheckStatus>(request::CheckStatusParams {
129                        local_checks_only: false,
130                    })
131                    .await?;
132                anyhow::Ok((server, status))
133            };
134
135            let server = start_language_server.await;
136            this.update(&mut cx, |this, cx| {
137                cx.notify();
138                match server {
139                    Ok((server, status)) => {
140                        this.server = CopilotServer::Started {
141                            server,
142                            status: SignInStatus::SignedOut,
143                        };
144                        this.update_sign_in_status(status, cx);
145                    }
146                    Err(error) => {
147                        this.server = CopilotServer::Error(error.to_string().into());
148                    }
149                }
150            })
151        })
152        .detach();
153
154        Self {
155            server: CopilotServer::Downloading,
156        }
157    }
158
159    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
160        if let CopilotServer::Started { server, .. } = &self.server {
161            let server = server.clone();
162            cx.spawn(|this, mut cx| async move {
163                let sign_in = server
164                    .request::<request::SignInInitiate>(request::SignInInitiateParams {})
165                    .await?;
166                if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
167                    this.update(&mut cx, |_, cx| {
168                        cx.emit(Event::PromptUserDeviceFlow {
169                            user_code: flow.user_code.clone(),
170                            verification_uri: flow.verification_uri,
171                        });
172                    });
173                    let response = server
174                        .request::<request::SignInConfirm>(request::SignInConfirmParams {
175                            user_code: flow.user_code,
176                        })
177                        .await?;
178                    this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
179                }
180                anyhow::Ok(())
181            })
182        } else {
183            Task::ready(Err(anyhow!("copilot hasn't started yet")))
184        }
185    }
186
187    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
188        if let CopilotServer::Started { server, .. } = &self.server {
189            let server = server.clone();
190            cx.spawn(|this, mut cx| async move {
191                server
192                    .request::<request::SignOut>(request::SignOutParams {})
193                    .await?;
194                this.update(&mut cx, |this, cx| {
195                    if let CopilotServer::Started { status, .. } = &mut this.server {
196                        *status = SignInStatus::SignedOut;
197                        cx.notify();
198                    }
199                });
200
201                anyhow::Ok(())
202            })
203        } else {
204            Task::ready(Err(anyhow!("copilot hasn't started yet")))
205        }
206    }
207
208    pub fn completion<T>(
209        &self,
210        buffer: &ModelHandle<Buffer>,
211        position: T,
212        cx: &mut ModelContext<Self>,
213    ) -> Task<Result<Option<Completion>>>
214    where
215        T: ToPointUtf16,
216    {
217        let server = match self.authenticated_server() {
218            Ok(server) => server,
219            Err(error) => return Task::ready(Err(error)),
220        };
221
222        let buffer = buffer.read(cx).snapshot();
223        let request = server
224            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
225        cx.background().spawn(async move {
226            let result = request.await?;
227            let completion = result
228                .completions
229                .into_iter()
230                .next()
231                .map(|completion| completion_from_lsp(completion, &buffer));
232            anyhow::Ok(completion)
233        })
234    }
235
236    pub fn completions_cycling<T>(
237        &self,
238        buffer: &ModelHandle<Buffer>,
239        position: T,
240        cx: &mut ModelContext<Self>,
241    ) -> Task<Result<Vec<Completion>>>
242    where
243        T: ToPointUtf16,
244    {
245        let server = match self.authenticated_server() {
246            Ok(server) => server,
247            Err(error) => return Task::ready(Err(error)),
248        };
249
250        let buffer = buffer.read(cx).snapshot();
251        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
252            &buffer, position, cx,
253        ));
254        cx.background().spawn(async move {
255            let result = request.await?;
256            let completions = result
257                .completions
258                .into_iter()
259                .map(|completion| completion_from_lsp(completion, &buffer))
260                .collect();
261            anyhow::Ok(completions)
262        })
263    }
264
265    pub fn status(&self) -> Status {
266        match &self.server {
267            CopilotServer::Downloading => Status::Downloading,
268            CopilotServer::Error(error) => Status::Error(error.clone()),
269            CopilotServer::Started { status, .. } => match status {
270                SignInStatus::Authorized { .. } => Status::Authorized,
271                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
272                SignInStatus::SignedOut => Status::SignedOut,
273            },
274        }
275    }
276
277    fn update_sign_in_status(
278        &mut self,
279        lsp_status: request::SignInStatus,
280        cx: &mut ModelContext<Self>,
281    ) {
282        if let CopilotServer::Started { status, .. } = &mut self.server {
283            *status = match lsp_status {
284                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
285                    SignInStatus::Authorized { user }
286                }
287                request::SignInStatus::NotAuthorized { user } => {
288                    SignInStatus::Unauthorized { user }
289                }
290                _ => SignInStatus::SignedOut,
291            };
292            cx.notify();
293        }
294    }
295
296    fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
297        match &self.server {
298            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
299            CopilotServer::Error(error) => Err(anyhow!(
300                "copilot was not started because of an error: {}",
301                error
302            )),
303            CopilotServer::Started { server, status } => {
304                if matches!(status, SignInStatus::Authorized { .. }) {
305                    Ok(server.clone())
306                } else {
307                    Err(anyhow!("must sign in before using copilot"))
308                }
309            }
310        }
311    }
312}
313
314fn build_completion_params<T>(
315    buffer: &BufferSnapshot,
316    position: T,
317    cx: &AppContext,
318) -> request::GetCompletionsParams
319where
320    T: ToPointUtf16,
321{
322    let position = position.to_point_utf16(&buffer);
323    let language_name = buffer.language_at(position).map(|language| language.name());
324    let language_name = language_name.as_deref();
325
326    let path;
327    let relative_path;
328    if let Some(file) = buffer.file() {
329        if let Some(file) = file.as_local() {
330            path = file.abs_path(cx);
331        } else {
332            path = file.full_path(cx);
333        }
334        relative_path = file.path().to_path_buf();
335    } else {
336        path = PathBuf::from("/untitled");
337        relative_path = PathBuf::from("untitled");
338    }
339
340    let settings = cx.global::<Settings>();
341    let language_id = match language_name {
342        Some("Plain Text") => "plaintext".to_string(),
343        Some(language_name) => language_name.to_lowercase(),
344        None => "plaintext".to_string(),
345    };
346    request::GetCompletionsParams {
347        doc: request::GetCompletionsDocument {
348            source: buffer.text(),
349            tab_size: settings.tab_size(language_name).into(),
350            indent_size: 1,
351            insert_spaces: !settings.hard_tabs(language_name),
352            uri: lsp::Url::from_file_path(&path).unwrap(),
353            path: path.to_string_lossy().into(),
354            relative_path: relative_path.to_string_lossy().into(),
355            language_id,
356            position: point_to_lsp(position),
357            version: 0,
358        },
359    }
360}
361
362fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
363    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
364    Completion {
365        position: buffer.anchor_before(position),
366        text: completion.display_text,
367    }
368}
369
370async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
371    ///Check for the latest copilot language server and download it if we haven't already
372    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
373        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
374        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
375        let asset = release
376            .assets
377            .iter()
378            .find(|asset| asset.name == asset_name)
379            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
380
381        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
382        let destination_path =
383            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
384
385        if fs::metadata(&destination_path).await.is_err() {
386            let mut response = http
387                .get(&asset.browser_download_url, Default::default(), true)
388                .await
389                .map_err(|err| anyhow!("error downloading release: {}", err))?;
390            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
391            let mut file = fs::File::create(&destination_path).await?;
392            futures::io::copy(decompressed_bytes, &mut file).await?;
393            fs::set_permissions(
394                &destination_path,
395                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
396            )
397            .await?;
398
399            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
400        }
401
402        Ok(destination_path)
403    }
404
405    match fetch_latest(http).await {
406        ok @ Result::Ok(..) => ok,
407        e @ Err(..) => {
408            e.log_err();
409            // Fetch a cached binary, if it exists
410            (|| async move {
411                let mut last = None;
412                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
413                while let Some(entry) = entries.next().await {
414                    last = Some(entry?.path());
415                }
416                last.ok_or_else(|| anyhow!("no cached binary"))
417            })()
418            .await
419        }
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    // End-to-end smoke test: starts Copilot with a real HTTP client, signs in,
    // and requests completions for a small buffer. It makes no assertions on
    // the results beyond the `unwrap`s; values are printed with `dbg!`.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http, cx));
        // NOTE(review): fixed delay, presumably to let the detached startup task
        // finish before signing in; nothing here actually waits on that task.
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        // 12 is the position passed to both completion requests below; it
        // equals the length of the buffer text.
        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        dbg!(copilot
            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
            .await
            .unwrap());
        dbg!(copilot
            .update(cx, |copilot, cx| copilot
                .completions_cycling(&buffer, 12, cx))
            .await
            .unwrap());
    }
}