1pub mod copilot_button;
2mod request;
3mod sign_in;
4
5use anyhow::{anyhow, Result};
6use client::Client;
7use futures::{future::Shared, Future, FutureExt, TryFutureExt};
8use gpui::{
9 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
10 Task,
11};
12use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
13use lsp::LanguageServer;
14use node_runtime::NodeRuntime;
15use settings::Settings;
16use smol::{fs, stream::StreamExt};
17use std::{
18 ffi::OsString,
19 path::{Path, PathBuf},
20 sync::Arc,
21};
22use util::{fs::remove_matching, http::HttpClient, paths, ResultExt};
23
/// Action namespace of the sign-in/sign-out actions declared right below. `init` adds it to the
/// command palette's filtered namespaces while Copilot is disabled and removes it otherwise.
const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
actions!(copilot_auth, [SignIn, SignOut]);

/// Action namespace of the suggestion actions declared right below. `init` keeps it in the
/// command palette's filtered namespaces unless the status is `Status::Authorized`.
const COPILOT_NAMESPACE: &'static str = "copilot";
actions!(copilot, [NextSuggestion, PreviousSuggestion, Toggle]);
29
30pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
31 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
32 cx.set_global(copilot.clone());
33 cx.add_global_action(|_: &SignIn, cx| {
34 let copilot = Copilot::global(cx).unwrap();
35 copilot
36 .update(cx, |copilot, cx| copilot.sign_in(cx))
37 .detach_and_log_err(cx);
38 });
39 cx.add_global_action(|_: &SignOut, cx| {
40 let copilot = Copilot::global(cx).unwrap();
41 copilot
42 .update(cx, |copilot, cx| copilot.sign_out(cx))
43 .detach_and_log_err(cx);
44 });
45
46 cx.observe(&copilot, |handle, cx| {
47 let status = handle.read(cx).status();
48 cx.update_global::<collections::CommandPaletteFilter, _, _>(
49 move |filter, _cx| match status {
50 Status::Disabled => {
51 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
52 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
53 }
54 Status::Authorized => {
55 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
56 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
57 }
58 _ => {
59 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
60 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
61 }
62 },
63 );
64 })
65 .detach();
66
67 sign_in::init(cx);
68}
69
/// Lifecycle state of the Copilot language server owned by the [`Copilot`] model.
enum CopilotServer {
    /// The `enable_copilot_integration` setting is off; no server is running or starting.
    Disabled,
    /// The server is being downloaded/launched.
    Starting {
        /// Shared handle to the task running `Copilot::start_language_server`; it is stored
        /// here but never read.
        _task: Shared<Task<()>>,
    },
    /// Starting the server failed; holds the stringified error.
    Error(Arc<str>),
    /// The server is initialized and answering requests.
    Started {
        /// Handle used to send requests to the server.
        server: Arc<LanguageServer>,
        /// Current authentication state, as tracked by this model.
        status: SignInStatus,
    },
}
81
/// Internal authentication state of a started Copilot server.
///
/// Derived from `request::SignInStatus` in `Copilot::update_sign_in_status` and exposed to the
/// rest of the app in simplified form as [`Status`].
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported a signed-in user that may use Copilot.
    Authorized {
        /// User reported by the server; stored but not read in this file.
        _user: String,
    },
    /// The server reported a signed-in user that is not authorized.
    Unauthorized {
        /// User reported by the server; stored but not read in this file.
        _user: String,
    },
    /// A sign-in flow started by `Copilot::sign_in` is in progress.
    SigningIn {
        /// Device-flow prompt to show the user; `None` until the server has returned one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared task driving the flow, so repeated `sign_in` calls can await the same work.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No user is signed in.
    SignedOut,
}
96
/// Public snapshot of the Copilot state, as returned by `Copilot::status`.
///
/// Combines the server lifecycle (`CopilotServer`) and the sign-in state (`SignInStatus`) into a
/// single comparable value.
#[derive(Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server is still starting.
    Starting,
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot integration is disabled.
    Disabled,
    /// The server is running but no user is signed in.
    SignedOut,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to present, once the server has provided it.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// A user is signed in but not authorized.
    Unauthorized,
    /// A user is signed in and authorized; completions can be requested.
    Authorized,
}
109
110impl Status {
111 pub fn is_authorized(&self) -> bool {
112 matches!(self, Status::Authorized)
113 }
114}
115
/// A single completion returned by Copilot, expressed in buffer terms.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer anchor built from the position reported by the server (see `completion_from_lsp`).
    pub position: Anchor,
    /// The completion's display text as reported by the server.
    pub text: String,
}
121
/// Model that owns the Copilot language server and tracks its sign-in state.
///
/// A handle to the single instance is stored as an app global by `init` and can be retrieved
/// with `Copilot::global`.
pub struct Copilot {
    /// Current lifecycle state of the language server.
    server: CopilotServer,
}
125
/// Makes `Copilot` usable as a gpui model.
impl Entity for Copilot {
    /// The event type is the unit type; state changes in this file are signalled with
    /// `cx.notify()` instead.
    type Event = ();
}
129
130impl Copilot {
131 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
132 if cx.has_global::<ModelHandle<Self>>() {
133 Some(cx.global::<ModelHandle<Self>>().clone())
134 } else {
135 None
136 }
137 }
138
139 fn start(
140 http: Arc<dyn HttpClient>,
141 node_runtime: Arc<NodeRuntime>,
142 cx: &mut ModelContext<Self>,
143 ) -> Self {
144 cx.observe_global::<Settings, _>({
145 let http = http.clone();
146 let node_runtime = node_runtime.clone();
147 move |this, cx| {
148 if cx.global::<Settings>().enable_copilot_integration {
149 if matches!(this.server, CopilotServer::Disabled) {
150 let start_task = cx
151 .spawn({
152 let http = http.clone();
153 let node_runtime = node_runtime.clone();
154 move |this, cx| {
155 Self::start_language_server(http, node_runtime, this, cx)
156 }
157 })
158 .shared();
159 this.server = CopilotServer::Starting { _task: start_task }
160 }
161 } else {
162 this.server = CopilotServer::Disabled
163 }
164 }
165 })
166 .detach();
167
168 if cx.global::<Settings>().enable_copilot_integration {
169 let start_task = cx
170 .spawn({
171 let http = http.clone();
172 let node_runtime = node_runtime.clone();
173 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
174 })
175 .shared();
176
177 Self {
178 server: CopilotServer::Starting { _task: start_task },
179 }
180 } else {
181 Self {
182 server: CopilotServer::Disabled,
183 }
184 }
185 }
186
187 fn start_language_server(
188 http: Arc<dyn HttpClient>,
189 node_runtime: Arc<NodeRuntime>,
190 this: ModelHandle<Self>,
191 mut cx: AsyncAppContext,
192 ) -> impl Future<Output = ()> {
193 async move {
194 let start_language_server = async {
195 let server_path = get_copilot_lsp(http, node_runtime.clone()).await?;
196 let node_path = node_runtime.binary_path().await?;
197 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
198 let server =
199 LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
200
201 let server = server.initialize(Default::default()).await?;
202 let status = server
203 .request::<request::CheckStatus>(request::CheckStatusParams {
204 local_checks_only: false,
205 })
206 .await?;
207 anyhow::Ok((server, status))
208 };
209
210 let server = start_language_server.await;
211 this.update(&mut cx, |this, cx| {
212 cx.notify();
213 match server {
214 Ok((server, status)) => {
215 this.server = CopilotServer::Started {
216 server,
217 status: SignInStatus::SignedOut,
218 };
219 this.update_sign_in_status(status, cx);
220 }
221 Err(error) => {
222 this.server = CopilotServer::Error(error.to_string().into());
223 cx.notify()
224 }
225 }
226 })
227 }
228 }
229
/// Starts the Copilot sign-in flow, or joins one that is already in progress.
///
/// Returns a task that resolves when the flow has finished. If the server is not in the
/// `Started` state the task fails immediately with "copilot hasn't started yet".
fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
    if let CopilotServer::Started { server, status } = &mut self.server {
        let task = match status {
            // Both of these states already carry a user, so no new flow is started.
            SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                Task::ready(Ok(())).shared()
            }
            // A flow is already running: notify observers and hand out the same shared task
            // instead of starting a second flow.
            SignInStatus::SigningIn { task, .. } => {
                cx.notify();
                task.clone()
            }
            SignInStatus::SignedOut => {
                let server = server.clone();
                let task = cx
                    .spawn(|this, mut cx| async move {
                        // The server interaction lives in its own async block so that every
                        // failure ends up in the single `match` on its result below.
                        let sign_in = async {
                            let sign_in = server
                                .request::<request::SignInInitiate>(
                                    request::SignInInitiateParams {},
                                )
                                .await?;
                            match sign_in {
                                request::SignInInitiateResult::AlreadySignedIn { user } => {
                                    Ok(request::SignInStatus::Ok { user })
                                }
                                request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                    // Publish the prompt on the model, but only if the model
                                    // is still in the `SigningIn` state.
                                    this.update(&mut cx, |this, cx| {
                                        if let CopilotServer::Started { status, .. } =
                                            &mut this.server
                                        {
                                            if let SignInStatus::SigningIn {
                                                prompt: prompt_flow,
                                                ..
                                            } = status
                                            {
                                                *prompt_flow = Some(flow.clone());
                                                cx.notify();
                                            }
                                        }
                                    });
                                    // Confirm with the user code from the prompt; the response
                                    // is the resulting sign-in status.
                                    let response = server
                                        .request::<request::SignInConfirm>(
                                            request::SignInConfirmParams {
                                                user_code: flow.user_code,
                                            },
                                        )
                                        .await?;
                                    Ok(response)
                                }
                            }
                        };

                        let sign_in = sign_in.await;
                        this.update(&mut cx, |this, cx| match sign_in {
                            Ok(status) => {
                                this.update_sign_in_status(status, cx);
                                Ok(())
                            }
                            // On failure the model goes back to `SignedOut`; the error is
                            // wrapped in an `Arc` because the shared task needs a cloneable
                            // output.
                            Err(error) => {
                                this.update_sign_in_status(
                                    request::SignInStatus::NotSignedIn,
                                    cx,
                                );
                                Err(Arc::new(error))
                            }
                        })
                    })
                    .shared();
                // Record the in-progress flow so later calls take the `SigningIn` arm above.
                *status = SignInStatus::SigningIn {
                    prompt: None,
                    task: task.clone(),
                };
                cx.notify();
                task
            }
        };

        // Convert the shared task's `Arc<anyhow::Error>` into a plain `anyhow::Error`.
        cx.foreground()
            .spawn(task.map_err(|err| anyhow!("{:?}", err)))
    } else {
        Task::ready(Err(anyhow!("copilot hasn't started yet")))
    }
}
312
313 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
314 if let CopilotServer::Started { server, status } = &mut self.server {
315 *status = SignInStatus::SignedOut;
316 cx.notify();
317
318 let server = server.clone();
319 cx.background().spawn(async move {
320 server
321 .request::<request::SignOut>(request::SignOutParams {})
322 .await?;
323 anyhow::Ok(())
324 })
325 } else {
326 Task::ready(Err(anyhow!("copilot hasn't started yet")))
327 }
328 }
329
330 pub fn completion<T>(
331 &self,
332 buffer: &ModelHandle<Buffer>,
333 position: T,
334 cx: &mut ModelContext<Self>,
335 ) -> Task<Result<Option<Completion>>>
336 where
337 T: ToPointUtf16,
338 {
339 let server = match self.authorized_server() {
340 Ok(server) => server,
341 Err(error) => return Task::ready(Err(error)),
342 };
343
344 let buffer = buffer.read(cx).snapshot();
345 let request = server
346 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
347 cx.background().spawn(async move {
348 let result = request.await?;
349 let completion = result
350 .completions
351 .into_iter()
352 .next()
353 .map(|completion| completion_from_lsp(completion, &buffer));
354 anyhow::Ok(completion)
355 })
356 }
357
358 pub fn completions_cycling<T>(
359 &self,
360 buffer: &ModelHandle<Buffer>,
361 position: T,
362 cx: &mut ModelContext<Self>,
363 ) -> Task<Result<Vec<Completion>>>
364 where
365 T: ToPointUtf16,
366 {
367 let server = match self.authorized_server() {
368 Ok(server) => server,
369 Err(error) => return Task::ready(Err(error)),
370 };
371
372 let buffer = buffer.read(cx).snapshot();
373 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
374 &buffer, position, cx,
375 ));
376 cx.background().spawn(async move {
377 let result = request.await?;
378 let completions = result
379 .completions
380 .into_iter()
381 .map(|completion| completion_from_lsp(completion, &buffer))
382 .collect();
383 anyhow::Ok(completions)
384 })
385 }
386
387 pub fn status(&self) -> Status {
388 match &self.server {
389 CopilotServer::Starting { .. } => Status::Starting,
390 CopilotServer::Disabled => Status::Disabled,
391 CopilotServer::Error(error) => Status::Error(error.clone()),
392 CopilotServer::Started { status, .. } => match status {
393 SignInStatus::Authorized { .. } => Status::Authorized,
394 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
395 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
396 prompt: prompt.clone(),
397 },
398 SignInStatus::SignedOut => Status::SignedOut,
399 },
400 }
401 }
402
403 fn update_sign_in_status(
404 &mut self,
405 lsp_status: request::SignInStatus,
406 cx: &mut ModelContext<Self>,
407 ) {
408 if let CopilotServer::Started { status, .. } = &mut self.server {
409 *status = match lsp_status {
410 request::SignInStatus::Ok { user }
411 | request::SignInStatus::MaybeOk { user }
412 | request::SignInStatus::AlreadySignedIn { user } => {
413 SignInStatus::Authorized { _user: user }
414 }
415 request::SignInStatus::NotAuthorized { user } => {
416 SignInStatus::Unauthorized { _user: user }
417 }
418 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
419 };
420 cx.notify();
421 }
422 }
423
424 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
425 match &self.server {
426 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
427 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
428 CopilotServer::Error(error) => Err(anyhow!(
429 "copilot was not started because of an error: {}",
430 error
431 )),
432 CopilotServer::Started { server, status } => {
433 if matches!(status, SignInStatus::Authorized { .. }) {
434 Ok(server.clone())
435 } else {
436 Err(anyhow!("must sign in before using copilot"))
437 }
438 }
439 }
440 }
441}
442
443fn build_completion_params<T>(
444 buffer: &BufferSnapshot,
445 position: T,
446 cx: &AppContext,
447) -> request::GetCompletionsParams
448where
449 T: ToPointUtf16,
450{
451 let position = position.to_point_utf16(&buffer);
452 let language_name = buffer.language_at(position).map(|language| language.name());
453 let language_name = language_name.as_deref();
454
455 let path;
456 let relative_path;
457 if let Some(file) = buffer.file() {
458 if let Some(file) = file.as_local() {
459 path = file.abs_path(cx);
460 } else {
461 path = file.full_path(cx);
462 }
463 relative_path = file.path().to_path_buf();
464 } else {
465 path = PathBuf::from("/untitled");
466 relative_path = PathBuf::from("untitled");
467 }
468
469 let settings = cx.global::<Settings>();
470 let language_id = match language_name {
471 Some("Plain Text") => "plaintext".to_string(),
472 Some(language_name) => language_name.to_lowercase(),
473 None => "plaintext".to_string(),
474 };
475 request::GetCompletionsParams {
476 doc: request::GetCompletionsDocument {
477 source: buffer.text(),
478 tab_size: settings.tab_size(language_name).into(),
479 indent_size: 1,
480 insert_spaces: !settings.hard_tabs(language_name),
481 uri: lsp::Url::from_file_path(&path).unwrap(),
482 path: path.to_string_lossy().into(),
483 relative_path: relative_path.to_string_lossy().into(),
484 language_id,
485 position: point_to_lsp(position),
486 version: 0,
487 },
488 }
489}
490
491fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
492 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
493 Completion {
494 position: buffer.anchor_before(position),
495 text: completion.display_text,
496 }
497}
498
499async fn get_copilot_lsp(
500 http: Arc<dyn HttpClient>,
501 node: Arc<NodeRuntime>,
502) -> anyhow::Result<PathBuf> {
503 const SERVER_PATH: &'static str = "node_modules/copilot-node-server/copilot/dist/agent.js";
504
505 ///Check for the latest copilot language server and download it if we haven't already
506 async fn fetch_latest(
507 _http: Arc<dyn HttpClient>,
508 node: Arc<NodeRuntime>,
509 ) -> anyhow::Result<PathBuf> {
510 const COPILOT_NPM_PACKAGE: &'static str = "copilot-node-server";
511
512 let release = node.npm_package_latest_version(COPILOT_NPM_PACKAGE).await?;
513
514 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.clone()));
515
516 fs::create_dir_all(version_dir).await?;
517 let server_path = version_dir.join(SERVER_PATH);
518
519 if fs::metadata(&server_path).await.is_err() {
520 node.npm_install_packages([(COPILOT_NPM_PACKAGE, release.as_str())], version_dir)
521 .await?;
522
523 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
524 }
525
526 Ok(server_path)
527 }
528
529 match fetch_latest(http, node).await {
530 ok @ Result::Ok(..) => ok,
531 e @ Err(..) => {
532 e.log_err();
533 // Fetch a cached binary, if it exists
534 (|| async move {
535 let mut last_version_dir = None;
536 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
537 while let Some(entry) = entries.next().await {
538 let entry = entry?;
539 if entry.file_type().await?.is_dir() {
540 last_version_dir = Some(entry.path());
541 }
542 }
543 let last_version_dir =
544 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
545 let server_path = last_version_dir.join(SERVER_PATH);
546 if server_path.exists() {
547 Ok(server_path)
548 } else {
549 Err(anyhow!(
550 "missing executable in directory {:?}",
551 last_version_dir
552 ))
553 }
554 })()
555 .await
556 }
557 }
558}