Start copilot and check sign in status

Antonio Scandurra created

Change summary

crates/copilot/src/copilot.rs | 107 +++++++++++++++++++++++++++---------
crates/copilot/src/request.rs |  22 +++++++
2 files changed, 101 insertions(+), 28 deletions(-)

Detailed changes

crates/copilot/src/copilot.rs 🔗

@@ -1,9 +1,16 @@
-use anyhow::{anyhow, Ok};
+mod request;
+
+use anyhow::{anyhow, Result};
 use async_compression::futures::bufread::GzipDecoder;
 use client::Client;
-use gpui::{actions, MutableAppContext};
+use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext};
+use lsp::LanguageServer;
 use smol::{fs, io::BufReader, stream::StreamExt};
-use std::{env::consts, path::PathBuf, sync::Arc};
+use std::{
+    env::consts,
+    path::{Path, PathBuf},
+    sync::Arc,
+};
 use util::{
     fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 };
@@ -11,46 +18,89 @@ use util::{
 actions!(copilot, [SignIn]);
 
 pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
-    cx.add_global_action(move |_: &SignIn, cx: &mut MutableAppContext| {
-        Copilot::sign_in(client.http_client(), cx)
-    });
+    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
+    cx.set_global(copilot);
+}
+
+enum CopilotServer {
+    Downloading,
+    Error(String),
+    Started {
+        server: Arc<LanguageServer>,
+        status: SignInStatus,
+    },
+}
+
+enum SignInStatus {
+    Authorized,
+    Unauthorized,
+    SignedOut,
 }
 
-#[derive(Debug)]
 struct Copilot {
-    copilot_server: PathBuf,
+    server: CopilotServer,
+}
+
+impl Entity for Copilot {
+    type Event = ();
 }
 
 impl Copilot {
-    fn sign_in(http: Arc<dyn HttpClient>, cx: &mut MutableAppContext) {
-        let maybe_copilot = cx.default_global::<Option<Arc<Copilot>>>().clone();
-
-        cx.spawn(|mut cx| async move {
-            // Lazily download / initialize copilot LSP
-            let copilot = if let Some(copilot) = maybe_copilot {
-                copilot
-            } else {
-                let copilot_server = get_lsp_binary(http).await?; // TODO: Make this error user visible
-                let new_copilot = Arc::new(Copilot { copilot_server });
-                cx.update({
-                    let new_copilot = new_copilot.clone();
-                    move |cx| cx.set_global(Some(new_copilot.clone()))
-                });
-                new_copilot
-            };
+    fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
+        if cx.has_global::<ModelHandle<Self>>() {
+            Some(cx.global::<ModelHandle<Self>>().clone())
+        } else {
+            None
+        }
+    }
 
-            dbg!(copilot);
+    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
+        let copilot = Self {
+            server: CopilotServer::Downloading,
+        };
+        cx.spawn(|this, mut cx| async move {
+            let start_language_server = async {
+                let server_path = get_lsp_binary(http).await?;
+                let server =
+                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
+                let server = server.initialize(Default::default()).await?;
+                let status = server
+                    .request::<request::CheckStatus>(request::CheckStatusParams {
+                        local_checks_only: false,
+                    })
+                    .await?;
+                let status = match status.status.as_str() {
+                    "OK" | "MaybeOk" => SignInStatus::Authorized,
+                    "NotAuthorized" => SignInStatus::Unauthorized,
+                    _ => SignInStatus::SignedOut,
+                };
+                anyhow::Ok((server, status))
+            };
 
-            Ok(())
+            let server = start_language_server.await;
+            this.update(&mut cx, |this, cx| {
+                cx.notify();
+                match server {
+                    Ok((server, status)) => {
+                        this.server = CopilotServer::Started { server, status };
+                        Ok(())
+                    }
+                    Err(error) => {
+                        this.server = CopilotServer::Error(error.to_string());
+                        Err(error)
+                    }
+                }
+            })
         })
-        .detach();
+        .detach_and_log_err(cx);
+        copilot
     }
 }
 
 async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
     ///Check for the latest copilot language server and download it if we haven't already
     async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
-        let release = latest_github_release("zed-industries/copilotserver", http.clone()).await?;
+        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
         let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
         let asset = release
             .assets
@@ -58,6 +108,7 @@ async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
             .find(|asset| asset.name == asset_name)
             .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
 
+        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
         let destination_path =
             paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
 

crates/copilot/src/request.rs 🔗

@@ -0,0 +1,22 @@
+use serde::{Deserialize, Serialize};
+
+pub enum CheckStatus {}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CheckStatusParams {
+    pub local_checks_only: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CheckStatusResult {
+    pub status: String,
+    pub user: Option<String>,
+}
+
+impl lsp::request::Request for CheckStatus {
+    type Params = CheckStatusParams;
+    type Result = CheckStatusResult;
+    const METHOD: &'static str = "checkStatus";
+}