copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Result};
  5use client::Client;
  6use futures::{future::Shared, Future, FutureExt, TryFutureExt};
  7use gpui::{
  8    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
  9    Task,
 10};
 11use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 12use lsp::LanguageServer;
 13use node_runtime::NodeRuntime;
 14use settings::Settings;
 15use smol::{fs, stream::StreamExt};
 16use std::{
 17    ffi::OsString,
 18    path::{Path, PathBuf},
 19    sync::Arc,
 20};
 21use util::{fs::remove_matching, http::HttpClient, paths, ResultExt};
 22
 23const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 24actions!(copilot_auth, [SignIn, SignOut]);
 25
 26const COPILOT_NAMESPACE: &'static str = "copilot";
 27actions!(copilot, [NextSuggestion]);
 28
 29pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 30    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
 31    cx.set_global(copilot.clone());
 32    cx.add_global_action(|_: &SignIn, cx| {
 33        let copilot = Copilot::global(cx).unwrap();
 34        copilot
 35            .update(cx, |copilot, cx| copilot.sign_in(cx))
 36            .detach_and_log_err(cx);
 37    });
 38    cx.add_global_action(|_: &SignOut, cx| {
 39        let copilot = Copilot::global(cx).unwrap();
 40        copilot
 41            .update(cx, |copilot, cx| copilot.sign_out(cx))
 42            .detach_and_log_err(cx);
 43    });
 44
 45    cx.observe(&copilot, |handle, cx| {
 46        let status = handle.read(cx).status();
 47        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 48            move |filter, _cx| match status {
 49                Status::Disabled => {
 50                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 51                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 52                }
 53                Status::Authorized => {
 54                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 55                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 56                }
 57                _ => {
 58                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 59                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 60                }
 61            },
 62        );
 63    })
 64    .detach();
 65
 66    sign_in::init(cx);
 67}
 68
 69enum CopilotServer {
 70    Downloading,
 71    Error(Arc<str>),
 72    Disabled,
 73    Started {
 74        server: Arc<LanguageServer>,
 75        status: SignInStatus,
 76    },
 77}
 78
 79#[derive(Clone, Debug)]
 80enum SignInStatus {
 81    Authorized {
 82        _user: String,
 83    },
 84    Unauthorized {
 85        _user: String,
 86    },
 87    SigningIn {
 88        prompt: Option<request::PromptUserDeviceFlow>,
 89        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
 90    },
 91    SignedOut,
 92}
 93
 94#[derive(Debug, PartialEq, Eq)]
 95pub enum Status {
 96    Downloading,
 97    Error(Arc<str>),
 98    Disabled,
 99    SignedOut,
100    SigningIn {
101        prompt: Option<request::PromptUserDeviceFlow>,
102    },
103    Unauthorized,
104    Authorized,
105}
106
107impl Status {
108    pub fn is_authorized(&self) -> bool {
109        matches!(self, Status::Authorized)
110    }
111}
112
113#[derive(Debug, PartialEq, Eq)]
114pub struct Completion {
115    pub position: Anchor,
116    pub text: String,
117}
118
119pub struct Copilot {
120    server: CopilotServer,
121}
122
123impl Entity for Copilot {
124    type Event = ();
125}
126
127impl Copilot {
128    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
129        if cx.has_global::<ModelHandle<Self>>() {
130            Some(cx.global::<ModelHandle<Self>>().clone())
131        } else {
132            None
133        }
134    }
135
136    fn start(
137        http: Arc<dyn HttpClient>,
138        node_runtime: Arc<NodeRuntime>,
139        cx: &mut ModelContext<Self>,
140    ) -> Self {
141        // TODO: Make this task resilient to users thrashing the copilot setting
142        cx.observe_global::<Settings, _>({
143            let http = http.clone();
144            let node_runtime = node_runtime.clone();
145            move |this, cx| {
146                if cx.global::<Settings>().copilot.as_bool() {
147                    if matches!(this.server, CopilotServer::Disabled) {
148                        cx.spawn({
149                            let http = http.clone();
150                            let node_runtime = node_runtime.clone();
151                            move |this, cx| {
152                                Self::start_language_server(http, node_runtime, this, cx)
153                            }
154                        })
155                        .detach();
156                    }
157                } else {
158                    // TODO: What else needs to be turned off here?
159                    this.server = CopilotServer::Disabled
160                }
161            }
162        })
163        .detach();
164
165        if !cx.global::<Settings>().copilot.as_bool() {
166            return Self {
167                server: CopilotServer::Disabled,
168            };
169        }
170
171        cx.spawn({
172            let http = http.clone();
173            let node_runtime = node_runtime.clone();
174            move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
175        })
176        .detach();
177
178        Self {
179            server: CopilotServer::Downloading,
180        }
181    }
182
183    fn start_language_server(
184        http: Arc<dyn HttpClient>,
185        node_runtime: Arc<NodeRuntime>,
186        this: ModelHandle<Self>,
187        mut cx: AsyncAppContext,
188    ) -> impl Future<Output = ()> {
189        async move {
190            let start_language_server = async {
191                let server_path = get_copilot_lsp(http, node_runtime.clone()).await?;
192                let node_path = node_runtime.binary_path().await?;
193                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
194                let server =
195                    LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
196
197                let server = server.initialize(Default::default()).await?;
198                let status = server
199                    .request::<request::CheckStatus>(request::CheckStatusParams {
200                        local_checks_only: false,
201                    })
202                    .await?;
203                anyhow::Ok((server, status))
204            };
205
206            let server = start_language_server.await;
207            this.update(&mut cx, |this, cx| {
208                cx.notify();
209                match server {
210                    Ok((server, status)) => {
211                        this.server = CopilotServer::Started {
212                            server,
213                            status: SignInStatus::SignedOut,
214                        };
215                        this.update_sign_in_status(status, cx);
216                    }
217                    Err(error) => {
218                        this.server = CopilotServer::Error(error.to_string().into());
219                    }
220                }
221            })
222        }
223    }
224
225    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
226        if let CopilotServer::Started { server, status } = &mut self.server {
227            let task = match status {
228                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
229                    Task::ready(Ok(())).shared()
230                }
231                SignInStatus::SigningIn { task, .. } => {
232                    cx.notify(); // To re-show the prompt, just in case.
233                    task.clone()
234                }
235                SignInStatus::SignedOut => {
236                    let server = server.clone();
237                    let task = cx
238                        .spawn(|this, mut cx| async move {
239                            let sign_in = async {
240                                let sign_in = server
241                                    .request::<request::SignInInitiate>(
242                                        request::SignInInitiateParams {},
243                                    )
244                                    .await?;
245                                match sign_in {
246                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
247                                        Ok(request::SignInStatus::Ok { user })
248                                    }
249                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
250                                        this.update(&mut cx, |this, cx| {
251                                            if let CopilotServer::Started { status, .. } =
252                                                &mut this.server
253                                            {
254                                                if let SignInStatus::SigningIn {
255                                                    prompt: prompt_flow,
256                                                    ..
257                                                } = status
258                                                {
259                                                    *prompt_flow = Some(flow.clone());
260                                                    cx.notify();
261                                                }
262                                            }
263                                        });
264                                        let response = server
265                                            .request::<request::SignInConfirm>(
266                                                request::SignInConfirmParams {
267                                                    user_code: flow.user_code,
268                                                },
269                                            )
270                                            .await?;
271                                        Ok(response)
272                                    }
273                                }
274                            };
275
276                            let sign_in = sign_in.await;
277                            this.update(&mut cx, |this, cx| match sign_in {
278                                Ok(status) => {
279                                    this.update_sign_in_status(status, cx);
280                                    Ok(())
281                                }
282                                Err(error) => {
283                                    this.update_sign_in_status(
284                                        request::SignInStatus::NotSignedIn,
285                                        cx,
286                                    );
287                                    Err(Arc::new(error))
288                                }
289                            })
290                        })
291                        .shared();
292                    *status = SignInStatus::SigningIn {
293                        prompt: None,
294                        task: task.clone(),
295                    };
296                    cx.notify();
297                    task
298                }
299            };
300
301            cx.foreground()
302                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
303        } else {
304            Task::ready(Err(anyhow!("copilot hasn't started yet")))
305        }
306    }
307
308    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
309        if let CopilotServer::Started { server, status } = &mut self.server {
310            *status = SignInStatus::SignedOut;
311            cx.notify();
312
313            let server = server.clone();
314            cx.background().spawn(async move {
315                server
316                    .request::<request::SignOut>(request::SignOutParams {})
317                    .await?;
318                anyhow::Ok(())
319            })
320        } else {
321            Task::ready(Err(anyhow!("copilot hasn't started yet")))
322        }
323    }
324
325    pub fn completion<T>(
326        &self,
327        buffer: &ModelHandle<Buffer>,
328        position: T,
329        cx: &mut ModelContext<Self>,
330    ) -> Task<Result<Option<Completion>>>
331    where
332        T: ToPointUtf16,
333    {
334        let server = match self.authorized_server() {
335            Ok(server) => server,
336            Err(error) => return Task::ready(Err(error)),
337        };
338
339        let buffer = buffer.read(cx).snapshot();
340        let request = server
341            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
342        cx.background().spawn(async move {
343            let result = request.await?;
344            let completion = result
345                .completions
346                .into_iter()
347                .next()
348                .map(|completion| completion_from_lsp(completion, &buffer));
349            anyhow::Ok(completion)
350        })
351    }
352
353    pub fn completions_cycling<T>(
354        &self,
355        buffer: &ModelHandle<Buffer>,
356        position: T,
357        cx: &mut ModelContext<Self>,
358    ) -> Task<Result<Vec<Completion>>>
359    where
360        T: ToPointUtf16,
361    {
362        let server = match self.authorized_server() {
363            Ok(server) => server,
364            Err(error) => return Task::ready(Err(error)),
365        };
366
367        let buffer = buffer.read(cx).snapshot();
368        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
369            &buffer, position, cx,
370        ));
371        cx.background().spawn(async move {
372            let result = request.await?;
373            let completions = result
374                .completions
375                .into_iter()
376                .map(|completion| completion_from_lsp(completion, &buffer))
377                .collect();
378            anyhow::Ok(completions)
379        })
380    }
381
382    pub fn status(&self) -> Status {
383        match &self.server {
384            CopilotServer::Downloading => Status::Downloading,
385            CopilotServer::Disabled => Status::Disabled,
386            CopilotServer::Error(error) => Status::Error(error.clone()),
387            CopilotServer::Started { status, .. } => match status {
388                SignInStatus::Authorized { .. } => Status::Authorized,
389                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
390                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
391                    prompt: prompt.clone(),
392                },
393                SignInStatus::SignedOut => Status::SignedOut,
394            },
395        }
396    }
397
398    fn update_sign_in_status(
399        &mut self,
400        lsp_status: request::SignInStatus,
401        cx: &mut ModelContext<Self>,
402    ) {
403        if let CopilotServer::Started { status, .. } = &mut self.server {
404            *status = match lsp_status {
405                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
406                    SignInStatus::Authorized { _user: user }
407                }
408                request::SignInStatus::NotAuthorized { user } => {
409                    SignInStatus::Unauthorized { _user: user }
410                }
411                _ => SignInStatus::SignedOut,
412            };
413            cx.notify();
414        }
415    }
416
417    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
418        match &self.server {
419            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
420            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
421            CopilotServer::Error(error) => Err(anyhow!(
422                "copilot was not started because of an error: {}",
423                error
424            )),
425            CopilotServer::Started { server, status } => {
426                if matches!(status, SignInStatus::Authorized { .. }) {
427                    Ok(server.clone())
428                } else {
429                    Err(anyhow!("must sign in before using copilot"))
430                }
431            }
432        }
433    }
434}
435
436fn build_completion_params<T>(
437    buffer: &BufferSnapshot,
438    position: T,
439    cx: &AppContext,
440) -> request::GetCompletionsParams
441where
442    T: ToPointUtf16,
443{
444    let position = position.to_point_utf16(&buffer);
445    let language_name = buffer.language_at(position).map(|language| language.name());
446    let language_name = language_name.as_deref();
447
448    let path;
449    let relative_path;
450    if let Some(file) = buffer.file() {
451        if let Some(file) = file.as_local() {
452            path = file.abs_path(cx);
453        } else {
454            path = file.full_path(cx);
455        }
456        relative_path = file.path().to_path_buf();
457    } else {
458        path = PathBuf::from("/untitled");
459        relative_path = PathBuf::from("untitled");
460    }
461
462    let settings = cx.global::<Settings>();
463    let language_id = match language_name {
464        Some("Plain Text") => "plaintext".to_string(),
465        Some(language_name) => language_name.to_lowercase(),
466        None => "plaintext".to_string(),
467    };
468    request::GetCompletionsParams {
469        doc: request::GetCompletionsDocument {
470            source: buffer.text(),
471            tab_size: settings.tab_size(language_name).into(),
472            indent_size: 1,
473            insert_spaces: !settings.hard_tabs(language_name),
474            uri: lsp::Url::from_file_path(&path).unwrap(),
475            path: path.to_string_lossy().into(),
476            relative_path: relative_path.to_string_lossy().into(),
477            language_id,
478            position: point_to_lsp(position),
479            version: 0,
480        },
481    }
482}
483
484fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
485    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
486    Completion {
487        position: buffer.anchor_before(position),
488        text: completion.display_text,
489    }
490}
491
492async fn get_copilot_lsp(
493    http: Arc<dyn HttpClient>,
494    node: Arc<NodeRuntime>,
495) -> anyhow::Result<PathBuf> {
496    const SERVER_PATH: &'static str = "node_modules/copilot-node-server/copilot/dist/agent.js";
497
498    ///Check for the latest copilot language server and download it if we haven't already
499    async fn fetch_latest(
500        _http: Arc<dyn HttpClient>,
501        node: Arc<NodeRuntime>,
502    ) -> anyhow::Result<PathBuf> {
503        const COPILOT_NPM_PACKAGE: &'static str = "copilot-node-server";
504
505        let release = node.npm_package_latest_version(COPILOT_NPM_PACKAGE).await?;
506
507        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.clone()));
508
509        fs::create_dir_all(version_dir).await?;
510        let server_path = version_dir.join(SERVER_PATH);
511
512        if fs::metadata(&server_path).await.is_err() {
513            node.npm_install_packages([(COPILOT_NPM_PACKAGE, release.as_str())], version_dir)
514                .await?;
515
516            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
517        }
518
519        Ok(server_path)
520    }
521
522    match fetch_latest(http, node).await {
523        ok @ Result::Ok(..) => ok,
524        e @ Err(..) => {
525            e.log_err();
526            // Fetch a cached binary, if it exists
527            (|| async move {
528                let mut last_version_dir = None;
529                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
530                while let Some(entry) = entries.next().await {
531                    let entry = entry?;
532                    if entry.file_type().await?.is_dir() {
533                        last_version_dir = Some(entry.path());
534                    }
535                }
536                let last_version_dir =
537                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
538                let server_path = last_version_dir.join(SERVER_PATH);
539                if server_path.exists() {
540                    Ok(server_path)
541                } else {
542                    Err(anyhow!(
543                        "missing executable in directory {:?}",
544                        last_version_dir
545                    ))
546                }
547            })()
548            .await
549        }
550    }
551}