1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Result};
5use client::Client;
6use futures::{future::Shared, Future, FutureExt, TryFutureExt};
7use gpui::{
8 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
9 Task,
10};
11use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
12use lsp::LanguageServer;
13use node_runtime::NodeRuntime;
14use settings::Settings;
15use smol::{fs, stream::StreamExt};
16use std::{
17 ffi::OsString,
18 path::{Path, PathBuf},
19 sync::Arc,
20};
21use util::{fs::remove_matching, http::HttpClient, paths, ResultExt};
22
23const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
24actions!(copilot_auth, [SignIn, SignOut]);
25
26const COPILOT_NAMESPACE: &'static str = "copilot";
27actions!(copilot, [NextSuggestion]);
28
29pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
30 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
31 cx.set_global(copilot.clone());
32 cx.add_global_action(|_: &SignIn, cx| {
33 let copilot = Copilot::global(cx).unwrap();
34 copilot
35 .update(cx, |copilot, cx| copilot.sign_in(cx))
36 .detach_and_log_err(cx);
37 });
38 cx.add_global_action(|_: &SignOut, cx| {
39 let copilot = Copilot::global(cx).unwrap();
40 copilot
41 .update(cx, |copilot, cx| copilot.sign_out(cx))
42 .detach_and_log_err(cx);
43 });
44
45 cx.observe(&copilot, |handle, cx| {
46 let status = handle.read(cx).status();
47 cx.update_global::<collections::CommandPaletteFilter, _, _>(
48 move |filter, _cx| match status {
49 Status::Disabled => {
50 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
51 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
52 }
53 Status::Authorized => {
54 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
55 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
56 }
57 _ => {
58 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
59 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
60 }
61 },
62 );
63 })
64 .detach();
65
66 sign_in::init(cx);
67}
68
69enum CopilotServer {
70 Downloading,
71 Error(Arc<str>),
72 Disabled,
73 Started {
74 server: Arc<LanguageServer>,
75 status: SignInStatus,
76 },
77}
78
79#[derive(Clone, Debug)]
80enum SignInStatus {
81 Authorized {
82 _user: String,
83 },
84 Unauthorized {
85 _user: String,
86 },
87 SigningIn {
88 prompt: Option<request::PromptUserDeviceFlow>,
89 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
90 },
91 SignedOut,
92}
93
94#[derive(Debug, PartialEq, Eq)]
95pub enum Status {
96 Downloading,
97 Error(Arc<str>),
98 Disabled,
99 SignedOut,
100 SigningIn {
101 prompt: Option<request::PromptUserDeviceFlow>,
102 },
103 Unauthorized,
104 Authorized,
105}
106
107impl Status {
108 pub fn is_authorized(&self) -> bool {
109 matches!(self, Status::Authorized)
110 }
111}
112
113#[derive(Debug, PartialEq, Eq)]
114pub struct Completion {
115 pub position: Anchor,
116 pub text: String,
117}
118
119pub struct Copilot {
120 server: CopilotServer,
121}
122
123impl Entity for Copilot {
124 type Event = ();
125}
126
127impl Copilot {
128 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
129 if cx.has_global::<ModelHandle<Self>>() {
130 Some(cx.global::<ModelHandle<Self>>().clone())
131 } else {
132 None
133 }
134 }
135
136 fn start(
137 http: Arc<dyn HttpClient>,
138 node_runtime: Arc<NodeRuntime>,
139 cx: &mut ModelContext<Self>,
140 ) -> Self {
141 // TODO: Make this task resilient to users thrashing the copilot setting
142 cx.observe_global::<Settings, _>({
143 let http = http.clone();
144 let node_runtime = node_runtime.clone();
145 move |this, cx| {
146 if cx.global::<Settings>().copilot.as_bool() {
147 if matches!(this.server, CopilotServer::Disabled) {
148 cx.spawn({
149 let http = http.clone();
150 let node_runtime = node_runtime.clone();
151 move |this, cx| {
152 Self::start_language_server(http, node_runtime, this, cx)
153 }
154 })
155 .detach();
156 }
157 } else {
158 // TODO: What else needs to be turned off here?
159 this.server = CopilotServer::Disabled
160 }
161 }
162 })
163 .detach();
164
165 if !cx.global::<Settings>().copilot.as_bool() {
166 return Self {
167 server: CopilotServer::Disabled,
168 };
169 }
170
171 cx.spawn({
172 let http = http.clone();
173 let node_runtime = node_runtime.clone();
174 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
175 })
176 .detach();
177
178 Self {
179 server: CopilotServer::Downloading,
180 }
181 }
182
183 fn start_language_server(
184 http: Arc<dyn HttpClient>,
185 node_runtime: Arc<NodeRuntime>,
186 this: ModelHandle<Self>,
187 mut cx: AsyncAppContext,
188 ) -> impl Future<Output = ()> {
189 async move {
190 let start_language_server = async {
191 let server_path = get_copilot_lsp(http, node_runtime.clone()).await?;
192 let node_path = node_runtime.binary_path().await?;
193 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
194 let server =
195 LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
196
197 let server = server.initialize(Default::default()).await?;
198 let status = server
199 .request::<request::CheckStatus>(request::CheckStatusParams {
200 local_checks_only: false,
201 })
202 .await?;
203 anyhow::Ok((server, status))
204 };
205
206 let server = start_language_server.await;
207 this.update(&mut cx, |this, cx| {
208 cx.notify();
209 match server {
210 Ok((server, status)) => {
211 this.server = CopilotServer::Started {
212 server,
213 status: SignInStatus::SignedOut,
214 };
215 this.update_sign_in_status(status, cx);
216 }
217 Err(error) => {
218 this.server = CopilotServer::Error(error.to_string().into());
219 }
220 }
221 })
222 }
223 }
224
225 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
226 if let CopilotServer::Started { server, status } = &mut self.server {
227 let task = match status {
228 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
229 Task::ready(Ok(())).shared()
230 }
231 SignInStatus::SigningIn { task, .. } => {
232 cx.notify(); // To re-show the prompt, just in case.
233 task.clone()
234 }
235 SignInStatus::SignedOut => {
236 let server = server.clone();
237 let task = cx
238 .spawn(|this, mut cx| async move {
239 let sign_in = async {
240 let sign_in = server
241 .request::<request::SignInInitiate>(
242 request::SignInInitiateParams {},
243 )
244 .await?;
245 match sign_in {
246 request::SignInInitiateResult::AlreadySignedIn { user } => {
247 Ok(request::SignInStatus::Ok { user })
248 }
249 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
250 this.update(&mut cx, |this, cx| {
251 if let CopilotServer::Started { status, .. } =
252 &mut this.server
253 {
254 if let SignInStatus::SigningIn {
255 prompt: prompt_flow,
256 ..
257 } = status
258 {
259 *prompt_flow = Some(flow.clone());
260 cx.notify();
261 }
262 }
263 });
264 let response = server
265 .request::<request::SignInConfirm>(
266 request::SignInConfirmParams {
267 user_code: flow.user_code,
268 },
269 )
270 .await?;
271 Ok(response)
272 }
273 }
274 };
275
276 let sign_in = sign_in.await;
277 this.update(&mut cx, |this, cx| match sign_in {
278 Ok(status) => {
279 this.update_sign_in_status(status, cx);
280 Ok(())
281 }
282 Err(error) => {
283 this.update_sign_in_status(
284 request::SignInStatus::NotSignedIn,
285 cx,
286 );
287 Err(Arc::new(error))
288 }
289 })
290 })
291 .shared();
292 *status = SignInStatus::SigningIn {
293 prompt: None,
294 task: task.clone(),
295 };
296 cx.notify();
297 task
298 }
299 };
300
301 cx.foreground()
302 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
303 } else {
304 Task::ready(Err(anyhow!("copilot hasn't started yet")))
305 }
306 }
307
308 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
309 if let CopilotServer::Started { server, status } = &mut self.server {
310 *status = SignInStatus::SignedOut;
311 cx.notify();
312
313 let server = server.clone();
314 cx.background().spawn(async move {
315 server
316 .request::<request::SignOut>(request::SignOutParams {})
317 .await?;
318 anyhow::Ok(())
319 })
320 } else {
321 Task::ready(Err(anyhow!("copilot hasn't started yet")))
322 }
323 }
324
325 pub fn completion<T>(
326 &self,
327 buffer: &ModelHandle<Buffer>,
328 position: T,
329 cx: &mut ModelContext<Self>,
330 ) -> Task<Result<Option<Completion>>>
331 where
332 T: ToPointUtf16,
333 {
334 let server = match self.authorized_server() {
335 Ok(server) => server,
336 Err(error) => return Task::ready(Err(error)),
337 };
338
339 let buffer = buffer.read(cx).snapshot();
340 let request = server
341 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
342 cx.background().spawn(async move {
343 let result = request.await?;
344 let completion = result
345 .completions
346 .into_iter()
347 .next()
348 .map(|completion| completion_from_lsp(completion, &buffer));
349 anyhow::Ok(completion)
350 })
351 }
352
353 pub fn completions_cycling<T>(
354 &self,
355 buffer: &ModelHandle<Buffer>,
356 position: T,
357 cx: &mut ModelContext<Self>,
358 ) -> Task<Result<Vec<Completion>>>
359 where
360 T: ToPointUtf16,
361 {
362 let server = match self.authorized_server() {
363 Ok(server) => server,
364 Err(error) => return Task::ready(Err(error)),
365 };
366
367 let buffer = buffer.read(cx).snapshot();
368 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
369 &buffer, position, cx,
370 ));
371 cx.background().spawn(async move {
372 let result = request.await?;
373 let completions = result
374 .completions
375 .into_iter()
376 .map(|completion| completion_from_lsp(completion, &buffer))
377 .collect();
378 anyhow::Ok(completions)
379 })
380 }
381
382 pub fn status(&self) -> Status {
383 match &self.server {
384 CopilotServer::Downloading => Status::Downloading,
385 CopilotServer::Disabled => Status::Disabled,
386 CopilotServer::Error(error) => Status::Error(error.clone()),
387 CopilotServer::Started { status, .. } => match status {
388 SignInStatus::Authorized { .. } => Status::Authorized,
389 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
390 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
391 prompt: prompt.clone(),
392 },
393 SignInStatus::SignedOut => Status::SignedOut,
394 },
395 }
396 }
397
398 fn update_sign_in_status(
399 &mut self,
400 lsp_status: request::SignInStatus,
401 cx: &mut ModelContext<Self>,
402 ) {
403 if let CopilotServer::Started { status, .. } = &mut self.server {
404 *status = match lsp_status {
405 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
406 SignInStatus::Authorized { _user: user }
407 }
408 request::SignInStatus::NotAuthorized { user } => {
409 SignInStatus::Unauthorized { _user: user }
410 }
411 _ => SignInStatus::SignedOut,
412 };
413 cx.notify();
414 }
415 }
416
417 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
418 match &self.server {
419 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
420 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
421 CopilotServer::Error(error) => Err(anyhow!(
422 "copilot was not started because of an error: {}",
423 error
424 )),
425 CopilotServer::Started { server, status } => {
426 if matches!(status, SignInStatus::Authorized { .. }) {
427 Ok(server.clone())
428 } else {
429 Err(anyhow!("must sign in before using copilot"))
430 }
431 }
432 }
433 }
434}
435
436fn build_completion_params<T>(
437 buffer: &BufferSnapshot,
438 position: T,
439 cx: &AppContext,
440) -> request::GetCompletionsParams
441where
442 T: ToPointUtf16,
443{
444 let position = position.to_point_utf16(&buffer);
445 let language_name = buffer.language_at(position).map(|language| language.name());
446 let language_name = language_name.as_deref();
447
448 let path;
449 let relative_path;
450 if let Some(file) = buffer.file() {
451 if let Some(file) = file.as_local() {
452 path = file.abs_path(cx);
453 } else {
454 path = file.full_path(cx);
455 }
456 relative_path = file.path().to_path_buf();
457 } else {
458 path = PathBuf::from("/untitled");
459 relative_path = PathBuf::from("untitled");
460 }
461
462 let settings = cx.global::<Settings>();
463 let language_id = match language_name {
464 Some("Plain Text") => "plaintext".to_string(),
465 Some(language_name) => language_name.to_lowercase(),
466 None => "plaintext".to_string(),
467 };
468 request::GetCompletionsParams {
469 doc: request::GetCompletionsDocument {
470 source: buffer.text(),
471 tab_size: settings.tab_size(language_name).into(),
472 indent_size: 1,
473 insert_spaces: !settings.hard_tabs(language_name),
474 uri: lsp::Url::from_file_path(&path).unwrap(),
475 path: path.to_string_lossy().into(),
476 relative_path: relative_path.to_string_lossy().into(),
477 language_id,
478 position: point_to_lsp(position),
479 version: 0,
480 },
481 }
482}
483
484fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
485 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
486 Completion {
487 position: buffer.anchor_before(position),
488 text: completion.display_text,
489 }
490}
491
492async fn get_copilot_lsp(
493 http: Arc<dyn HttpClient>,
494 node: Arc<NodeRuntime>,
495) -> anyhow::Result<PathBuf> {
496 const SERVER_PATH: &'static str = "node_modules/copilot-node-server/copilot/dist/agent.js";
497
498 ///Check for the latest copilot language server and download it if we haven't already
499 async fn fetch_latest(
500 _http: Arc<dyn HttpClient>,
501 node: Arc<NodeRuntime>,
502 ) -> anyhow::Result<PathBuf> {
503 const COPILOT_NPM_PACKAGE: &'static str = "copilot-node-server";
504
505 let release = node.npm_package_latest_version(COPILOT_NPM_PACKAGE).await?;
506
507 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.clone()));
508
509 fs::create_dir_all(version_dir).await?;
510 let server_path = version_dir.join(SERVER_PATH);
511
512 if fs::metadata(&server_path).await.is_err() {
513 node.npm_install_packages([(COPILOT_NPM_PACKAGE, release.as_str())], version_dir)
514 .await?;
515
516 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
517 }
518
519 Ok(server_path)
520 }
521
522 match fetch_latest(http, node).await {
523 ok @ Result::Ok(..) => ok,
524 e @ Err(..) => {
525 e.log_err();
526 // Fetch a cached binary, if it exists
527 (|| async move {
528 let mut last_version_dir = None;
529 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
530 while let Some(entry) = entries.next().await {
531 let entry = entry?;
532 if entry.file_type().await?.is_dir() {
533 last_version_dir = Some(entry.path());
534 }
535 }
536 let last_version_dir =
537 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
538 let server_path = last_version_dir.join(SERVER_PATH);
539 if server_path.exists() {
540 Ok(server_path)
541 } else {
542 Err(anyhow!(
543 "missing executable in directory {:?}",
544 last_version_dir
545 ))
546 }
547 })()
548 .await
549 }
550 }
551}