Cargo.lock 🔗
@@ -4687,6 +4687,7 @@ dependencies = [
"client",
"clock",
"collections",
+ "copilot",
"ctor",
"db",
"env_logger",
Antonio Scandurra created
Cargo.lock | 1
crates/copilot/src/copilot.rs | 338 ++++++++++++++++++++++++++----------
crates/copilot/src/request.rs | 3
crates/project/Cargo.toml | 1
crates/project/src/project.rs | 24 ++
5 files changed, 268 insertions(+), 99 deletions(-)
@@ -4687,6 +4687,7 @@ dependencies = [
"client",
"clock",
"collections",
+ "copilot",
"ctor",
"db",
"env_logger",
@@ -6,8 +6,13 @@ use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use collections::HashMap;
use futures::{future::Shared, Future, FutureExt, TryFutureExt};
-use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
-use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
+use gpui::{
+ actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
+};
+use language::{
+ point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
+ ToPointUtf16,
+};
use log::{debug, error};
use lsp::LanguageServer;
use node_runtime::NodeRuntime;
@@ -105,7 +110,7 @@ enum CopilotServer {
Started {
server: Arc<LanguageServer>,
status: SignInStatus,
- subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
+ registered_buffers: HashMap<usize, RegisteredBuffer>,
},
}
@@ -141,6 +146,66 @@ impl Status {
}
}
+struct RegisteredBuffer {
+ uri: lsp::Url,
+ snapshot: Option<(i32, BufferSnapshot)>,
+ _subscriptions: [gpui::Subscription; 2],
+}
+
+impl RegisteredBuffer {
+ fn report_changes(
+ &mut self,
+ buffer: &ModelHandle<Buffer>,
+ server: &LanguageServer,
+ cx: &AppContext,
+ ) -> Result<(i32, BufferSnapshot)> {
+ let buffer = buffer.read(cx);
+ let (version, prev_snapshot) = self
+ .snapshot
+ .as_ref()
+ .ok_or_else(|| anyhow!("expected at least one snapshot"))?;
+ let next_snapshot = buffer.snapshot();
+
+ let content_changes = buffer
+ .edits_since::<(PointUtf16, usize)>(prev_snapshot.version())
+ .map(|edit| {
+ let edit_start = edit.new.start.0;
+ let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
+ let new_text = next_snapshot
+ .text_for_range(edit.new.start.1..edit.new.end.1)
+ .collect();
+ lsp::TextDocumentContentChangeEvent {
+ range: Some(lsp::Range::new(
+ point_to_lsp(edit_start),
+ point_to_lsp(edit_end),
+ )),
+ range_length: None,
+ text: new_text,
+ }
+ })
+ .collect::<Vec<_>>();
+
+ if content_changes.is_empty() {
+ Ok((*version, prev_snapshot.clone()))
+ } else {
+ let next_version = version + 1;
+ self.snapshot = Some((next_version, next_snapshot.clone()));
+
+ server.notify::<lsp::notification::DidChangeTextDocument>(
+ lsp::DidChangeTextDocumentParams {
+ text_document: lsp::VersionedTextDocumentIdentifier::new(
+ self.uri.clone(),
+ next_version,
+ ),
+ content_changes,
+ },
+ )?;
+
+ Ok((next_version, next_snapshot))
+ }
+ }
+}
+
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
pub range: Range<Anchor>,
@@ -151,6 +216,7 @@ pub struct Copilot {
http: Arc<dyn HttpClient>,
node_runtime: Arc<NodeRuntime>,
server: CopilotServer,
+ buffers: HashMap<usize, WeakModelHandle<Buffer>>,
}
impl Entity for Copilot {
@@ -212,12 +278,14 @@ impl Copilot {
http,
node_runtime,
server: CopilotServer::Starting { task: start_task },
+ buffers: Default::default(),
}
} else {
Self {
http,
node_runtime,
server: CopilotServer::Disabled,
+ buffers: Default::default(),
}
}
}
@@ -233,8 +301,9 @@ impl Copilot {
server: CopilotServer::Started {
server: Arc::new(server),
status: SignInStatus::Authorized,
- subscriptions_by_buffer_id: Default::default(),
+ registered_buffers: Default::default(),
},
+ buffers: Default::default(),
});
(this, fake_server)
}
@@ -297,7 +366,7 @@ impl Copilot {
this.server = CopilotServer::Started {
server,
status: SignInStatus::SignedOut,
- subscriptions_by_buffer_id: Default::default(),
+ registered_buffers: Default::default(),
};
this.update_sign_in_status(status, cx);
}
@@ -396,10 +465,8 @@ impl Copilot {
}
fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
- if let CopilotServer::Started { server, status, .. } = &mut self.server {
- *status = SignInStatus::SignedOut;
- cx.notify();
-
+ self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
+ if let CopilotServer::Started { server, .. } = &self.server {
let server = server.clone();
cx.background().spawn(async move {
server
@@ -433,6 +500,108 @@ impl Copilot {
cx.foreground().spawn(start_task)
}
+ pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
+ let buffer_id = buffer.id();
+ self.buffers.insert(buffer_id, buffer.downgrade());
+
+ if let CopilotServer::Started {
+ server,
+ status,
+ registered_buffers,
+ ..
+ } = &mut self.server
+ {
+ if !matches!(status, SignInStatus::Authorized { .. }) {
+ return;
+ }
+
+ let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
+ registered_buffers.entry(buffer.id()).or_insert_with(|| {
+ let snapshot = buffer.read(cx).snapshot();
+ server
+ .notify::<lsp::notification::DidOpenTextDocument>(
+ lsp::DidOpenTextDocumentParams {
+ text_document: lsp::TextDocumentItem {
+ uri: uri.clone(),
+ language_id: id_for_language(buffer.read(cx).language()),
+ version: 0,
+ text: snapshot.text(),
+ },
+ },
+ )
+ .log_err();
+
+ RegisteredBuffer {
+ uri,
+ snapshot: Some((0, snapshot)),
+ _subscriptions: [
+ cx.subscribe(buffer, |this, buffer, event, cx| {
+ this.handle_buffer_event(buffer, event, cx).log_err();
+ }),
+ cx.observe_release(buffer, move |this, _buffer, _cx| {
+ this.buffers.remove(&buffer_id);
+ this.unregister_buffer(buffer_id);
+ }),
+ ],
+ }
+ });
+ }
+ }
+
+ fn handle_buffer_event(
+ &mut self,
+ buffer: ModelHandle<Buffer>,
+ event: &language::Event,
+ cx: &mut ModelContext<Self>,
+ ) -> Result<()> {
+ if let CopilotServer::Started {
+ server,
+ registered_buffers,
+ ..
+ } = &mut self.server
+ {
+ if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) {
+ match event {
+ language::Event::Edited => {
+ registered_buffer.report_changes(&buffer, server, cx)?;
+ }
+ language::Event::Saved => {
+ server.notify::<lsp::notification::DidSaveTextDocument>(
+ lsp::DidSaveTextDocumentParams {
+ text_document: lsp::TextDocumentIdentifier::new(
+ registered_buffer.uri.clone(),
+ ),
+ text: None,
+ },
+ )?;
+ }
+ _ => {}
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn unregister_buffer(&mut self, buffer_id: usize) {
+ if let CopilotServer::Started {
+ server,
+ registered_buffers,
+ ..
+ } = &mut self.server
+ {
+ if let Some(buffer) = registered_buffers.remove(&buffer_id) {
+ server
+ .notify::<lsp::notification::DidCloseTextDocument>(
+ lsp::DidCloseTextDocumentParams {
+ text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
+ },
+ )
+ .log_err();
+ }
+ }
+ }
+
pub fn completions<T>(
&mut self,
buffer: &ModelHandle<Buffer>,
@@ -464,16 +633,14 @@ impl Copilot {
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>>
where
- R: lsp::request::Request<
- Params = request::GetCompletionsParams,
- Result = request::GetCompletionsResult,
- >,
+ R: 'static
+ + lsp::request::Request<
+ Params = request::GetCompletionsParams,
+ Result = request::GetCompletionsResult,
+ >,
T: ToPointUtf16,
{
- let buffer_id = buffer.id();
- let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
- let snapshot = buffer.read(cx).snapshot();
- let server = match &mut self.server {
+ let (server, registered_buffer) = match &mut self.server {
CopilotServer::Starting { .. } => {
return Task::ready(Err(anyhow!("copilot is still starting")))
}
@@ -487,56 +654,28 @@ impl Copilot {
CopilotServer::Started {
server,
status,
- subscriptions_by_buffer_id,
+ registered_buffers,
+ ..
} => {
if matches!(status, SignInStatus::Authorized { .. }) {
- subscriptions_by_buffer_id
- .entry(buffer_id)
- .or_insert_with(|| {
- server
- .notify::<lsp::notification::DidOpenTextDocument>(
- lsp::DidOpenTextDocumentParams {
- text_document: lsp::TextDocumentItem {
- uri: uri.clone(),
- language_id: id_for_language(
- buffer.read(cx).language(),
- ),
- version: 0,
- text: snapshot.text(),
- },
- },
- )
- .log_err();
-
- let uri = uri.clone();
- cx.observe_release(buffer, move |this, _, _| {
- if let CopilotServer::Started {
- server,
- subscriptions_by_buffer_id,
- ..
- } = &mut this.server
- {
- server
- .notify::<lsp::notification::DidCloseTextDocument>(
- lsp::DidCloseTextDocumentParams {
- text_document: lsp::TextDocumentIdentifier::new(
- uri.clone(),
- ),
- },
- )
- .log_err();
- subscriptions_by_buffer_id.remove(&buffer_id);
- }
- })
- });
-
- server.clone()
+ if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) {
+ (server.clone(), registered_buffer)
+ } else {
+ return Task::ready(Err(anyhow!(
+ "requested completions for an unregistered buffer"
+ )));
+ }
} else {
return Task::ready(Err(anyhow!("must sign in before using copilot")));
}
}
};
+ let (version, snapshot) = match registered_buffer.report_changes(buffer, &server, cx) {
+ Ok((version, snapshot)) => (version, snapshot),
+ Err(error) => return Task::ready(Err(error)),
+ };
+ let uri = registered_buffer.uri.clone();
let settings = cx.global::<Settings>();
let position = position.to_point_utf16(&snapshot);
let language = snapshot.language_at(position);
@@ -544,39 +683,23 @@ impl Copilot {
let language_name = language_name.as_deref();
let tab_size = settings.tab_size(language_name);
let hard_tabs = settings.hard_tabs(language_name);
- let language_id = id_for_language(language);
-
- let path;
- let relative_path;
- if let Some(file) = snapshot.file() {
- if let Some(file) = file.as_local() {
- path = file.abs_path(cx);
- } else {
- path = file.full_path(cx);
- }
- relative_path = file.path().to_path_buf();
- } else {
- path = PathBuf::new();
- relative_path = PathBuf::new();
- }
-
+ let relative_path = snapshot
+ .file()
+ .map(|file| file.path().to_path_buf())
+ .unwrap_or_default();
+ let request = server.request::<R>(request::GetCompletionsParams {
+ doc: request::GetCompletionsDocument {
+ uri,
+ tab_size: tab_size.into(),
+ indent_size: 1,
+ insert_spaces: !hard_tabs,
+ relative_path: relative_path.to_string_lossy().into(),
+ position: point_to_lsp(position),
+ version: version.try_into().unwrap(),
+ },
+ });
cx.background().spawn(async move {
- let result = server
- .request::<R>(request::GetCompletionsParams {
- doc: request::GetCompletionsDocument {
- source: snapshot.text(),
- tab_size: tab_size.into(),
- indent_size: 1,
- insert_spaces: !hard_tabs,
- uri,
- path: path.to_string_lossy().into(),
- relative_path: relative_path.to_string_lossy().into(),
- language_id,
- position: point_to_lsp(position),
- version: 0,
- },
- })
- .await?;
+ let result = request.await?;
let completions = result
.completions
.into_iter()
@@ -616,14 +739,37 @@ impl Copilot {
lsp_status: request::SignInStatus,
cx: &mut ModelContext<Self>,
) {
+ self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
+
if let CopilotServer::Started { status, .. } = &mut self.server {
- *status = match lsp_status {
+ match lsp_status {
request::SignInStatus::Ok { .. }
| request::SignInStatus::MaybeOk { .. }
- | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
- request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
- request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
- };
+ | request::SignInStatus::AlreadySignedIn { .. } => {
+ *status = SignInStatus::Authorized;
+
+ for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
+ if let Some(buffer) = buffer.upgrade(cx) {
+ self.register_buffer(&buffer, cx);
+ }
+ }
+ }
+ request::SignInStatus::NotAuthorized { .. } => {
+ *status = SignInStatus::Unauthorized;
+
+ for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
+ self.unregister_buffer(buffer_id);
+ }
+ }
+ request::SignInStatus::NotSignedIn => {
+ *status = SignInStatus::SignedOut;
+
+ for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
+ self.unregister_buffer(buffer_id);
+ }
+ }
+ }
+
cx.notify();
}
}
@@ -99,14 +99,11 @@ pub struct GetCompletionsParams {
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetCompletionsDocument {
- pub source: String,
pub tab_size: u32,
pub indent_size: u32,
pub insert_spaces: bool,
pub uri: lsp::Url,
- pub path: String,
pub relative_path: String,
- pub language_id: String,
pub position: lsp::Position,
pub version: usize,
}
@@ -19,6 +19,7 @@ test-support = [
[dependencies]
text = { path = "../text" }
+copilot = { path = "../copilot" }
client = { path = "../client" }
clock = { path = "../clock" }
collections = { path = "../collections" }
@@ -12,6 +12,7 @@ use anyhow::{anyhow, Context, Result};
use client::{proto, Client, TypedEnvelope, UserStore};
use clock::ReplicaId;
use collections::{hash_map, BTreeMap, HashMap, HashSet};
+use copilot::Copilot;
use futures::{
channel::mpsc::{self, UnboundedReceiver},
future::{try_join_all, Shared},
@@ -129,6 +130,7 @@ pub struct Project {
_maintain_buffer_languages: Task<()>,
_maintain_workspace_config: Task<()>,
terminals: Terminals,
+ copilot_enabled: bool,
}
enum BufferMessage {
@@ -472,6 +474,7 @@ impl Project {
terminals: Terminals {
local_handles: Vec::new(),
},
+ copilot_enabled: Copilot::global(cx).is_some(),
}
})
}
@@ -559,6 +562,7 @@ impl Project {
terminals: Terminals {
local_handles: Vec::new(),
},
+ copilot_enabled: Copilot::global(cx).is_some(),
};
for worktree in worktrees {
let _ = this.add_worktree(&worktree, cx);
@@ -664,6 +668,15 @@ impl Project {
self.start_language_server(worktree_id, worktree_path, language, cx);
}
+ if !self.copilot_enabled && Copilot::global(cx).is_some() {
+ self.copilot_enabled = true;
+ for buffer in self.opened_buffers.values() {
+ if let Some(buffer) = buffer.upgrade(cx) {
+ self.register_buffer_with_copilot(&buffer, cx);
+ }
+ }
+ }
+
cx.notify();
}
@@ -1616,6 +1629,7 @@ impl Project {
self.detect_language_for_buffer(buffer, cx);
self.register_buffer_with_language_server(buffer, cx);
+ self.register_buffer_with_copilot(buffer, cx);
cx.observe_release(buffer, |this, buffer, cx| {
if let Some(file) = File::from_dyn(buffer.file()) {
if file.is_local() {
@@ -1731,6 +1745,16 @@ impl Project {
});
}
+ fn register_buffer_with_copilot(
+ &self,
+ buffer_handle: &ModelHandle<Buffer>,
+ cx: &mut ModelContext<Self>,
+ ) {
+ if let Some(copilot) = Copilot::global(cx) {
+ copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx));
+ }
+ }
+
async fn send_buffer_messages(
this: WeakModelHandle<Self>,
rx: UnboundedReceiver<BufferMessage>,