copilot.rs

  1mod request;
  2
  3use anyhow::{anyhow, Result};
  4use async_compression::futures::bufread::GzipDecoder;
  5use client::Client;
  6use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
  7use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
  8use lsp::LanguageServer;
  9use settings::Settings;
 10use smol::{fs, io::BufReader, stream::StreamExt};
 11use std::{
 12    env::consts,
 13    path::{Path, PathBuf},
 14    sync::Arc,
 15};
 16use util::{
 17    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 18};
 19
// Global actions for signing in to and out of GitHub Copilot; their handlers
// are registered in `init`.
actions!(copilot, [SignIn, SignOut]);
 21
 22pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 23    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
 24    cx.set_global(copilot);
 25    cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| {
 26        if let Some(copilot) = Copilot::global(cx) {
 27            copilot
 28                .update(cx, |copilot, cx| copilot.sign_in(cx))
 29                .detach_and_log_err(cx);
 30        }
 31    });
 32    cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| {
 33        if let Some(copilot) = Copilot::global(cx) {
 34            copilot
 35                .update(cx, |copilot, cx| copilot.sign_out(cx))
 36                .detach_and_log_err(cx);
 37        }
 38    });
 39}
 40
/// Lifecycle state of the Copilot language server owned by `Copilot`.
enum CopilotServer {
    /// Initial state set by `Copilot::start` while the binary is fetched and
    /// the server is launched.
    Downloading,
    /// Fetching, launching or initializing the server failed; holds the error
    /// message.
    Error(Arc<str>),
    /// The server is running and has been initialized.
    Started {
        server: Arc<LanguageServer>,
        /// Current sign-in state, as last reported by the server or reset
        /// locally after a sign-out.
        status: SignInStatus,
    },
}
 49
/// Sign-in state tracked for a started Copilot server. The `user` fields hold
/// the `user` value from the server's sign-in status response.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SignInStatus {
    /// The server reported `Ok` or `MaybeOk`.
    Authorized { user: String },
    /// The server reported `NotAuthorized`.
    Unauthorized { user: String },
    /// No user is signed in (also used for any other server status).
    SignedOut,
}
 56
/// Events emitted by the `Copilot` model.
#[derive(Debug)]
pub enum Event {
    /// Emitted during sign-in when the server starts a device-flow login.
    /// Presumably the UI should show `user_code` and direct the user to
    /// `verification_uri` to confirm it.
    PromptUserDeviceFlow {
        user_code: String,
        verification_uri: String,
    },
}
 64
 65#[derive(Debug)]
 66pub enum Status {
 67    Downloading,
 68    Error(Arc<str>),
 69    SignedOut,
 70    Unauthorized,
 71    Authorized,
 72}
 73
 74impl Status {
 75    fn is_authorized(&self) -> bool {
 76        matches!(self, Status::Authorized)
 77    }
 78}
 79
/// A single Copilot suggestion, positioned in the buffer it was requested for.
#[derive(Debug)]
pub struct Completion {
    /// Anchor (biased before) at the server-reported position, after clipping
    /// that position to the buffer.
    pub position: Anchor,
    /// The suggestion's `display_text` as returned by the server.
    pub text: String,
}
 85
/// GPUI model owning the Copilot language server and its sign-in state.
/// `init` registers a handle to it as a global.
struct Copilot {
    server: CopilotServer,
}

// The model emits `Event`s (e.g. the device-flow prompt during sign-in).
impl Entity for Copilot {
    type Event = Event;
}
 93
 94impl Copilot {
 95    fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
 96        if cx.has_global::<ModelHandle<Self>>() {
 97            let copilot = cx.global::<ModelHandle<Self>>().clone();
 98            if copilot.read(cx).status().is_authorized() {
 99                Some(copilot)
100            } else {
101                None
102            }
103        } else {
104            None
105        }
106    }
107
108    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
109        cx.spawn(|this, mut cx| async move {
110            let start_language_server = async {
111                let server_path = get_lsp_binary(http).await?;
112                let server =
113                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
114                let server = server.initialize(Default::default()).await?;
115                let status = server
116                    .request::<request::CheckStatus>(request::CheckStatusParams {
117                        local_checks_only: false,
118                    })
119                    .await?;
120                anyhow::Ok((server, status))
121            };
122
123            let server = start_language_server.await;
124            this.update(&mut cx, |this, cx| {
125                cx.notify();
126                match server {
127                    Ok((server, status)) => {
128                        this.server = CopilotServer::Started {
129                            server,
130                            status: SignInStatus::SignedOut,
131                        };
132                        this.update_sign_in_status(status, cx);
133                    }
134                    Err(error) => {
135                        this.server = CopilotServer::Error(error.to_string().into());
136                    }
137                }
138            })
139        })
140        .detach();
141        Self {
142            server: CopilotServer::Downloading,
143        }
144    }
145
146    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
147        if let CopilotServer::Started { server, .. } = &self.server {
148            let server = server.clone();
149            cx.spawn(|this, mut cx| async move {
150                let sign_in = server
151                    .request::<request::SignInInitiate>(request::SignInInitiateParams {})
152                    .await?;
153                if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
154                    this.update(&mut cx, |_, cx| {
155                        cx.emit(Event::PromptUserDeviceFlow {
156                            user_code: flow.user_code.clone(),
157                            verification_uri: flow.verification_uri,
158                        });
159                    });
160                    let response = server
161                        .request::<request::SignInConfirm>(request::SignInConfirmParams {
162                            user_code: flow.user_code,
163                        })
164                        .await?;
165                    this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
166                }
167                anyhow::Ok(())
168            })
169        } else {
170            Task::ready(Err(anyhow!("copilot hasn't started yet")))
171        }
172    }
173
174    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
175        if let CopilotServer::Started { server, .. } = &self.server {
176            let server = server.clone();
177            cx.spawn(|this, mut cx| async move {
178                server
179                    .request::<request::SignOut>(request::SignOutParams {})
180                    .await?;
181                this.update(&mut cx, |this, cx| {
182                    if let CopilotServer::Started { status, .. } = &mut this.server {
183                        *status = SignInStatus::SignedOut;
184                        cx.notify();
185                    }
186                });
187
188                anyhow::Ok(())
189            })
190        } else {
191            Task::ready(Err(anyhow!("copilot hasn't started yet")))
192        }
193    }
194
195    pub fn completion<T>(
196        &self,
197        buffer: &ModelHandle<Buffer>,
198        position: T,
199        cx: &mut ModelContext<Self>,
200    ) -> Task<Result<Option<Completion>>>
201    where
202        T: ToPointUtf16,
203    {
204        let server = match self.authenticated_server() {
205            Ok(server) => server,
206            Err(error) => return Task::ready(Err(error)),
207        };
208
209        let buffer = buffer.read(cx).snapshot();
210        let request = server
211            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
212        cx.background().spawn(async move {
213            let result = request.await?;
214            let completion = result
215                .completions
216                .into_iter()
217                .next()
218                .map(|completion| completion_from_lsp(completion, &buffer));
219            anyhow::Ok(completion)
220        })
221    }
222
223    pub fn completions_cycling<T>(
224        &self,
225        buffer: &ModelHandle<Buffer>,
226        position: T,
227        cx: &mut ModelContext<Self>,
228    ) -> Task<Result<Vec<Completion>>>
229    where
230        T: ToPointUtf16,
231    {
232        let server = match self.authenticated_server() {
233            Ok(server) => server,
234            Err(error) => return Task::ready(Err(error)),
235        };
236
237        let buffer = buffer.read(cx).snapshot();
238        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
239            &buffer, position, cx,
240        ));
241        cx.background().spawn(async move {
242            let result = request.await?;
243            let completions = result
244                .completions
245                .into_iter()
246                .map(|completion| completion_from_lsp(completion, &buffer))
247                .collect();
248            anyhow::Ok(completions)
249        })
250    }
251
252    pub fn status(&self) -> Status {
253        match &self.server {
254            CopilotServer::Downloading => Status::Downloading,
255            CopilotServer::Error(error) => Status::Error(error.clone()),
256            CopilotServer::Started { status, .. } => match status {
257                SignInStatus::Authorized { .. } => Status::Authorized,
258                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
259                SignInStatus::SignedOut => Status::SignedOut,
260            },
261        }
262    }
263
264    fn update_sign_in_status(
265        &mut self,
266        lsp_status: request::SignInStatus,
267        cx: &mut ModelContext<Self>,
268    ) {
269        if let CopilotServer::Started { status, .. } = &mut self.server {
270            *status = match lsp_status {
271                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
272                    SignInStatus::Authorized { user }
273                }
274                request::SignInStatus::NotAuthorized { user } => {
275                    SignInStatus::Unauthorized { user }
276                }
277                _ => SignInStatus::SignedOut,
278            };
279            cx.notify();
280        }
281    }
282
283    fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
284        match &self.server {
285            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
286            CopilotServer::Error(error) => Err(anyhow!(
287                "copilot was not started because of an error: {}",
288                error
289            )),
290            CopilotServer::Started { server, status } => {
291                if matches!(status, SignInStatus::Authorized { .. }) {
292                    Ok(server.clone())
293                } else {
294                    Err(anyhow!("must sign in before using copilot"))
295                }
296            }
297        }
298    }
299}
300
301fn build_completion_params<T>(
302    buffer: &BufferSnapshot,
303    position: T,
304    cx: &AppContext,
305) -> request::GetCompletionsParams
306where
307    T: ToPointUtf16,
308{
309    let position = position.to_point_utf16(&buffer);
310    let language_name = buffer.language_at(position).map(|language| language.name());
311    let language_name = language_name.as_deref();
312
313    let path;
314    let relative_path;
315    if let Some(file) = buffer.file() {
316        if let Some(file) = file.as_local() {
317            path = file.abs_path(cx);
318        } else {
319            path = file.full_path(cx);
320        }
321        relative_path = file.path().to_path_buf();
322    } else {
323        path = PathBuf::from("/untitled");
324        relative_path = PathBuf::from("untitled");
325    }
326
327    let settings = cx.global::<Settings>();
328    let language_id = match language_name {
329        Some("Plain Text") => "plaintext".to_string(),
330        Some(language_name) => language_name.to_lowercase(),
331        None => "plaintext".to_string(),
332    };
333    request::GetCompletionsParams {
334        doc: request::GetCompletionsDocument {
335            source: buffer.text(),
336            tab_size: settings.tab_size(language_name).into(),
337            indent_size: 1,
338            insert_spaces: !settings.hard_tabs(language_name),
339            uri: lsp::Url::from_file_path(&path).unwrap(),
340            path: path.to_string_lossy().into(),
341            relative_path: relative_path.to_string_lossy().into(),
342            language_id,
343            position: point_to_lsp(position),
344            version: 0,
345        },
346    }
347}
348
349fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
350    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
351    Completion {
352        position: buffer.anchor_before(position),
353        text: completion.display_text,
354    }
355}
356
357async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
358    ///Check for the latest copilot language server and download it if we haven't already
359    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
360        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
361        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
362        let asset = release
363            .assets
364            .iter()
365            .find(|asset| asset.name == asset_name)
366            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
367
368        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
369        let destination_path =
370            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
371
372        if fs::metadata(&destination_path).await.is_err() {
373            let mut response = http
374                .get(&asset.browser_download_url, Default::default(), true)
375                .await
376                .map_err(|err| anyhow!("error downloading release: {}", err))?;
377            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
378            let mut file = fs::File::create(&destination_path).await?;
379            futures::io::copy(decompressed_bytes, &mut file).await?;
380            fs::set_permissions(
381                &destination_path,
382                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
383            )
384            .await?;
385
386            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
387        }
388
389        Ok(destination_path)
390    }
391
392    match fetch_latest(http).await {
393        ok @ Result::Ok(..) => ok,
394        e @ Err(..) => {
395            e.log_err();
396            // Fetch a cached binary, if it exists
397            (|| async move {
398                let mut last = None;
399                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
400                while let Some(entry) = entries.next().await {
401                    last = Some(entry?.path());
402                }
403                last.ok_or_else(|| anyhow!("no cached binary"))
404            })()
405            .await
406        }
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    // Manual end-to-end smoke test against the real Copilot server.
    //
    // It uses a real HTTP client (`http::client()`), so it needs network
    // access to fetch the binary. `sign_in` waits on `SignInConfirm` after
    // emitting the device-flow prompt, so it presumably blocks until someone
    // completes the flow. NOTE(review): this likely cannot pass unattended
    // in CI; consider gating it.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http, cx));
        // Fixed delay, hoping the download and server startup finish in time.
        // `sign_in` errors with "copilot hasn't started yet" otherwise.
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        // Offset 12 is the end of the text "fn foo() -> ".
        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        dbg!(copilot
            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
            .await
            .unwrap());
        dbg!(copilot
            .update(cx, |copilot, cx| copilot
                .completions_cycling(&buffer, 12, cx))
            .await
            .unwrap());
    }
}