1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use client::Client;
7use futures::{future::Shared, FutureExt, TryFutureExt};
8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
10use lsp::LanguageServer;
11use node_runtime::NodeRuntime;
12use settings::Settings;
13use smol::{fs, io::BufReader, stream::StreamExt};
14use std::{
15 env::consts,
16 path::{Path, PathBuf},
17 sync::Arc,
18};
19use util::{
20 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
21};
22
// Declares the actions of the `copilot` namespace. `init` registers global handlers for
// `SignIn` and `SignOut`; `NextSuggestion` is not handled in `init`
// (NOTE(review): presumably handled by an editor integration elsewhere — confirm).
actions!(copilot, [SignIn, SignOut, NextSuggestion]);
24
25pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
26 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
27 cx.set_global(copilot.clone());
28 cx.add_global_action(|_: &SignIn, cx| {
29 let copilot = Copilot::global(cx).unwrap();
30 copilot
31 .update(cx, |copilot, cx| copilot.sign_in(cx))
32 .detach_and_log_err(cx);
33 });
34 cx.add_global_action(|_: &SignOut, cx| {
35 let copilot = Copilot::global(cx).unwrap();
36 copilot
37 .update(cx, |copilot, cx| copilot.sign_out(cx))
38 .detach_and_log_err(cx);
39 });
40 sign_in::init(cx);
41}
42
43enum CopilotServer {
44 Downloading,
45 Error(Arc<str>),
46 Started {
47 server: Arc<LanguageServer>,
48 status: SignInStatus,
49 },
50}
51
52#[derive(Clone, Debug)]
53enum SignInStatus {
54 Authorized {
55 _user: String,
56 },
57 Unauthorized {
58 _user: String,
59 },
60 SigningIn {
61 prompt: Option<request::PromptUserDeviceFlow>,
62 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
63 },
64 SignedOut,
65}
66
67#[derive(Debug, PartialEq, Eq)]
68pub enum Status {
69 Downloading,
70 Error(Arc<str>),
71 SignedOut,
72 SigningIn {
73 prompt: Option<request::PromptUserDeviceFlow>,
74 },
75 Unauthorized,
76 Authorized,
77}
78
79impl Status {
80 pub fn is_authorized(&self) -> bool {
81 matches!(self, Status::Authorized)
82 }
83}
84
85#[derive(Debug, PartialEq, Eq)]
86pub struct Completion {
87 pub position: Anchor,
88 pub text: String,
89}
90
91pub struct Copilot {
92 server: CopilotServer,
93}
94
95impl Entity for Copilot {
96 type Event = ();
97}
98
99impl Copilot {
100 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
101 if cx.has_global::<ModelHandle<Self>>() {
102 Some(cx.global::<ModelHandle<Self>>().clone())
103 } else {
104 None
105 }
106 }
107
108 fn start(
109 http: Arc<dyn HttpClient>,
110 node_runtime: Arc<NodeRuntime>,
111 cx: &mut ModelContext<Self>,
112 ) -> Self {
113 // TODO: Don't eagerly download the LSP
114 cx.spawn(|this, mut cx| async move {
115 let start_language_server = async {
116 let server_path = get_lsp_binary(http).await?;
117 let server =
118 LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
119 let server = server.initialize(Default::default()).await?;
120 let status = server
121 .request::<request::CheckStatus>(request::CheckStatusParams {
122 local_checks_only: false,
123 })
124 .await?;
125 anyhow::Ok((server, status))
126 };
127
128 let server = start_language_server.await;
129 this.update(&mut cx, |this, cx| {
130 cx.notify();
131 match server {
132 Ok((server, status)) => {
133 this.server = CopilotServer::Started {
134 server,
135 status: SignInStatus::SignedOut,
136 };
137 this.update_sign_in_status(status, cx);
138 }
139 Err(error) => {
140 this.server = CopilotServer::Error(error.to_string().into());
141 }
142 }
143 })
144 })
145 .detach();
146
147 Self {
148 server: CopilotServer::Downloading,
149 }
150 }
151
152 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
153 if let CopilotServer::Started { server, status } = &mut self.server {
154 let task = match status {
155 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
156 Task::ready(Ok(())).shared()
157 }
158 SignInStatus::SigningIn { task, .. } => {
159 cx.notify(); // To re-show the prompt, just in case.
160 task.clone()
161 }
162 SignInStatus::SignedOut => {
163 let server = server.clone();
164 let task = cx
165 .spawn(|this, mut cx| async move {
166 let sign_in = async {
167 let sign_in = server
168 .request::<request::SignInInitiate>(
169 request::SignInInitiateParams {},
170 )
171 .await?;
172 match sign_in {
173 request::SignInInitiateResult::AlreadySignedIn { user } => {
174 Ok(request::SignInStatus::Ok { user })
175 }
176 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
177 this.update(&mut cx, |this, cx| {
178 if let CopilotServer::Started { status, .. } =
179 &mut this.server
180 {
181 if let SignInStatus::SigningIn {
182 prompt: prompt_flow,
183 ..
184 } = status
185 {
186 *prompt_flow = Some(flow.clone());
187 cx.notify();
188 }
189 }
190 });
191 let response = server
192 .request::<request::SignInConfirm>(
193 request::SignInConfirmParams {
194 user_code: flow.user_code,
195 },
196 )
197 .await?;
198 Ok(response)
199 }
200 }
201 };
202
203 let sign_in = sign_in.await;
204 this.update(&mut cx, |this, cx| match sign_in {
205 Ok(status) => {
206 this.update_sign_in_status(status, cx);
207 Ok(())
208 }
209 Err(error) => {
210 this.update_sign_in_status(
211 request::SignInStatus::NotSignedIn,
212 cx,
213 );
214 Err(Arc::new(error))
215 }
216 })
217 })
218 .shared();
219 *status = SignInStatus::SigningIn {
220 prompt: None,
221 task: task.clone(),
222 };
223 cx.notify();
224 task
225 }
226 };
227
228 cx.foreground()
229 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
230 } else {
231 Task::ready(Err(anyhow!("copilot hasn't started yet")))
232 }
233 }
234
235 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
236 if let CopilotServer::Started { server, status } = &mut self.server {
237 *status = SignInStatus::SignedOut;
238 cx.notify();
239
240 let server = server.clone();
241 cx.background().spawn(async move {
242 server
243 .request::<request::SignOut>(request::SignOutParams {})
244 .await?;
245 anyhow::Ok(())
246 })
247 } else {
248 Task::ready(Err(anyhow!("copilot hasn't started yet")))
249 }
250 }
251
252 pub fn completion<T>(
253 &self,
254 buffer: &ModelHandle<Buffer>,
255 position: T,
256 cx: &mut ModelContext<Self>,
257 ) -> Task<Result<Option<Completion>>>
258 where
259 T: ToPointUtf16,
260 {
261 let server = match self.authorized_server() {
262 Ok(server) => server,
263 Err(error) => return Task::ready(Err(error)),
264 };
265
266 let buffer = buffer.read(cx).snapshot();
267 let request = server
268 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
269 cx.background().spawn(async move {
270 let result = request.await?;
271 let completion = result
272 .completions
273 .into_iter()
274 .next()
275 .map(|completion| completion_from_lsp(completion, &buffer));
276 anyhow::Ok(completion)
277 })
278 }
279
280 pub fn completions_cycling<T>(
281 &self,
282 buffer: &ModelHandle<Buffer>,
283 position: T,
284 cx: &mut ModelContext<Self>,
285 ) -> Task<Result<Vec<Completion>>>
286 where
287 T: ToPointUtf16,
288 {
289 let server = match self.authorized_server() {
290 Ok(server) => server,
291 Err(error) => return Task::ready(Err(error)),
292 };
293
294 let buffer = buffer.read(cx).snapshot();
295 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
296 &buffer, position, cx,
297 ));
298 cx.background().spawn(async move {
299 let result = request.await?;
300 let completions = result
301 .completions
302 .into_iter()
303 .map(|completion| completion_from_lsp(completion, &buffer))
304 .collect();
305 anyhow::Ok(completions)
306 })
307 }
308
309 pub fn status(&self) -> Status {
310 match &self.server {
311 CopilotServer::Downloading => Status::Downloading,
312 CopilotServer::Error(error) => Status::Error(error.clone()),
313 CopilotServer::Started { status, .. } => match status {
314 SignInStatus::Authorized { .. } => Status::Authorized,
315 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
316 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
317 prompt: prompt.clone(),
318 },
319 SignInStatus::SignedOut => Status::SignedOut,
320 },
321 }
322 }
323
324 fn update_sign_in_status(
325 &mut self,
326 lsp_status: request::SignInStatus,
327 cx: &mut ModelContext<Self>,
328 ) {
329 if let CopilotServer::Started { status, .. } = &mut self.server {
330 *status = match lsp_status {
331 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
332 SignInStatus::Authorized { _user: user }
333 }
334 request::SignInStatus::NotAuthorized { user } => {
335 SignInStatus::Unauthorized { _user: user }
336 }
337 _ => SignInStatus::SignedOut,
338 };
339 cx.notify();
340 }
341 }
342
343 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
344 match &self.server {
345 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
346 CopilotServer::Error(error) => Err(anyhow!(
347 "copilot was not started because of an error: {}",
348 error
349 )),
350 CopilotServer::Started { server, status } => {
351 if matches!(status, SignInStatus::Authorized { .. }) {
352 Ok(server.clone())
353 } else {
354 Err(anyhow!("must sign in before using copilot"))
355 }
356 }
357 }
358 }
359}
360
361fn build_completion_params<T>(
362 buffer: &BufferSnapshot,
363 position: T,
364 cx: &AppContext,
365) -> request::GetCompletionsParams
366where
367 T: ToPointUtf16,
368{
369 let position = position.to_point_utf16(&buffer);
370 let language_name = buffer.language_at(position).map(|language| language.name());
371 let language_name = language_name.as_deref();
372
373 let path;
374 let relative_path;
375 if let Some(file) = buffer.file() {
376 if let Some(file) = file.as_local() {
377 path = file.abs_path(cx);
378 } else {
379 path = file.full_path(cx);
380 }
381 relative_path = file.path().to_path_buf();
382 } else {
383 path = PathBuf::from("/untitled");
384 relative_path = PathBuf::from("untitled");
385 }
386
387 let settings = cx.global::<Settings>();
388 let language_id = match language_name {
389 Some("Plain Text") => "plaintext".to_string(),
390 Some(language_name) => language_name.to_lowercase(),
391 None => "plaintext".to_string(),
392 };
393 request::GetCompletionsParams {
394 doc: request::GetCompletionsDocument {
395 source: buffer.text(),
396 tab_size: settings.tab_size(language_name).into(),
397 indent_size: 1,
398 insert_spaces: !settings.hard_tabs(language_name),
399 uri: lsp::Url::from_file_path(&path).unwrap(),
400 path: path.to_string_lossy().into(),
401 relative_path: relative_path.to_string_lossy().into(),
402 language_id,
403 position: point_to_lsp(position),
404 version: 0,
405 },
406 }
407}
408
409fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
410 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
411 Completion {
412 position: buffer.anchor_before(position),
413 text: completion.display_text,
414 }
415}
416
417async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
418 ///Check for the latest copilot language server and download it if we haven't already
419 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
420 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
421 let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
422 let asset = release
423 .assets
424 .iter()
425 .find(|asset| asset.name == asset_name)
426 .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
427
428 fs::create_dir_all(&*paths::COPILOT_DIR).await?;
429 let destination_path =
430 paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
431
432 if fs::metadata(&destination_path).await.is_err() {
433 let mut response = http
434 .get(&asset.browser_download_url, Default::default(), true)
435 .await
436 .map_err(|err| anyhow!("error downloading release: {}", err))?;
437 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
438 let mut file = fs::File::create(&destination_path).await?;
439 futures::io::copy(decompressed_bytes, &mut file).await?;
440 fs::set_permissions(
441 &destination_path,
442 <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
443 )
444 .await?;
445
446 remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
447 }
448
449 Ok(destination_path)
450 }
451
452 match fetch_latest(http).await {
453 ok @ Result::Ok(..) => ok,
454 e @ Err(..) => {
455 e.log_err();
456 // Fetch a cached binary, if it exists
457 (|| async move {
458 let mut last = None;
459 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
460 while let Some(entry) = entries.next().await {
461 last = Some(entry?.path());
462 }
463 last.ok_or_else(|| anyhow!("no cached binary"))
464 })()
465 .await
466 }
467 }
468}