1mod auth_modal;
2mod request;
3
4use anyhow::{anyhow, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use auth_modal::AuthModal;
7use client::Client;
8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
10use lsp::LanguageServer;
11use settings::Settings;
12use smol::{fs, io::BufReader, stream::StreamExt};
13use std::{
14 env::consts,
15 path::{Path, PathBuf},
16 sync::Arc,
17};
18use util::{
19 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
20};
21use workspace::Workspace;
22
// Declares the copilot actions `SignIn`, `SignOut` and `ToggleAuthStatus`.
// Their handlers are registered on the application in `init`.
actions!(copilot, [SignIn, SignOut, ToggleAuthStatus]);
24
25pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
26 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
27 cx.set_global(copilot.clone());
28 cx.add_action(|_workspace: &mut Workspace, _: &SignIn, cx| {
29 let copilot = Copilot::global(cx);
30 if copilot.read(cx).status() == Status::Authorized {
31 return;
32 }
33
34 if !copilot.read(cx).has_subscription() {
35 let display_subscription =
36 cx.subscribe(&copilot, |workspace, _copilot, e, cx| match e {
37 Event::PromptUserDeviceFlow => {
38 workspace.toggle_modal(cx, |_workspace, cx| build_auth_modal(cx));
39 }
40 });
41
42 copilot.update(cx, |copilot, _cx| {
43 copilot.set_subscription(display_subscription)
44 })
45 }
46
47 copilot
48 .update(cx, |copilot, cx| copilot.sign_in(cx))
49 .detach_and_log_err(cx);
50 });
51 cx.add_action(|workspace: &mut Workspace, _: &SignOut, cx| {
52 let copilot = Copilot::global(cx);
53
54 copilot
55 .update(cx, |copilot, cx| copilot.sign_out(cx))
56 .detach_and_log_err(cx);
57
58 if workspace.modal::<AuthModal>().is_some() {
59 workspace.dismiss_modal(cx)
60 }
61 });
62 cx.add_action(|workspace: &mut Workspace, _: &ToggleAuthStatus, cx| {
63 workspace.toggle_modal(cx, |_workspace, cx| build_auth_modal(cx))
64 })
65}
66
67fn build_auth_modal(cx: &mut gpui::ViewContext<Workspace>) -> gpui::ViewHandle<AuthModal> {
68 let modal = cx.add_view(|cx| AuthModal::new(cx));
69
70 cx.subscribe(&modal, |workspace, _, e: &auth_modal::Event, cx| match e {
71 auth_modal::Event::Dismiss => workspace.dismiss_modal(cx),
72 })
73 .detach();
74
75 modal
76}
77
78enum CopilotServer {
79 Downloading,
80 Error(Arc<str>),
81 Started {
82 server: Arc<LanguageServer>,
83 status: SignInStatus,
84 },
85}
86
/// Device-flow details recorded while a sign-in is waiting on the user.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PromptingUser {
    /// Code returned by the sign-in initiation request.
    user_code: String,
    /// URI returned by the sign-in initiation request.
    verification_uri: String,
}

/// Sign-in state of a started copilot server.
#[derive(Debug, Clone, PartialEq, Eq)]
enum SignInStatus {
    /// The server reported the user as signed in and allowed to use copilot.
    Authorized { user: String },
    /// The server reported the user as not authorized.
    Unauthorized { user: String },
    /// A device-flow sign-in is in progress.
    PromptingUser(PromptingUser),
    /// Nobody is signed in.
    SignedOut,
}
100
/// Events emitted by the [`Copilot`] model.
#[derive(Debug)]
pub enum Event {
    /// Emitted by `sign_in` once the server asks for the device flow, after the
    /// prompt details have been stored on the model.
    PromptUserDeviceFlow,
}
105
/// Coarse, externally visible copilot state, as returned by `Copilot::status`.
#[derive(Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server is still being fetched and started.
    Downloading,
    /// Startup failed; holds the error message.
    Error(Arc<str>),
    /// The server is running and nobody is signed in.
    SignedOut,
    /// The server is running but the user is not (yet) authorized.
    Unauthorized,
    /// The server is running and the user is authorized.
    Authorized,
}

impl Status {
    /// Returns `true` only for [`Status::Authorized`].
    fn is_authorized(&self) -> bool {
        *self == Status::Authorized
    }
}
120
121#[derive(Debug)]
122pub struct Completion {
123 pub position: Anchor,
124 pub text: String,
125}
126
127struct Copilot {
128 server: CopilotServer,
129 _display_subscription: Option<gpui::Subscription>,
130}
131
132impl Entity for Copilot {
133 type Event = Event;
134}
135
136impl Copilot {
137 fn global(cx: &AppContext) -> ModelHandle<Self> {
138 cx.global::<ModelHandle<Self>>().clone()
139 }
140
141 fn has_subscription(&self) -> bool {
142 self._display_subscription.is_some()
143 }
144
    /// Stores the subscription that displays the auth modal.
    ///
    /// Expected to be called at most once: debug builds assert that no
    /// subscription has been stored yet.
    fn set_subscription(&mut self, display_subscription: gpui::Subscription) {
        debug_assert!(self._display_subscription.is_none());
        self._display_subscription = Some(display_subscription);
    }
149
150 fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
151 // TODO: Don't eagerly download the LSP
152 cx.spawn(|this, mut cx| async move {
153 let start_language_server = async {
154 let server_path = get_lsp_binary(http).await?;
155 let server =
156 LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
157 let server = server.initialize(Default::default()).await?;
158 let status = server
159 .request::<request::CheckStatus>(request::CheckStatusParams {
160 local_checks_only: false,
161 })
162 .await?;
163 anyhow::Ok((server, status))
164 };
165
166 let server = start_language_server.await;
167 this.update(&mut cx, |this, cx| {
168 cx.notify();
169 match server {
170 Ok((server, status)) => {
171 this.server = CopilotServer::Started {
172 server,
173 status: SignInStatus::SignedOut,
174 };
175 this.update_sign_in_status(status, cx);
176 }
177 Err(error) => {
178 this.server = CopilotServer::Error(error.to_string().into());
179 }
180 }
181 })
182 })
183 .detach();
184
185 Self {
186 server: CopilotServer::Downloading,
187 _display_subscription: None,
188 }
189 }
190
191 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
192 if let CopilotServer::Started { server, .. } = &self.server {
193 let server = server.clone();
194 cx.spawn(|this, mut cx| async move {
195 let sign_in = server
196 .request::<request::SignInInitiate>(request::SignInInitiateParams {})
197 .await?;
198 if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
199 this.update(&mut cx, |this, cx| {
200 this.update_prompting_user(
201 flow.user_code.clone(),
202 flow.verification_uri,
203 cx,
204 );
205
206 cx.emit(Event::PromptUserDeviceFlow)
207 });
208 // TODO: catch an error here and clear the corresponding user code
209 let response = server
210 .request::<request::SignInConfirm>(request::SignInConfirmParams {
211 user_code: flow.user_code,
212 })
213 .await?;
214
215 this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
216 }
217 anyhow::Ok(())
218 })
219 } else {
220 Task::ready(Err(anyhow!("copilot hasn't started yet")))
221 }
222 }
223
224 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
225 if let CopilotServer::Started { server, .. } = &self.server {
226 let server = server.clone();
227 cx.spawn(|this, mut cx| async move {
228 server
229 .request::<request::SignOut>(request::SignOutParams {})
230 .await?;
231 this.update(&mut cx, |this, cx| {
232 if let CopilotServer::Started { status, .. } = &mut this.server {
233 *status = SignInStatus::SignedOut;
234 cx.notify();
235 }
236 });
237
238 anyhow::Ok(())
239 })
240 } else {
241 Task::ready(Err(anyhow!("copilot hasn't started yet")))
242 }
243 }
244
245 pub fn completion<T>(
246 &self,
247 buffer: &ModelHandle<Buffer>,
248 position: T,
249 cx: &mut ModelContext<Self>,
250 ) -> Task<Result<Option<Completion>>>
251 where
252 T: ToPointUtf16,
253 {
254 let server = match self.authenticated_server() {
255 Ok(server) => server,
256 Err(error) => return Task::ready(Err(error)),
257 };
258
259 let buffer = buffer.read(cx).snapshot();
260 let request = server
261 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
262 cx.background().spawn(async move {
263 let result = request.await?;
264 let completion = result
265 .completions
266 .into_iter()
267 .next()
268 .map(|completion| completion_from_lsp(completion, &buffer));
269 anyhow::Ok(completion)
270 })
271 }
272
273 pub fn completions_cycling<T>(
274 &self,
275 buffer: &ModelHandle<Buffer>,
276 position: T,
277 cx: &mut ModelContext<Self>,
278 ) -> Task<Result<Vec<Completion>>>
279 where
280 T: ToPointUtf16,
281 {
282 let server = match self.authenticated_server() {
283 Ok(server) => server,
284 Err(error) => return Task::ready(Err(error)),
285 };
286
287 let buffer = buffer.read(cx).snapshot();
288 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
289 &buffer, position, cx,
290 ));
291 cx.background().spawn(async move {
292 let result = request.await?;
293 let completions = result
294 .completions
295 .into_iter()
296 .map(|completion| completion_from_lsp(completion, &buffer))
297 .collect();
298 anyhow::Ok(completions)
299 })
300 }
301
302 pub fn status(&self) -> Status {
303 match &self.server {
304 CopilotServer::Downloading => Status::Downloading,
305 CopilotServer::Error(error) => Status::Error(error.clone()),
306 CopilotServer::Started { status, .. } => match status {
307 SignInStatus::Authorized { .. } => Status::Authorized,
308 SignInStatus::Unauthorized { .. } | SignInStatus::PromptingUser { .. } => {
309 Status::Unauthorized
310 }
311 SignInStatus::SignedOut => Status::SignedOut,
312 },
313 }
314 }
315
316 pub fn prompting_user(&self) -> Option<&PromptingUser> {
317 if let CopilotServer::Started { status, .. } = &self.server {
318 if let SignInStatus::PromptingUser(prompt) = status {
319 return Some(prompt);
320 }
321 }
322 None
323 }
324
325 fn update_prompting_user(
326 &mut self,
327 user_code: String,
328 verification_uri: String,
329 cx: &mut ModelContext<Self>,
330 ) {
331 if let CopilotServer::Started { status, .. } = &mut self.server {
332 *status = SignInStatus::PromptingUser(PromptingUser {
333 user_code,
334 verification_uri,
335 });
336 cx.notify();
337 }
338 }
339
340 fn update_sign_in_status(
341 &mut self,
342 lsp_status: request::SignInStatus,
343 cx: &mut ModelContext<Self>,
344 ) {
345 if let CopilotServer::Started { status, .. } = &mut self.server {
346 *status = match lsp_status {
347 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
348 SignInStatus::Authorized { user }
349 }
350 request::SignInStatus::NotAuthorized { user } => {
351 SignInStatus::Unauthorized { user }
352 }
353 _ => SignInStatus::SignedOut,
354 };
355 cx.notify();
356 }
357 }
358
359 fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
360 match &self.server {
361 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
362 CopilotServer::Error(error) => Err(anyhow!(
363 "copilot was not started because of an error: {}",
364 error
365 )),
366 CopilotServer::Started { server, status } => {
367 if matches!(status, SignInStatus::Authorized { .. }) {
368 Ok(server.clone())
369 } else {
370 Err(anyhow!("must sign in before using copilot"))
371 }
372 }
373 }
374 }
375}
376
377fn build_completion_params<T>(
378 buffer: &BufferSnapshot,
379 position: T,
380 cx: &AppContext,
381) -> request::GetCompletionsParams
382where
383 T: ToPointUtf16,
384{
385 let position = position.to_point_utf16(&buffer);
386 let language_name = buffer.language_at(position).map(|language| language.name());
387 let language_name = language_name.as_deref();
388
389 let path;
390 let relative_path;
391 if let Some(file) = buffer.file() {
392 if let Some(file) = file.as_local() {
393 path = file.abs_path(cx);
394 } else {
395 path = file.full_path(cx);
396 }
397 relative_path = file.path().to_path_buf();
398 } else {
399 path = PathBuf::from("/untitled");
400 relative_path = PathBuf::from("untitled");
401 }
402
403 let settings = cx.global::<Settings>();
404 let language_id = match language_name {
405 Some("Plain Text") => "plaintext".to_string(),
406 Some(language_name) => language_name.to_lowercase(),
407 None => "plaintext".to_string(),
408 };
409 request::GetCompletionsParams {
410 doc: request::GetCompletionsDocument {
411 source: buffer.text(),
412 tab_size: settings.tab_size(language_name).into(),
413 indent_size: 1,
414 insert_spaces: !settings.hard_tabs(language_name),
415 uri: lsp::Url::from_file_path(&path).unwrap(),
416 path: path.to_string_lossy().into(),
417 relative_path: relative_path.to_string_lossy().into(),
418 language_id,
419 position: point_to_lsp(position),
420 version: 0,
421 },
422 }
423}
424
425fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
426 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
427 Completion {
428 position: buffer.anchor_before(position),
429 text: completion.display_text,
430 }
431}
432
433async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
434 ///Check for the latest copilot language server and download it if we haven't already
435 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
436 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
437 let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
438 let asset = release
439 .assets
440 .iter()
441 .find(|asset| asset.name == asset_name)
442 .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
443
444 fs::create_dir_all(&*paths::COPILOT_DIR).await?;
445 let destination_path =
446 paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
447
448 if fs::metadata(&destination_path).await.is_err() {
449 let mut response = http
450 .get(&asset.browser_download_url, Default::default(), true)
451 .await
452 .map_err(|err| anyhow!("error downloading release: {}", err))?;
453 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
454 let mut file = fs::File::create(&destination_path).await?;
455 futures::io::copy(decompressed_bytes, &mut file).await?;
456 fs::set_permissions(
457 &destination_path,
458 <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
459 )
460 .await?;
461
462 remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
463 }
464
465 Ok(destination_path)
466 }
467
468 match fetch_latest(http).await {
469 ok @ Result::Ok(..) => ok,
470 e @ Err(..) => {
471 e.log_err();
472 // Fetch a cached binary, if it exists
473 (|| async move {
474 let mut last = None;
475 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
476 while let Some(entry) = entries.next().await {
477 last = Some(entry?.path());
478 }
479 last.ok_or_else(|| anyhow!("no cached binary"))
480 })()
481 .await
482 }
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    // Smoke test: starts copilot, signs in, then requests a single completion
    // and the cycling completions for a small buffer, printing every result
    // with `dbg!`. Nothing is asserted beyond the calls succeeding.
    //
    // NOTE(review): looks like this needs network access and an interactive
    // device-flow sign-in to pass — confirm before relying on it in CI.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http, cx));
        // Give the startup task time to finish; if the server has not started
        // by then, `sign_in` returns an error and the `unwrap` below panics.
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        // Offset 12 is the end of the 12-character buffer text.
        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        dbg!(copilot
            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
            .await
            .unwrap());
        dbg!(copilot
            .update(cx, |copilot, cx| copilot
                .completions_cycling(&buffer, 12, cx))
            .await
            .unwrap());
    }
}