copilot.rs

  1pub mod copilot_button;
  2mod request;
  3mod sign_in;
  4
  5use anyhow::{anyhow, Result};
  6use client::Client;
  7use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  8use gpui::{
  9    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 10    Task,
 11};
 12use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 13use lsp::LanguageServer;
 14use node_runtime::NodeRuntime;
 15use settings::Settings;
 16use smol::{fs, stream::StreamExt};
 17use std::{
 18    ffi::OsString,
 19    path::{Path, PathBuf},
 20    sync::Arc,
 21};
 22use util::{fs::remove_matching, http::HttpClient, paths, ResultExt};
 23
 24const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 25actions!(copilot_auth, [SignIn, SignOut]);
 26
 27const COPILOT_NAMESPACE: &'static str = "copilot";
 28actions!(copilot, [NextSuggestion, PreviousSuggestion, Toggle]);
 29
 30pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 31    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 32    cx.set_global(copilot.clone());
 33    cx.add_global_action(|_: &SignIn, cx| {
 34        let copilot = Copilot::global(cx).unwrap();
 35        copilot
 36            .update(cx, |copilot, cx| copilot.sign_in(cx))
 37            .detach_and_log_err(cx);
 38    });
 39    cx.add_global_action(|_: &SignOut, cx| {
 40        let copilot = Copilot::global(cx).unwrap();
 41        copilot
 42            .update(cx, |copilot, cx| copilot.sign_out(cx))
 43            .detach_and_log_err(cx);
 44    });
 45
 46    cx.observe(&copilot, |handle, cx| {
 47        let status = handle.read(cx).status();
 48        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 49            move |filter, _cx| match status {
 50                Status::Disabled => {
 51                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 52                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 53                }
 54                Status::Authorized => {
 55                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 56                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 57                }
 58                _ => {
 59                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 60                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 61                }
 62            },
 63        );
 64    })
 65    .detach();
 66
 67    sign_in::init(cx);
 68}
 69
 70enum CopilotServer {
 71    Disabled,
 72    Starting {
 73        _task: Shared<Task<()>>,
 74    },
 75    Error(Arc<str>),
 76    Started {
 77        server: Arc<LanguageServer>,
 78        status: SignInStatus,
 79    },
 80}
 81
 82#[derive(Clone, Debug)]
 83enum SignInStatus {
 84    Authorized {
 85        _user: String,
 86    },
 87    Unauthorized {
 88        _user: String,
 89    },
 90    SigningIn {
 91        prompt: Option<request::PromptUserDeviceFlow>,
 92        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 93    },
 94    SignedOut,
 95}
 96
 97#[derive(Debug, PartialEq, Eq)]
 98pub enum Status {
 99    Starting,
100    Error(Arc<str>),
101    Disabled,
102    SignedOut,
103    SigningIn {
104        prompt: Option<request::PromptUserDeviceFlow>,
105    },
106    Unauthorized,
107    Authorized,
108}
109
110impl Status {
111    pub fn is_authorized(&self) -> bool {
112        matches!(self, Status::Authorized)
113    }
114}
115
116#[derive(Debug, PartialEq, Eq)]
117pub struct Completion {
118    pub position: Anchor,
119    pub text: String,
120}
121
122pub struct Copilot {
123    server: CopilotServer,
124}
125
126impl Entity for Copilot {
127    type Event = ();
128}
129
130impl Copilot {
131    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
132        if cx.has_global::<ModelHandle<Self>>() {
133            Some(cx.global::<ModelHandle<Self>>().clone())
134        } else {
135            None
136        }
137    }
138
139    fn start(
140        http: Arc<dyn HttpClient>,
141        node_runtime: Arc<NodeRuntime>,
142        cx: &mut ModelContext<Self>,
143    ) -> Self {
144        cx.observe_global::<Settings, _>({
145            let http = http.clone();
146            let node_runtime = node_runtime.clone();
147            move |this, cx| {
148                if cx.global::<Settings>().enable_copilot_integration {
149                    if matches!(this.server, CopilotServer::Disabled) {
150                        let start_task = cx
151                            .spawn({
152                                let http = http.clone();
153                                let node_runtime = node_runtime.clone();
154                                move |this, cx| {
155                                    Self::start_language_server(http, node_runtime, this, cx)
156                                }
157                            })
158                            .shared();
159                        this.server = CopilotServer::Starting { _task: start_task }
160                    }
161                } else {
162                    this.server = CopilotServer::Disabled
163                }
164            }
165        })
166        .detach();
167
168        if cx.global::<Settings>().enable_copilot_integration {
169            let start_task = cx
170                .spawn({
171                    let http = http.clone();
172                    let node_runtime = node_runtime.clone();
173                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
174                })
175                .shared();
176
177            Self {
178                server: CopilotServer::Starting { _task: start_task },
179            }
180        } else {
181            Self {
182                server: CopilotServer::Disabled,
183            }
184        }
185    }
186
187    fn start_language_server(
188        http: Arc<dyn HttpClient>,
189        node_runtime: Arc<NodeRuntime>,
190        this: ModelHandle<Self>,
191        mut cx: AsyncAppContext,
192    ) -> impl Future<Output = ()> {
193        async move {
194            let start_language_server = async {
195                let server_path = get_copilot_lsp(http, node_runtime.clone()).await?;
196                let node_path = node_runtime.binary_path().await?;
197                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
198                let server =
199                    LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
200
201                let server = server.initialize(Default::default()).await?;
202                let status = server
203                    .request::<request::CheckStatus>(request::CheckStatusParams {
204                        local_checks_only: false,
205                    })
206                    .await?;
207                anyhow::Ok((server, status))
208            };
209
210            let server = start_language_server.await;
211            this.update(&mut cx, |this, cx| {
212                cx.notify();
213                match server {
214                    Ok((server, status)) => {
215                        this.server = CopilotServer::Started {
216                            server,
217                            status: SignInStatus::SignedOut,
218                        };
219                        this.update_sign_in_status(status, cx);
220                    }
221                    Err(error) => {
222                        this.server = CopilotServer::Error(error.to_string().into());
223                        cx.notify()
224                    }
225                }
226            })
227        }
228    }
229
230    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
231        if let CopilotServer::Started { server, status } = &mut self.server {
232            let task = match status {
233                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
234                    Task::ready(Ok(())).shared()
235                }
236                SignInStatus::SigningIn { task, .. } => {
237                    cx.notify();
238                    task.clone()
239                }
240                SignInStatus::SignedOut => {
241                    let server = server.clone();
242                    let task = cx
243                        .spawn(|this, mut cx| async move {
244                            let sign_in = async {
245                                let sign_in = server
246                                    .request::<request::SignInInitiate>(
247                                        request::SignInInitiateParams {},
248                                    )
249                                    .await?;
250                                match sign_in {
251                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
252                                        Ok(request::SignInStatus::Ok { user })
253                                    }
254                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
255                                        this.update(&mut cx, |this, cx| {
256                                            if let CopilotServer::Started { status, .. } =
257                                                &mut this.server
258                                            {
259                                                if let SignInStatus::SigningIn {
260                                                    prompt: prompt_flow,
261                                                    ..
262                                                } = status
263                                                {
264                                                    *prompt_flow = Some(flow.clone());
265                                                    cx.notify();
266                                                }
267                                            }
268                                        });
269                                        let response = server
270                                            .request::<request::SignInConfirm>(
271                                                request::SignInConfirmParams {
272                                                    user_code: flow.user_code,
273                                                },
274                                            )
275                                            .await?;
276                                        Ok(response)
277                                    }
278                                }
279                            };
280
281                            let sign_in = sign_in.await;
282                            this.update(&mut cx, |this, cx| match sign_in {
283                                Ok(status) => {
284                                    this.update_sign_in_status(status, cx);
285                                    Ok(())
286                                }
287                                Err(error) => {
288                                    this.update_sign_in_status(
289                                        request::SignInStatus::NotSignedIn,
290                                        cx,
291                                    );
292                                    Err(Arc::new(error))
293                                }
294                            })
295                        })
296                        .shared();
297                    *status = SignInStatus::SigningIn {
298                        prompt: None,
299                        task: task.clone(),
300                    };
301                    cx.notify();
302                    task
303                }
304            };
305
306            cx.foreground()
307                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
308        } else {
309            Task::ready(Err(anyhow!("copilot hasn't started yet")))
310        }
311    }
312
313    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
314        if let CopilotServer::Started { server, status } = &mut self.server {
315            *status = SignInStatus::SignedOut;
316            cx.notify();
317
318            let server = server.clone();
319            cx.background().spawn(async move {
320                server
321                    .request::<request::SignOut>(request::SignOutParams {})
322                    .await?;
323                anyhow::Ok(())
324            })
325        } else {
326            Task::ready(Err(anyhow!("copilot hasn't started yet")))
327        }
328    }
329
330    pub fn completion<T>(
331        &self,
332        buffer: &ModelHandle<Buffer>,
333        position: T,
334        cx: &mut ModelContext<Self>,
335    ) -> Task<Result<Option<Completion>>>
336    where
337        T: ToPointUtf16,
338    {
339        let server = match self.authorized_server() {
340            Ok(server) => server,
341            Err(error) => return Task::ready(Err(error)),
342        };
343
344        let buffer = buffer.read(cx).snapshot();
345        let request = server
346            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
347        cx.background().spawn(async move {
348            let result = request.await?;
349            let completion = result
350                .completions
351                .into_iter()
352                .next()
353                .map(|completion| completion_from_lsp(completion, &buffer));
354            anyhow::Ok(completion)
355        })
356    }
357
358    pub fn completions_cycling<T>(
359        &self,
360        buffer: &ModelHandle<Buffer>,
361        position: T,
362        cx: &mut ModelContext<Self>,
363    ) -> Task<Result<Vec<Completion>>>
364    where
365        T: ToPointUtf16,
366    {
367        let server = match self.authorized_server() {
368            Ok(server) => server,
369            Err(error) => return Task::ready(Err(error)),
370        };
371
372        let buffer = buffer.read(cx).snapshot();
373        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
374            &buffer, position, cx,
375        ));
376        cx.background().spawn(async move {
377            let result = request.await?;
378            let completions = result
379                .completions
380                .into_iter()
381                .map(|completion| completion_from_lsp(completion, &buffer))
382                .collect();
383            anyhow::Ok(completions)
384        })
385    }
386
387    pub fn status(&self) -> Status {
388        match &self.server {
389            CopilotServer::Starting { .. } => Status::Starting,
390            CopilotServer::Disabled => Status::Disabled,
391            CopilotServer::Error(error) => Status::Error(error.clone()),
392            CopilotServer::Started { status, .. } => match status {
393                SignInStatus::Authorized { .. } => Status::Authorized,
394                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
395                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
396                    prompt: prompt.clone(),
397                },
398                SignInStatus::SignedOut => Status::SignedOut,
399            },
400        }
401    }
402
403    fn update_sign_in_status(
404        &mut self,
405        lsp_status: request::SignInStatus,
406        cx: &mut ModelContext<Self>,
407    ) {
408        if let CopilotServer::Started { status, .. } = &mut self.server {
409            *status = match lsp_status {
410                request::SignInStatus::Ok { user }
411                | request::SignInStatus::MaybeOk { user }
412                | request::SignInStatus::AlreadySignedIn { user } => {
413                    SignInStatus::Authorized { _user: user }
414                }
415                request::SignInStatus::NotAuthorized { user } => {
416                    SignInStatus::Unauthorized { _user: user }
417                }
418                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
419            };
420            cx.notify();
421        }
422    }
423
424    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
425        match &self.server {
426            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
427            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
428            CopilotServer::Error(error) => Err(anyhow!(
429                "copilot was not started because of an error: {}",
430                error
431            )),
432            CopilotServer::Started { server, status } => {
433                if matches!(status, SignInStatus::Authorized { .. }) {
434                    Ok(server.clone())
435                } else {
436                    Err(anyhow!("must sign in before using copilot"))
437                }
438            }
439        }
440    }
441}
442
443fn build_completion_params<T>(
444    buffer: &BufferSnapshot,
445    position: T,
446    cx: &AppContext,
447) -> request::GetCompletionsParams
448where
449    T: ToPointUtf16,
450{
451    let position = position.to_point_utf16(&buffer);
452    let language_name = buffer.language_at(position).map(|language| language.name());
453    let language_name = language_name.as_deref();
454
455    let path;
456    let relative_path;
457    if let Some(file) = buffer.file() {
458        if let Some(file) = file.as_local() {
459            path = file.abs_path(cx);
460        } else {
461            path = file.full_path(cx);
462        }
463        relative_path = file.path().to_path_buf();
464    } else {
465        path = PathBuf::from("/untitled");
466        relative_path = PathBuf::from("untitled");
467    }
468
469    let settings = cx.global::<Settings>();
470    let language_id = match language_name {
471        Some("Plain Text") => "plaintext".to_string(),
472        Some(language_name) => language_name.to_lowercase(),
473        None => "plaintext".to_string(),
474    };
475    request::GetCompletionsParams {
476        doc: request::GetCompletionsDocument {
477            source: buffer.text(),
478            tab_size: settings.tab_size(language_name).into(),
479            indent_size: 1,
480            insert_spaces: !settings.hard_tabs(language_name),
481            uri: lsp::Url::from_file_path(&path).unwrap(),
482            path: path.to_string_lossy().into(),
483            relative_path: relative_path.to_string_lossy().into(),
484            language_id,
485            position: point_to_lsp(position),
486            version: 0,
487        },
488    }
489}
490
491fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
492    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
493    Completion {
494        position: buffer.anchor_before(position),
495        text: completion.display_text,
496    }
497}
498
499async fn get_copilot_lsp(
500    http: Arc<dyn HttpClient>,
501    node: Arc<NodeRuntime>,
502) -> anyhow::Result<PathBuf> {
503    const SERVER_PATH: &'static str = "node_modules/copilot-node-server/copilot/dist/agent.js";
504
505    ///Check for the latest copilot language server and download it if we haven't already
506    async fn fetch_latest(
507        _http: Arc<dyn HttpClient>,
508        node: Arc<NodeRuntime>,
509    ) -> anyhow::Result<PathBuf> {
510        const COPILOT_NPM_PACKAGE: &'static str = "copilot-node-server";
511
512        let release = node.npm_package_latest_version(COPILOT_NPM_PACKAGE).await?;
513
514        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.clone()));
515
516        fs::create_dir_all(version_dir).await?;
517        let server_path = version_dir.join(SERVER_PATH);
518
519        if fs::metadata(&server_path).await.is_err() {
520            node.npm_install_packages([(COPILOT_NPM_PACKAGE, release.as_str())], version_dir)
521                .await?;
522
523            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
524        }
525
526        Ok(server_path)
527    }
528
529    match fetch_latest(http, node).await {
530        ok @ Result::Ok(..) => ok,
531        e @ Err(..) => {
532            e.log_err();
533            // Fetch a cached binary, if it exists
534            (|| async move {
535                let mut last_version_dir = None;
536                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
537                while let Some(entry) = entries.next().await {
538                    let entry = entry?;
539                    if entry.file_type().await?.is_dir() {
540                        last_version_dir = Some(entry.path());
541                    }
542                }
543                let last_version_dir =
544                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
545                let server_path = last_version_dir.join(SERVER_PATH);
546                if server_path.exists() {
547                    Ok(server_path)
548                } else {
549                    Err(anyhow!(
550                        "missing executable in directory {:?}",
551                        last_version_dir
552                    ))
553                }
554            })()
555            .await
556        }
557    }
558}