Start on copilot completions

Antonio Scandurra created

Change summary

Cargo.lock                    |  1 
crates/copilot/Cargo.toml     |  1 
crates/copilot/src/copilot.rs | 73 +++++++++++++++++++++++++++++++++++-
3 files changed, 72 insertions(+), 3 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -1340,6 +1340,7 @@ dependencies = [
  "client",
  "futures 0.3.25",
  "gpui",
+ "language",
  "log",
  "lsp",
  "serde",

crates/copilot/Cargo.toml 🔗

@@ -10,6 +10,7 @@ doctest = false
 
 [dependencies]
 gpui = { path = "../gpui" }
+language = { path = "../language" }
 settings = { path = "../settings" }
 lsp = { path = "../lsp" }
 util = { path = "../util" }

crates/copilot/src/copilot.rs 🔗

@@ -4,6 +4,7 @@ use anyhow::{anyhow, Result};
 use async_compression::futures::bufread::GzipDecoder;
 use client::Client;
 use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
+use language::{Buffer, ToPointUtf16};
 use lsp::LanguageServer;
 use smol::{fs, io::BufReader, stream::StreamExt};
 use std::{
@@ -38,7 +39,7 @@ pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 
 enum CopilotServer {
     Downloading,
-    Error(String),
+    Error(Arc<str>),
     Started {
         server: Arc<LanguageServer>,
         status: SignInStatus,
@@ -59,6 +60,21 @@ pub enum Event {
     },
 }
 
+#[derive(Debug)]
+pub enum Status {
+    Downloading,
+    Error(Arc<str>),
+    SignedOut,
+    Unauthorized,
+    Authorized,
+}
+
+impl Status {
+    fn is_authorized(&self) -> bool {
+        matches!(self, Status::Authorized)
+    }
+}
+
 struct Copilot {
     server: CopilotServer,
 }
@@ -70,7 +86,12 @@ impl Entity for Copilot {
 impl Copilot {
     fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
         if cx.has_global::<ModelHandle<Self>>() {
-            Some(cx.global::<ModelHandle<Self>>().clone())
+            let copilot = cx.global::<ModelHandle<Self>>().clone();
+            if copilot.read(cx).status().is_authorized() {
+                Some(copilot)
+            } else {
+                None
+            }
         } else {
             None
         }
@@ -103,7 +124,7 @@ impl Copilot {
                         this.update_sign_in_status(status, cx);
                     }
                     Err(error) => {
-                        this.server = CopilotServer::Error(error.to_string());
+                        this.server = CopilotServer::Error(error.to_string().into());
                     }
                 }
             })
@@ -163,6 +184,35 @@ impl Copilot {
         }
     }
 
+    pub fn completions<T>(
+        &self,
+        buffer: &ModelHandle<Buffer>,
+        position: T,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>>
+    where
+        T: ToPointUtf16,
+    {
+        let server = match self.authenticated_server() {
+            Ok(server) => server,
+            Err(error) => return Task::ready(Err(error)),
+        };
+
+        cx.spawn(|this, cx| async move { anyhow::Ok(()) })
+    }
+
+    pub fn status(&self) -> Status {
+        match &self.server {
+            CopilotServer::Downloading => Status::Downloading,
+            CopilotServer::Error(error) => Status::Error(error.clone()),
+            CopilotServer::Started { status, .. } => match status {
+                SignInStatus::Authorized { .. } => Status::Authorized,
+                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
+                SignInStatus::SignedOut => Status::SignedOut,
+            },
+        }
+    }
+
     fn update_sign_in_status(
         &mut self,
         lsp_status: request::SignInStatus,
@@ -181,6 +231,23 @@ impl Copilot {
             cx.notify();
         }
     }
+
+    fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
+        match &self.server {
+            CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
+            CopilotServer::Error(error) => Err(anyhow!(
+                "copilot was not started because of an error: {}",
+                error
+            )),
+            CopilotServer::Started { server, status } => {
+                if matches!(status, SignInStatus::Authorized { .. }) {
+                    Ok(server.clone())
+                } else {
+                    Err(anyhow!("must sign in before using copilot"))
+                }
+            }
+        }
+    }
 }
 
 async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {