1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Result};
5use client::Client;
6use futures::{future::Shared, Future, FutureExt, TryFutureExt};
7use gpui::{
8 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
9 Task,
10};
11use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
12use lsp::LanguageServer;
13use node_runtime::NodeRuntime;
14use settings::Settings;
15use smol::{fs, stream::StreamExt};
16use std::{
17 ffi::OsString,
18 path::{Path, PathBuf},
19 sync::Arc,
20};
21use util::{fs::remove_matching, http::HttpClient, paths, ResultExt};
22
23const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
24actions!(copilot_auth, [SignIn, SignOut]);
25
26const COPILOT_NAMESPACE: &'static str = "copilot";
27actions!(copilot, [NextSuggestion, PreviousSuggestion, Toggle]);
28
29pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
30 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
31 cx.set_global(copilot.clone());
32 cx.add_global_action(|_: &SignIn, cx| {
33 let copilot = Copilot::global(cx).unwrap();
34 copilot
35 .update(cx, |copilot, cx| copilot.sign_in(cx))
36 .detach_and_log_err(cx);
37 });
38 cx.add_global_action(|_: &SignOut, cx| {
39 let copilot = Copilot::global(cx).unwrap();
40 copilot
41 .update(cx, |copilot, cx| copilot.sign_out(cx))
42 .detach_and_log_err(cx);
43 });
44
45 cx.observe(&copilot, |handle, cx| {
46 let status = handle.read(cx).status();
47 cx.update_global::<collections::CommandPaletteFilter, _, _>(
48 move |filter, _cx| match status {
49 Status::Disabled => {
50 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
51 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
52 }
53 Status::Authorized => {
54 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
55 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
56 }
57 _ => {
58 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
59 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
60 }
61 },
62 );
63 })
64 .detach();
65
66 sign_in::init(cx);
67}
68
69enum CopilotServer {
70 Disabled,
71 Starting {
72 _task: Shared<Task<()>>,
73 },
74 Error(Arc<str>),
75 Started {
76 server: Arc<LanguageServer>,
77 status: SignInStatus,
78 },
79}
80
81#[derive(Clone, Debug)]
82enum SignInStatus {
83 Authorized {
84 _user: String,
85 },
86 Unauthorized {
87 _user: String,
88 },
89 SigningIn {
90 prompt: Option<request::PromptUserDeviceFlow>,
91 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
92 },
93 SignedOut,
94}
95
96#[derive(Debug, PartialEq, Eq)]
97pub enum Status {
98 Starting,
99 Error(Arc<str>),
100 Disabled,
101 SignedOut,
102 SigningIn {
103 prompt: Option<request::PromptUserDeviceFlow>,
104 },
105 Unauthorized,
106 Authorized,
107}
108
109impl Status {
110 pub fn is_authorized(&self) -> bool {
111 matches!(self, Status::Authorized)
112 }
113}
114
115#[derive(Debug, PartialEq, Eq)]
116pub struct Completion {
117 pub position: Anchor,
118 pub text: String,
119}
120
121pub struct Copilot {
122 server: CopilotServer,
123}
124
125impl Entity for Copilot {
126 type Event = ();
127}
128
129impl Copilot {
130 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
131 if cx.has_global::<ModelHandle<Self>>() {
132 Some(cx.global::<ModelHandle<Self>>().clone())
133 } else {
134 None
135 }
136 }
137
138 fn start(
139 http: Arc<dyn HttpClient>,
140 node_runtime: Arc<NodeRuntime>,
141 cx: &mut ModelContext<Self>,
142 ) -> Self {
143 cx.observe_global::<Settings, _>({
144 let http = http.clone();
145 let node_runtime = node_runtime.clone();
146 move |this, cx| {
147 if cx.global::<Settings>().enable_copilot_integration {
148 if matches!(this.server, CopilotServer::Disabled) {
149 let start_task = cx
150 .spawn({
151 let http = http.clone();
152 let node_runtime = node_runtime.clone();
153 move |this, cx| {
154 Self::start_language_server(http, node_runtime, this, cx)
155 }
156 })
157 .shared();
158 this.server = CopilotServer::Starting { _task: start_task }
159 }
160 } else {
161 this.server = CopilotServer::Disabled
162 }
163 }
164 })
165 .detach();
166
167 if cx.global::<Settings>().enable_copilot_integration {
168 let start_task = cx
169 .spawn({
170 let http = http.clone();
171 let node_runtime = node_runtime.clone();
172 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
173 })
174 .shared();
175
176 Self {
177 server: CopilotServer::Starting { _task: start_task },
178 }
179 } else {
180 Self {
181 server: CopilotServer::Disabled,
182 }
183 }
184 }
185
186 fn start_language_server(
187 http: Arc<dyn HttpClient>,
188 node_runtime: Arc<NodeRuntime>,
189 this: ModelHandle<Self>,
190 mut cx: AsyncAppContext,
191 ) -> impl Future<Output = ()> {
192 async move {
193 let start_language_server = async {
194 let server_path = get_copilot_lsp(http, node_runtime.clone()).await?;
195 let node_path = node_runtime.binary_path().await?;
196 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
197 let server =
198 LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
199
200 let server = server.initialize(Default::default()).await?;
201 let status = server
202 .request::<request::CheckStatus>(request::CheckStatusParams {
203 local_checks_only: false,
204 })
205 .await?;
206 anyhow::Ok((server, status))
207 };
208
209 let server = start_language_server.await;
210 this.update(&mut cx, |this, cx| {
211 cx.notify();
212 match server {
213 Ok((server, status)) => {
214 this.server = CopilotServer::Started {
215 server,
216 status: SignInStatus::SignedOut,
217 };
218 this.update_sign_in_status(status, cx);
219 }
220 Err(error) => {
221 this.server = CopilotServer::Error(error.to_string().into());
222 cx.notify()
223 }
224 }
225 })
226 }
227 }
228
229 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
230 if let CopilotServer::Started { server, status } = &mut self.server {
231 let task = match status {
232 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
233 Task::ready(Ok(())).shared()
234 }
235 SignInStatus::SigningIn { task, .. } => {
236 cx.notify();
237 task.clone()
238 }
239 SignInStatus::SignedOut => {
240 let server = server.clone();
241 let task = cx
242 .spawn(|this, mut cx| async move {
243 let sign_in = async {
244 let sign_in = server
245 .request::<request::SignInInitiate>(
246 request::SignInInitiateParams {},
247 )
248 .await?;
249 match sign_in {
250 request::SignInInitiateResult::AlreadySignedIn { user } => {
251 Ok(request::SignInStatus::Ok { user })
252 }
253 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
254 this.update(&mut cx, |this, cx| {
255 if let CopilotServer::Started { status, .. } =
256 &mut this.server
257 {
258 if let SignInStatus::SigningIn {
259 prompt: prompt_flow,
260 ..
261 } = status
262 {
263 *prompt_flow = Some(flow.clone());
264 cx.notify();
265 }
266 }
267 });
268 let response = server
269 .request::<request::SignInConfirm>(
270 request::SignInConfirmParams {
271 user_code: flow.user_code,
272 },
273 )
274 .await?;
275 Ok(response)
276 }
277 }
278 };
279
280 let sign_in = sign_in.await;
281 this.update(&mut cx, |this, cx| match sign_in {
282 Ok(status) => {
283 this.update_sign_in_status(status, cx);
284 Ok(())
285 }
286 Err(error) => {
287 this.update_sign_in_status(
288 request::SignInStatus::NotSignedIn,
289 cx,
290 );
291 Err(Arc::new(error))
292 }
293 })
294 })
295 .shared();
296 *status = SignInStatus::SigningIn {
297 prompt: None,
298 task: task.clone(),
299 };
300 cx.notify();
301 task
302 }
303 };
304
305 cx.foreground()
306 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
307 } else {
308 Task::ready(Err(anyhow!("copilot hasn't started yet")))
309 }
310 }
311
312 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
313 if let CopilotServer::Started { server, status } = &mut self.server {
314 *status = SignInStatus::SignedOut;
315 cx.notify();
316
317 let server = server.clone();
318 cx.background().spawn(async move {
319 server
320 .request::<request::SignOut>(request::SignOutParams {})
321 .await?;
322 anyhow::Ok(())
323 })
324 } else {
325 Task::ready(Err(anyhow!("copilot hasn't started yet")))
326 }
327 }
328
329 pub fn completion<T>(
330 &self,
331 buffer: &ModelHandle<Buffer>,
332 position: T,
333 cx: &mut ModelContext<Self>,
334 ) -> Task<Result<Option<Completion>>>
335 where
336 T: ToPointUtf16,
337 {
338 let server = match self.authorized_server() {
339 Ok(server) => server,
340 Err(error) => return Task::ready(Err(error)),
341 };
342
343 let buffer = buffer.read(cx).snapshot();
344 let request = server
345 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
346 cx.background().spawn(async move {
347 let result = request.await?;
348 let completion = result
349 .completions
350 .into_iter()
351 .next()
352 .map(|completion| completion_from_lsp(completion, &buffer));
353 anyhow::Ok(completion)
354 })
355 }
356
357 pub fn completions_cycling<T>(
358 &self,
359 buffer: &ModelHandle<Buffer>,
360 position: T,
361 cx: &mut ModelContext<Self>,
362 ) -> Task<Result<Vec<Completion>>>
363 where
364 T: ToPointUtf16,
365 {
366 let server = match self.authorized_server() {
367 Ok(server) => server,
368 Err(error) => return Task::ready(Err(error)),
369 };
370
371 let buffer = buffer.read(cx).snapshot();
372 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
373 &buffer, position, cx,
374 ));
375 cx.background().spawn(async move {
376 let result = request.await?;
377 let completions = result
378 .completions
379 .into_iter()
380 .map(|completion| completion_from_lsp(completion, &buffer))
381 .collect();
382 anyhow::Ok(completions)
383 })
384 }
385
386 pub fn status(&self) -> Status {
387 match &self.server {
388 CopilotServer::Starting { .. } => Status::Starting,
389 CopilotServer::Disabled => Status::Disabled,
390 CopilotServer::Error(error) => Status::Error(error.clone()),
391 CopilotServer::Started { status, .. } => match status {
392 SignInStatus::Authorized { .. } => Status::Authorized,
393 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
394 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
395 prompt: prompt.clone(),
396 },
397 SignInStatus::SignedOut => Status::SignedOut,
398 },
399 }
400 }
401
402 fn update_sign_in_status(
403 &mut self,
404 lsp_status: request::SignInStatus,
405 cx: &mut ModelContext<Self>,
406 ) {
407 if let CopilotServer::Started { status, .. } = &mut self.server {
408 *status = match lsp_status {
409 request::SignInStatus::Ok { user }
410 | request::SignInStatus::MaybeOk { user }
411 | request::SignInStatus::AlreadySignedIn { user } => {
412 SignInStatus::Authorized { _user: user }
413 }
414 request::SignInStatus::NotAuthorized { user } => {
415 SignInStatus::Unauthorized { _user: user }
416 }
417 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
418 };
419 cx.notify();
420 }
421 }
422
423 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
424 match &self.server {
425 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
426 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
427 CopilotServer::Error(error) => Err(anyhow!(
428 "copilot was not started because of an error: {}",
429 error
430 )),
431 CopilotServer::Started { server, status } => {
432 if matches!(status, SignInStatus::Authorized { .. }) {
433 Ok(server.clone())
434 } else {
435 Err(anyhow!("must sign in before using copilot"))
436 }
437 }
438 }
439 }
440}
441
442fn build_completion_params<T>(
443 buffer: &BufferSnapshot,
444 position: T,
445 cx: &AppContext,
446) -> request::GetCompletionsParams
447where
448 T: ToPointUtf16,
449{
450 let position = position.to_point_utf16(&buffer);
451 let language_name = buffer.language_at(position).map(|language| language.name());
452 let language_name = language_name.as_deref();
453
454 let path;
455 let relative_path;
456 if let Some(file) = buffer.file() {
457 if let Some(file) = file.as_local() {
458 path = file.abs_path(cx);
459 } else {
460 path = file.full_path(cx);
461 }
462 relative_path = file.path().to_path_buf();
463 } else {
464 path = PathBuf::from("/untitled");
465 relative_path = PathBuf::from("untitled");
466 }
467
468 let settings = cx.global::<Settings>();
469 let language_id = match language_name {
470 Some("Plain Text") => "plaintext".to_string(),
471 Some(language_name) => language_name.to_lowercase(),
472 None => "plaintext".to_string(),
473 };
474 request::GetCompletionsParams {
475 doc: request::GetCompletionsDocument {
476 source: buffer.text(),
477 tab_size: settings.tab_size(language_name).into(),
478 indent_size: 1,
479 insert_spaces: !settings.hard_tabs(language_name),
480 uri: lsp::Url::from_file_path(&path).unwrap(),
481 path: path.to_string_lossy().into(),
482 relative_path: relative_path.to_string_lossy().into(),
483 language_id,
484 position: point_to_lsp(position),
485 version: 0,
486 },
487 }
488}
489
490fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
491 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
492 Completion {
493 position: buffer.anchor_before(position),
494 text: completion.display_text,
495 }
496}
497
498async fn get_copilot_lsp(
499 http: Arc<dyn HttpClient>,
500 node: Arc<NodeRuntime>,
501) -> anyhow::Result<PathBuf> {
502 const SERVER_PATH: &'static str = "node_modules/copilot-node-server/copilot/dist/agent.js";
503
504 ///Check for the latest copilot language server and download it if we haven't already
505 async fn fetch_latest(
506 _http: Arc<dyn HttpClient>,
507 node: Arc<NodeRuntime>,
508 ) -> anyhow::Result<PathBuf> {
509 const COPILOT_NPM_PACKAGE: &'static str = "copilot-node-server";
510
511 let release = node.npm_package_latest_version(COPILOT_NPM_PACKAGE).await?;
512
513 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.clone()));
514
515 fs::create_dir_all(version_dir).await?;
516 let server_path = version_dir.join(SERVER_PATH);
517
518 if fs::metadata(&server_path).await.is_err() {
519 node.npm_install_packages([(COPILOT_NPM_PACKAGE, release.as_str())], version_dir)
520 .await?;
521
522 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
523 }
524
525 Ok(server_path)
526 }
527
528 match fetch_latest(http, node).await {
529 ok @ Result::Ok(..) => ok,
530 e @ Err(..) => {
531 e.log_err();
532 // Fetch a cached binary, if it exists
533 (|| async move {
534 let mut last_version_dir = None;
535 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
536 while let Some(entry) = entries.next().await {
537 let entry = entry?;
538 if entry.file_type().await?.is_dir() {
539 last_version_dir = Some(entry.path());
540 }
541 }
542 let last_version_dir =
543 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
544 let server_path = last_version_dir.join(SERVER_PATH);
545 if server_path.exists() {
546 Ok(server_path)
547 } else {
548 Err(anyhow!(
549 "missing executable in directory {:?}",
550 last_version_dir
551 ))
552 }
553 })()
554 .await
555 }
556 }
557}