1mod request;
2
3use anyhow::{anyhow, Result};
4use async_compression::futures::bufread::GzipDecoder;
5use client::Client;
6use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
7use language::{point_to_lsp, Buffer, ToPointUtf16};
8use lsp::LanguageServer;
9use settings::Settings;
10use smol::{fs, io::BufReader, stream::StreamExt};
11use std::{
12 env::consts,
13 path::{Path, PathBuf},
14 sync::Arc,
15};
16use util::{
17 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
18};
19
// Actions for the `copilot` namespace; their handlers are registered as global
// actions in `init` below.
actions!(copilot, [SignIn, SignOut]);
21
22pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
23 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
24 cx.set_global(copilot);
25 cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| {
26 if let Some(copilot) = Copilot::global(cx) {
27 copilot
28 .update(cx, |copilot, cx| copilot.sign_in(cx))
29 .detach_and_log_err(cx);
30 }
31 });
32 cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| {
33 if let Some(copilot) = Copilot::global(cx) {
34 copilot
35 .update(cx, |copilot, cx| copilot.sign_out(cx))
36 .detach_and_log_err(cx);
37 }
38 });
39}
40
/// Lifecycle of the copilot language server owned by `Copilot`.
enum CopilotServer {
    // Initial state set by `Copilot::start` until its startup task finishes.
    Downloading,
    // Startup failed; holds the stringified error.
    Error(Arc<str>),
    // The server is running; `status` tracks the user's sign-in state.
    Started {
        server: Arc<LanguageServer>,
        status: SignInStatus,
    },
}
49
/// Sign-in state of a started server, derived from the server's `request::SignInStatus`
/// in `Copilot::update_sign_in_status` (`Ok`/`MaybeOk` -> `Authorized`,
/// `NotAuthorized` -> `Unauthorized`, anything else -> `SignedOut`).
#[derive(Clone, Debug, PartialEq, Eq)]
enum SignInStatus {
    Authorized { user: String },
    Unauthorized { user: String },
    SignedOut,
}
56
/// Events emitted by the `Copilot` model.
#[derive(Debug)]
pub enum Event {
    /// Emitted by `Copilot::sign_in` when the server asks for a device-flow sign-in.
    /// Presumably the UI shows `user_code` and sends the user to `verification_uri` —
    /// confirm against subscribers.
    PromptUserDeviceFlow {
        user_code: String,
        verification_uri: String,
    },
}
64
/// Publicly observable state of the copilot integration, as returned by `Copilot::status`.
#[derive(Debug)]
pub enum Status {
    /// Startup (download, launch, initial status check) has not finished yet.
    Downloading,
    /// Startup failed; holds the error message.
    Error(Arc<str>),
    /// The server is running and no user is signed in.
    SignedOut,
    /// The server is running but the signed-in user is not authorized.
    Unauthorized,
    /// The server is running and the user is authorized.
    Authorized,
}

impl Status {
    /// Returns `true` only for [`Status::Authorized`].
    ///
    /// Public so that consumers of the (public) `Status` type can query it.
    pub fn is_authorized(&self) -> bool {
        matches!(self, Status::Authorized)
    }
}
79
/// gpui model owning the copilot language server; `init` registers its handle as an
/// app global.
struct Copilot {
    server: CopilotServer,
}
83
// Makes `Copilot` a gpui model whose emitted events are of type `Event`.
impl Entity for Copilot {
    type Event = Event;
}
87
88impl Copilot {
89 fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
90 if cx.has_global::<ModelHandle<Self>>() {
91 let copilot = cx.global::<ModelHandle<Self>>().clone();
92 if copilot.read(cx).status().is_authorized() {
93 Some(copilot)
94 } else {
95 None
96 }
97 } else {
98 None
99 }
100 }
101
102 fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
103 cx.spawn(|this, mut cx| async move {
104 let start_language_server = async {
105 let server_path = get_lsp_binary(http).await?;
106 let server =
107 LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
108 let server = server.initialize(Default::default()).await?;
109 let status = server
110 .request::<request::CheckStatus>(request::CheckStatusParams {
111 local_checks_only: false,
112 })
113 .await?;
114 anyhow::Ok((server, status))
115 };
116
117 let server = start_language_server.await;
118 this.update(&mut cx, |this, cx| {
119 cx.notify();
120 match server {
121 Ok((server, status)) => {
122 this.server = CopilotServer::Started {
123 server,
124 status: SignInStatus::SignedOut,
125 };
126 this.update_sign_in_status(status, cx);
127 }
128 Err(error) => {
129 this.server = CopilotServer::Error(error.to_string().into());
130 }
131 }
132 })
133 })
134 .detach();
135 Self {
136 server: CopilotServer::Downloading,
137 }
138 }
139
140 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
141 if let CopilotServer::Started { server, .. } = &self.server {
142 let server = server.clone();
143 cx.spawn(|this, mut cx| async move {
144 let sign_in = server
145 .request::<request::SignInInitiate>(request::SignInInitiateParams {})
146 .await?;
147 if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
148 this.update(&mut cx, |_, cx| {
149 cx.emit(Event::PromptUserDeviceFlow {
150 user_code: flow.user_code.clone(),
151 verification_uri: flow.verification_uri,
152 });
153 });
154 let response = server
155 .request::<request::SignInConfirm>(request::SignInConfirmParams {
156 user_code: flow.user_code,
157 })
158 .await?;
159 this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
160 }
161 anyhow::Ok(())
162 })
163 } else {
164 Task::ready(Err(anyhow!("copilot hasn't started yet")))
165 }
166 }
167
168 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
169 if let CopilotServer::Started { server, .. } = &self.server {
170 let server = server.clone();
171 cx.spawn(|this, mut cx| async move {
172 server
173 .request::<request::SignOut>(request::SignOutParams {})
174 .await?;
175 this.update(&mut cx, |this, cx| {
176 if let CopilotServer::Started { status, .. } = &mut this.server {
177 *status = SignInStatus::SignedOut;
178 cx.notify();
179 }
180 });
181
182 anyhow::Ok(())
183 })
184 } else {
185 Task::ready(Err(anyhow!("copilot hasn't started yet")))
186 }
187 }
188
189 pub fn completions<T>(
190 &self,
191 buffer: &ModelHandle<Buffer>,
192 position: T,
193 cx: &mut ModelContext<Self>,
194 ) -> Task<Result<()>>
195 where
196 T: ToPointUtf16,
197 {
198 let server = match self.authenticated_server() {
199 Ok(server) => server,
200 Err(error) => return Task::ready(Err(error)),
201 };
202
203 let buffer = buffer.read(cx).snapshot();
204 let position = position.to_point_utf16(&buffer);
205 let language_name = buffer.language_at(position).map(|language| language.name());
206 let language_name = language_name.as_deref();
207
208 let path;
209 let relative_path;
210 if let Some(file) = buffer.file() {
211 if let Some(file) = file.as_local() {
212 path = file.abs_path(cx);
213 } else {
214 path = file.full_path(cx);
215 }
216 relative_path = file.path().to_path_buf();
217 } else {
218 path = PathBuf::from("/untitled");
219 relative_path = PathBuf::from("untitled");
220 }
221
222 let settings = cx.global::<Settings>();
223 let request = server.request::<request::GetCompletions>(request::GetCompletionsParams {
224 doc: request::GetCompletionsDocument {
225 source: buffer.text(),
226 tab_size: settings.tab_size(language_name).into(),
227 indent_size: 1,
228 insert_spaces: !settings.hard_tabs(language_name),
229 uri: lsp::Url::from_file_path(&path).unwrap(),
230 path: path.to_string_lossy().into(),
231 relative_path: relative_path.to_string_lossy().into(),
232 language_id: "csharp".into(),
233 position: point_to_lsp(position),
234 version: 0,
235 },
236 });
237 cx.spawn(|this, cx| async move {
238 dbg!(request.await?);
239
240 anyhow::Ok(())
241 })
242 }
243
244 pub fn status(&self) -> Status {
245 match &self.server {
246 CopilotServer::Downloading => Status::Downloading,
247 CopilotServer::Error(error) => Status::Error(error.clone()),
248 CopilotServer::Started { status, .. } => match status {
249 SignInStatus::Authorized { .. } => Status::Authorized,
250 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
251 SignInStatus::SignedOut => Status::SignedOut,
252 },
253 }
254 }
255
256 fn update_sign_in_status(
257 &mut self,
258 lsp_status: request::SignInStatus,
259 cx: &mut ModelContext<Self>,
260 ) {
261 if let CopilotServer::Started { status, .. } = &mut self.server {
262 *status = match lsp_status {
263 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
264 SignInStatus::Authorized { user }
265 }
266 request::SignInStatus::NotAuthorized { user } => {
267 SignInStatus::Unauthorized { user }
268 }
269 _ => SignInStatus::SignedOut,
270 };
271 cx.notify();
272 }
273 }
274
275 fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
276 match &self.server {
277 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
278 CopilotServer::Error(error) => Err(anyhow!(
279 "copilot was not started because of an error: {}",
280 error
281 )),
282 CopilotServer::Started { server, status } => {
283 if matches!(status, SignInStatus::Authorized { .. }) {
284 Ok(server.clone())
285 } else {
286 Err(anyhow!("must sign in before using copilot"))
287 }
288 }
289 }
290 }
291}
292
293async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
294 ///Check for the latest copilot language server and download it if we haven't already
295 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
296 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
297 let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
298 let asset = release
299 .assets
300 .iter()
301 .find(|asset| asset.name == asset_name)
302 .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
303
304 fs::create_dir_all(&*paths::COPILOT_DIR).await?;
305 let destination_path =
306 paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
307
308 if fs::metadata(&destination_path).await.is_err() {
309 let mut response = http
310 .get(&asset.browser_download_url, Default::default(), true)
311 .await
312 .map_err(|err| anyhow!("error downloading release: {}", err))?;
313 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
314 let mut file = fs::File::create(&destination_path).await?;
315 futures::io::copy(decompressed_bytes, &mut file).await?;
316 fs::set_permissions(
317 &destination_path,
318 <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
319 )
320 .await?;
321
322 remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
323 }
324
325 Ok(destination_path)
326 }
327
328 match fetch_latest(http).await {
329 ok @ Result::Ok(..) => ok,
330 e @ Err(..) => {
331 e.log_err();
332 // Fetch a cached binary, if it exists
333 (|| async move {
334 let mut last = None;
335 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
336 while let Some(entry) = entries.next().await {
337 last = Some(entry?.path());
338 }
339 last.ok_or_else(|| anyhow!("no cached binary"))
340 })()
341 .await
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use std::time::Duration;
    use util::http;

    /// End-to-end smoke test against the real copilot language server.
    ///
    /// NOTE(review): this fetches the server via `latest_github_release`, so it needs
    /// network access. `sign_in` waits on `SignInConfirm`, which presumably requires a
    /// human to complete the device flow, so this is not suitable for unattended CI.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);

        let http_client = http::client();
        let model = cx.add_model(|cx| Copilot::start(http_client, cx));

        // Crude wait for the startup task spawned by `Copilot::start`.
        smol::Timer::after(Duration::from_secs(5)).await;

        let sign_in_task = model.update(cx, |copilot, cx| copilot.sign_in(cx));
        sign_in_task.await.unwrap();
        let current_status = model.read_with(cx, |copilot, _| copilot.status());
        dbg!(current_status);

        let text_buffer = cx.add_model(|cx| language::Buffer::new(0, "Lorem ipsum dol", cx));
        let completions_task =
            model.update(cx, |copilot, cx| copilot.completions(&text_buffer, 15, cx));
        completions_task.await.unwrap();
    }
}