1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use client::Client;
7use futures::{future::Shared, FutureExt, TryFutureExt};
8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
10use lsp::LanguageServer;
11use settings::Settings;
12use smol::{fs, io::BufReader, stream::StreamExt};
13use std::{
14 env::consts,
15 path::{Path, PathBuf},
16 sync::Arc,
17};
18use util::{
19 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
20};
21
// Actions in the `copilot` namespace. `init` registers handlers for `SignIn` and `SignOut`;
// NOTE(review): `NextSuggestion` has no handler in `init` — presumably bound elsewhere.
actions!(copilot, [SignIn, SignOut, NextSuggestion]);
23
24pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
25 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
26 cx.set_global(copilot.clone());
27 cx.add_global_action(|_: &SignIn, cx| {
28 let copilot = Copilot::global(cx).unwrap();
29 copilot
30 .update(cx, |copilot, cx| copilot.sign_in(cx))
31 .detach_and_log_err(cx);
32 });
33 cx.add_global_action(|_: &SignOut, cx| {
34 let copilot = Copilot::global(cx).unwrap();
35 copilot
36 .update(cx, |copilot, cx| copilot.sign_out(cx))
37 .detach_and_log_err(cx);
38 });
39 sign_in::init(cx);
40}
41
/// Lifecycle of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// Initial state set by `Copilot::start` while the binary is fetched and the server is
    /// being launched.
    Downloading,
    /// Starting the server failed; holds the error message.
    Error(Arc<str>),
    /// The server was started and initialized.
    Started {
        server: Arc<LanguageServer>,
        /// Sign-in state last recorded for this server.
        status: SignInStatus,
    },
}
50
/// Internal sign-in state, only tracked once the server is `Started`.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported the user as signed in (`Ok` or `MaybeOk`).
    Authorized {
        user: String,
    },
    /// The server reported the user as signed in but `NotAuthorized`.
    Unauthorized {
        user: String,
    },
    /// A sign-in flow is running.
    SigningIn {
        /// Device-flow prompt, filled in once the server asks the user to authorize.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in task, shared so repeated `sign_in` calls await the same
        /// outcome. The error is an `Arc` because `Shared` requires a `Clone` output.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No user is signed in.
    SignedOut,
}
65
/// Publicly observable state of Copilot, as returned by [`Copilot::status`].
#[derive(Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server is still being fetched / started.
    Downloading,
    /// The language server could not be started; holds the error message.
    Error(Arc<str>),
    /// The server is running but no user is signed in.
    SignedOut,
    /// A sign-in is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// A user is signed in but the server reported them as not authorized.
    Unauthorized,
    /// A user is signed in and authorized; completions can be requested.
    Authorized,
}
77
78impl Status {
79 pub fn is_authorized(&self) -> bool {
80 matches!(self, Status::Authorized)
81 }
82}
83
/// A single suggestion returned by the Copilot server.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Where the suggestion applies: the server's position clipped to the buffer and
    /// anchored before that point.
    pub position: Anchor,
    /// The suggestion text (the server's `display_text`).
    pub text: String,
}
89
/// Model that owns the Copilot language server and the user's sign-in state.
pub struct Copilot {
    server: CopilotServer,
}
93
// `Copilot` emits no events of its own; state changes are signalled with `cx.notify()`.
impl Entity for Copilot {
    type Event = ();
}
97
impl Copilot {
    /// Returns the globally registered Copilot model, if one has been set (see `init`).
    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
        if cx.has_global::<ModelHandle<Self>>() {
            Some(cx.global::<ModelHandle<Self>>().clone())
        } else {
            None
        }
    }

    /// Creates the model in the `Downloading` state and spawns a detached task that
    /// fetches the server binary, launches and initializes the server, and asks it for
    /// the current sign-in status.
    ///
    /// When the task finishes the model becomes `Started` (with the status reported by
    /// `CheckStatus`) or `Error` (with the failure message), and observers are notified.
    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
        // TODO: Don't eagerly download the LSP
        cx.spawn(|this, mut cx| async move {
            let start_language_server = async {
                let server_path = get_lsp_binary(http).await?;
                let server =
                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
                let server = server.initialize(Default::default()).await?;
                let status = server
                    .request::<request::CheckStatus>(request::CheckStatusParams {
                        local_checks_only: false,
                    })
                    .await?;
                anyhow::Ok((server, status))
            };

            let server = start_language_server.await;
            this.update(&mut cx, |this, cx| {
                cx.notify();
                match server {
                    Ok((server, status)) => {
                        // Enter `Started` first: `update_sign_in_status` is a no-op
                        // unless the server is `Started`, and then overwrites the
                        // placeholder `SignedOut` with the reported status.
                        this.server = CopilotServer::Started {
                            server,
                            status: SignInStatus::SignedOut,
                        };
                        this.update_sign_in_status(status, cx);
                    }
                    Err(error) => {
                        this.server = CopilotServer::Error(error.to_string().into());
                    }
                }
            })
        })
        .detach();

        Self {
            server: CopilotServer::Downloading,
        }
    }

    /// Starts (or joins) a sign-in flow.
    ///
    /// - Already `Authorized` / `Unauthorized`: resolves immediately with `Ok(())`.
    /// - `SigningIn`: returns a task awaiting the in-flight flow.
    /// - `SignedOut`: sends `SignInInitiate`. If the server asks for a device flow, the
    ///   prompt is stored on the status (so `status()` exposes it) and `SignInConfirm` is
    ///   sent with the user code. The final status is recorded when the flow completes.
    ///   On failure the status goes back to signed out.
    ///
    /// Errors immediately if the server has not started.
    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status } = &mut self.server {
            let task = match status {
                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                    Task::ready(Ok(())).shared()
                }
                // Share the existing flow so concurrent callers observe the same result.
                SignInStatus::SigningIn { task, .. } => task.clone(),
                SignInStatus::SignedOut => {
                    let server = server.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = server
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the prompt, but only if the flow is still
                                        // recorded as `SigningIn` (e.g. not signed out since).
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Started { status, .. } =
                                                &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        // NOTE(review): presumably resolves once the user
                                        // completes the device flow — confirm with the server.
                                        let response = server
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            // NOTE(review): this overwrites the status unconditionally, so a
                            // `sign_out` issued while this flow was running may be undone.
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // `Shared` needs a `Clone` output, hence the `Arc`.
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    // Record the flow so the prompt update above (and later `sign_in`
                    // calls) can find it.
                    *status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // Convert the shared `Arc<anyhow::Error>` back into a plain `anyhow::Error`.
            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }

    /// Marks the user as signed out (immediately, before the server responds) and sends
    /// a `SignOut` request on the background executor.
    ///
    /// Errors immediately if the server has not started.
    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status } = &mut self.server {
            *status = SignInStatus::SignedOut;
            cx.notify();

            let server = server.clone();
            cx.background().spawn(async move {
                server
                    .request::<request::SignOut>(request::SignOutParams {})
                    .await?;
                anyhow::Ok(())
            })
        } else {
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }

    /// Requests completions at `position` in `buffer` and returns the first one, if any.
    ///
    /// The request is built from a snapshot taken now. It fails immediately unless the
    /// server is started and the user is authorized.
    pub fn completion<T>(
        &self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Option<Completion>>>
    where
        T: ToPointUtf16,
    {
        let server = match self.authorized_server() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };

        let buffer = buffer.read(cx).snapshot();
        let request = server
            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
        cx.background().spawn(async move {
            let result = request.await?;
            let completion = result
                .completions
                .into_iter()
                .next()
                .map(|completion| completion_from_lsp(completion, &buffer));
            anyhow::Ok(completion)
        })
    }

    /// Requests all cycling completions (`GetCompletionsCycling`) at `position` in `buffer`.
    ///
    /// Fails immediately unless the server is started and the user is authorized.
    pub fn completions_cycling<T>(
        &self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        T: ToPointUtf16,
    {
        let server = match self.authorized_server() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };

        let buffer = buffer.read(cx).snapshot();
        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
            &buffer, position, cx,
        ));
        cx.background().spawn(async move {
            let result = request.await?;
            let completions = result
                .completions
                .into_iter()
                .map(|completion| completion_from_lsp(completion, &buffer))
                .collect();
            anyhow::Ok(completions)
        })
    }

    /// Maps the internal server/sign-in state to the public [`Status`].
    pub fn status(&self) -> Status {
        match &self.server {
            CopilotServer::Downloading => Status::Downloading,
            CopilotServer::Error(error) => Status::Error(error.clone()),
            CopilotServer::Started { status, .. } => match status {
                SignInStatus::Authorized { .. } => Status::Authorized,
                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
                    prompt: prompt.clone(),
                },
                SignInStatus::SignedOut => Status::SignedOut,
            },
        }
    }

    /// Records a sign-in status reported by the server and notifies observers.
    ///
    /// `Ok` and `MaybeOk` count as authorized and `NotAuthorized` as unauthorized. Any
    /// other status (e.g. `NotSignedIn`) counts as signed out. This is a no-op unless the
    /// server is `Started`.
    fn update_sign_in_status(
        &mut self,
        lsp_status: request::SignInStatus,
        cx: &mut ModelContext<Self>,
    ) {
        if let CopilotServer::Started { status, .. } = &mut self.server {
            *status = match lsp_status {
                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
                    SignInStatus::Authorized { user }
                }
                request::SignInStatus::NotAuthorized { user } => {
                    SignInStatus::Unauthorized { user }
                }
                _ => SignInStatus::SignedOut,
            };
            cx.notify();
        }
    }

    /// Returns the running server if the user is authorized, or an error describing why
    /// completions are unavailable.
    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
        match &self.server {
            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
            CopilotServer::Error(error) => Err(anyhow!(
                "copilot was not started because of an error: {}",
                error
            )),
            CopilotServer::Started { server, status } => {
                if matches!(status, SignInStatus::Authorized { .. }) {
                    Ok(server.clone())
                } else {
                    Err(anyhow!("must sign in before using copilot"))
                }
            }
        }
    }
}
352
353fn build_completion_params<T>(
354 buffer: &BufferSnapshot,
355 position: T,
356 cx: &AppContext,
357) -> request::GetCompletionsParams
358where
359 T: ToPointUtf16,
360{
361 let position = position.to_point_utf16(&buffer);
362 let language_name = buffer.language_at(position).map(|language| language.name());
363 let language_name = language_name.as_deref();
364
365 let path;
366 let relative_path;
367 if let Some(file) = buffer.file() {
368 if let Some(file) = file.as_local() {
369 path = file.abs_path(cx);
370 } else {
371 path = file.full_path(cx);
372 }
373 relative_path = file.path().to_path_buf();
374 } else {
375 path = PathBuf::from("/untitled");
376 relative_path = PathBuf::from("untitled");
377 }
378
379 let settings = cx.global::<Settings>();
380 let language_id = match language_name {
381 Some("Plain Text") => "plaintext".to_string(),
382 Some(language_name) => language_name.to_lowercase(),
383 None => "plaintext".to_string(),
384 };
385 request::GetCompletionsParams {
386 doc: request::GetCompletionsDocument {
387 source: buffer.text(),
388 tab_size: settings.tab_size(language_name).into(),
389 indent_size: 1,
390 insert_spaces: !settings.hard_tabs(language_name),
391 uri: lsp::Url::from_file_path(&path).unwrap(),
392 path: path.to_string_lossy().into(),
393 relative_path: relative_path.to_string_lossy().into(),
394 language_id,
395 position: point_to_lsp(position),
396 version: 0,
397 },
398 }
399}
400
401fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
402 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
403 Completion {
404 position: buffer.anchor_before(position),
405 text: completion.display_text,
406 }
407}
408
409async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
410 ///Check for the latest copilot language server and download it if we haven't already
411 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
412 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
413 let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
414 let asset = release
415 .assets
416 .iter()
417 .find(|asset| asset.name == asset_name)
418 .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
419
420 fs::create_dir_all(&*paths::COPILOT_DIR).await?;
421 let destination_path =
422 paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
423
424 if fs::metadata(&destination_path).await.is_err() {
425 let mut response = http
426 .get(&asset.browser_download_url, Default::default(), true)
427 .await
428 .map_err(|err| anyhow!("error downloading release: {}", err))?;
429 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
430 let mut file = fs::File::create(&destination_path).await?;
431 futures::io::copy(decompressed_bytes, &mut file).await?;
432 fs::set_permissions(
433 &destination_path,
434 <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
435 )
436 .await?;
437
438 remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
439 }
440
441 Ok(destination_path)
442 }
443
444 match fetch_latest(http).await {
445 ok @ Result::Ok(..) => ok,
446 e @ Err(..) => {
447 e.log_err();
448 // Fetch a cached binary, if it exists
449 (|| async move {
450 let mut last = None;
451 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
452 while let Some(entry) = entries.next().await {
453 last = Some(entry?.path());
454 }
455 last.ok_or_else(|| anyhow!("no cached binary"))
456 })()
457 .await
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    /// End-to-end smoke test against the real Copilot language server.
    ///
    /// NOTE(review): this uses a real HTTP client, so it downloads the server binary over
    /// the network. It waits a fixed two seconds (presumably for the server to start).
    /// `completion` / `completions_cycling` only succeed when the user is already
    /// authorized. It is therefore presumably meant to be run manually rather than in
    /// CI — confirm. Results are printed with `dbg!`; the only checks are the `unwrap`s.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http, cx));
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        // The status is read but not asserted on.
        copilot.read_with(cx, |copilot, _| copilot.status());

        // Offset 12 is the end of the buffer text.
        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        dbg!(copilot
            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
            .await
            .unwrap());
        dbg!(copilot
            .update(cx, |copilot, cx| copilot
                .completions_cycling(&buffer, 12, cx))
            .await
            .unwrap());
    }
}