copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Result};
  5use client::Client;
  6use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  7use gpui::{
  8    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
  9    Task,
 10};
 11use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 12use lsp::LanguageServer;
 13use node_runtime::NodeRuntime;
 14use settings::Settings;
 15use smol::{fs, stream::StreamExt};
 16use std::{
 17    ffi::OsString,
 18    path::{Path, PathBuf},
 19    sync::Arc,
 20};
 21use util::{fs::remove_matching, http::HttpClient, paths, ResultExt};
 22
 23const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 24actions!(copilot_auth, [SignIn, SignOut]);
 25
 26const COPILOT_NAMESPACE: &'static str = "copilot";
 27actions!(copilot, [NextSuggestion]);
 28
 29pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 30    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 31    cx.set_global(copilot.clone());
 32    cx.add_global_action(|_: &SignIn, cx| {
 33        let copilot = Copilot::global(cx).unwrap();
 34        copilot
 35            .update(cx, |copilot, cx| copilot.sign_in(cx))
 36            .detach_and_log_err(cx);
 37    });
 38    cx.add_global_action(|_: &SignOut, cx| {
 39        let copilot = Copilot::global(cx).unwrap();
 40        copilot
 41            .update(cx, |copilot, cx| copilot.sign_out(cx))
 42            .detach_and_log_err(cx);
 43    });
 44
 45    cx.observe(&copilot, |handle, cx| {
 46        let status = handle.read(cx).status();
 47        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 48            move |filter, _cx| match status {
 49                Status::Disabled => {
 50                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 51                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 52                }
 53                Status::Authorized => {
 54                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 55                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 56                }
 57                _ => {
 58                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 59                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 60                }
 61            },
 62        );
 63    })
 64    .detach();
 65
 66    sign_in::init(cx);
 67}
 68
 69enum CopilotServer {
 70    Downloading,
 71    Error(Arc<str>),
 72    Disabled,
 73    Started {
 74        server: Arc<LanguageServer>,
 75        status: SignInStatus,
 76    },
 77}
 78
 79#[derive(Clone, Debug)]
 80enum SignInStatus {
 81    Authorized {
 82        _user: String,
 83    },
 84    Unauthorized {
 85        _user: String,
 86    },
 87    SigningIn {
 88        prompt: Option<request::PromptUserDeviceFlow>,
 89        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 90    },
 91    SignedOut,
 92}
 93
 94#[derive(Debug, PartialEq, Eq)]
 95pub enum Status {
 96    Downloading,
 97    Error(Arc<str>),
 98    Disabled,
 99    SignedOut,
100    SigningIn {
101        prompt: Option<request::PromptUserDeviceFlow>,
102    },
103    Unauthorized,
104    Authorized,
105}
106
107impl Status {
108    pub fn is_authorized(&self) -> bool {
109        matches!(self, Status::Authorized)
110    }
111}
112
113#[derive(Debug, PartialEq, Eq)]
114pub struct Completion {
115    pub position: Anchor,
116    pub text: String,
117}
118
119pub struct Copilot {
120    server: CopilotServer,
121}
122
123impl Entity for Copilot {
124    type Event = ();
125}
126
127impl Copilot {
128    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
129        if cx.has_global::<ModelHandle<Self>>() {
130            Some(cx.global::<ModelHandle<Self>>().clone())
131        } else {
132            None
133        }
134    }
135
136    fn start(
137        http: Arc<dyn HttpClient>,
138        node_runtime: Arc<NodeRuntime>,
139        cx: &mut ModelContext<Self>,
140    ) -> Self {
141        // TODO: Make this task resilient to users thrashing the copilot setting
142        cx.observe_global::<Settings, _>({
143            let http = http.clone();
144            let node_runtime = node_runtime.clone();
145            move |this, cx| {
146                if cx.global::<Settings>().copilot.as_bool() {
147                    if matches!(this.server, CopilotServer::Disabled) {
148                        cx.spawn({
149                            let http = http.clone();
150                            let node_runtime = node_runtime.clone();
151                            move |this, cx| {
152                                Self::start_language_server(http, node_runtime, this, cx)
153                            }
154                        })
155                        .detach();
156                    }
157                } else {
158                    // TODO: What else needs to be turned off here?
159                    this.server = CopilotServer::Disabled
160                }
161            }
162        })
163        .detach();
164
165        if !cx.global::<Settings>().copilot.as_bool() {
166            return Self {
167                server: CopilotServer::Disabled,
168            };
169        }
170
171        cx.spawn({
172            let http = http.clone();
173            let node_runtime = node_runtime.clone();
174            move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
175        })
176        .detach();
177
178        Self {
179            server: CopilotServer::Downloading,
180        }
181    }
182
183    fn start_language_server(
184        http: Arc<dyn HttpClient>,
185        node_runtime: Arc<NodeRuntime>,
186        this: ModelHandle<Self>,
187        mut cx: AsyncAppContext,
188    ) -> impl Future<Output = ()> {
189        async move {
190            let start_language_server = async {
191                let server_path = get_copilot_lsp(http, node_runtime.clone()).await?;
192                let node_path = node_runtime.binary_path().await?;
193                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
194                let server =
195                    LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
196
197                let server = server.initialize(Default::default()).await?;
198                let status = server
199                    .request::<request::CheckStatus>(request::CheckStatusParams {
200                        local_checks_only: false,
201                    })
202                    .await?;
203                anyhow::Ok((server, status))
204            };
205
206            let server = start_language_server.await;
207            this.update(&mut cx, |this, cx| {
208                cx.notify();
209                match server {
210                    Ok((server, status)) => {
211                        this.server = CopilotServer::Started {
212                            server,
213                            status: SignInStatus::SignedOut,
214                        };
215                        this.update_sign_in_status(status, cx);
216                    }
217                    Err(error) => {
218                        this.server = CopilotServer::Error(error.to_string().into());
219                    }
220                }
221            })
222        }
223    }
224
225    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
226        if let CopilotServer::Started { server, status } = &mut self.server {
227            let task = match status {
228                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
229                    cx.notify();
230                    Task::ready(Ok(())).shared()
231                }
232                SignInStatus::SigningIn { task, .. } => {
233                    cx.notify(); // To re-show the prompt, just in case.
234                    task.clone()
235                }
236                SignInStatus::SignedOut => {
237                    let server = server.clone();
238                    let task = cx
239                        .spawn(|this, mut cx| async move {
240                            let sign_in = async {
241                                let sign_in = server
242                                    .request::<request::SignInInitiate>(
243                                        request::SignInInitiateParams {},
244                                    )
245                                    .await?;
246                                match sign_in {
247                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
248                                        Ok(request::SignInStatus::Ok { user })
249                                    }
250                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
251                                        this.update(&mut cx, |this, cx| {
252                                            if let CopilotServer::Started { status, .. } =
253                                                &mut this.server
254                                            {
255                                                if let SignInStatus::SigningIn {
256                                                    prompt: prompt_flow,
257                                                    ..
258                                                } = status
259                                                {
260                                                    *prompt_flow = Some(flow.clone());
261                                                    cx.notify();
262                                                }
263                                            }
264                                        });
265                                        let response = server
266                                            .request::<request::SignInConfirm>(
267                                                request::SignInConfirmParams {
268                                                    user_code: flow.user_code,
269                                                },
270                                            )
271                                            .await?;
272                                        Ok(response)
273                                    }
274                                }
275                            };
276
277                            let sign_in = sign_in.await;
278                            this.update(&mut cx, |this, cx| match sign_in {
279                                Ok(status) => {
280                                    this.update_sign_in_status(status, cx);
281                                    Ok(())
282                                }
283                                Err(error) => {
284                                    this.update_sign_in_status(
285                                        request::SignInStatus::NotSignedIn,
286                                        cx,
287                                    );
288                                    Err(Arc::new(error))
289                                }
290                            })
291                        })
292                        .shared();
293                    *status = SignInStatus::SigningIn {
294                        prompt: None,
295                        task: task.clone(),
296                    };
297                    cx.notify();
298                    task
299                }
300            };
301
302            cx.foreground()
303                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
304        } else {
305            Task::ready(Err(anyhow!("copilot hasn't started yet")))
306        }
307    }
308
309    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
310        if let CopilotServer::Started { server, status } = &mut self.server {
311            *status = SignInStatus::SignedOut;
312            cx.notify();
313
314            let server = server.clone();
315            cx.background().spawn(async move {
316                server
317                    .request::<request::SignOut>(request::SignOutParams {})
318                    .await?;
319                anyhow::Ok(())
320            })
321        } else {
322            Task::ready(Err(anyhow!("copilot hasn't started yet")))
323        }
324    }
325
326    pub fn completion<T>(
327        &self,
328        buffer: &ModelHandle<Buffer>,
329        position: T,
330        cx: &mut ModelContext<Self>,
331    ) -> Task<Result<Option<Completion>>>
332    where
333        T: ToPointUtf16,
334    {
335        let server = match self.authorized_server() {
336            Ok(server) => server,
337            Err(error) => return Task::ready(Err(error)),
338        };
339
340        let buffer = buffer.read(cx).snapshot();
341        let request = server
342            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
343        cx.background().spawn(async move {
344            let result = request.await?;
345            let completion = result
346                .completions
347                .into_iter()
348                .next()
349                .map(|completion| completion_from_lsp(completion, &buffer));
350            anyhow::Ok(completion)
351        })
352    }
353
354    pub fn completions_cycling<T>(
355        &self,
356        buffer: &ModelHandle<Buffer>,
357        position: T,
358        cx: &mut ModelContext<Self>,
359    ) -> Task<Result<Vec<Completion>>>
360    where
361        T: ToPointUtf16,
362    {
363        let server = match self.authorized_server() {
364            Ok(server) => server,
365            Err(error) => return Task::ready(Err(error)),
366        };
367
368        let buffer = buffer.read(cx).snapshot();
369        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
370            &buffer, position, cx,
371        ));
372        cx.background().spawn(async move {
373            let result = request.await?;
374            let completions = result
375                .completions
376                .into_iter()
377                .map(|completion| completion_from_lsp(completion, &buffer))
378                .collect();
379            anyhow::Ok(completions)
380        })
381    }
382
383    pub fn status(&self) -> Status {
384        match &self.server {
385            CopilotServer::Downloading => Status::Downloading,
386            CopilotServer::Disabled => Status::Disabled,
387            CopilotServer::Error(error) => Status::Error(error.clone()),
388            CopilotServer::Started { status, .. } => match status {
389                SignInStatus::Authorized { .. } => Status::Authorized,
390                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
391                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
392                    prompt: prompt.clone(),
393                },
394                SignInStatus::SignedOut => Status::SignedOut,
395            },
396        }
397    }
398
399    fn update_sign_in_status(
400        &mut self,
401        lsp_status: request::SignInStatus,
402        cx: &mut ModelContext<Self>,
403    ) {
404        if let CopilotServer::Started { status, .. } = &mut self.server {
405            *status = match lsp_status {
406                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
407                    SignInStatus::Authorized { _user: user }
408                }
409                request::SignInStatus::NotAuthorized { user } => {
410                    SignInStatus::Unauthorized { _user: user }
411                }
412                _ => SignInStatus::SignedOut,
413            };
414            cx.notify();
415        }
416    }
417
418    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
419        match &self.server {
420            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
421            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
422            CopilotServer::Error(error) => Err(anyhow!(
423                "copilot was not started because of an error: {}",
424                error
425            )),
426            CopilotServer::Started { server, status } => {
427                if matches!(status, SignInStatus::Authorized { .. }) {
428                    Ok(server.clone())
429                } else {
430                    Err(anyhow!("must sign in before using copilot"))
431                }
432            }
433        }
434    }
435}
436
437fn build_completion_params<T>(
438    buffer: &BufferSnapshot,
439    position: T,
440    cx: &AppContext,
441) -> request::GetCompletionsParams
442where
443    T: ToPointUtf16,
444{
445    let position = position.to_point_utf16(&buffer);
446    let language_name = buffer.language_at(position).map(|language| language.name());
447    let language_name = language_name.as_deref();
448
449    let path;
450    let relative_path;
451    if let Some(file) = buffer.file() {
452        if let Some(file) = file.as_local() {
453            path = file.abs_path(cx);
454        } else {
455            path = file.full_path(cx);
456        }
457        relative_path = file.path().to_path_buf();
458    } else {
459        path = PathBuf::from("/untitled");
460        relative_path = PathBuf::from("untitled");
461    }
462
463    let settings = cx.global::<Settings>();
464    let language_id = match language_name {
465        Some("Plain Text") => "plaintext".to_string(),
466        Some(language_name) => language_name.to_lowercase(),
467        None => "plaintext".to_string(),
468    };
469    request::GetCompletionsParams {
470        doc: request::GetCompletionsDocument {
471            source: buffer.text(),
472            tab_size: settings.tab_size(language_name).into(),
473            indent_size: 1,
474            insert_spaces: !settings.hard_tabs(language_name),
475            uri: lsp::Url::from_file_path(&path).unwrap(),
476            path: path.to_string_lossy().into(),
477            relative_path: relative_path.to_string_lossy().into(),
478            language_id,
479            position: point_to_lsp(position),
480            version: 0,
481        },
482    }
483}
484
485fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
486    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
487    Completion {
488        position: buffer.anchor_before(position),
489        text: completion.display_text,
490    }
491}
492
493async fn get_copilot_lsp(
494    http: Arc<dyn HttpClient>,
495    node: Arc<NodeRuntime>,
496) -> anyhow::Result<PathBuf> {
497    const SERVER_PATH: &'static str = "node_modules/copilot-node-server/copilot/dist/agent.js";
498
499    ///Check for the latest copilot language server and download it if we haven't already
500    async fn fetch_latest(
501        _http: Arc<dyn HttpClient>,
502        node: Arc<NodeRuntime>,
503    ) -> anyhow::Result<PathBuf> {
504        const COPILOT_NPM_PACKAGE: &'static str = "copilot-node-server";
505
506        let release = node.npm_package_latest_version(COPILOT_NPM_PACKAGE).await?;
507
508        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.clone()));
509
510        fs::create_dir_all(version_dir).await?;
511        let server_path = version_dir.join(SERVER_PATH);
512
513        if fs::metadata(&server_path).await.is_err() {
514            node.npm_install_packages([(COPILOT_NPM_PACKAGE, release.as_str())], version_dir)
515                .await?;
516
517            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
518        }
519
520        Ok(server_path)
521    }
522
523    match fetch_latest(http, node).await {
524        ok @ Result::Ok(..) => ok,
525        e @ Err(..) => {
526            e.log_err();
527            // Fetch a cached binary, if it exists
528            (|| async move {
529                let mut last_version_dir = None;
530                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
531                while let Some(entry) = entries.next().await {
532                    let entry = entry?;
533                    if entry.file_type().await?.is_dir() {
534                        last_version_dir = Some(entry.path());
535                    }
536                }
537                let last_version_dir =
538                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
539                let server_path = last_version_dir.join(SERVER_PATH);
540                if server_path.exists() {
541                    Ok(server_path)
542                } else {
543                    Err(anyhow!(
544                        "missing executable in directory {:?}",
545                        last_version_dir
546                    ))
547                }
548            })()
549            .await
550        }
551    }
552}