copilot.rs

  1mod request;
  2
  3use anyhow::{anyhow, Result};
  4use async_compression::futures::bufread::GzipDecoder;
  5use client::Client;
  6use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
  7use language::{Buffer, ToPointUtf16};
  8use lsp::LanguageServer;
  9use smol::{fs, io::BufReader, stream::StreamExt};
 10use std::{
 11    env::consts,
 12    path::{Path, PathBuf},
 13    sync::Arc,
 14};
 15use util::{
 16    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 17};
 18
// Declares the `SignIn` and `SignOut` actions (grouped under `copilot`);
// `init` registers a global handler for each of them.
actions!(copilot, [SignIn, SignOut]);
 20
 21pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 22    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
 23    cx.set_global(copilot);
 24    cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| {
 25        if let Some(copilot) = Copilot::global(cx) {
 26            copilot
 27                .update(cx, |copilot, cx| copilot.sign_in(cx))
 28                .detach_and_log_err(cx);
 29        }
 30    });
 31    cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| {
 32        if let Some(copilot) = Copilot::global(cx) {
 33            copilot
 34                .update(cx, |copilot, cx| copilot.sign_out(cx))
 35                .detach_and_log_err(cx);
 36        }
 37    });
 38}
 39
/// Lifecycle state of the Copilot language server held by `Copilot`.
enum CopilotServer {
    /// Initial state set by `Copilot::start`, kept until the background task
    /// that fetches, starts and queries the server has finished.
    Downloading,
    /// Starting the server failed; holds the error's message.
    Error(Arc<str>),
    /// The server started successfully.
    Started {
        /// Handle used to send requests to the language server.
        server: Arc<LanguageServer>,
        /// Locally cached sign-in status, updated from the server's responses
        /// and by `sign_out`.
        status: SignInStatus,
    },
}
 48
/// Locally cached sign-in state of the Copilot user.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SignInStatus {
    /// Signed in and authorized; the server reported `Ok` or `MaybeOk`. This is
    /// the only state in which `authenticated_server` hands out the server.
    Authorized { user: String },
    /// Signed in as `user`, but the server reported `NotAuthorized`.
    Unauthorized { user: String },
    /// Not signed in. Also the initial state after startup, the state after
    /// `sign_out`, and the fallback for any other status the server reports.
    SignedOut,
}
 55
 56pub enum Event {
 57    PromptUserDeviceFlow {
 58        user_code: String,
 59        verification_uri: String,
 60    },
 61}
 62
 63#[derive(Debug)]
 64pub enum Status {
 65    Downloading,
 66    Error(Arc<str>),
 67    SignedOut,
 68    Unauthorized,
 69    Authorized,
 70}
 71
 72impl Status {
 73    fn is_authorized(&self) -> bool {
 74        matches!(self, Status::Authorized)
 75    }
 76}
 77
/// Model that holds the Copilot language server and the user's sign-in state.
struct Copilot {
    /// Current server state: downloading, failed to start, or started.
    server: CopilotServer,
}
 81
/// Makes `Copilot` a gpui model whose emitted events are of type [`Event`].
impl Entity for Copilot {
    type Event = Event;
}
 85
 86impl Copilot {
 87    fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
 88        if cx.has_global::<ModelHandle<Self>>() {
 89            let copilot = cx.global::<ModelHandle<Self>>().clone();
 90            if copilot.read(cx).status().is_authorized() {
 91                Some(copilot)
 92            } else {
 93                None
 94            }
 95        } else {
 96            None
 97        }
 98    }
 99
100    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
101        cx.spawn(|this, mut cx| async move {
102            let start_language_server = async {
103                let server_path = get_lsp_binary(http).await?;
104                let server =
105                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
106                let server = server.initialize(Default::default()).await?;
107                let status = server
108                    .request::<request::CheckStatus>(request::CheckStatusParams {
109                        local_checks_only: false,
110                    })
111                    .await?;
112                anyhow::Ok((server, status))
113            };
114
115            let server = start_language_server.await;
116            this.update(&mut cx, |this, cx| {
117                cx.notify();
118                match server {
119                    Ok((server, status)) => {
120                        this.server = CopilotServer::Started {
121                            server,
122                            status: SignInStatus::SignedOut,
123                        };
124                        this.update_sign_in_status(status, cx);
125                    }
126                    Err(error) => {
127                        this.server = CopilotServer::Error(error.to_string().into());
128                    }
129                }
130            })
131        })
132        .detach();
133        Self {
134            server: CopilotServer::Downloading,
135        }
136    }
137
138    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
139        if let CopilotServer::Started { server, .. } = &self.server {
140            let server = server.clone();
141            cx.spawn(|this, mut cx| async move {
142                let sign_in = server
143                    .request::<request::SignInInitiate>(request::SignInInitiateParams {})
144                    .await?;
145                if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
146                    this.update(&mut cx, |_, cx| {
147                        cx.emit(Event::PromptUserDeviceFlow {
148                            user_code: flow.user_code.clone(),
149                            verification_uri: flow.verification_uri,
150                        });
151                    });
152                    let response = server
153                        .request::<request::SignInConfirm>(request::SignInConfirmParams {
154                            user_code: flow.user_code,
155                        })
156                        .await?;
157                    this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
158                }
159                anyhow::Ok(())
160            })
161        } else {
162            Task::ready(Err(anyhow!("copilot hasn't started yet")))
163        }
164    }
165
166    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
167        if let CopilotServer::Started { server, .. } = &self.server {
168            let server = server.clone();
169            cx.spawn(|this, mut cx| async move {
170                server
171                    .request::<request::SignOut>(request::SignOutParams {})
172                    .await?;
173                this.update(&mut cx, |this, cx| {
174                    if let CopilotServer::Started { status, .. } = &mut this.server {
175                        *status = SignInStatus::SignedOut;
176                        cx.notify();
177                    }
178                });
179
180                anyhow::Ok(())
181            })
182        } else {
183            Task::ready(Err(anyhow!("copilot hasn't started yet")))
184        }
185    }
186
187    pub fn completions<T>(
188        &self,
189        buffer: &ModelHandle<Buffer>,
190        position: T,
191        cx: &mut ModelContext<Self>,
192    ) -> Task<Result<()>>
193    where
194        T: ToPointUtf16,
195    {
196        let server = match self.authenticated_server() {
197            Ok(server) => server,
198            Err(error) => return Task::ready(Err(error)),
199        };
200
201        cx.spawn(|this, cx| async move { anyhow::Ok(()) })
202    }
203
204    pub fn status(&self) -> Status {
205        match &self.server {
206            CopilotServer::Downloading => Status::Downloading,
207            CopilotServer::Error(error) => Status::Error(error.clone()),
208            CopilotServer::Started { status, .. } => match status {
209                SignInStatus::Authorized { .. } => Status::Authorized,
210                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
211                SignInStatus::SignedOut => Status::SignedOut,
212            },
213        }
214    }
215
216    fn update_sign_in_status(
217        &mut self,
218        lsp_status: request::SignInStatus,
219        cx: &mut ModelContext<Self>,
220    ) {
221        if let CopilotServer::Started { status, .. } = &mut self.server {
222            *status = match lsp_status {
223                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
224                    SignInStatus::Authorized { user }
225                }
226                request::SignInStatus::NotAuthorized { user } => {
227                    SignInStatus::Unauthorized { user }
228                }
229                _ => SignInStatus::SignedOut,
230            };
231            cx.notify();
232        }
233    }
234
235    fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
236        match &self.server {
237            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
238            CopilotServer::Error(error) => Err(anyhow!(
239                "copilot was not started because of an error: {}",
240                error
241            )),
242            CopilotServer::Started { server, status } => {
243                if matches!(status, SignInStatus::Authorized { .. }) {
244                    Ok(server.clone())
245                } else {
246                    Err(anyhow!("must sign in before using copilot"))
247                }
248            }
249        }
250    }
251}
252
253async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
254    ///Check for the latest copilot language server and download it if we haven't already
255    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
256        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
257        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
258        let asset = release
259            .assets
260            .iter()
261            .find(|asset| asset.name == asset_name)
262            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
263
264        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
265        let destination_path =
266            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
267
268        if fs::metadata(&destination_path).await.is_err() {
269            let mut response = http
270                .get(&asset.browser_download_url, Default::default(), true)
271                .await
272                .map_err(|err| anyhow!("error downloading release: {}", err))?;
273            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
274            let mut file = fs::File::create(&destination_path).await?;
275            futures::io::copy(decompressed_bytes, &mut file).await?;
276            fs::set_permissions(
277                &destination_path,
278                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
279            )
280            .await?;
281
282            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
283        }
284
285        Ok(destination_path)
286    }
287
288    match fetch_latest(http).await {
289        ok @ Result::Ok(..) => ok,
290        e @ Err(..) => {
291            e.log_err();
292            // Fetch a cached binary, if it exists
293            (|| async move {
294                let mut last = None;
295                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
296                while let Some(entry) = entries.next().await {
297                    last = Some(entry?.path());
298                }
299                last.ok_or_else(|| anyhow!("no cached binary"))
300            })()
301            .await
302        }
303    }
304}