1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Result};
5use client::Client;
6use futures::{future::Shared, Future, FutureExt, TryFutureExt};
7use gpui::{
8 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
9 Task,
10};
11use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
12use lsp::LanguageServer;
13use node_runtime::NodeRuntime;
14use settings::Settings;
15use smol::{fs, stream::StreamExt};
16use std::{
17 ffi::OsString,
18 path::{Path, PathBuf},
19 sync::Arc,
20};
21use util::{fs::remove_matching, http::HttpClient, paths, ResultExt};
22
23const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
24actions!(copilot_auth, [SignIn, SignOut]);
25
26const COPILOT_NAMESPACE: &'static str = "copilot";
27actions!(copilot, [NextSuggestion]);
28
29pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
30 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
31 cx.set_global(copilot.clone());
32 cx.add_global_action(|_: &SignIn, cx| {
33 let copilot = Copilot::global(cx).unwrap();
34 copilot
35 .update(cx, |copilot, cx| copilot.sign_in(cx))
36 .detach_and_log_err(cx);
37 });
38 cx.add_global_action(|_: &SignOut, cx| {
39 let copilot = Copilot::global(cx).unwrap();
40 copilot
41 .update(cx, |copilot, cx| copilot.sign_out(cx))
42 .detach_and_log_err(cx);
43 });
44
45 cx.observe(&copilot, |handle, cx| {
46 let status = handle.read(cx).status();
47 cx.update_global::<collections::CommandPaletteFilter, _, _>(
48 move |filter, _cx| match status {
49 Status::Disabled => {
50 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
51 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
52 }
53 Status::Authorized => {
54 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
55 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
56 }
57 _ => {
58 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
59 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
60 }
61 },
62 );
63 })
64 .detach();
65
66 sign_in::init(cx);
67}
68
69enum CopilotServer {
70 Downloading,
71 Error(Arc<str>),
72 Disabled,
73 Started {
74 server: Arc<LanguageServer>,
75 status: SignInStatus,
76 },
77}
78
79#[derive(Clone, Debug)]
80enum SignInStatus {
81 Authorized {
82 _user: String,
83 },
84 Unauthorized {
85 _user: String,
86 },
87 SigningIn {
88 prompt: Option<request::PromptUserDeviceFlow>,
89 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
90 },
91 SignedOut,
92}
93
94#[derive(Debug, PartialEq, Eq)]
95pub enum Status {
96 Downloading,
97 Error(Arc<str>),
98 Disabled,
99 SignedOut,
100 SigningIn {
101 prompt: Option<request::PromptUserDeviceFlow>,
102 },
103 Unauthorized,
104 Authorized,
105}
106
107impl Status {
108 pub fn is_authorized(&self) -> bool {
109 matches!(self, Status::Authorized)
110 }
111}
112
113#[derive(Debug, PartialEq, Eq)]
114pub struct Completion {
115 pub position: Anchor,
116 pub text: String,
117}
118
119pub struct Copilot {
120 server: CopilotServer,
121}
122
123impl Entity for Copilot {
124 type Event = ();
125}
126
127impl Copilot {
128 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
129 if cx.has_global::<ModelHandle<Self>>() {
130 Some(cx.global::<ModelHandle<Self>>().clone())
131 } else {
132 None
133 }
134 }
135
136 fn start(
137 http: Arc<dyn HttpClient>,
138 node_runtime: Arc<NodeRuntime>,
139 cx: &mut ModelContext<Self>,
140 ) -> Self {
141 // TODO: Make this task resilient to users thrashing the copilot setting
142 cx.observe_global::<Settings, _>({
143 let http = http.clone();
144 let node_runtime = node_runtime.clone();
145 move |this, cx| {
146 if cx.global::<Settings>().copilot.as_bool() {
147 if matches!(this.server, CopilotServer::Disabled) {
148 cx.spawn({
149 let http = http.clone();
150 let node_runtime = node_runtime.clone();
151 move |this, cx| {
152 Self::start_language_server(http, node_runtime, this, cx)
153 }
154 })
155 .detach();
156 }
157 } else {
158 // TODO: What else needs to be turned off here?
159 this.server = CopilotServer::Disabled
160 }
161 }
162 })
163 .detach();
164
165 if !cx.global::<Settings>().copilot.as_bool() {
166 return Self {
167 server: CopilotServer::Disabled,
168 };
169 }
170
171 cx.spawn({
172 let http = http.clone();
173 let node_runtime = node_runtime.clone();
174 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
175 })
176 .detach();
177
178 Self {
179 server: CopilotServer::Downloading,
180 }
181 }
182
183 fn start_language_server(
184 http: Arc<dyn HttpClient>,
185 node_runtime: Arc<NodeRuntime>,
186 this: ModelHandle<Self>,
187 mut cx: AsyncAppContext,
188 ) -> impl Future<Output = ()> {
189 async move {
190 let start_language_server = async {
191 let server_path = get_copilot_lsp(http, node_runtime.clone()).await?;
192 let node_path = node_runtime.binary_path().await?;
193 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
194 let server =
195 LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
196
197 let server = server.initialize(Default::default()).await?;
198 let status = server
199 .request::<request::CheckStatus>(request::CheckStatusParams {
200 local_checks_only: false,
201 })
202 .await?;
203 anyhow::Ok((server, status))
204 };
205
206 let server = start_language_server.await;
207 this.update(&mut cx, |this, cx| {
208 cx.notify();
209 match server {
210 Ok((server, status)) => {
211 this.server = CopilotServer::Started {
212 server,
213 status: SignInStatus::SignedOut,
214 };
215 this.update_sign_in_status(status, cx);
216 }
217 Err(error) => {
218 this.server = CopilotServer::Error(error.to_string().into());
219 }
220 }
221 })
222 }
223 }
224
225 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
226 if let CopilotServer::Started { server, status } = &mut self.server {
227 let task = match status {
228 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
229 cx.notify();
230 Task::ready(Ok(())).shared()
231 }
232 SignInStatus::SigningIn { task, .. } => {
233 cx.notify(); // To re-show the prompt, just in case.
234 task.clone()
235 }
236 SignInStatus::SignedOut => {
237 let server = server.clone();
238 let task = cx
239 .spawn(|this, mut cx| async move {
240 let sign_in = async {
241 let sign_in = server
242 .request::<request::SignInInitiate>(
243 request::SignInInitiateParams {},
244 )
245 .await?;
246 match sign_in {
247 request::SignInInitiateResult::AlreadySignedIn { user } => {
248 Ok(request::SignInStatus::Ok { user })
249 }
250 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
251 this.update(&mut cx, |this, cx| {
252 if let CopilotServer::Started { status, .. } =
253 &mut this.server
254 {
255 if let SignInStatus::SigningIn {
256 prompt: prompt_flow,
257 ..
258 } = status
259 {
260 *prompt_flow = Some(flow.clone());
261 cx.notify();
262 }
263 }
264 });
265 let response = server
266 .request::<request::SignInConfirm>(
267 request::SignInConfirmParams {
268 user_code: flow.user_code,
269 },
270 )
271 .await?;
272 Ok(response)
273 }
274 }
275 };
276
277 let sign_in = sign_in.await;
278 this.update(&mut cx, |this, cx| match sign_in {
279 Ok(status) => {
280 this.update_sign_in_status(status, cx);
281 Ok(())
282 }
283 Err(error) => {
284 this.update_sign_in_status(
285 request::SignInStatus::NotSignedIn,
286 cx,
287 );
288 Err(Arc::new(error))
289 }
290 })
291 })
292 .shared();
293 *status = SignInStatus::SigningIn {
294 prompt: None,
295 task: task.clone(),
296 };
297 cx.notify();
298 task
299 }
300 };
301
302 cx.foreground()
303 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
304 } else {
305 Task::ready(Err(anyhow!("copilot hasn't started yet")))
306 }
307 }
308
309 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
310 if let CopilotServer::Started { server, status } = &mut self.server {
311 *status = SignInStatus::SignedOut;
312 cx.notify();
313
314 let server = server.clone();
315 cx.background().spawn(async move {
316 server
317 .request::<request::SignOut>(request::SignOutParams {})
318 .await?;
319 anyhow::Ok(())
320 })
321 } else {
322 Task::ready(Err(anyhow!("copilot hasn't started yet")))
323 }
324 }
325
326 pub fn completion<T>(
327 &self,
328 buffer: &ModelHandle<Buffer>,
329 position: T,
330 cx: &mut ModelContext<Self>,
331 ) -> Task<Result<Option<Completion>>>
332 where
333 T: ToPointUtf16,
334 {
335 let server = match self.authorized_server() {
336 Ok(server) => server,
337 Err(error) => return Task::ready(Err(error)),
338 };
339
340 let buffer = buffer.read(cx).snapshot();
341 let request = server
342 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
343 cx.background().spawn(async move {
344 let result = request.await?;
345 let completion = result
346 .completions
347 .into_iter()
348 .next()
349 .map(|completion| completion_from_lsp(completion, &buffer));
350 anyhow::Ok(completion)
351 })
352 }
353
354 pub fn completions_cycling<T>(
355 &self,
356 buffer: &ModelHandle<Buffer>,
357 position: T,
358 cx: &mut ModelContext<Self>,
359 ) -> Task<Result<Vec<Completion>>>
360 where
361 T: ToPointUtf16,
362 {
363 let server = match self.authorized_server() {
364 Ok(server) => server,
365 Err(error) => return Task::ready(Err(error)),
366 };
367
368 let buffer = buffer.read(cx).snapshot();
369 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
370 &buffer, position, cx,
371 ));
372 cx.background().spawn(async move {
373 let result = request.await?;
374 let completions = result
375 .completions
376 .into_iter()
377 .map(|completion| completion_from_lsp(completion, &buffer))
378 .collect();
379 anyhow::Ok(completions)
380 })
381 }
382
383 pub fn status(&self) -> Status {
384 match &self.server {
385 CopilotServer::Downloading => Status::Downloading,
386 CopilotServer::Disabled => Status::Disabled,
387 CopilotServer::Error(error) => Status::Error(error.clone()),
388 CopilotServer::Started { status, .. } => match status {
389 SignInStatus::Authorized { .. } => Status::Authorized,
390 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
391 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
392 prompt: prompt.clone(),
393 },
394 SignInStatus::SignedOut => Status::SignedOut,
395 },
396 }
397 }
398
399 fn update_sign_in_status(
400 &mut self,
401 lsp_status: request::SignInStatus,
402 cx: &mut ModelContext<Self>,
403 ) {
404 if let CopilotServer::Started { status, .. } = &mut self.server {
405 *status = match lsp_status {
406 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
407 SignInStatus::Authorized { _user: user }
408 }
409 request::SignInStatus::NotAuthorized { user } => {
410 SignInStatus::Unauthorized { _user: user }
411 }
412 _ => SignInStatus::SignedOut,
413 };
414 cx.notify();
415 }
416 }
417
418 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
419 match &self.server {
420 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
421 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
422 CopilotServer::Error(error) => Err(anyhow!(
423 "copilot was not started because of an error: {}",
424 error
425 )),
426 CopilotServer::Started { server, status } => {
427 if matches!(status, SignInStatus::Authorized { .. }) {
428 Ok(server.clone())
429 } else {
430 Err(anyhow!("must sign in before using copilot"))
431 }
432 }
433 }
434 }
435}
436
437fn build_completion_params<T>(
438 buffer: &BufferSnapshot,
439 position: T,
440 cx: &AppContext,
441) -> request::GetCompletionsParams
442where
443 T: ToPointUtf16,
444{
445 let position = position.to_point_utf16(&buffer);
446 let language_name = buffer.language_at(position).map(|language| language.name());
447 let language_name = language_name.as_deref();
448
449 let path;
450 let relative_path;
451 if let Some(file) = buffer.file() {
452 if let Some(file) = file.as_local() {
453 path = file.abs_path(cx);
454 } else {
455 path = file.full_path(cx);
456 }
457 relative_path = file.path().to_path_buf();
458 } else {
459 path = PathBuf::from("/untitled");
460 relative_path = PathBuf::from("untitled");
461 }
462
463 let settings = cx.global::<Settings>();
464 let language_id = match language_name {
465 Some("Plain Text") => "plaintext".to_string(),
466 Some(language_name) => language_name.to_lowercase(),
467 None => "plaintext".to_string(),
468 };
469 request::GetCompletionsParams {
470 doc: request::GetCompletionsDocument {
471 source: buffer.text(),
472 tab_size: settings.tab_size(language_name).into(),
473 indent_size: 1,
474 insert_spaces: !settings.hard_tabs(language_name),
475 uri: lsp::Url::from_file_path(&path).unwrap(),
476 path: path.to_string_lossy().into(),
477 relative_path: relative_path.to_string_lossy().into(),
478 language_id,
479 position: point_to_lsp(position),
480 version: 0,
481 },
482 }
483}
484
485fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
486 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
487 Completion {
488 position: buffer.anchor_before(position),
489 text: completion.display_text,
490 }
491}
492
493async fn get_copilot_lsp(
494 http: Arc<dyn HttpClient>,
495 node: Arc<NodeRuntime>,
496) -> anyhow::Result<PathBuf> {
497 const SERVER_PATH: &'static str = "node_modules/copilot-node-server/copilot/dist/agent.js";
498
499 ///Check for the latest copilot language server and download it if we haven't already
500 async fn fetch_latest(
501 _http: Arc<dyn HttpClient>,
502 node: Arc<NodeRuntime>,
503 ) -> anyhow::Result<PathBuf> {
504 const COPILOT_NPM_PACKAGE: &'static str = "copilot-node-server";
505
506 let release = node.npm_package_latest_version(COPILOT_NPM_PACKAGE).await?;
507
508 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.clone()));
509
510 fs::create_dir_all(version_dir).await?;
511 let server_path = version_dir.join(SERVER_PATH);
512
513 if fs::metadata(&server_path).await.is_err() {
514 node.npm_install_packages([(COPILOT_NPM_PACKAGE, release.as_str())], version_dir)
515 .await?;
516
517 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
518 }
519
520 Ok(server_path)
521 }
522
523 match fetch_latest(http, node).await {
524 ok @ Result::Ok(..) => ok,
525 e @ Err(..) => {
526 e.log_err();
527 // Fetch a cached binary, if it exists
528 (|| async move {
529 let mut last_version_dir = None;
530 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
531 while let Some(entry) = entries.next().await {
532 let entry = entry?;
533 if entry.file_type().await?.is_dir() {
534 last_version_dir = Some(entry.path());
535 }
536 }
537 let last_version_dir =
538 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
539 let server_path = last_version_dir.join(SERVER_PATH);
540 if server_path.exists() {
541 Ok(server_path)
542 } else {
543 Err(anyhow!(
544 "missing executable in directory {:?}",
545 last_version_dir
546 ))
547 }
548 })()
549 .await
550 }
551 }
552}