1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use client::Client;
7use futures::{future::Shared, FutureExt, TryFutureExt};
8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
10use lsp::LanguageServer;
11use settings::Settings;
12use smol::{fs, io::BufReader, stream::StreamExt};
13use std::{
14 env::consts,
15 path::{Path, PathBuf},
16 sync::Arc,
17};
18use util::{
19 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
20};
21
// Defines the `SignIn` and `SignOut` action types that `init` registers global handlers for.
// NOTE(review): the first macro argument is presumably the action namespace — confirm against gpui.
actions!(copilot, [SignIn, SignOut]);
23
24pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
25 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
26 cx.set_global(copilot.clone());
27 cx.add_global_action(|_: &SignIn, cx| {
28 let copilot = Copilot::global(cx);
29 copilot
30 .update(cx, |copilot, cx| copilot.sign_in(cx))
31 .detach_and_log_err(cx);
32 });
33 cx.add_global_action(|_: &SignOut, cx| {
34 let copilot = Copilot::global(cx);
35 copilot
36 .update(cx, |copilot, cx| copilot.sign_out(cx))
37 .detach_and_log_err(cx);
38 });
39 sign_in::init(cx);
40}
41
/// Lifecycle state of the Copilot language server, as tracked by `Copilot`.
enum CopilotServer {
    /// Initial state set by `Copilot::start` while the server binary is being fetched
    /// and the server is being initialized in the background.
    Downloading,
    /// Startup failed; holds the string form of the error.
    Error(Arc<str>),
    /// The server was started and initialized successfully.
    Started {
        /// Handle used to send requests to the language server.
        server: Arc<LanguageServer>,
        /// Current sign-in state for this server.
        status: SignInStatus,
    },
}
50
/// Internal sign-in state kept alongside a started server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported the user as signed in (`Ok` or `MaybeOk`).
    Authorized {
        user: String,
    },
    /// The server reported the user as `NotAuthorized`.
    Unauthorized {
        user: String,
    },
    /// A sign-in attempt is in flight.
    SigningIn {
        /// Device-flow prompt, filled in once the server has returned one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared so that repeated `sign_in` calls can await the same attempt.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No user is signed in.
    SignedOut,
}
65
/// Public summary of Copilot's state, as returned by `Copilot::status`.
///
/// Combines the server lifecycle (`Downloading`, `Error`) with the sign-in state of a
/// started server, leaving out the server handle, the user name and the in-flight task.
#[derive(Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server is still being fetched or initialized.
    Downloading,
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// The server is running but nobody is signed in.
    SignedOut,
    /// A sign-in attempt is in progress.
    SigningIn {
        /// Device-flow prompt, if the server has provided one yet.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server reported the user as not authorized.
    Unauthorized,
    /// The user is signed in; completions can be requested.
    Authorized,
}
77
/// A single completion suggestion converted from the language server's response.
#[derive(Debug)]
pub struct Completion {
    /// Buffer anchor created from the position reported by the server.
    pub position: Anchor,
    /// The server-provided `display_text` of the completion.
    pub text: String,
}
83
/// Model that owns the Copilot language server and its sign-in state.
struct Copilot {
    /// Current lifecycle state of the language server.
    server: CopilotServer,
}
87
/// Lets `Copilot` be used as a gpui model; it declares no event payload (`()`).
impl Entity for Copilot {
    type Event = ();
}
91
92impl Copilot {
93 fn global(cx: &AppContext) -> ModelHandle<Self> {
94 cx.global::<ModelHandle<Self>>().clone()
95 }
96
97 fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
98 // TODO: Don't eagerly download the LSP
99 cx.spawn(|this, mut cx| async move {
100 let start_language_server = async {
101 let server_path = get_lsp_binary(http).await?;
102 let server =
103 LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
104 let server = server.initialize(Default::default()).await?;
105 let status = server
106 .request::<request::CheckStatus>(request::CheckStatusParams {
107 local_checks_only: false,
108 })
109 .await?;
110 anyhow::Ok((server, status))
111 };
112
113 let server = start_language_server.await;
114 this.update(&mut cx, |this, cx| {
115 cx.notify();
116 match server {
117 Ok((server, status)) => {
118 this.server = CopilotServer::Started {
119 server,
120 status: SignInStatus::SignedOut,
121 };
122 this.update_sign_in_status(status, cx);
123 }
124 Err(error) => {
125 this.server = CopilotServer::Error(error.to_string().into());
126 }
127 }
128 })
129 })
130 .detach();
131
132 Self {
133 server: CopilotServer::Downloading,
134 }
135 }
136
137 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
138 if let CopilotServer::Started { server, status } = &mut self.server {
139 let task = match status {
140 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
141 Task::ready(Ok(())).shared()
142 }
143 SignInStatus::SigningIn { task, .. } => task.clone(),
144 SignInStatus::SignedOut => {
145 let server = server.clone();
146 let task = cx
147 .spawn(|this, mut cx| async move {
148 let sign_in = async {
149 let sign_in = server
150 .request::<request::SignInInitiate>(
151 request::SignInInitiateParams {},
152 )
153 .await?;
154 match sign_in {
155 request::SignInInitiateResult::AlreadySignedIn { user } => {
156 Ok(request::SignInStatus::Ok { user })
157 }
158 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
159 this.update(&mut cx, |this, cx| {
160 if let CopilotServer::Started { status, .. } =
161 &mut this.server
162 {
163 if let SignInStatus::SigningIn {
164 prompt: prompt_flow,
165 ..
166 } = status
167 {
168 *prompt_flow = Some(flow.clone());
169 cx.notify();
170 }
171 }
172 });
173 let response = server
174 .request::<request::SignInConfirm>(
175 request::SignInConfirmParams {
176 user_code: flow.user_code,
177 },
178 )
179 .await?;
180 Ok(response)
181 }
182 }
183 };
184
185 let sign_in = sign_in.await;
186 this.update(&mut cx, |this, cx| match sign_in {
187 Ok(status) => {
188 this.update_sign_in_status(status, cx);
189 Ok(())
190 }
191 Err(error) => {
192 this.update_sign_in_status(
193 request::SignInStatus::NotSignedIn,
194 cx,
195 );
196 Err(Arc::new(error))
197 }
198 })
199 })
200 .shared();
201 *status = SignInStatus::SigningIn {
202 prompt: None,
203 task: task.clone(),
204 };
205 cx.notify();
206 task
207 }
208 };
209
210 cx.foreground()
211 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
212 } else {
213 Task::ready(Err(anyhow!("copilot hasn't started yet")))
214 }
215 }
216
217 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
218 if let CopilotServer::Started { server, status } = &mut self.server {
219 *status = SignInStatus::SignedOut;
220 cx.notify();
221
222 let server = server.clone();
223 cx.background().spawn(async move {
224 server
225 .request::<request::SignOut>(request::SignOutParams {})
226 .await?;
227 anyhow::Ok(())
228 })
229 } else {
230 Task::ready(Err(anyhow!("copilot hasn't started yet")))
231 }
232 }
233
234 pub fn completion<T>(
235 &self,
236 buffer: &ModelHandle<Buffer>,
237 position: T,
238 cx: &mut ModelContext<Self>,
239 ) -> Task<Result<Option<Completion>>>
240 where
241 T: ToPointUtf16,
242 {
243 let server = match self.authenticated_server() {
244 Ok(server) => server,
245 Err(error) => return Task::ready(Err(error)),
246 };
247
248 let buffer = buffer.read(cx).snapshot();
249 let request = server
250 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
251 cx.background().spawn(async move {
252 let result = request.await?;
253 let completion = result
254 .completions
255 .into_iter()
256 .next()
257 .map(|completion| completion_from_lsp(completion, &buffer));
258 anyhow::Ok(completion)
259 })
260 }
261
262 pub fn completions_cycling<T>(
263 &self,
264 buffer: &ModelHandle<Buffer>,
265 position: T,
266 cx: &mut ModelContext<Self>,
267 ) -> Task<Result<Vec<Completion>>>
268 where
269 T: ToPointUtf16,
270 {
271 let server = match self.authenticated_server() {
272 Ok(server) => server,
273 Err(error) => return Task::ready(Err(error)),
274 };
275
276 let buffer = buffer.read(cx).snapshot();
277 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
278 &buffer, position, cx,
279 ));
280 cx.background().spawn(async move {
281 let result = request.await?;
282 let completions = result
283 .completions
284 .into_iter()
285 .map(|completion| completion_from_lsp(completion, &buffer))
286 .collect();
287 anyhow::Ok(completions)
288 })
289 }
290
291 pub fn status(&self) -> Status {
292 match &self.server {
293 CopilotServer::Downloading => Status::Downloading,
294 CopilotServer::Error(error) => Status::Error(error.clone()),
295 CopilotServer::Started { status, .. } => match status {
296 SignInStatus::Authorized { .. } => Status::Authorized,
297 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
298 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
299 prompt: prompt.clone(),
300 },
301 SignInStatus::SignedOut => Status::SignedOut,
302 },
303 }
304 }
305
306 fn update_sign_in_status(
307 &mut self,
308 lsp_status: request::SignInStatus,
309 cx: &mut ModelContext<Self>,
310 ) {
311 if let CopilotServer::Started { status, .. } = &mut self.server {
312 *status = match lsp_status {
313 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
314 SignInStatus::Authorized { user }
315 }
316 request::SignInStatus::NotAuthorized { user } => {
317 SignInStatus::Unauthorized { user }
318 }
319 _ => SignInStatus::SignedOut,
320 };
321 cx.notify();
322 }
323 }
324
325 fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
326 match &self.server {
327 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
328 CopilotServer::Error(error) => Err(anyhow!(
329 "copilot was not started because of an error: {}",
330 error
331 )),
332 CopilotServer::Started { server, status } => {
333 if matches!(status, SignInStatus::Authorized { .. }) {
334 Ok(server.clone())
335 } else {
336 Err(anyhow!("must sign in before using copilot"))
337 }
338 }
339 }
340 }
341}
342
343fn build_completion_params<T>(
344 buffer: &BufferSnapshot,
345 position: T,
346 cx: &AppContext,
347) -> request::GetCompletionsParams
348where
349 T: ToPointUtf16,
350{
351 let position = position.to_point_utf16(&buffer);
352 let language_name = buffer.language_at(position).map(|language| language.name());
353 let language_name = language_name.as_deref();
354
355 let path;
356 let relative_path;
357 if let Some(file) = buffer.file() {
358 if let Some(file) = file.as_local() {
359 path = file.abs_path(cx);
360 } else {
361 path = file.full_path(cx);
362 }
363 relative_path = file.path().to_path_buf();
364 } else {
365 path = PathBuf::from("/untitled");
366 relative_path = PathBuf::from("untitled");
367 }
368
369 let settings = cx.global::<Settings>();
370 let language_id = match language_name {
371 Some("Plain Text") => "plaintext".to_string(),
372 Some(language_name) => language_name.to_lowercase(),
373 None => "plaintext".to_string(),
374 };
375 request::GetCompletionsParams {
376 doc: request::GetCompletionsDocument {
377 source: buffer.text(),
378 tab_size: settings.tab_size(language_name).into(),
379 indent_size: 1,
380 insert_spaces: !settings.hard_tabs(language_name),
381 uri: lsp::Url::from_file_path(&path).unwrap(),
382 path: path.to_string_lossy().into(),
383 relative_path: relative_path.to_string_lossy().into(),
384 language_id,
385 position: point_to_lsp(position),
386 version: 0,
387 },
388 }
389}
390
391fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
392 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
393 Completion {
394 position: buffer.anchor_before(position),
395 text: completion.display_text,
396 }
397}
398
399async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
400 ///Check for the latest copilot language server and download it if we haven't already
401 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
402 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
403 let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
404 let asset = release
405 .assets
406 .iter()
407 .find(|asset| asset.name == asset_name)
408 .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
409
410 fs::create_dir_all(&*paths::COPILOT_DIR).await?;
411 let destination_path =
412 paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
413
414 if fs::metadata(&destination_path).await.is_err() {
415 let mut response = http
416 .get(&asset.browser_download_url, Default::default(), true)
417 .await
418 .map_err(|err| anyhow!("error downloading release: {}", err))?;
419 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
420 let mut file = fs::File::create(&destination_path).await?;
421 futures::io::copy(decompressed_bytes, &mut file).await?;
422 fs::set_permissions(
423 &destination_path,
424 <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
425 )
426 .await?;
427
428 remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
429 }
430
431 Ok(destination_path)
432 }
433
434 match fetch_latest(http).await {
435 ok @ Result::Ok(..) => ok,
436 e @ Err(..) => {
437 e.log_err();
438 // Fetch a cached binary, if it exists
439 (|| async move {
440 let mut last = None;
441 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
442 while let Some(entry) = entries.next().await {
443 last = Some(entry?.path());
444 }
445 last.ok_or_else(|| anyhow!("no cached binary"))
446 })()
447 .await
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    /// Smoke test: starts Copilot, signs in, then prints the status and the results of
    /// both completion requests with `dbg!`. Nothing is asserted beyond the `unwrap`s.
    ///
    /// NOTE(review): this appears to need network access and a Copilot account that can
    /// complete sign-in — confirm before relying on it in CI.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http, cx));
        // Give the background startup task time to finish; `sign_in` fails while the
        // server is not yet started.
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        // Offset 12 is the end of the buffer text "fn foo() -> ".
        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        dbg!(copilot
            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
            .await
            .unwrap());
        dbg!(copilot
            .update(cx, |copilot, cx| copilot
                .completions_cycling(&buffer, 12, cx))
            .await
            .unwrap());
    }
}