1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
11 Task,
12};
13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
14use lsp::LanguageServer;
15use node_runtime::NodeRuntime;
16use settings::Settings;
17use smol::{fs, io::BufReader, stream::StreamExt};
18use std::{
19 ffi::OsString,
20 path::{Path, PathBuf},
21 sync::Arc,
22};
23use util::{
24 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
25};
26
27const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
28actions!(copilot_auth, [SignIn, SignOut]);
29
30const COPILOT_NAMESPACE: &'static str = "copilot";
31actions!(copilot, [NextSuggestion, PreviousSuggestion, Toggle]);
32
33pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
34 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
35 cx.set_global(copilot.clone());
36 cx.add_global_action(|_: &SignIn, cx| {
37 let copilot = Copilot::global(cx).unwrap();
38 copilot
39 .update(cx, |copilot, cx| copilot.sign_in(cx))
40 .detach_and_log_err(cx);
41 });
42 cx.add_global_action(|_: &SignOut, cx| {
43 let copilot = Copilot::global(cx).unwrap();
44 copilot
45 .update(cx, |copilot, cx| copilot.sign_out(cx))
46 .detach_and_log_err(cx);
47 });
48
49 cx.observe(&copilot, |handle, cx| {
50 let status = handle.read(cx).status();
51 cx.update_global::<collections::CommandPaletteFilter, _, _>(
52 move |filter, _cx| match status {
53 Status::Disabled => {
54 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
55 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
56 }
57 Status::Authorized => {
58 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
59 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
60 }
61 _ => {
62 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
63 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
64 }
65 },
66 );
67 })
68 .detach();
69
70 sign_in::init(cx);
71}
72
/// Lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// Copilot integration is turned off in settings; no server is running.
    Disabled,
    /// The server is being fetched and launched. The start task is owned by this
    /// variant (the leading underscore marks it as held only to keep it alive).
    Starting {
        _task: Shared<Task<()>>,
    },
    /// Fetching or launching the server failed; holds the error message.
    Error(Arc<str>),
    /// The server is running; `status` tracks the user's sign-in state with it.
    Started {
        server: Arc<LanguageServer>,
        status: SignInStatus,
    },
}
84
/// Sign-in state of a running Copilot server, derived from the server's
/// responses to status and sign-in requests.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in and authorized to use Copilot.
    Authorized {
        _user: String,
    },
    /// Signed in, but the server reported the account as not authorized.
    Unauthorized {
        _user: String,
    },
    /// A sign-in attempt is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared task that resolves when this sign-in attempt finishes.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut,
}
99
/// Public, flattened view of Copilot's state (server lifecycle plus sign-in),
/// as returned by `Copilot::status`.
#[derive(Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server is being fetched / launched.
    Starting,
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot integration is turned off in settings.
    Disabled,
    /// The server is running but the user is not signed in.
    SignedOut,
    /// A sign-in is in progress; `prompt` holds the device-flow prompt once known.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// Signed in, but the account is not authorized to use Copilot.
    Unauthorized,
    /// Signed in and authorized; completions can be requested.
    Authorized,
}
112
113impl Status {
114 pub fn is_authorized(&self) -> bool {
115 matches!(self, Status::Authorized)
116 }
117}
118
/// A single Copilot suggestion resolved against a buffer snapshot.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer position the suggestion is anchored at.
    pub position: Anchor,
    /// Suggested text (taken from the server's `display_text`).
    pub text: String,
}
124
/// Model that owns the Copilot language server and the user's sign-in state.
/// A single instance is installed as a global by `init`.
pub struct Copilot {
    server: CopilotServer,
}
128
// Copilot emits no events of its own; observers react to `cx.notify()` instead.
impl Entity for Copilot {
    type Event = ();
}
132
133impl Copilot {
134 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
135 if cx.has_global::<ModelHandle<Self>>() {
136 Some(cx.global::<ModelHandle<Self>>().clone())
137 } else {
138 None
139 }
140 }
141
142 fn start(
143 http: Arc<dyn HttpClient>,
144 node_runtime: Arc<NodeRuntime>,
145 cx: &mut ModelContext<Self>,
146 ) -> Self {
147 cx.observe_global::<Settings, _>({
148 let http = http.clone();
149 let node_runtime = node_runtime.clone();
150 move |this, cx| {
151 if cx.global::<Settings>().enable_copilot_integration {
152 if matches!(this.server, CopilotServer::Disabled) {
153 let start_task = cx
154 .spawn({
155 let http = http.clone();
156 let node_runtime = node_runtime.clone();
157 move |this, cx| {
158 Self::start_language_server(http, node_runtime, this, cx)
159 }
160 })
161 .shared();
162 this.server = CopilotServer::Starting { _task: start_task }
163 }
164 } else {
165 this.server = CopilotServer::Disabled
166 }
167 }
168 })
169 .detach();
170
171 if cx.global::<Settings>().enable_copilot_integration {
172 let start_task = cx
173 .spawn({
174 let http = http.clone();
175 let node_runtime = node_runtime.clone();
176 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
177 })
178 .shared();
179
180 Self {
181 server: CopilotServer::Starting { _task: start_task },
182 }
183 } else {
184 Self {
185 server: CopilotServer::Disabled,
186 }
187 }
188 }
189
190 fn start_language_server(
191 http: Arc<dyn HttpClient>,
192 node_runtime: Arc<NodeRuntime>,
193 this: ModelHandle<Self>,
194 mut cx: AsyncAppContext,
195 ) -> impl Future<Output = ()> {
196 async move {
197 let start_language_server = async {
198 let server_path = get_copilot_lsp(http).await?;
199 let node_path = node_runtime.binary_path().await?;
200 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
201 let server =
202 LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
203
204 let server = server.initialize(Default::default()).await?;
205 let status = server
206 .request::<request::CheckStatus>(request::CheckStatusParams {
207 local_checks_only: false,
208 })
209 .await?;
210 anyhow::Ok((server, status))
211 };
212
213 let server = start_language_server.await;
214 this.update(&mut cx, |this, cx| {
215 cx.notify();
216 match server {
217 Ok((server, status)) => {
218 this.server = CopilotServer::Started {
219 server,
220 status: SignInStatus::SignedOut,
221 };
222 this.update_sign_in_status(status, cx);
223 }
224 Err(error) => {
225 this.server = CopilotServer::Error(error.to_string().into());
226 cx.notify()
227 }
228 }
229 })
230 }
231 }
232
233 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
234 if let CopilotServer::Started { server, status } = &mut self.server {
235 let task = match status {
236 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
237 Task::ready(Ok(())).shared()
238 }
239 SignInStatus::SigningIn { task, .. } => {
240 cx.notify();
241 task.clone()
242 }
243 SignInStatus::SignedOut => {
244 let server = server.clone();
245 let task = cx
246 .spawn(|this, mut cx| async move {
247 let sign_in = async {
248 let sign_in = server
249 .request::<request::SignInInitiate>(
250 request::SignInInitiateParams {},
251 )
252 .await?;
253 match sign_in {
254 request::SignInInitiateResult::AlreadySignedIn { user } => {
255 Ok(request::SignInStatus::Ok { user })
256 }
257 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
258 this.update(&mut cx, |this, cx| {
259 if let CopilotServer::Started { status, .. } =
260 &mut this.server
261 {
262 if let SignInStatus::SigningIn {
263 prompt: prompt_flow,
264 ..
265 } = status
266 {
267 *prompt_flow = Some(flow.clone());
268 cx.notify();
269 }
270 }
271 });
272 let response = server
273 .request::<request::SignInConfirm>(
274 request::SignInConfirmParams {
275 user_code: flow.user_code,
276 },
277 )
278 .await?;
279 Ok(response)
280 }
281 }
282 };
283
284 let sign_in = sign_in.await;
285 this.update(&mut cx, |this, cx| match sign_in {
286 Ok(status) => {
287 this.update_sign_in_status(status, cx);
288 Ok(())
289 }
290 Err(error) => {
291 this.update_sign_in_status(
292 request::SignInStatus::NotSignedIn,
293 cx,
294 );
295 Err(Arc::new(error))
296 }
297 })
298 })
299 .shared();
300 *status = SignInStatus::SigningIn {
301 prompt: None,
302 task: task.clone(),
303 };
304 cx.notify();
305 task
306 }
307 };
308
309 cx.foreground()
310 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
311 } else {
312 // If we're downloading, wait until download is finished
313 // If we're in a stuck state, display to the user
314 Task::ready(Err(anyhow!("copilot hasn't started yet")))
315 }
316 }
317
318 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
319 if let CopilotServer::Started { server, status } = &mut self.server {
320 *status = SignInStatus::SignedOut;
321 cx.notify();
322
323 let server = server.clone();
324 cx.background().spawn(async move {
325 server
326 .request::<request::SignOut>(request::SignOutParams {})
327 .await?;
328 anyhow::Ok(())
329 })
330 } else {
331 Task::ready(Err(anyhow!("copilot hasn't started yet")))
332 }
333 }
334
335 pub fn completion<T>(
336 &self,
337 buffer: &ModelHandle<Buffer>,
338 position: T,
339 cx: &mut ModelContext<Self>,
340 ) -> Task<Result<Option<Completion>>>
341 where
342 T: ToPointUtf16,
343 {
344 let server = match self.authorized_server() {
345 Ok(server) => server,
346 Err(error) => return Task::ready(Err(error)),
347 };
348
349 let buffer = buffer.read(cx).snapshot();
350 let request = server
351 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
352 cx.background().spawn(async move {
353 let result = request.await?;
354 let completion = result
355 .completions
356 .into_iter()
357 .next()
358 .map(|completion| completion_from_lsp(completion, &buffer));
359 anyhow::Ok(completion)
360 })
361 }
362
363 pub fn completions_cycling<T>(
364 &self,
365 buffer: &ModelHandle<Buffer>,
366 position: T,
367 cx: &mut ModelContext<Self>,
368 ) -> Task<Result<Vec<Completion>>>
369 where
370 T: ToPointUtf16,
371 {
372 let server = match self.authorized_server() {
373 Ok(server) => server,
374 Err(error) => return Task::ready(Err(error)),
375 };
376
377 let buffer = buffer.read(cx).snapshot();
378 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
379 &buffer, position, cx,
380 ));
381 cx.background().spawn(async move {
382 let result = request.await?;
383 let completions = result
384 .completions
385 .into_iter()
386 .map(|completion| completion_from_lsp(completion, &buffer))
387 .collect();
388 anyhow::Ok(completions)
389 })
390 }
391
392 pub fn status(&self) -> Status {
393 match &self.server {
394 CopilotServer::Starting { .. } => Status::Starting,
395 CopilotServer::Disabled => Status::Disabled,
396 CopilotServer::Error(error) => Status::Error(error.clone()),
397 CopilotServer::Started { status, .. } => match status {
398 SignInStatus::Authorized { .. } => Status::Authorized,
399 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
400 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
401 prompt: prompt.clone(),
402 },
403 SignInStatus::SignedOut => Status::SignedOut,
404 },
405 }
406 }
407
408 fn update_sign_in_status(
409 &mut self,
410 lsp_status: request::SignInStatus,
411 cx: &mut ModelContext<Self>,
412 ) {
413 if let CopilotServer::Started { status, .. } = &mut self.server {
414 *status = match lsp_status {
415 request::SignInStatus::Ok { user }
416 | request::SignInStatus::MaybeOk { user }
417 | request::SignInStatus::AlreadySignedIn { user } => {
418 SignInStatus::Authorized { _user: user }
419 }
420 request::SignInStatus::NotAuthorized { user } => {
421 SignInStatus::Unauthorized { _user: user }
422 }
423 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
424 };
425 cx.notify();
426 }
427 }
428
429 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
430 match &self.server {
431 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
432 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
433 CopilotServer::Error(error) => Err(anyhow!(
434 "copilot was not started because of an error: {}",
435 error
436 )),
437 CopilotServer::Started { server, status } => {
438 if matches!(status, SignInStatus::Authorized { .. }) {
439 Ok(server.clone())
440 } else {
441 Err(anyhow!("must sign in before using copilot"))
442 }
443 }
444 }
445 }
446}
447
448fn build_completion_params<T>(
449 buffer: &BufferSnapshot,
450 position: T,
451 cx: &AppContext,
452) -> request::GetCompletionsParams
453where
454 T: ToPointUtf16,
455{
456 let position = position.to_point_utf16(&buffer);
457 let language_name = buffer.language_at(position).map(|language| language.name());
458 let language_name = language_name.as_deref();
459
460 let path;
461 let relative_path;
462 if let Some(file) = buffer.file() {
463 if let Some(file) = file.as_local() {
464 path = file.abs_path(cx);
465 } else {
466 path = file.full_path(cx);
467 }
468 relative_path = file.path().to_path_buf();
469 } else {
470 path = PathBuf::from("/untitled");
471 relative_path = PathBuf::from("untitled");
472 }
473
474 let settings = cx.global::<Settings>();
475 let language_id = match language_name {
476 Some("Plain Text") => "plaintext".to_string(),
477 Some(language_name) => language_name.to_lowercase(),
478 None => "plaintext".to_string(),
479 };
480 request::GetCompletionsParams {
481 doc: request::GetCompletionsDocument {
482 source: buffer.text(),
483 tab_size: settings.tab_size(language_name).into(),
484 indent_size: 1,
485 insert_spaces: !settings.hard_tabs(language_name),
486 uri: lsp::Url::from_file_path(&path).unwrap(),
487 path: path.to_string_lossy().into(),
488 relative_path: relative_path.to_string_lossy().into(),
489 language_id,
490 position: point_to_lsp(position),
491 version: 0,
492 },
493 }
494}
495
496fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
497 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
498 Completion {
499 position: buffer.anchor_before(position),
500 text: completion.display_text,
501 }
502}
503
504async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
505 const SERVER_PATH: &'static str = "agent.js";
506
507 ///Check for the latest copilot language server and download it if we haven't already
508 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
509 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
510
511 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
512
513 fs::create_dir_all(version_dir).await?;
514 let server_path = version_dir.join(SERVER_PATH);
515
516 if fs::metadata(&server_path).await.is_err() {
517 let url = &release
518 .assets
519 .get(0)
520 .context("Github release for copilot contained no assets")?
521 .browser_download_url;
522
523 let mut response = http
524 .get(&url, Default::default(), true)
525 .await
526 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
527 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
528 let archive = Archive::new(decompressed_bytes);
529 archive.unpack(version_dir).await?;
530
531 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
532 }
533
534 Ok(server_path)
535 }
536
537 match fetch_latest(http).await {
538 ok @ Result::Ok(..) => ok,
539 e @ Err(..) => {
540 e.log_err();
541 // Fetch a cached binary, if it exists
542 (|| async move {
543 let mut last_version_dir = None;
544 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
545 while let Some(entry) = entries.next().await {
546 let entry = entry?;
547 if entry.file_type().await?.is_dir() {
548 last_version_dir = Some(entry.path());
549 }
550 }
551 let last_version_dir =
552 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
553 let server_path = last_version_dir.join(SERVER_PATH);
554 if server_path.exists() {
555 Ok(server_path)
556 } else {
557 Err(anyhow!(
558 "missing executable in directory {:?}",
559 last_version_dir
560 ))
561 }
562 })()
563 .await
564 }
565 }
566}