copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Result};
  5use client::Client;
  6use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  7use gpui::{
  8    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
  9    Task,
 10};
 11use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 12use lsp::LanguageServer;
 13use node_runtime::NodeRuntime;
 14use settings::Settings;
 15use smol::{fs, stream::StreamExt};
 16use std::{
 17    ffi::OsString,
 18    path::{Path, PathBuf},
 19    sync::Arc,
 20};
 21use util::{fs::remove_matching, http::HttpClient, paths, ResultExt};
 22
 23const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 24actions!(copilot_auth, [SignIn, SignOut]);
 25
 26const COPILOT_NAMESPACE: &'static str = "copilot";
 27actions!(copilot, [NextSuggestion, PreviousSuggestion, Toggle]);
 28
 29pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 30    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 31    cx.set_global(copilot.clone());
 32    cx.add_global_action(|_: &SignIn, cx| {
 33        let copilot = Copilot::global(cx).unwrap();
 34        copilot
 35            .update(cx, |copilot, cx| copilot.sign_in(cx))
 36            .detach_and_log_err(cx);
 37    });
 38    cx.add_global_action(|_: &SignOut, cx| {
 39        let copilot = Copilot::global(cx).unwrap();
 40        copilot
 41            .update(cx, |copilot, cx| copilot.sign_out(cx))
 42            .detach_and_log_err(cx);
 43    });
 44
 45    cx.observe(&copilot, |handle, cx| {
 46        let status = handle.read(cx).status();
 47        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 48            move |filter, _cx| match status {
 49                Status::Disabled => {
 50                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 51                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 52                }
 53                Status::Authorized => {
 54                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 55                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 56                }
 57                _ => {
 58                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 59                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 60                }
 61            },
 62        );
 63    })
 64    .detach();
 65
 66    sign_in::init(cx);
 67}
 68
 69enum CopilotServer {
 70    Disabled,
 71    Starting {
 72        _task: Shared<Task<()>>,
 73    },
 74    Error(Arc<str>),
 75    Started {
 76        server: Arc<LanguageServer>,
 77        status: SignInStatus,
 78    },
 79}
 80
 81#[derive(Clone, Debug)]
 82enum SignInStatus {
 83    Authorized {
 84        _user: String,
 85    },
 86    Unauthorized {
 87        _user: String,
 88    },
 89    SigningIn {
 90        prompt: Option<request::PromptUserDeviceFlow>,
 91        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 92    },
 93    SignedOut,
 94}
 95
 96#[derive(Debug, PartialEq, Eq)]
 97pub enum Status {
 98    Starting,
 99    Error(Arc<str>),
100    Disabled,
101    SignedOut,
102    SigningIn {
103        prompt: Option<request::PromptUserDeviceFlow>,
104    },
105    Unauthorized,
106    Authorized,
107}
108
109impl Status {
110    pub fn is_authorized(&self) -> bool {
111        matches!(self, Status::Authorized)
112    }
113}
114
115#[derive(Debug, PartialEq, Eq)]
116pub struct Completion {
117    pub position: Anchor,
118    pub text: String,
119}
120
121pub struct Copilot {
122    server: CopilotServer,
123}
124
125impl Entity for Copilot {
126    type Event = ();
127}
128
129impl Copilot {
130    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
131        if cx.has_global::<ModelHandle<Self>>() {
132            Some(cx.global::<ModelHandle<Self>>().clone())
133        } else {
134            None
135        }
136    }
137
138    fn start(
139        http: Arc<dyn HttpClient>,
140        node_runtime: Arc<NodeRuntime>,
141        cx: &mut ModelContext<Self>,
142    ) -> Self {
143        cx.observe_global::<Settings, _>({
144            let http = http.clone();
145            let node_runtime = node_runtime.clone();
146            move |this, cx| {
147                if cx.global::<Settings>().enable_copilot_integration {
148                    if matches!(this.server, CopilotServer::Disabled) {
149                        let start_task = cx
150                            .spawn({
151                                let http = http.clone();
152                                let node_runtime = node_runtime.clone();
153                                move |this, cx| {
154                                    Self::start_language_server(http, node_runtime, this, cx)
155                                }
156                            })
157                            .shared();
158                        this.server = CopilotServer::Starting { _task: start_task }
159                    }
160                } else {
161                    this.server = CopilotServer::Disabled
162                }
163            }
164        })
165        .detach();
166
167        if cx.global::<Settings>().enable_copilot_integration {
168            let start_task = cx
169                .spawn({
170                    let http = http.clone();
171                    let node_runtime = node_runtime.clone();
172                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
173                })
174                .shared();
175
176            Self {
177                server: CopilotServer::Starting { _task: start_task },
178            }
179        } else {
180            Self {
181                server: CopilotServer::Disabled,
182            }
183        }
184    }
185
186    fn start_language_server(
187        http: Arc<dyn HttpClient>,
188        node_runtime: Arc<NodeRuntime>,
189        this: ModelHandle<Self>,
190        mut cx: AsyncAppContext,
191    ) -> impl Future<Output = ()> {
192        async move {
193            let start_language_server = async {
194                let server_path = get_copilot_lsp(http, node_runtime.clone()).await?;
195                let node_path = node_runtime.binary_path().await?;
196                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
197                let server =
198                    LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
199
200                let server = server.initialize(Default::default()).await?;
201                let status = server
202                    .request::<request::CheckStatus>(request::CheckStatusParams {
203                        local_checks_only: false,
204                    })
205                    .await?;
206                anyhow::Ok((server, status))
207            };
208
209            let server = start_language_server.await;
210            this.update(&mut cx, |this, cx| {
211                cx.notify();
212                match server {
213                    Ok((server, status)) => {
214                        this.server = CopilotServer::Started {
215                            server,
216                            status: SignInStatus::SignedOut,
217                        };
218                        this.update_sign_in_status(status, cx);
219                    }
220                    Err(error) => {
221                        this.server = CopilotServer::Error(error.to_string().into());
222                        cx.notify()
223                    }
224                }
225            })
226        }
227    }
228
229    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
230        if let CopilotServer::Started { server, status } = &mut self.server {
231            let task = match status {
232                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
233                    Task::ready(Ok(())).shared()
234                }
235                SignInStatus::SigningIn { task, .. } => {
236                    cx.notify();
237                    task.clone()
238                }
239                SignInStatus::SignedOut => {
240                    let server = server.clone();
241                    let task = cx
242                        .spawn(|this, mut cx| async move {
243                            let sign_in = async {
244                                let sign_in = server
245                                    .request::<request::SignInInitiate>(
246                                        request::SignInInitiateParams {},
247                                    )
248                                    .await?;
249                                match sign_in {
250                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
251                                        Ok(request::SignInStatus::Ok { user })
252                                    }
253                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
254                                        this.update(&mut cx, |this, cx| {
255                                            if let CopilotServer::Started { status, .. } =
256                                                &mut this.server
257                                            {
258                                                if let SignInStatus::SigningIn {
259                                                    prompt: prompt_flow,
260                                                    ..
261                                                } = status
262                                                {
263                                                    *prompt_flow = Some(flow.clone());
264                                                    cx.notify();
265                                                }
266                                            }
267                                        });
268                                        let response = server
269                                            .request::<request::SignInConfirm>(
270                                                request::SignInConfirmParams {
271                                                    user_code: flow.user_code,
272                                                },
273                                            )
274                                            .await?;
275                                        Ok(response)
276                                    }
277                                }
278                            };
279
280                            let sign_in = sign_in.await;
281                            this.update(&mut cx, |this, cx| match sign_in {
282                                Ok(status) => {
283                                    this.update_sign_in_status(status, cx);
284                                    Ok(())
285                                }
286                                Err(error) => {
287                                    this.update_sign_in_status(
288                                        request::SignInStatus::NotSignedIn,
289                                        cx,
290                                    );
291                                    Err(Arc::new(error))
292                                }
293                            })
294                        })
295                        .shared();
296                    *status = SignInStatus::SigningIn {
297                        prompt: None,
298                        task: task.clone(),
299                    };
300                    cx.notify();
301                    task
302                }
303            };
304
305            cx.foreground()
306                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
307        } else {
308            Task::ready(Err(anyhow!("copilot hasn't started yet")))
309        }
310    }
311
312    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
313        if let CopilotServer::Started { server, status } = &mut self.server {
314            *status = SignInStatus::SignedOut;
315            cx.notify();
316
317            let server = server.clone();
318            cx.background().spawn(async move {
319                server
320                    .request::<request::SignOut>(request::SignOutParams {})
321                    .await?;
322                anyhow::Ok(())
323            })
324        } else {
325            Task::ready(Err(anyhow!("copilot hasn't started yet")))
326        }
327    }
328
329    pub fn completion<T>(
330        &self,
331        buffer: &ModelHandle<Buffer>,
332        position: T,
333        cx: &mut ModelContext<Self>,
334    ) -> Task<Result<Option<Completion>>>
335    where
336        T: ToPointUtf16,
337    {
338        let server = match self.authorized_server() {
339            Ok(server) => server,
340            Err(error) => return Task::ready(Err(error)),
341        };
342
343        let buffer = buffer.read(cx).snapshot();
344        let request = server
345            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
346        cx.background().spawn(async move {
347            let result = request.await?;
348            let completion = result
349                .completions
350                .into_iter()
351                .next()
352                .map(|completion| completion_from_lsp(completion, &buffer));
353            anyhow::Ok(completion)
354        })
355    }
356
357    pub fn completions_cycling<T>(
358        &self,
359        buffer: &ModelHandle<Buffer>,
360        position: T,
361        cx: &mut ModelContext<Self>,
362    ) -> Task<Result<Vec<Completion>>>
363    where
364        T: ToPointUtf16,
365    {
366        let server = match self.authorized_server() {
367            Ok(server) => server,
368            Err(error) => return Task::ready(Err(error)),
369        };
370
371        let buffer = buffer.read(cx).snapshot();
372        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
373            &buffer, position, cx,
374        ));
375        cx.background().spawn(async move {
376            let result = request.await?;
377            let completions = result
378                .completions
379                .into_iter()
380                .map(|completion| completion_from_lsp(completion, &buffer))
381                .collect();
382            anyhow::Ok(completions)
383        })
384    }
385
386    pub fn status(&self) -> Status {
387        match &self.server {
388            CopilotServer::Starting { .. } => Status::Starting,
389            CopilotServer::Disabled => Status::Disabled,
390            CopilotServer::Error(error) => Status::Error(error.clone()),
391            CopilotServer::Started { status, .. } => match status {
392                SignInStatus::Authorized { .. } => Status::Authorized,
393                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
394                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
395                    prompt: prompt.clone(),
396                },
397                SignInStatus::SignedOut => Status::SignedOut,
398            },
399        }
400    }
401
402    fn update_sign_in_status(
403        &mut self,
404        lsp_status: request::SignInStatus,
405        cx: &mut ModelContext<Self>,
406    ) {
407        if let CopilotServer::Started { status, .. } = &mut self.server {
408            *status = match lsp_status {
409                request::SignInStatus::Ok { user }
410                | request::SignInStatus::MaybeOk { user }
411                | request::SignInStatus::AlreadySignedIn { user } => {
412                    SignInStatus::Authorized { _user: user }
413                }
414                request::SignInStatus::NotAuthorized { user } => {
415                    SignInStatus::Unauthorized { _user: user }
416                }
417                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
418            };
419            cx.notify();
420        }
421    }
422
423    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
424        match &self.server {
425            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
426            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
427            CopilotServer::Error(error) => Err(anyhow!(
428                "copilot was not started because of an error: {}",
429                error
430            )),
431            CopilotServer::Started { server, status } => {
432                if matches!(status, SignInStatus::Authorized { .. }) {
433                    Ok(server.clone())
434                } else {
435                    Err(anyhow!("must sign in before using copilot"))
436                }
437            }
438        }
439    }
440}
441
442fn build_completion_params<T>(
443    buffer: &BufferSnapshot,
444    position: T,
445    cx: &AppContext,
446) -> request::GetCompletionsParams
447where
448    T: ToPointUtf16,
449{
450    let position = position.to_point_utf16(&buffer);
451    let language_name = buffer.language_at(position).map(|language| language.name());
452    let language_name = language_name.as_deref();
453
454    let path;
455    let relative_path;
456    if let Some(file) = buffer.file() {
457        if let Some(file) = file.as_local() {
458            path = file.abs_path(cx);
459        } else {
460            path = file.full_path(cx);
461        }
462        relative_path = file.path().to_path_buf();
463    } else {
464        path = PathBuf::from("/untitled");
465        relative_path = PathBuf::from("untitled");
466    }
467
468    let settings = cx.global::<Settings>();
469    let language_id = match language_name {
470        Some("Plain Text") => "plaintext".to_string(),
471        Some(language_name) => language_name.to_lowercase(),
472        None => "plaintext".to_string(),
473    };
474    request::GetCompletionsParams {
475        doc: request::GetCompletionsDocument {
476            source: buffer.text(),
477            tab_size: settings.tab_size(language_name).into(),
478            indent_size: 1,
479            insert_spaces: !settings.hard_tabs(language_name),
480            uri: lsp::Url::from_file_path(&path).unwrap(),
481            path: path.to_string_lossy().into(),
482            relative_path: relative_path.to_string_lossy().into(),
483            language_id,
484            position: point_to_lsp(position),
485            version: 0,
486        },
487    }
488}
489
490fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
491    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
492    Completion {
493        position: buffer.anchor_before(position),
494        text: completion.display_text,
495    }
496}
497
498async fn get_copilot_lsp(
499    http: Arc<dyn HttpClient>,
500    node: Arc<NodeRuntime>,
501) -> anyhow::Result<PathBuf> {
502    const SERVER_PATH: &'static str = "node_modules/copilot-node-server/copilot/dist/agent.js";
503
504    ///Check for the latest copilot language server and download it if we haven't already
505    async fn fetch_latest(
506        _http: Arc<dyn HttpClient>,
507        node: Arc<NodeRuntime>,
508    ) -> anyhow::Result<PathBuf> {
509        const COPILOT_NPM_PACKAGE: &'static str = "copilot-node-server";
510
511        let release = node.npm_package_latest_version(COPILOT_NPM_PACKAGE).await?;
512
513        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.clone()));
514
515        fs::create_dir_all(version_dir).await?;
516        let server_path = version_dir.join(SERVER_PATH);
517
518        if fs::metadata(&server_path).await.is_err() {
519            node.npm_install_packages([(COPILOT_NPM_PACKAGE, release.as_str())], version_dir)
520                .await?;
521
522            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
523        }
524
525        Ok(server_path)
526    }
527
528    match fetch_latest(http, node).await {
529        ok @ Result::Ok(..) => ok,
530        e @ Err(..) => {
531            e.log_err();
532            // Fetch a cached binary, if it exists
533            (|| async move {
534                let mut last_version_dir = None;
535                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
536                while let Some(entry) = entries.next().await {
537                    let entry = entry?;
538                    if entry.file_type().await?.is_dir() {
539                        last_version_dir = Some(entry.path());
540                    }
541                }
542                let last_version_dir =
543                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
544                let server_path = last_version_dir.join(SERVER_PATH);
545                if server_path.exists() {
546                    Ok(server_path)
547                } else {
548                    Err(anyhow!(
549                        "missing executable in directory {:?}",
550                        last_version_dir
551                    ))
552                }
553            })()
554            .await
555        }
556    }
557}