1mod request;
2
3use anyhow::{anyhow, Result};
4use async_compression::futures::bufread::GzipDecoder;
5use client::Client;
6use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
7use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
8use lsp::LanguageServer;
9use settings::Settings;
10use smol::{fs, io::BufReader, stream::StreamExt};
11use std::{
12 env::consts,
13 path::{Path, PathBuf},
14 sync::Arc,
15};
16use util::{
17 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
18};
19
// Declares the `SignIn` and `SignOut` actions under the `copilot` action namespace;
// their global handlers are registered in `init`.
actions!(copilot, [SignIn, SignOut]);
21
22pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
23 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
24 cx.set_global(copilot);
25 cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| {
26 if let Some(copilot) = Copilot::global(cx) {
27 copilot
28 .update(cx, |copilot, cx| copilot.sign_in(cx))
29 .detach_and_log_err(cx);
30 }
31 });
32 cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| {
33 if let Some(copilot) = Copilot::global(cx) {
34 copilot
35 .update(cx, |copilot, cx| copilot.sign_out(cx))
36 .detach_and_log_err(cx);
37 }
38 });
39}
40
/// Lifecycle state of the Copilot language server held by the `Copilot` model.
enum CopilotServer {
    /// Initial state set by `Copilot::start` while the binary is fetched and the
    /// server is launched.
    Downloading,
    /// Fetching or starting the server failed; holds the error's message.
    Error(Arc<str>),
    /// The server is running and has been initialized.
    Started {
        server: Arc<LanguageServer>,
        /// Locally cached sign-in state, updated from server responses or set to
        /// `SignedOut` after a successful sign-out request.
        status: SignInStatus,
    },
}
49
/// Local sign-in state, derived from the server's `request::SignInStatus` by
/// `Copilot::update_sign_in_status`.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SignInStatus {
    /// The server reported `Ok` or `MaybeOk` for this user.
    Authorized { user: String },
    /// The server reported `NotAuthorized` for this user.
    Unauthorized { user: String },
    /// Any other server status, or the state after signing out.
    SignedOut,
}
56
/// Events emitted by the `Copilot` model.
#[derive(Debug)]
pub enum Event {
    /// Emitted by `Copilot::sign_in` when the server asks for a device-code flow.
    /// Presumably the UI should show `user_code` and direct the user to
    /// `verification_uri` — confirm against subscribers.
    PromptUserDeviceFlow {
        user_code: String,
        verification_uri: String,
    },
}
64
/// Coarse view of the Copilot state as returned by `Copilot::status`,
/// flattening the server lifecycle and the sign-in state into one enum.
#[derive(Debug)]
pub enum Status {
    /// The server is still being fetched/started.
    Downloading,
    /// Starting the server failed with this message.
    Error(Arc<str>),
    /// The server is running and no user is signed in.
    SignedOut,
    /// The server is running; a user is known but not authorized.
    Unauthorized,
    /// The server is running and the user is authorized.
    Authorized,
}
73
74impl Status {
75 fn is_authorized(&self) -> bool {
76 matches!(self, Status::Authorized)
77 }
78}
79
/// A single suggestion returned by the Copilot server.
#[derive(Debug)]
pub struct Completion {
    /// Anchor for the server-reported position, clipped into the buffer snapshot.
    pub position: Anchor,
    /// The server's `display_text` for this completion.
    pub text: String,
}
85
/// Model owning the Copilot language server; `init` installs a handle to it as a
/// global.
struct Copilot {
    /// Current lifecycle state of the server.
    server: CopilotServer,
}
89
// The model's event type is `Event` (emitted during sign-in).
impl Entity for Copilot {
    type Event = Event;
}
93
94impl Copilot {
95 fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
96 if cx.has_global::<ModelHandle<Self>>() {
97 let copilot = cx.global::<ModelHandle<Self>>().clone();
98 if copilot.read(cx).status().is_authorized() {
99 Some(copilot)
100 } else {
101 None
102 }
103 } else {
104 None
105 }
106 }
107
108 fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
109 cx.spawn(|this, mut cx| async move {
110 let start_language_server = async {
111 let server_path = get_lsp_binary(http).await?;
112 let server =
113 LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
114 let server = server.initialize(Default::default()).await?;
115 let status = server
116 .request::<request::CheckStatus>(request::CheckStatusParams {
117 local_checks_only: false,
118 })
119 .await?;
120 anyhow::Ok((server, status))
121 };
122
123 let server = start_language_server.await;
124 this.update(&mut cx, |this, cx| {
125 cx.notify();
126 match server {
127 Ok((server, status)) => {
128 this.server = CopilotServer::Started {
129 server,
130 status: SignInStatus::SignedOut,
131 };
132 this.update_sign_in_status(status, cx);
133 }
134 Err(error) => {
135 this.server = CopilotServer::Error(error.to_string().into());
136 }
137 }
138 })
139 })
140 .detach();
141 Self {
142 server: CopilotServer::Downloading,
143 }
144 }
145
146 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
147 if let CopilotServer::Started { server, .. } = &self.server {
148 let server = server.clone();
149 cx.spawn(|this, mut cx| async move {
150 let sign_in = server
151 .request::<request::SignInInitiate>(request::SignInInitiateParams {})
152 .await?;
153 if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
154 this.update(&mut cx, |_, cx| {
155 cx.emit(Event::PromptUserDeviceFlow {
156 user_code: flow.user_code.clone(),
157 verification_uri: flow.verification_uri,
158 });
159 });
160 let response = server
161 .request::<request::SignInConfirm>(request::SignInConfirmParams {
162 user_code: flow.user_code,
163 })
164 .await?;
165 this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
166 }
167 anyhow::Ok(())
168 })
169 } else {
170 Task::ready(Err(anyhow!("copilot hasn't started yet")))
171 }
172 }
173
174 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
175 if let CopilotServer::Started { server, .. } = &self.server {
176 let server = server.clone();
177 cx.spawn(|this, mut cx| async move {
178 server
179 .request::<request::SignOut>(request::SignOutParams {})
180 .await?;
181 this.update(&mut cx, |this, cx| {
182 if let CopilotServer::Started { status, .. } = &mut this.server {
183 *status = SignInStatus::SignedOut;
184 cx.notify();
185 }
186 });
187
188 anyhow::Ok(())
189 })
190 } else {
191 Task::ready(Err(anyhow!("copilot hasn't started yet")))
192 }
193 }
194
195 pub fn completion<T>(
196 &self,
197 buffer: &ModelHandle<Buffer>,
198 position: T,
199 cx: &mut ModelContext<Self>,
200 ) -> Task<Result<Option<Completion>>>
201 where
202 T: ToPointUtf16,
203 {
204 let server = match self.authenticated_server() {
205 Ok(server) => server,
206 Err(error) => return Task::ready(Err(error)),
207 };
208
209 let buffer = buffer.read(cx).snapshot();
210 let request = server
211 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
212 cx.background().spawn(async move {
213 let result = request.await?;
214 let completion = result
215 .completions
216 .into_iter()
217 .next()
218 .map(|completion| completion_from_lsp(completion, &buffer));
219 anyhow::Ok(completion)
220 })
221 }
222
223 pub fn completions_cycling<T>(
224 &self,
225 buffer: &ModelHandle<Buffer>,
226 position: T,
227 cx: &mut ModelContext<Self>,
228 ) -> Task<Result<Vec<Completion>>>
229 where
230 T: ToPointUtf16,
231 {
232 let server = match self.authenticated_server() {
233 Ok(server) => server,
234 Err(error) => return Task::ready(Err(error)),
235 };
236
237 let buffer = buffer.read(cx).snapshot();
238 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
239 &buffer, position, cx,
240 ));
241 cx.background().spawn(async move {
242 let result = request.await?;
243 let completions = result
244 .completions
245 .into_iter()
246 .map(|completion| completion_from_lsp(completion, &buffer))
247 .collect();
248 anyhow::Ok(completions)
249 })
250 }
251
252 pub fn status(&self) -> Status {
253 match &self.server {
254 CopilotServer::Downloading => Status::Downloading,
255 CopilotServer::Error(error) => Status::Error(error.clone()),
256 CopilotServer::Started { status, .. } => match status {
257 SignInStatus::Authorized { .. } => Status::Authorized,
258 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
259 SignInStatus::SignedOut => Status::SignedOut,
260 },
261 }
262 }
263
264 fn update_sign_in_status(
265 &mut self,
266 lsp_status: request::SignInStatus,
267 cx: &mut ModelContext<Self>,
268 ) {
269 if let CopilotServer::Started { status, .. } = &mut self.server {
270 *status = match lsp_status {
271 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
272 SignInStatus::Authorized { user }
273 }
274 request::SignInStatus::NotAuthorized { user } => {
275 SignInStatus::Unauthorized { user }
276 }
277 _ => SignInStatus::SignedOut,
278 };
279 cx.notify();
280 }
281 }
282
283 fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
284 match &self.server {
285 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
286 CopilotServer::Error(error) => Err(anyhow!(
287 "copilot was not started because of an error: {}",
288 error
289 )),
290 CopilotServer::Started { server, status } => {
291 if matches!(status, SignInStatus::Authorized { .. }) {
292 Ok(server.clone())
293 } else {
294 Err(anyhow!("must sign in before using copilot"))
295 }
296 }
297 }
298 }
299}
300
301fn build_completion_params<T>(
302 buffer: &BufferSnapshot,
303 position: T,
304 cx: &AppContext,
305) -> request::GetCompletionsParams
306where
307 T: ToPointUtf16,
308{
309 let position = position.to_point_utf16(&buffer);
310 let language_name = buffer.language_at(position).map(|language| language.name());
311 let language_name = language_name.as_deref();
312
313 let path;
314 let relative_path;
315 if let Some(file) = buffer.file() {
316 if let Some(file) = file.as_local() {
317 path = file.abs_path(cx);
318 } else {
319 path = file.full_path(cx);
320 }
321 relative_path = file.path().to_path_buf();
322 } else {
323 path = PathBuf::from("/untitled");
324 relative_path = PathBuf::from("untitled");
325 }
326
327 let settings = cx.global::<Settings>();
328 let language_id = match language_name {
329 Some("Plain Text") => "plaintext".to_string(),
330 Some(language_name) => language_name.to_lowercase(),
331 None => "plaintext".to_string(),
332 };
333 request::GetCompletionsParams {
334 doc: request::GetCompletionsDocument {
335 source: buffer.text(),
336 tab_size: settings.tab_size(language_name).into(),
337 indent_size: 1,
338 insert_spaces: !settings.hard_tabs(language_name),
339 uri: lsp::Url::from_file_path(&path).unwrap(),
340 path: path.to_string_lossy().into(),
341 relative_path: relative_path.to_string_lossy().into(),
342 language_id,
343 position: point_to_lsp(position),
344 version: 0,
345 },
346 }
347}
348
349fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
350 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
351 Completion {
352 position: buffer.anchor_before(position),
353 text: completion.display_text,
354 }
355}
356
357async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
358 ///Check for the latest copilot language server and download it if we haven't already
359 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
360 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
361 let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
362 let asset = release
363 .assets
364 .iter()
365 .find(|asset| asset.name == asset_name)
366 .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
367
368 fs::create_dir_all(&*paths::COPILOT_DIR).await?;
369 let destination_path =
370 paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
371
372 if fs::metadata(&destination_path).await.is_err() {
373 let mut response = http
374 .get(&asset.browser_download_url, Default::default(), true)
375 .await
376 .map_err(|err| anyhow!("error downloading release: {}", err))?;
377 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
378 let mut file = fs::File::create(&destination_path).await?;
379 futures::io::copy(decompressed_bytes, &mut file).await?;
380 fs::set_permissions(
381 &destination_path,
382 <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
383 )
384 .await?;
385
386 remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
387 }
388
389 Ok(destination_path)
390 }
391
392 match fetch_latest(http).await {
393 ok @ Result::Ok(..) => ok,
394 e @ Err(..) => {
395 e.log_err();
396 // Fetch a cached binary, if it exists
397 (|| async move {
398 let mut last = None;
399 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
400 while let Some(entry) = entries.next().await {
401 last = Some(entry?.path());
402 }
403 last.ok_or_else(|| anyhow!("no cached binary"))
404 })()
405 .await
406 }
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::http;

    // NOTE(review): this is a manual smoke test, not a hermetic unit test. It
    // uses a real HTTP client, so starting the model may download the server
    // from GitHub. It also waits a fixed two seconds for startup and then runs
    // the interactive sign-in flow. It presumably only passes on a networked
    // macOS machine (the downloaded asset is `copilot-darwin-*`) where sign-in
    // can complete — confirm before relying on it in CI.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http, cx));
        // Fixed delay in the hope that the server has started; if it hasn't,
        // `sign_in` fails with "copilot hasn't started yet" and the unwrap panics.
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        // Request completions right after the `-> ` in `fn foo() -> ` (offset 12).
        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        dbg!(copilot
            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
            .await
            .unwrap());
        dbg!(copilot
            .update(cx, |copilot, cx| copilot
                .completions_cycling(&buffer, 12, cx))
            .await
            .unwrap());
    }
}