Implement Copilot sign in and sign out

Antonio Scandurra created

Change summary

Cargo.lock                    |   1 
crates/copilot/Cargo.toml     |   1 
crates/copilot/src/copilot.rs | 186 +++++++++++++++++++++++++++++-------
crates/copilot/src/request.rs |  81 ++++++++++++++-
4 files changed, 222 insertions(+), 47 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -1340,6 +1340,7 @@ dependencies = [
  "client",
  "futures 0.3.25",
  "gpui",
+ "log",
  "lsp",
  "serde",
  "serde_derive",

crates/copilot/Cargo.toml 🔗

@@ -17,6 +17,7 @@ client = { path = "../client" }
 workspace = { path = "../workspace" }
 async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] }
 anyhow = "1.0"
+log = "0.4"
 serde = { workspace = true }
 serde_derive = { workspace = true }
 smol = "1.2.5"

crates/copilot/src/copilot.rs 🔗

@@ -3,7 +3,7 @@ mod request;
 use anyhow::{anyhow, Result};
 use async_compression::futures::bufread::GzipDecoder;
 use client::Client;
-use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext};
+use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
 use lsp::LanguageServer;
 use smol::{fs, io::BufReader, stream::StreamExt};
 use std::{
@@ -15,11 +15,32 @@ use util::{
     fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 };
 
-actions!(copilot, [SignIn]);
+actions!(copilot, [SignIn, SignOut]);
 
 pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
-    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
+    let (copilot, task) = Copilot::start(client.http_client(), cx);
     cx.set_global(copilot);
+    cx.spawn(|mut cx| async move {
+        task.await?;
+        cx.update(|cx| {
+            cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| {
+                if let Some(copilot) = Copilot::global(cx) {
+                    copilot
+                        .update(cx, |copilot, cx| copilot.sign_in(cx))
+                        .detach_and_log_err(cx);
+                }
+            });
+            cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| {
+                if let Some(copilot) = Copilot::global(cx) {
+                    copilot
+                        .update(cx, |copilot, cx| copilot.sign_out(cx))
+                        .detach_and_log_err(cx);
+                }
+            });
+        });
+        anyhow::Ok(())
+    })
+    .detach_and_log_err(cx);
 }
 
 enum CopilotServer {
@@ -31,18 +52,26 @@ enum CopilotServer {
     },
 }
 
+#[derive(Clone, Debug, PartialEq, Eq)]
 enum SignInStatus {
-    Authorized,
-    Unauthorized,
+    Authorized { user: String },
+    Unauthorized { user: String },
     SignedOut,
 }
 
+pub enum Event {
+    PromptUserDeviceFlow {
+        user_code: String,
+        verification_uri: String,
+    },
+}
+
 struct Copilot {
     server: CopilotServer,
 }
 
 impl Entity for Copilot {
-    type Event = ();
+    type Event = Event;
 }
 
 impl Copilot {
@@ -54,46 +83,123 @@ impl Copilot {
         }
     }
 
-    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
-        let copilot = Self {
+    fn start(
+        http: Arc<dyn HttpClient>,
+        cx: &mut MutableAppContext,
+    ) -> (ModelHandle<Self>, Task<Result<()>>) {
+        let this = cx.add_model(|_| Self {
             server: CopilotServer::Downloading,
-        };
-        cx.spawn(|this, mut cx| async move {
-            let start_language_server = async {
-                let server_path = get_lsp_binary(http).await?;
-                let server =
-                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
-                let server = server.initialize(Default::default()).await?;
-                let status = server
-                    .request::<request::CheckStatus>(request::CheckStatusParams {
-                        local_checks_only: false,
-                    })
-                    .await?;
-                let status = match status.status.as_str() {
-                    "OK" | "MaybeOk" => SignInStatus::Authorized,
-                    "NotAuthorized" => SignInStatus::Unauthorized,
-                    _ => SignInStatus::SignedOut,
+        });
+        let task = cx.spawn({
+            let this = this.clone();
+            |mut cx| async move {
+                let start_language_server = async {
+                    let server_path = get_lsp_binary(http).await?;
+                    let server = LanguageServer::new(
+                        0,
+                        &server_path,
+                        &["--stdio"],
+                        Path::new("/"),
+                        cx.clone(),
+                    )?;
+                    let server = server.initialize(Default::default()).await?;
+                    let status = server
+                        .request::<request::CheckStatus>(request::CheckStatusParams {
+                            local_checks_only: false,
+                        })
+                        .await?;
+                    anyhow::Ok((server, status))
                 };
-                anyhow::Ok((server, status))
-            };
 
-            let server = start_language_server.await;
-            this.update(&mut cx, |this, cx| {
-                cx.notify();
-                match server {
-                    Ok((server, status)) => {
-                        this.server = CopilotServer::Started { server, status };
-                        Ok(())
-                    }
-                    Err(error) => {
-                        this.server = CopilotServer::Error(error.to_string());
-                        Err(error)
+                let server = start_language_server.await;
+                this.update(&mut cx, |this, cx| {
+                    cx.notify();
+                    match server {
+                        Ok((server, status)) => {
+                            this.server = CopilotServer::Started {
+                                server,
+                                status: SignInStatus::SignedOut,
+                            };
+                            this.update_sign_in_status(status, cx);
+                            Ok(())
+                        }
+                        Err(error) => {
+                            this.server = CopilotServer::Error(error.to_string());
+                            Err(error)
+                        }
                     }
+                })
+            }
+        });
+        (this, task)
+    }
+
+    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        if let CopilotServer::Started { server, .. } = &self.server {
+            let server = server.clone();
+            cx.spawn(|this, mut cx| async move {
+                let sign_in = server
+                    .request::<request::SignInInitiate>(request::SignInInitiateParams {})
+                    .await?;
+                if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
+                    this.update(&mut cx, |_, cx| {
+                        cx.emit(Event::PromptUserDeviceFlow {
+                            user_code: flow.user_code.clone(),
+                            verification_uri: flow.verification_uri,
+                        });
+                    });
+                    let response = server
+                        .request::<request::SignInConfirm>(request::SignInConfirmParams {
+                            user_code: flow.user_code,
+                        })
+                        .await?;
+                    this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
                 }
+                anyhow::Ok(())
             })
-        })
-        .detach_and_log_err(cx);
-        copilot
+        } else {
+            Task::ready(Err(anyhow!("copilot hasn't started yet")))
+        }
+    }
+
+    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        if let CopilotServer::Started { server, .. } = &self.server {
+            let server = server.clone();
+            cx.spawn(|this, mut cx| async move {
+                server
+                    .request::<request::SignOut>(request::SignOutParams {})
+                    .await?;
+                this.update(&mut cx, |this, cx| {
+                    if let CopilotServer::Started { status, .. } = &mut this.server {
+                        *status = SignInStatus::SignedOut;
+                        cx.notify();
+                    }
+                });
+
+                anyhow::Ok(())
+            })
+        } else {
+            Task::ready(Err(anyhow!("copilot hasn't started yet")))
+        }
+    }
+
+    fn update_sign_in_status(
+        &mut self,
+        lsp_status: request::SignInStatus,
+        cx: &mut ModelContext<Self>,
+    ) {
+        if let CopilotServer::Started { status, .. } = &mut self.server {
+            *status = match lsp_status {
+                request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
+                    SignInStatus::Authorized { user }
+                }
+                request::SignInStatus::NotAuthorized { user } => {
+                    SignInStatus::Unauthorized { user }
+                }
+                _ => SignInStatus::SignedOut,
+            };
+            cx.notify();
+        }
     }
 }
 

crates/copilot/src/request.rs 🔗

@@ -8,15 +8,82 @@ pub struct CheckStatusParams {
     pub local_checks_only: bool,
 }
 
+impl lsp::request::Request for CheckStatus {
+    type Params = CheckStatusParams;
+    type Result = SignInStatus;
+    const METHOD: &'static str = "checkStatus";
+}
+
+pub enum SignInInitiate {}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SignInInitiateParams {}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "status")]
+pub enum SignInInitiateResult {
+    AlreadySignedIn { user: String },
+    PromptUserDeviceFlow(PromptUserDeviceFlow),
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-pub struct CheckStatusResult {
-    pub status: String,
-    pub user: Option<String>,
+pub struct PromptUserDeviceFlow {
+    pub user_code: String,
+    pub verification_uri: String,
 }
 
-impl lsp::request::Request for CheckStatus {
-    type Params = CheckStatusParams;
-    type Result = CheckStatusResult;
-    const METHOD: &'static str = "checkStatus";
+impl lsp::request::Request for SignInInitiate {
+    type Params = SignInInitiateParams;
+    type Result = SignInInitiateResult;
+    const METHOD: &'static str = "signInInitiate";
+}
+
+pub enum SignInConfirm {}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SignInConfirmParams {
+    pub user_code: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "status")]
+pub enum SignInStatus {
+    #[serde(rename = "OK")]
+    Ok {
+        user: String,
+    },
+    MaybeOk {
+        user: String,
+    },
+    AlreadySignedIn {
+        user: String,
+    },
+    NotAuthorized {
+        user: String,
+    },
+    NotSignedIn,
+}
+
+impl lsp::request::Request for SignInConfirm {
+    type Params = SignInConfirmParams;
+    type Result = SignInStatus;
+    const METHOD: &'static str = "signInConfirm";
+}
+
+pub enum SignOut {}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SignOutParams {}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SignOutResult {}
+
+impl lsp::request::Request for SignOut {
+    type Params = SignOutParams;
+    type Result = SignOutResult;
+    const METHOD: &'static str = "signOut";
 }