1mod auth_modal;
2mod request;
3
4use anyhow::{anyhow, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use auth_modal::AuthModal;
7use client::Client;
8use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
9use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
10use lsp::LanguageServer;
11use settings::Settings;
12use smol::{fs, io::BufReader, stream::StreamExt};
13use std::{
14 env::consts,
15 path::{Path, PathBuf},
16 sync::Arc,
17};
18use util::{
19 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
20};
21use workspace::Workspace;
22
23actions!(copilot, [SignIn, SignOut, ToggleAuthStatus]);
24
25pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
26 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
27 cx.set_global(copilot.clone());
28 cx.add_action(|workspace: &mut Workspace, _: &SignIn, cx| {
29 if let Some(copilot) = Copilot::global(cx) {
30 if copilot.read(cx).status() == Status::Authorized {
31 return;
32 }
33
34 copilot
35 .update(cx, |copilot, cx| copilot.sign_in(cx))
36 .detach_and_log_err(cx);
37
38 workspace.toggle_modal(cx, |_workspace, cx| cx.add_view(|_cx| AuthModal {}));
39 }
40 });
41 cx.add_action(|_: &mut Workspace, _: &SignOut, cx| {
42 if let Some(copilot) = Copilot::global(cx) {
43 copilot
44 .update(cx, |copilot, cx| copilot.sign_out(cx))
45 .detach_and_log_err(cx);
46 }
47 });
48 cx.add_action(|workspace: &mut Workspace, _: &ToggleAuthStatus, cx| {
49 workspace.toggle_modal(cx, |_workspace, cx| cx.add_view(|_cx| AuthModal {}))
50 })
51}
52
53enum CopilotServer {
54 Downloading,
55 Error(Arc<str>),
56 Started {
57 server: Arc<LanguageServer>,
58 status: SignInStatus,
59 },
60}
61
/// Sign-in state recorded for a started Copilot server.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SignInStatus {
    /// `user` is signed in and may use Copilot.
    Authorized { user: String },
    /// `user` is signed in but is not authorized to use Copilot.
    Unauthorized { user: String },
    /// Nobody is signed in.
    SignedOut,
}
68
/// Events emitted by the [`Copilot`] model.
#[derive(Debug)]
pub enum Event {
    /// The server asked for the device sign-in flow; carries the code and
    /// verification URI it returned.
    PromptUserDeviceFlow {
        user_code: String,
        verification_uri: String,
    },
}
76
/// Public summary of the Copilot model's state, as returned by
/// `Copilot::status`.
#[derive(Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server has not finished starting.
    Downloading,
    /// The language server failed to start; holds the error text.
    Error(Arc<str>),
    /// The server is running and nobody is signed in.
    SignedOut,
    /// A user is signed in but not authorized to use Copilot.
    Unauthorized,
    /// A user is signed in and authorized.
    Authorized,
}

impl Status {
    /// Returns `true` only for [`Status::Authorized`].
    fn is_authorized(&self) -> bool {
        match self {
            Status::Authorized => true,
            Status::Downloading
            | Status::Error(_)
            | Status::SignedOut
            | Status::Unauthorized => false,
        }
    }
}
91
92#[derive(Debug)]
93pub struct Completion {
94 pub position: Anchor,
95 pub text: String,
96}
97
98struct Copilot {
99 server: CopilotServer,
100}
101
102impl Entity for Copilot {
103 type Event = Event;
104}
105
106impl Copilot {
107 fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
108 if cx.has_global::<ModelHandle<Self>>() {
109 let copilot = cx.global::<ModelHandle<Self>>().clone();
110 if copilot.read(cx).status().is_authorized() {
111 Some(copilot)
112 } else {
113 None
114 }
115 } else {
116 None
117 }
118 }
119
120 fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
121 cx.spawn(|this, mut cx| async move {
122 let start_language_server = async {
123 let server_path = get_lsp_binary(http).await?;
124 let server =
125 LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
126 let server = server.initialize(Default::default()).await?;
127 let status = server
128 .request::<request::CheckStatus>(request::CheckStatusParams {
129 local_checks_only: false,
130 })
131 .await?;
132 anyhow::Ok((server, status))
133 };
134
135 let server = start_language_server.await;
136 this.update(&mut cx, |this, cx| {
137 cx.notify();
138 match server {
139 Ok((server, status)) => {
140 this.server = CopilotServer::Started {
141 server,
142 status: SignInStatus::SignedOut,
143 };
144 this.update_sign_in_status(status, cx);
145 }
146 Err(error) => {
147 this.server = CopilotServer::Error(error.to_string().into());
148 }
149 }
150 })
151 })
152 .detach();
153
154 Self {
155 server: CopilotServer::Downloading,
156 }
157 }
158
159 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
160 if let CopilotServer::Started { server, .. } = &self.server {
161 let server = server.clone();
162 cx.spawn(|this, mut cx| async move {
163 let sign_in = server
164 .request::<request::SignInInitiate>(request::SignInInitiateParams {})
165 .await?;
166 if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
167 this.update(&mut cx, |_, cx| {
168 cx.emit(Event::PromptUserDeviceFlow {
169 user_code: flow.user_code.clone(),
170 verification_uri: flow.verification_uri,
171 });
172 });
173 let response = server
174 .request::<request::SignInConfirm>(request::SignInConfirmParams {
175 user_code: flow.user_code,
176 })
177 .await?;
178 this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
179 }
180 anyhow::Ok(())
181 })
182 } else {
183 Task::ready(Err(anyhow!("copilot hasn't started yet")))
184 }
185 }
186
187 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
188 if let CopilotServer::Started { server, .. } = &self.server {
189 let server = server.clone();
190 cx.spawn(|this, mut cx| async move {
191 server
192 .request::<request::SignOut>(request::SignOutParams {})
193 .await?;
194 this.update(&mut cx, |this, cx| {
195 if let CopilotServer::Started { status, .. } = &mut this.server {
196 *status = SignInStatus::SignedOut;
197 cx.notify();
198 }
199 });
200
201 anyhow::Ok(())
202 })
203 } else {
204 Task::ready(Err(anyhow!("copilot hasn't started yet")))
205 }
206 }
207
208 pub fn completion<T>(
209 &self,
210 buffer: &ModelHandle<Buffer>,
211 position: T,
212 cx: &mut ModelContext<Self>,
213 ) -> Task<Result<Option<Completion>>>
214 where
215 T: ToPointUtf16,
216 {
217 let server = match self.authenticated_server() {
218 Ok(server) => server,
219 Err(error) => return Task::ready(Err(error)),
220 };
221
222 let buffer = buffer.read(cx).snapshot();
223 let request = server
224 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
225 cx.background().spawn(async move {
226 let result = request.await?;
227 let completion = result
228 .completions
229 .into_iter()
230 .next()
231 .map(|completion| completion_from_lsp(completion, &buffer));
232 anyhow::Ok(completion)
233 })
234 }
235
236 pub fn completions_cycling<T>(
237 &self,
238 buffer: &ModelHandle<Buffer>,
239 position: T,
240 cx: &mut ModelContext<Self>,
241 ) -> Task<Result<Vec<Completion>>>
242 where
243 T: ToPointUtf16,
244 {
245 let server = match self.authenticated_server() {
246 Ok(server) => server,
247 Err(error) => return Task::ready(Err(error)),
248 };
249
250 let buffer = buffer.read(cx).snapshot();
251 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
252 &buffer, position, cx,
253 ));
254 cx.background().spawn(async move {
255 let result = request.await?;
256 let completions = result
257 .completions
258 .into_iter()
259 .map(|completion| completion_from_lsp(completion, &buffer))
260 .collect();
261 anyhow::Ok(completions)
262 })
263 }
264
265 pub fn status(&self) -> Status {
266 match &self.server {
267 CopilotServer::Downloading => Status::Downloading,
268 CopilotServer::Error(error) => Status::Error(error.clone()),
269 CopilotServer::Started { status, .. } => match status {
270 SignInStatus::Authorized { .. } => Status::Authorized,
271 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
272 SignInStatus::SignedOut => Status::SignedOut,
273 },
274 }
275 }
276
277 fn update_sign_in_status(
278 &mut self,
279 lsp_status: request::SignInStatus,
280 cx: &mut ModelContext<Self>,
281 ) {
282 if let CopilotServer::Started { status, .. } = &mut self.server {
283 *status = match lsp_status {
284 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
285 SignInStatus::Authorized { user }
286 }
287 request::SignInStatus::NotAuthorized { user } => {
288 SignInStatus::Unauthorized { user }
289 }
290 _ => SignInStatus::SignedOut,
291 };
292 cx.notify();
293 }
294 }
295
296 fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
297 match &self.server {
298 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
299 CopilotServer::Error(error) => Err(anyhow!(
300 "copilot was not started because of an error: {}",
301 error
302 )),
303 CopilotServer::Started { server, status } => {
304 if matches!(status, SignInStatus::Authorized { .. }) {
305 Ok(server.clone())
306 } else {
307 Err(anyhow!("must sign in before using copilot"))
308 }
309 }
310 }
311 }
312}
313
314fn build_completion_params<T>(
315 buffer: &BufferSnapshot,
316 position: T,
317 cx: &AppContext,
318) -> request::GetCompletionsParams
319where
320 T: ToPointUtf16,
321{
322 let position = position.to_point_utf16(&buffer);
323 let language_name = buffer.language_at(position).map(|language| language.name());
324 let language_name = language_name.as_deref();
325
326 let path;
327 let relative_path;
328 if let Some(file) = buffer.file() {
329 if let Some(file) = file.as_local() {
330 path = file.abs_path(cx);
331 } else {
332 path = file.full_path(cx);
333 }
334 relative_path = file.path().to_path_buf();
335 } else {
336 path = PathBuf::from("/untitled");
337 relative_path = PathBuf::from("untitled");
338 }
339
340 let settings = cx.global::<Settings>();
341 let language_id = match language_name {
342 Some("Plain Text") => "plaintext".to_string(),
343 Some(language_name) => language_name.to_lowercase(),
344 None => "plaintext".to_string(),
345 };
346 request::GetCompletionsParams {
347 doc: request::GetCompletionsDocument {
348 source: buffer.text(),
349 tab_size: settings.tab_size(language_name).into(),
350 indent_size: 1,
351 insert_spaces: !settings.hard_tabs(language_name),
352 uri: lsp::Url::from_file_path(&path).unwrap(),
353 path: path.to_string_lossy().into(),
354 relative_path: relative_path.to_string_lossy().into(),
355 language_id,
356 position: point_to_lsp(position),
357 version: 0,
358 },
359 }
360}
361
362fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
363 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
364 Completion {
365 position: buffer.anchor_before(position),
366 text: completion.display_text,
367 }
368}
369
370async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
371 ///Check for the latest copilot language server and download it if we haven't already
372 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
373 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
374 let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
375 let asset = release
376 .assets
377 .iter()
378 .find(|asset| asset.name == asset_name)
379 .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
380
381 fs::create_dir_all(&*paths::COPILOT_DIR).await?;
382 let destination_path =
383 paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
384
385 if fs::metadata(&destination_path).await.is_err() {
386 let mut response = http
387 .get(&asset.browser_download_url, Default::default(), true)
388 .await
389 .map_err(|err| anyhow!("error downloading release: {}", err))?;
390 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
391 let mut file = fs::File::create(&destination_path).await?;
392 futures::io::copy(decompressed_bytes, &mut file).await?;
393 fs::set_permissions(
394 &destination_path,
395 <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
396 )
397 .await?;
398
399 remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
400 }
401
402 Ok(destination_path)
403 }
404
405 match fetch_latest(http).await {
406 ok @ Result::Ok(..) => ok,
407 e @ Err(..) => {
408 e.log_err();
409 // Fetch a cached binary, if it exists
410 (|| async move {
411 let mut last = None;
412 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
413 while let Some(entry) = entries.next().await {
414 last = Some(entry?.path());
415 }
416 last.ok_or_else(|| anyhow!("no cached binary"))
417 })()
418 .await
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use std::time::Duration;
    use util::http;

    /// End-to-end smoke test: starts Copilot with a real HTTP client, signs
    /// in, and prints the status and the completions for a small buffer.
    #[gpui::test]
    async fn test_smoke(cx: &mut TestAppContext) {
        Settings::test_async(cx);
        let http_client = http::client();
        let copilot = cx.add_model(|cx| Copilot::start(http_client, cx));

        // Give the spawned start-up task time to bring the server up.
        smol::Timer::after(Duration::from_secs(2)).await;

        let sign_in = copilot.update(cx, |copilot, cx| copilot.sign_in(cx));
        sign_in.await.unwrap();
        dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));

        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
        dbg!(copilot
            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
            .await
            .unwrap());
        dbg!(copilot
            .update(cx, |copilot, cx| copilot
                .completions_cycling(&buffer, 12, cx))
            .await
            .unwrap());
    }
}