@@ -24,6 +24,8 @@ pub struct PredictEditsRequest {
pub can_collect_data: bool,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub diagnostic_groups: Vec<DiagnosticGroup>,
+ #[serde(skip_serializing_if = "is_default", default)]
+ pub diagnostic_groups_truncated: bool,
/// Info about the git repository state, only present when can_collect_data is true.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub git_info: Option<PredictEditsGitInfo>,
@@ -92,10 +94,8 @@ pub struct ScoreComponents {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct DiagnosticGroup {
- pub language_server: String,
- pub diagnostic_group: serde_json::Value,
-}
+/// A single diagnostic group carried as an already-serialized JSON value.
+///
+/// Because of `#[serde(transparent)]`, this type serializes and deserializes exactly
+/// as the wrapped raw JSON, with no extra wrapper object around it.
+#[serde(transparent)]
+pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
@@ -119,3 +119,7 @@ pub struct Edit {
pub range: Range<usize>,
pub content: String,
}
+
+/// Reports whether `value` is equal to the `Default` value of its type.
+///
+/// Referenced by name from `#[serde(skip_serializing_if = "is_default")]`, so fields
+/// that still hold their default are left out of the serialized output.
+fn is_default<T: Default + PartialEq>(value: &T) -> bool {
+    value.eq(&T::default())
+}
@@ -18,7 +18,9 @@ use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
};
-use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
+use language::{
+ Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
+};
use language::{BufferSnapshot, EditPreview};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::Project;
@@ -45,6 +47,11 @@ pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPrediction
target_before_cursor_over_total_bytes: 0.5,
};
+/// Baseline `ZetaOptions`: the value a `Zeta` is constructed with, and the value the
+/// inspector's "Reset" button compares against.
+pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
+ excerpt: DEFAULT_EXCERPT_OPTIONS,
+ max_diagnostic_bytes: 2048,
+};
+
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);
@@ -56,11 +63,17 @@ pub struct Zeta {
llm_token: LlmApiToken,
_llm_token_subscription: Subscription,
projects: HashMap<EntityId, ZetaProject>,
- pub excerpt_options: EditPredictionExcerptOptions,
+ options: ZetaOptions,
update_required: bool,
debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
+/// Tunable settings that shape the request `Zeta` builds for an edit prediction.
+#[derive(Debug, Clone, PartialEq)]
+pub struct ZetaOptions {
+ /// Excerpt-selection settings, handed to `EditPredictionContext::gather_context`.
+ pub excerpt: EditPredictionExcerptOptions,
+ /// Upper bound on the combined byte length of the serialized diagnostic groups
+ /// attached to a request (see `gather_nearby_diagnostics`).
+ pub max_diagnostic_bytes: usize,
+}
+
pub struct PredictionDebugInfo {
pub context: EditPredictionContext,
pub retrieval_time: TimeDelta,
@@ -113,7 +126,7 @@ impl Zeta {
projects: HashMap::new(),
client,
user_store,
- excerpt_options: DEFAULT_EXCERPT_OPTIONS,
+ options: DEFAULT_OPTIONS,
llm_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
@@ -138,12 +151,12 @@ impl Zeta {
debug_watch_rx
}
- pub fn excerpt_options(&self) -> &EditPredictionExcerptOptions {
- &self.excerpt_options
+ /// Returns the options currently in effect for this `Zeta`.
+ pub fn options(&self) -> &ZetaOptions {
+ &self.options
}
- pub fn set_excerpt_options(&mut self, options: EditPredictionExcerptOptions) {
- self.excerpt_options = options;
+ /// Replaces the stored options wholesale with `options`.
+ pub fn set_options(&mut self, options: ZetaOptions) {
+ self.options = options;
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -290,7 +303,7 @@ impl Zeta {
.syntax_index
.read_with(cx, |index, _cx| index.state().clone())
});
- let excerpt_options = self.excerpt_options.clone();
+ let options = self.options.clone();
let snapshot = buffer.read(cx).snapshot();
let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
@@ -343,6 +356,8 @@ impl Zeta {
})
.unwrap_or_default();
+ let diagnostics = snapshot.diagnostic_sets().clone();
+
let request_task = cx.background_spawn({
let snapshot = snapshot.clone();
let buffer = buffer.clone();
@@ -353,14 +368,15 @@ impl Zeta {
None
};
- let cursor_point = position.to_point(&snapshot);
+ let cursor_offset = position.to_offset(&snapshot);
+ let cursor_point = cursor_offset.to_point(&snapshot);
let before_retrieval = chrono::Utc::now();
let Some(context) = EditPredictionContext::gather_context(
cursor_point,
&snapshot,
- &excerpt_options,
+ &options.excerpt,
index_state.as_deref(),
) else {
return Ok(None);
@@ -372,13 +388,22 @@ impl Zeta {
None
};
+ let (diagnostic_groups, diagnostic_groups_truncated) =
+ Self::gather_nearby_diagnostics(
+ cursor_offset,
+ &diagnostics,
+ &snapshot,
+ options.max_diagnostic_bytes,
+ );
+
let request = make_cloud_request(
excerpt_path.clone(),
context,
events,
// TODO data collection
false,
- Vec::new(),
+ diagnostic_groups,
+ diagnostic_groups_truncated,
None,
debug_context.is_some(),
&worktree_snapshots,
@@ -575,6 +600,52 @@ impl Zeta {
}
}
+    /// Collects diagnostic groups near the cursor as raw JSON, staying within a byte budget.
+    ///
+    /// Groups from every `(LanguageServerId, DiagnosticSet)` pair in `diagnostic_sets` are
+    /// resolved against `snapshot` into `usize` ranges and ordered by how close the range
+    /// of each group's primary entry is to `cursor_offset`. They are then serialized one
+    /// by one, nearest first, until adding the next group would push the running total of
+    /// serialized bytes above `max_diagnostics_bytes`.
+    ///
+    /// Returns the serialized groups together with a flag that is `true` when at least one
+    /// group was left out, either because the budget was exhausted or because the group
+    /// could not be serialized.
+    fn gather_nearby_diagnostics(
+        cursor_offset: usize,
+        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
+        snapshot: &BufferSnapshot,
+        max_diagnostics_bytes: usize,
+    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
+        // TODO: Could make this more efficient
+        let mut diagnostic_groups = Vec::new();
+        for (language_server_id, diagnostics) in diagnostic_sets {
+            let mut groups = Vec::new();
+            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
+            diagnostic_groups.extend(
+                groups
+                    .into_iter()
+                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
+            );
+        }
+
+        // Sort by proximity to the cursor, measured against the primary entry's range.
+        diagnostic_groups.sort_by_key(|group| {
+            let range = &group.entries[group.primary_ix].range;
+            if range.start >= cursor_offset {
+                // Range begins at or after the cursor.
+                range.start - cursor_offset
+            } else if cursor_offset >= range.end {
+                // Range ends at or before the cursor.
+                cursor_offset - range.end
+            } else {
+                // Cursor is strictly inside the range: distance to the nearer edge.
+                (cursor_offset - range.start).min(range.end - cursor_offset)
+            }
+        });
+
+        let mut results = Vec::new();
+        let mut diagnostic_groups_truncated = false;
+        let mut diagnostics_byte_count = 0;
+        for group in diagnostic_groups {
+            // Diagnostics are optional context for the request, so a group that fails to
+            // serialize is dropped rather than panicking the whole prediction task. The
+            // flag is raised so the receiver knows the list is incomplete.
+            let Ok(raw_value) = serde_json::value::to_raw_value(&group) else {
+                diagnostic_groups_truncated = true;
+                continue;
+            };
+            diagnostics_byte_count += raw_value.get().len();
+            if diagnostics_byte_count > max_diagnostics_bytes {
+                // Groups are visited nearest-first, so everything from here on is dropped.
+                diagnostic_groups_truncated = true;
+                break;
+            }
+            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
+        }
+
+        (results, diagnostic_groups_truncated)
+    }
+
+
// TODO: Dedupe with similar code in request_prediction?
pub fn cloud_request_for_zeta_cli(
&mut self,
@@ -590,7 +661,7 @@ impl Zeta {
.syntax_index
.read_with(cx, |index, _cx| index.state().clone())
});
- let excerpt_options = self.excerpt_options.clone();
+ let options = self.options.clone();
let snapshot = buffer.read(cx).snapshot();
let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
@@ -614,7 +685,7 @@ impl Zeta {
EditPredictionContext::gather_context(
cursor_point,
&snapshot,
- &excerpt_options,
+ &options.excerpt,
index_state.as_deref(),
)
.context("Failed to select excerpt")
@@ -626,6 +697,7 @@ impl Zeta {
Vec::new(),
false,
Vec::new(),
+ false,
None,
debug_info,
&worktree_snapshots,
@@ -985,6 +1057,7 @@ fn make_cloud_request(
events: Vec<predict_edits_v3::Event>,
can_collect_data: bool,
diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
+ diagnostic_groups_truncated: bool,
git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
debug_info: bool,
worktrees: &Vec<worktree::Snapshot>,
@@ -1058,6 +1131,8 @@ fn make_cloud_request(
events,
can_collect_data,
diagnostic_groups,
+ diagnostic_groups_truncated,
+
git_info,
debug_info,
}
@@ -1141,7 +1216,6 @@ fn interpolate(
mod tests {
use super::*;
use gpui::TestAppContext;
- use language::ToOffset as _;
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -22,7 +22,7 @@ use ui::prelude::*;
use ui_input::SingleLineInput;
use util::ResultExt;
use workspace::{Item, SplitDirection, Workspace};
-use zeta2::Zeta;
+use zeta2::{Zeta, ZetaOptions};
use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
@@ -137,25 +137,28 @@ impl Zeta2Inspector {
_update_state_task: Task::ready(()),
_receive_task: receive_task,
};
- this.set_input_options(&zeta.read(cx).excerpt_options().clone(), window, cx);
+ this.set_input_options(&zeta.read(cx).options().clone(), window, cx);
this
}
fn set_input_options(
&mut self,
- options: &EditPredictionExcerptOptions,
+ options: &ZetaOptions,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.max_bytes_input.update(cx, |input, cx| {
- input.set_text(options.max_bytes.to_string(), window, cx);
+ input.set_text(options.excerpt.max_bytes.to_string(), window, cx);
});
self.min_bytes_input.update(cx, |input, cx| {
- input.set_text(options.min_bytes.to_string(), window, cx);
+ input.set_text(options.excerpt.min_bytes.to_string(), window, cx);
});
self.cursor_context_ratio_input.update(cx, |input, cx| {
input.set_text(
- format!("{:.2}", options.target_before_cursor_over_total_bytes),
+ format!(
+ "{:.2}",
+ options.excerpt.target_before_cursor_over_total_bytes
+ ),
window,
cx,
);
@@ -163,9 +166,8 @@ impl Zeta2Inspector {
cx.notify();
}
- fn set_options(&mut self, options: EditPredictionExcerptOptions, cx: &mut Context<Self>) {
- self.zeta
- .update(cx, |this, _cx| this.set_excerpt_options(options));
+ fn set_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
+ self.zeta.update(cx, |this, _cx| this.set_options(options));
const THROTTLE_TIME: Duration = Duration::from_millis(100);
@@ -233,7 +235,7 @@ impl Zeta2Inspector {
.unwrap_or_default()
}
- let options = EditPredictionExcerptOptions {
+ let excerpt_options = EditPredictionExcerptOptions {
max_bytes: number_input_value(&this.max_bytes_input, cx),
min_bytes: number_input_value(&this.min_bytes_input, cx),
target_before_cursor_over_total_bytes: number_input_value(
@@ -242,7 +244,13 @@ impl Zeta2Inspector {
),
};
- this.set_options(options, cx);
+ this.set_options(
+ ZetaOptions {
+ excerpt: excerpt_options,
+ ..this.zeta.read(cx).options().clone()
+ },
+ cx,
+ );
},
)
.detach();
@@ -525,15 +533,15 @@ impl Render for Zeta2Inspector {
.child(
ui::Button::new("reset-options", "Reset")
.disabled(
- self.zeta.read(cx).excerpt_options()
- == &zeta2::DEFAULT_EXCERPT_OPTIONS,
+ self.zeta.read(cx).options()
+ == &zeta2::DEFAULT_OPTIONS,
)
.style(ButtonStyle::Outlined)
.size(ButtonSize::Large)
.on_click(cx.listener(
|this, _, window, cx| {
this.set_input_options(
- &zeta2::DEFAULT_EXCERPT_OPTIONS,
+ &zeta2::DEFAULT_OPTIONS,
window,
cx,
);
@@ -70,6 +70,8 @@ struct Zeta2Args {
excerpt_min_bytes: usize,
#[arg(long, default_value_t = 0.66)]
target_before_cursor_over_total_bytes: f32,
+ #[arg(long, default_value_t = 1024)]
+ max_diagnostic_bytes: usize,
}
#[derive(Debug, Clone)]
@@ -221,12 +223,15 @@ async fn get_context(
});
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(&buffer, &project, cx);
- zeta.excerpt_options = EditPredictionExcerptOptions {
- max_bytes: zeta2_args.excerpt_max_bytes,
- min_bytes: zeta2_args.excerpt_min_bytes,
- target_before_cursor_over_total_bytes: zeta2_args
- .target_before_cursor_over_total_bytes,
- }
+ zeta.set_options(zeta2::ZetaOptions {
+ excerpt: EditPredictionExcerptOptions {
+ max_bytes: zeta2_args.excerpt_max_bytes,
+ min_bytes: zeta2_args.excerpt_min_bytes,
+ target_before_cursor_over_total_bytes: zeta2_args
+ .target_before_cursor_over_total_bytes,
+ },
+ max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
+ })
});
// TODO: Actually wait for indexing.
let timer = cx.background_executor().timer(Duration::from_secs(5));