@@ -1,16 +1,17 @@
mod request;
mod sign_in;
-use anyhow::{anyhow, bail, Context, Result};
+use anyhow::{anyhow, Context, Result};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use client::Client;
+use collections::HashMap;
use futures::{future::Shared, Future, FutureExt, TryFutureExt};
use gpui::{
actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
Task,
};
-use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
+use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
use log::{debug, error};
use lsp::LanguageServer;
use node_runtime::NodeRuntime;
@@ -92,6 +93,7 @@ enum CopilotServer {
Started {
server: Arc<LanguageServer>,
status: SignInStatus,
+ subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
},
}
@@ -275,6 +277,7 @@ impl Copilot {
this.server = CopilotServer::Started {
server,
status: SignInStatus::SignedOut,
+ subscriptions_by_buffer_id: Default::default(),
};
this.update_sign_in_status(status, cx);
}
@@ -288,7 +291,7 @@ impl Copilot {
}
fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
- if let CopilotServer::Started { server, status } = &mut self.server {
+ if let CopilotServer::Started { server, status, .. } = &mut self.server {
let task = match status {
SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
Task::ready(Ok(())).shared()
@@ -373,7 +376,7 @@ impl Copilot {
}
fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
- if let CopilotServer::Started { server, status } = &mut self.server {
+ if let CopilotServer::Started { server, status, .. } = &mut self.server {
*status = SignInStatus::SignedOut;
cx.notify();
@@ -410,43 +413,20 @@ impl Copilot {
cx.foreground().spawn(start_task)
}
- pub fn completion<T>(
- &self,
+ pub fn completions<T>(
+ &mut self,
buffer: &ModelHandle<Buffer>,
position: T,
cx: &mut ModelContext<Self>,
- ) -> Task<Result<Option<Completion>>>
+ ) -> Task<Result<Vec<Completion>>>
where
T: ToPointUtf16,
{
- let server = match self.authorized_server() {
- Ok(server) => server,
- Err(error) => return Task::ready(Err(error)),
- };
-
- let buffer = buffer.read(cx);
-
- if !buffer.file().map(|file| file.is_local()).unwrap_or(true) {
- return Task::ready(Err(anyhow!("Copilot only works locally")));
- }
-
- let buffer = buffer.snapshot();
- let request = server.request::<request::GetCompletions>(
- build_completion_params(&buffer, position, cx).unwrap(),
- );
- cx.background().spawn(async move {
- let result = request.await?;
- let completion = result
- .completions
- .into_iter()
- .next()
- .map(|completion| completion_from_lsp(completion, &buffer));
- anyhow::Ok(completion)
- })
+ self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
}
pub fn completions_cycling<T>(
- &self,
+ &mut self,
buffer: &ModelHandle<Buffer>,
position: T,
cx: &mut ModelContext<Self>,
@@ -454,27 +434,138 @@ impl Copilot {
where
T: ToPointUtf16,
{
- let server = match self.authorized_server() {
- Ok(server) => server,
- Err(error) => return Task::ready(Err(error)),
- };
+ self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
+ }
- let buffer = buffer.read(cx);
+ fn request_completions<R, T>(
+ &mut self,
+ buffer: &ModelHandle<Buffer>,
+ position: T,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Vec<Completion>>>
+ where
+ R: lsp::request::Request<
+ Params = request::GetCompletionsParams,
+ Result = request::GetCompletionsResult,
+ >,
+ T: ToPointUtf16,
+ {
+ let buffer_id = buffer.id();
+ let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
+ let snapshot = buffer.read(cx).snapshot();
+ let server = match &mut self.server {
+ CopilotServer::Starting { .. } => {
+ return Task::ready(Err(anyhow!("copilot is still starting")))
+ }
+ CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
+ CopilotServer::Error(error) => {
+ return Task::ready(Err(anyhow!(
+ "copilot was not started because of an error: {}",
+ error
+ )))
+ }
+ CopilotServer::Started {
+ server,
+ status,
+ subscriptions_by_buffer_id,
+ } => {
+ if matches!(status, SignInStatus::Authorized { .. }) {
+ subscriptions_by_buffer_id
+ .entry(buffer_id)
+ .or_insert_with(|| {
+ server
+ .notify::<lsp::notification::DidOpenTextDocument>(
+ lsp::DidOpenTextDocumentParams {
+ text_document: lsp::TextDocumentItem {
+ uri: uri.clone(),
+ language_id: id_for_language(
+ buffer.read(cx).language(),
+ ),
+ version: 0,
+ text: snapshot.text(),
+ },
+ },
+ )
+ .log_err();
+
+ let uri = uri.clone();
+ cx.observe_release(buffer, move |this, _, _| {
+ if let CopilotServer::Started {
+ server,
+ subscriptions_by_buffer_id,
+ ..
+ } = &mut this.server
+ {
+ server
+ .notify::<lsp::notification::DidCloseTextDocument>(
+ lsp::DidCloseTextDocumentParams {
+ text_document: lsp::TextDocumentIdentifier::new(
+ uri.clone(),
+ ),
+ },
+ )
+ .log_err();
+ subscriptions_by_buffer_id.remove(&buffer_id);
+ }
+ })
+ });
- if !buffer.file().map(|file| file.is_local()).unwrap_or(true) {
- return Task::ready(Err(anyhow!("Copilot only works locally")));
+ server.clone()
+ } else {
+ return Task::ready(Err(anyhow!("must sign in before using copilot")));
+ }
+ }
+ };
+
+ let settings = cx.global::<Settings>();
+ let position = position.to_point_utf16(&snapshot);
+ let language = snapshot.language_at(position);
+ let language_name = language.map(|language| language.name());
+ let language_name = language_name.as_deref();
+
+ let path;
+ let relative_path;
+ if let Some(file) = snapshot.file() {
+ if let Some(file) = file.as_local() {
+ path = file.abs_path(cx);
+ } else {
+ path = file.full_path(cx);
+ }
+ relative_path = file.path().to_path_buf();
+ } else {
+ path = PathBuf::new();
+ relative_path = PathBuf::new();
}
- let buffer = buffer.snapshot();
- let request = server.request::<request::GetCompletionsCycling>(
- build_completion_params(&buffer, position, cx).unwrap(),
- );
+ let params = request::GetCompletionsParams {
+ doc: request::GetCompletionsDocument {
+ source: snapshot.text(),
+ tab_size: settings.tab_size(language_name).into(),
+ indent_size: 1,
+ insert_spaces: !settings.hard_tabs(language_name),
+ uri,
+ path: path.to_string_lossy().into(),
+ relative_path: relative_path.to_string_lossy().into(),
+ language_id: id_for_language(language),
+ position: point_to_lsp(position),
+ version: 0,
+ },
+ };
cx.background().spawn(async move {
- let result = request.await?;
+ let result = server.request::<R>(params).await?;
let completions = result
.completions
.into_iter()
- .map(|completion| completion_from_lsp(completion, &buffer))
+ .map(|completion| {
+ let start = snapshot
+ .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
+ let end =
+ snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
+ Completion {
+ range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
+ text: completion.text,
+ }
+ })
.collect();
anyhow::Ok(completions)
})
@@ -516,85 +607,14 @@ impl Copilot {
cx.notify();
}
}
-
- fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
- match &self.server {
- CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
- CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
- CopilotServer::Error(error) => Err(anyhow!(
- "copilot was not started because of an error: {}",
- error
- )),
- CopilotServer::Started { server, status } => {
- if matches!(status, SignInStatus::Authorized { .. }) {
- Ok(server.clone())
- } else {
- Err(anyhow!("must sign in before using copilot"))
- }
- }
- }
- }
}
-fn build_completion_params<T>(
- buffer: &BufferSnapshot,
- position: T,
- cx: &AppContext,
-) -> anyhow::Result<request::GetCompletionsParams>
-where
- T: ToPointUtf16,
-{
- let position = position.to_point_utf16(&buffer);
- let language_name = buffer.language_at(position).map(|language| language.name());
- let language_name = language_name.as_deref();
-
- let path;
- let relative_path;
- if let Some(file) = buffer.file() {
- if let Some(file) = file.as_local() {
- path = file.abs_path(cx);
- } else {
- path = file.full_path(cx);
- }
- relative_path = file.path().to_path_buf();
- } else {
- path = PathBuf::from("/untitled");
- relative_path = PathBuf::from("untitled");
- }
-
- let settings = cx.global::<Settings>();
- let language_id = match language_name {
+fn id_for_language(language: Option<&Arc<Language>>) -> String {
+ let language_name = language.map(|language| language.name());
+ match language_name.as_deref() {
Some("Plain Text") => "plaintext".to_string(),
Some(language_name) => language_name.to_lowercase(),
None => "plaintext".to_string(),
- };
-
- let Ok(uri) = lsp::Url::from_file_path(&path) else {
- bail!("Failed convert file path")
- };
-
- Ok(request::GetCompletionsParams {
- doc: request::GetCompletionsDocument {
- source: buffer.text(),
- tab_size: settings.tab_size(language_name).into(),
- indent_size: 1,
- insert_spaces: !settings.hard_tabs(language_name),
- uri,
- path: path.to_string_lossy().into(),
- relative_path: relative_path.to_string_lossy().into(),
- language_id,
- position: point_to_lsp(position),
- version: 0,
- },
- })
-}
-
-fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
- let start = buffer.clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
- let end = buffer.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
- Completion {
- range: buffer.anchor_before(start)..buffer.anchor_after(end),
- text: completion.text,
}
}