Detailed changes
@@ -5309,7 +5309,6 @@ dependencies = [
"workspace",
"zed_actions",
"zeta",
- "zeta2",
]
[[package]]
@@ -21316,7 +21315,6 @@ dependencies = [
"zed_actions",
"zed_env_vars",
"zeta",
- "zeta2",
"zeta2_tools",
"zlog",
"zlog_settings",
@@ -21636,48 +21634,52 @@ dependencies = [
"ai_onboarding",
"anyhow",
"arrayvec",
- "call",
+ "brotli",
+ "buffer_diff",
"client",
"clock",
"cloud_api_types",
"cloud_llm_client",
+ "cloud_zeta2_prompt",
"collections",
"command_palette_hooks",
"copilot",
"ctor",
"db",
"edit_prediction",
+ "edit_prediction_context",
"editor",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
- "http_client",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"log",
+ "lsp",
+ "markdown",
"menu",
+ "open_ai",
"parking_lot",
"postage",
+ "pretty_assertions",
"project",
"rand 0.9.2",
"regex",
"release_channel",
- "reqwest_client",
- "rpc",
"semver",
"serde",
"serde_json",
"settings",
+ "smol",
+ "strsim",
"strum 0.27.2",
"telemetry",
"telemetry_events",
"theme",
"thiserror 2.0.17",
- "tree-sitter-go",
- "tree-sitter-rust",
"ui",
"util",
"uuid",
@@ -21687,53 +21689,11 @@ dependencies = [
"zlog",
]
-[[package]]
-name = "zeta2"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "arrayvec",
- "brotli",
- "chrono",
- "client",
- "clock",
- "cloud_llm_client",
- "cloud_zeta2_prompt",
- "collections",
- "edit_prediction",
- "edit_prediction_context",
- "feature_flags",
- "futures 0.3.31",
- "gpui",
- "indoc",
- "language",
- "language_model",
- "log",
- "lsp",
- "open_ai",
- "pretty_assertions",
- "project",
- "release_channel",
- "semver",
- "serde",
- "serde_json",
- "settings",
- "smol",
- "strsim",
- "thiserror 2.0.17",
- "util",
- "uuid",
- "workspace",
- "worktree",
- "zlog",
-]
-
[[package]]
name = "zeta2_tools"
version = "0.1.0"
dependencies = [
"anyhow",
- "chrono",
"clap",
"client",
"cloud_llm_client",
@@ -21746,9 +21706,7 @@ dependencies = [
"gpui",
"indoc",
"language",
- "log",
"multi_buffer",
- "ordered-float 2.10.1",
"pretty_assertions",
"project",
"serde",
@@ -21760,7 +21718,7 @@ dependencies = [
"ui_input",
"util",
"workspace",
- "zeta2",
+ "zeta",
"zlog",
]
@@ -21810,7 +21768,6 @@ dependencies = [
"util",
"watch",
"zeta",
- "zeta2",
"zlog",
]
@@ -201,7 +201,6 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/zeta",
- "crates/zeta2",
"crates/zeta_cli",
"crates/zlog",
"crates/zlog_settings",
@@ -433,7 +432,6 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
zeta = { path = "crates/zeta" }
-zeta2 = { path = "crates/zeta2" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
@@ -1218,23 +1218,23 @@
}
},
{
- "context": "RateCompletionModal",
+ "context": "RatePredictionsModal",
"use_key_equivalents": true,
"bindings": {
- "cmd-shift-enter": "zeta::ThumbsUpActiveCompletion",
- "cmd-shift-backspace": "zeta::ThumbsDownActiveCompletion",
+ "cmd-shift-enter": "zeta::ThumbsUpActivePrediction",
+ "cmd-shift-backspace": "zeta::ThumbsDownActivePrediction",
"shift-down": "zeta::NextEdit",
"shift-up": "zeta::PreviousEdit",
- "right": "zeta::PreviewCompletion"
+ "right": "zeta::PreviewPrediction"
}
},
{
- "context": "RateCompletionModal > Editor",
+ "context": "RatePredictionsModal > Editor",
"use_key_equivalents": true,
"bindings": {
- "escape": "zeta::FocusCompletions",
- "cmd-shift-enter": "zeta::ThumbsUpActiveCompletion",
- "cmd-shift-backspace": "zeta::ThumbsDownActiveCompletion"
+ "escape": "zeta::FocusPredictions",
+ "cmd-shift-enter": "zeta::ThumbsUpActivePrediction",
+ "cmd-shift-backspace": "zeta::ThumbsDownActivePrediction"
}
},
{
@@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use std::{
fmt::{Display, Write as _},
ops::{Add, Range, Sub},
- path::{Path, PathBuf},
+ path::Path,
sync::Arc,
};
use strum::EnumIter;
@@ -17,7 +17,7 @@ pub struct PlanContextRetrievalRequest {
pub excerpt_path: Arc<Path>,
pub excerpt_line_range: Range<Line>,
pub cursor_file_max_row: Line,
- pub events: Vec<Event>,
+ pub events: Vec<Arc<Event>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -36,7 +36,7 @@ pub struct PredictEditsRequest {
pub signatures: Vec<Signature>,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub referenced_declarations: Vec<ReferencedDeclaration>,
- pub events: Vec<Event>,
+ pub events: Vec<Arc<Event>>,
#[serde(default)]
pub can_collect_data: bool,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
@@ -120,10 +120,11 @@ impl std::fmt::Display for PromptFormat {
#[serde(tag = "event")]
pub enum Event {
BufferChange {
- path: Option<PathBuf>,
- old_path: Option<PathBuf>,
+ path: Arc<Path>,
+ old_path: Arc<Path>,
diff: String,
predicted: bool,
+ in_open_source_repo: bool,
},
}
@@ -135,23 +136,21 @@ impl Display for Event {
old_path,
diff,
predicted,
+ ..
} => {
- let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
- let old_path = old_path.as_deref().unwrap_or(new_path);
-
if *predicted {
write!(
f,
"// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
DiffPathFmt(old_path),
- DiffPathFmt(new_path)
+ DiffPathFmt(path)
)
} else {
write!(
f,
"--- a/{}\n+++ b/{}\n{diff}",
DiffPathFmt(old_path),
- DiffPathFmt(new_path)
+ DiffPathFmt(path)
)
}
}
@@ -300,10 +299,11 @@ mod tests {
#[test]
fn test_event_display() {
let ev = Event::BufferChange {
- path: None,
- old_path: None,
+ path: Path::new("untitled").into(),
+ old_path: Path::new("untitled").into(),
diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
predicted: false,
+ in_open_source_repo: true,
};
assert_eq!(
ev.to_string(),
@@ -317,10 +317,11 @@ mod tests {
);
let ev = Event::BufferChange {
- path: Some(PathBuf::from("foo/bar.txt")),
- old_path: Some(PathBuf::from("foo/bar.txt")),
+ path: Path::new("foo/bar.txt").into(),
+ old_path: Path::new("foo/bar.txt").into(),
diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
predicted: false,
+ in_open_source_repo: true,
};
assert_eq!(
ev.to_string(),
@@ -334,10 +335,11 @@ mod tests {
);
let ev = Event::BufferChange {
- path: Some(PathBuf::from("abc.txt")),
- old_path: Some(PathBuf::from("123.txt")),
+ path: Path::new("abc.txt").into(),
+ old_path: Path::new("123.txt").into(),
diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
predicted: false,
+ in_open_source_repo: true,
};
assert_eq!(
ev.to_string(),
@@ -351,10 +353,11 @@ mod tests {
);
let ev = Event::BufferChange {
- path: Some(PathBuf::from("abc.txt")),
- old_path: Some(PathBuf::from("123.txt")),
+ path: Path::new("abc.txt").into(),
+ old_path: Path::new("123.txt").into(),
diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
predicted: true,
+ in_open_source_repo: true,
};
assert_eq!(
ev.to_string(),
@@ -432,7 +432,7 @@ pub fn write_excerpts<'a>(
}
}
-pub fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
+pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
if events.is_empty() {
return;
};
@@ -910,7 +910,7 @@ fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle
}
struct PromptData {
- events: Vec<Event>,
+ events: Vec<Arc<Event>>,
cursor_point: Point,
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
included_files: Vec<IncludedFile>,
@@ -35,7 +35,6 @@ ui.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeta.workspace = true
-zeta2.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }
@@ -21,7 +21,9 @@ use language::{
use project::DisableAiSettings;
use regex::Regex;
use settings::{
- EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file,
+ EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+ EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore,
+ update_settings_file,
};
use std::{
sync::{Arc, LazyLock},
@@ -38,7 +40,7 @@ use workspace::{
};
use zed_actions::OpenBrowser;
use zeta::RateCompletions;
-use zeta2::SweepFeatureFlag;
+use zeta::{SweepFeatureFlag, Zeta2FeatureFlag};
actions!(
edit_prediction,
@@ -300,10 +302,7 @@ impl Render for EditPredictionButton {
.with_handle(self.popover_menu_handle.clone()),
)
}
- provider @ (EditPredictionProvider::Experimental(
- EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
- )
- | EditPredictionProvider::Zed) => {
+ provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
let enabled = self.editor_enabled.unwrap_or(true);
let is_sweep = matches!(
@@ -430,9 +429,7 @@ impl Render for EditPredictionButton {
div().child(popover_menu.into_any_element())
}
- EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
- div().hidden()
- }
+ EditPredictionProvider::None => div().hidden(),
}
}
}
@@ -497,6 +494,12 @@ impl EditPredictionButton {
));
}
+ if cx.has_flag::<Zeta2FeatureFlag>() {
+ providers.push(EditPredictionProvider::Experimental(
+ EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
+ ));
+ }
+
providers
}
@@ -554,7 +557,7 @@ impl EditPredictionButton {
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
) => {
- let has_api_token = zeta2::Zeta::try_global(cx)
+ let has_api_token = zeta::Zeta::try_global(cx)
.map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
let entry = ContextMenuEntry::new("Sweep")
@@ -571,6 +574,11 @@ impl EditPredictionButton {
menu.item(entry)
}
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
+ ) => menu.entry("Zeta2", None, move |_, cx| {
+ set_completion_provider(fs.clone(), cx, provider);
+ }),
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
continue;
}
@@ -13,6 +13,7 @@ use crate::{
},
task_context::RunnableRange,
text_diff::text_diff,
+ unified_diff,
};
pub use crate::{
Grammar, Language, LanguageRegistry,
@@ -745,6 +746,33 @@ pub struct EditPreview {
}
impl EditPreview {
+ pub fn as_unified_diff(&self, edits: &[(Range<Anchor>, impl AsRef<str>)]) -> Option<String> {
+ let (first, _) = edits.first()?;
+ let (last, _) = edits.last()?;
+
+ let start = first.start.to_point(&self.old_snapshot);
+ let old_end = last.end.to_point(&self.old_snapshot);
+ let new_end = last
+ .end
+ .bias_right(&self.old_snapshot)
+ .to_point(&self.applied_edits_snapshot);
+
+ let start = Point::new(start.row.saturating_sub(3), 0);
+ let old_end = Point::new(old_end.row + 3, 0).min(self.old_snapshot.max_point());
+ let new_end = Point::new(new_end.row + 3, 0).min(self.applied_edits_snapshot.max_point());
+
+ Some(unified_diff(
+ &self
+ .old_snapshot
+ .text_for_range(start..old_end)
+ .collect::<String>(),
+ &self
+ .applied_edits_snapshot
+ .text_for_range(start..new_end)
+ .collect::<String>(),
+ ))
+ }
+
pub fn highlight_edits(
&self,
current_snapshot: &BufferSnapshot,
@@ -758,6 +786,8 @@ impl EditPreview {
let mut highlighted_text = HighlightedTextBuilder::default();
+ let visible_range_in_preview_snapshot =
+ visible_range_in_preview_snapshot.to_offset(&self.applied_edits_snapshot);
let mut offset_in_preview_snapshot = visible_range_in_preview_snapshot.start;
let insertion_highlight_style = HighlightStyle {
@@ -825,7 +855,19 @@ impl EditPreview {
highlighted_text.build()
}
- fn compute_visible_range<T>(&self, edits: &[(Range<Anchor>, T)]) -> Option<Range<usize>> {
+ pub fn build_result_buffer(&self, cx: &mut App) -> Entity<Buffer> {
+ cx.new(|cx| {
+ let mut buffer = Buffer::local_normalized(
+ self.applied_edits_snapshot.as_rope().clone(),
+ self.applied_edits_snapshot.line_ending(),
+ cx,
+ );
+ buffer.set_language(self.syntax_snapshot.root_language(), cx);
+ buffer
+ })
+ }
+
+ pub fn compute_visible_range<T>(&self, edits: &[(Range<Anchor>, T)]) -> Option<Range<Point>> {
let (first, _) = edits.first()?;
let (last, _) = edits.last()?;
@@ -842,7 +884,7 @@ impl EditPreview {
let range = Point::new(start.row, 0)
..Point::new(end.row, self.applied_edits_snapshot.line_len(end.row));
- Some(range.to_offset(&self.applied_edits_snapshot))
+ Some(range)
}
}
@@ -279,6 +279,13 @@ impl SyntaxSnapshot {
self.layers.is_empty()
}
+ pub fn root_language(&self) -> Option<Arc<Language>> {
+ match &self.layers.first()?.content {
+ SyntaxLayerContent::Parsed { language, .. } => Some(language.clone()),
+ SyntaxLayerContent::Pending { .. } => None,
+ }
+ }
+
pub fn update_count(&self) -> usize {
self.update_count
}
@@ -78,6 +78,7 @@ pub enum EditPredictionProvider {
}
pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep";
+pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2";
impl<'de> Deserialize<'de> for EditPredictionProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
@@ -101,17 +102,25 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
Content::Supermaven => EditPredictionProvider::Supermaven,
Content::Zed => EditPredictionProvider::Zed,
Content::Codestral => EditPredictionProvider::Codestral,
+ Content::Experimental(name)
+ if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME =>
+ {
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+ )
+ }
+ Content::Experimental(name)
+ if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME =>
+ {
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
+ )
+ }
Content::Experimental(name) => {
- if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME {
- EditPredictionProvider::Experimental(
- EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
- )
- } else {
- return Err(D::Error::custom(format!(
- "Unknown experimental edit prediction provider: {}",
- name
- )));
- }
+ return Err(D::Error::custom(format!(
+ "Unknown experimental edit prediction provider: {}",
+ name
+ )));
}
})
}
@@ -161,7 +161,6 @@ workspace.workspace = true
zed_actions.workspace = true
zed_env_vars.workspace = true
zeta.workspace = true
-zeta2.workspace = true
zlog.workspace = true
zlog_settings.workspace = true
chrono.workspace = true
@@ -7,13 +7,14 @@ use feature_flags::FeatureFlagAppExt;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::MistralLanguageModelProvider;
-use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
+use settings::{
+ EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+ EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore,
+};
use std::{cell::RefCell, rc::Rc, sync::Arc};
use supermaven::{Supermaven, SupermavenCompletionProvider};
use ui::Window;
-use zeta::ZetaEditPredictionProvider;
-use zeta2::SweepFeatureFlag;
-use zeta2::Zeta2FeatureFlag;
+use zeta::{SweepFeatureFlag, Zeta2FeatureFlag, ZetaEditPredictionProvider};
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
@@ -100,9 +101,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
}
fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
- if let Some(zeta) = zeta::Zeta::global(cx) {
- zeta.update(cx, |zeta, _| zeta.clear_history());
- } else if let Some(zeta) = zeta2::Zeta::try_global(cx) {
+ if let Some(zeta) = zeta::Zeta::try_global(cx) {
zeta.update(cx, |zeta, _| zeta.clear_history());
}
}
@@ -204,86 +203,41 @@ fn assign_edit_prediction_provider(
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
- let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
-
- if let Some(project) = editor.project() {
- let mut worktree = None;
- if let Some(buffer) = &singleton_buffer
- && let Some(file) = buffer.read(cx).file()
- {
- let id = file.worktree_id(cx);
- worktree = project.read(cx).worktree_for_id(id, cx);
- }
-
- if let EditPredictionProvider::Experimental(name) = value
- && name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
- && cx.has_flag::<SweepFeatureFlag>()
- {
- let provider = cx.new(|cx| {
- zeta2::ZetaEditPredictionProvider::new(
- project.clone(),
- &client,
- &user_store,
- cx,
- )
- });
-
- if let Some(buffer) = &singleton_buffer
- && buffer.read(cx).file().is_some()
- {
- zeta2.update(cx, |zeta, cx| {
- zeta.set_edit_prediction_model(zeta2::ZetaEditPredictionModel::Sweep);
- zeta.register_buffer(buffer, project, cx);
- });
- }
-
- editor.set_edit_prediction_provider(Some(provider), window, cx);
- } else if user_store.read(cx).current_user().is_some() {
- if cx.has_flag::<Zeta2FeatureFlag>() {
- let zeta = zeta2::Zeta::global(client, &user_store, cx);
- let provider = cx.new(|cx| {
- zeta2::ZetaEditPredictionProvider::new(
- project.clone(),
- &client,
- &user_store,
- cx,
- )
- });
-
- // TODO [zeta2] handle multibuffers
- if let Some(buffer) = &singleton_buffer
- && buffer.read(cx).file().is_some()
+ let zeta = zeta::Zeta::global(client, &user_store, cx);
+
+ if let Some(project) = editor.project()
+ && let Some(buffer) = &singleton_buffer
+ && buffer.read(cx).file().is_some()
+ {
+ let has_model = zeta.update(cx, |zeta, cx| {
+ let model = if let EditPredictionProvider::Experimental(name) = value {
+ if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
+ && cx.has_flag::<SweepFeatureFlag>()
+ {
+ zeta::ZetaEditPredictionModel::Sweep
+ } else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
+ && cx.has_flag::<Zeta2FeatureFlag>()
{
- zeta.update(cx, |zeta, cx| {
- zeta.set_edit_prediction_model(
- zeta2::ZetaEditPredictionModel::ZedCloud,
- );
- zeta.register_buffer(buffer, project, cx);
- });
+ zeta::ZetaEditPredictionModel::Zeta2
+ } else {
+ return false;
}
-
- editor.set_edit_prediction_provider(Some(provider), window, cx);
+ } else if user_store.read(cx).current_user().is_some() {
+ zeta::ZetaEditPredictionModel::Zeta1
} else {
- let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
+ return false;
+ };
- if let Some(buffer) = &singleton_buffer
- && buffer.read(cx).file().is_some()
- {
- zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(buffer, project, cx);
- });
- }
+ zeta.set_edit_prediction_model(model);
+ zeta.register_buffer(buffer, project, cx);
+ true
+ });
- let provider = cx.new(|cx| {
- zeta::ZetaEditPredictionProvider::new(
- zeta,
- project.clone(),
- singleton_buffer,
- cx,
- )
- });
- editor.set_edit_prediction_provider(Some(provider), window, cx);
- }
+ if has_model {
+ let provider = cx.new(|cx| {
+ ZetaEditPredictionProvider::new(project.clone(), &client, &user_store, cx)
+ });
+ editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
}
@@ -4,81 +4,80 @@ version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
-exclude = ["fixtures"]
[lints]
workspace = true
[lib]
path = "src/zeta.rs"
-doctest = false
[features]
-test-support = []
+eval-support = []
[dependencies]
ai_onboarding.workspace = true
anyhow.workspace = true
arrayvec.workspace = true
+brotli.workspace = true
+buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
+cloud_zeta2_prompt.workspace = true
+copilot.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
-copilot.workspace = true
db.workspace = true
edit_prediction.workspace = true
+edit_prediction_context.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
-http_client.workspace = true
indoc.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
+lsp.workspace = true
+markdown.workspace = true
menu.workspace = true
+open_ai.workspace = true
+pretty_assertions.workspace = true
postage.workspace = true
project.workspace = true
rand.workspace = true
-regex.workspace = true
release_channel.workspace = true
+regex.workspace = true
semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
+smol.workspace = true
+strsim.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
theme.workspace = true
thiserror.workspace = true
-ui.workspace = true
util.workspace = true
+ui.workspace = true
uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
[dev-dependencies]
-call = { workspace = true, features = ["test-support"] }
-client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
cloud_api_types.workspace = true
-collections = { workspace = true, features = ["test-support"] }
+cloud_llm_client = { workspace = true, features = ["test-support"] }
ctor.workspace = true
-editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
-http_client = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
+language_model = { workspace = true, features = ["test-support"] }
+lsp.workspace = true
parking_lot.workspace = true
-reqwest_client = { workspace = true, features = ["test-support"] }
-rpc = { workspace = true, features = ["test-support"] }
+project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
-theme = { workspace = true, features = ["test-support"] }
-tree-sitter-go.workspace = true
-tree-sitter-rust.workspace = true
-workspace = { workspace = true, features = ["test-support"] }
-worktree = { workspace = true, features = ["test-support"] }
zlog.workspace = true
@@ -1,173 +0,0 @@
-use std::cmp;
-
-use crate::EditPrediction;
-use gpui::{
- AnyElement, App, BorderStyle, Bounds, Corners, Edges, HighlightStyle, Hsla, StyledText,
- TextLayout, TextStyle, point, prelude::*, quad, size,
-};
-use language::OffsetRangeExt;
-use settings::Settings;
-use theme::ThemeSettings;
-use ui::prelude::*;
-
-pub struct CompletionDiffElement {
- element: AnyElement,
- text_layout: TextLayout,
- cursor_offset: usize,
-}
-
-impl CompletionDiffElement {
- pub fn new(completion: &EditPrediction, cx: &App) -> Self {
- let mut diff = completion
- .snapshot
- .text_for_range(completion.excerpt_range.clone())
- .collect::<String>();
-
- let mut cursor_offset_in_diff = None;
- let mut delta = 0;
- let mut diff_highlights = Vec::new();
- for (old_range, new_text) in completion.edits.iter() {
- let old_range = old_range.to_offset(&completion.snapshot);
-
- if cursor_offset_in_diff.is_none() && completion.cursor_offset <= old_range.end {
- cursor_offset_in_diff =
- Some(completion.cursor_offset - completion.excerpt_range.start + delta);
- }
-
- let old_start_in_diff = old_range.start - completion.excerpt_range.start + delta;
- let old_end_in_diff = old_range.end - completion.excerpt_range.start + delta;
- if old_start_in_diff < old_end_in_diff {
- diff_highlights.push((
- old_start_in_diff..old_end_in_diff,
- HighlightStyle {
- background_color: Some(cx.theme().status().deleted_background),
- strikethrough: Some(gpui::StrikethroughStyle {
- thickness: px(1.),
- color: Some(cx.theme().colors().text_muted),
- }),
- ..Default::default()
- },
- ));
- }
-
- if !new_text.is_empty() {
- diff.insert_str(old_end_in_diff, new_text);
- diff_highlights.push((
- old_end_in_diff..old_end_in_diff + new_text.len(),
- HighlightStyle {
- background_color: Some(cx.theme().status().created_background),
- ..Default::default()
- },
- ));
- delta += new_text.len();
- }
- }
-
- let cursor_offset_in_diff = cursor_offset_in_diff
- .unwrap_or_else(|| completion.cursor_offset - completion.excerpt_range.start + delta);
-
- let settings = ThemeSettings::get_global(cx).clone();
- let text_style = TextStyle {
- color: cx.theme().colors().editor_foreground,
- font_size: settings.buffer_font_size(cx).into(),
- font_family: settings.buffer_font.family,
- font_features: settings.buffer_font.features,
- font_fallbacks: settings.buffer_font.fallbacks,
- line_height: relative(settings.buffer_line_height.value()),
- font_weight: settings.buffer_font.weight,
- font_style: settings.buffer_font.style,
- ..Default::default()
- };
- let element = StyledText::new(diff).with_default_highlights(&text_style, diff_highlights);
- let text_layout = element.layout().clone();
-
- CompletionDiffElement {
- element: element.into_any_element(),
- text_layout,
- cursor_offset: cursor_offset_in_diff,
- }
- }
-}
-
-impl IntoElement for CompletionDiffElement {
- type Element = Self;
-
- fn into_element(self) -> Self {
- self
- }
-}
-
-impl Element for CompletionDiffElement {
- type RequestLayoutState = ();
- type PrepaintState = ();
-
- fn id(&self) -> Option<ElementId> {
- None
- }
-
- fn source_location(&self) -> Option<&'static core::panic::Location<'static>> {
- None
- }
-
- fn request_layout(
- &mut self,
- _id: Option<&gpui::GlobalElementId>,
- _inspector_id: Option<&gpui::InspectorElementId>,
- window: &mut Window,
- cx: &mut App,
- ) -> (gpui::LayoutId, Self::RequestLayoutState) {
- (self.element.request_layout(window, cx), ())
- }
-
- fn prepaint(
- &mut self,
- _id: Option<&gpui::GlobalElementId>,
- _inspector_id: Option<&gpui::InspectorElementId>,
- _bounds: gpui::Bounds<Pixels>,
- _request_layout: &mut Self::RequestLayoutState,
- window: &mut Window,
- cx: &mut App,
- ) -> Self::PrepaintState {
- self.element.prepaint(window, cx);
- }
-
- fn paint(
- &mut self,
- _id: Option<&gpui::GlobalElementId>,
- _inspector_id: Option<&gpui::InspectorElementId>,
- _bounds: gpui::Bounds<Pixels>,
- _request_layout: &mut Self::RequestLayoutState,
- _prepaint: &mut Self::PrepaintState,
- window: &mut Window,
- cx: &mut App,
- ) {
- if let Some(position) = self.text_layout.position_for_index(self.cursor_offset) {
- let bounds = self.text_layout.bounds();
- let line_height = self.text_layout.line_height();
- let line_width = self
- .text_layout
- .line_layout_for_index(self.cursor_offset)
- .map_or(bounds.size.width, |layout| layout.width());
- window.paint_quad(quad(
- Bounds::new(
- point(bounds.origin.x, position.y),
- size(cmp::max(bounds.size.width, line_width), line_height),
- ),
- Corners::default(),
- cx.theme().colors().editor_active_line_background,
- Edges::default(),
- Hsla::transparent_black(),
- BorderStyle::default(),
- ));
- self.element.paint(window, cx);
- window.paint_quad(quad(
- Bounds::new(position, size(px(2.), line_height)),
- Corners::default(),
- cx.theme().players().local().cursor,
- Edges::default(),
- Hsla::transparent_black(),
- BorderStyle::default(),
- ));
- }
- }
-}
@@ -1,110 +0,0 @@
-use std::any::{Any, TypeId};
-
-use command_palette_hooks::CommandPaletteFilter;
-use feature_flags::{FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
-use gpui::actions;
-use language::language_settings::EditPredictionProvider;
-use project::DisableAiSettings;
-use settings::{Settings, SettingsStore, update_settings_file};
-use ui::App;
-use workspace::Workspace;
-
-use crate::{RateCompletionModal, onboarding_modal::ZedPredictModal};
-
-actions!(
- edit_prediction,
- [
- /// Resets the edit prediction onboarding state.
- ResetOnboarding,
- /// Opens the rate completions modal.
- RateCompletions
- ]
-);
-
-pub fn init(cx: &mut App) {
- feature_gate_predict_edits_actions(cx);
-
- cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
- workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
- if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
- RateCompletionModal::toggle(workspace, window, cx);
- }
- });
-
- workspace.register_action(
- move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
- ZedPredictModal::toggle(
- workspace,
- workspace.user_store().clone(),
- workspace.client().clone(),
- window,
- cx,
- )
- },
- );
-
- workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
- update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
- settings
- .project
- .all_languages
- .features
- .get_or_insert_default()
- .edit_prediction_provider = Some(EditPredictionProvider::None)
- });
- });
- })
- .detach();
-}
-
-fn feature_gate_predict_edits_actions(cx: &mut App) {
- let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
- let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
- let zeta_all_action_types = [
- TypeId::of::<RateCompletions>(),
- TypeId::of::<ResetOnboarding>(),
- zed_actions::OpenZedPredictOnboarding.type_id(),
- TypeId::of::<crate::ClearHistory>(),
- TypeId::of::<crate::ThumbsUpActiveCompletion>(),
- TypeId::of::<crate::ThumbsDownActiveCompletion>(),
- TypeId::of::<crate::NextEdit>(),
- TypeId::of::<crate::PreviousEdit>(),
- ];
-
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.hide_action_types(&rate_completion_action_types);
- filter.hide_action_types(&reset_onboarding_action_types);
- filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
- });
-
- cx.observe_global::<SettingsStore>(move |cx| {
- let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
- let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();
-
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- if is_ai_disabled {
- filter.hide_action_types(&zeta_all_action_types);
- } else if has_feature_flag {
- filter.show_action_types(&rate_completion_action_types);
- } else {
- filter.hide_action_types(&rate_completion_action_types);
- }
- });
- })
- .detach();
-
- cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
- if !DisableAiSettings::get_global(cx).disable_ai {
- if is_enabled {
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.show_action_types(&rate_completion_action_types);
- });
- } else {
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.hide_action_types(&rate_completion_action_types);
- });
- }
- }
- })
- .detach();
-}
@@ -1,6 +1,6 @@
use std::sync::Arc;
-use crate::{ZedPredictUpsell, onboarding_event};
+use crate::ZedPredictUpsell;
use ai_onboarding::EditPredictionOnboarding;
use client::{Client, UserStore};
use db::kvp::Dismissable;
@@ -14,6 +14,16 @@ use settings::update_settings_file;
use ui::{Vector, VectorName, prelude::*};
use workspace::{ModalView, Workspace};
+#[macro_export]
+macro_rules! onboarding_event {
+ ($name:expr) => {
+ telemetry::event!($name, source = "Edit Prediction Onboarding");
+ };
+ ($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => {
+ telemetry::event!($name, source = "Edit Prediction Onboarding", $($key $(= $value)?),+);
+ };
+}
+
/// Introduces user to Zed's Edit Prediction feature
pub struct ZedPredictModal {
onboarding: Entity<EditPredictionOnboarding>,
@@ -1,9 +0,0 @@
-#[macro_export]
-macro_rules! onboarding_event {
- ($name:expr) => {
- telemetry::event!($name, source = "Edit Prediction Onboarding");
- };
- ($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => {
- telemetry::event!($name, source = "Edit Prediction Onboarding", $($key $(= $value)?),+);
- };
-}
@@ -1,7 +1,13 @@
-use std::{ops::Range, sync::Arc};
+use std::{
+ ops::Range,
+ path::Path,
+ sync::Arc,
+ time::{Duration, Instant},
+};
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
+use serde::Serialize;
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -26,6 +32,17 @@ pub struct EditPrediction {
pub edit_preview: EditPreview,
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
pub buffer: Entity<Buffer>,
+ pub buffer_snapshotted_at: Instant,
+ pub response_received_at: Instant,
+ pub inputs: EditPredictionInputs,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct EditPredictionInputs {
+ pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
+ pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
+ pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
+ pub cursor_path: Arc<Path>,
}
impl EditPrediction {
@@ -33,14 +50,17 @@ impl EditPrediction {
id: EditPredictionId,
edited_buffer: &Entity<Buffer>,
edited_buffer_snapshot: &BufferSnapshot,
- edits: Vec<(Range<Anchor>, Arc<str>)>,
+ edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+ buffer_snapshotted_at: Instant,
+ response_received_at: Instant,
+ inputs: EditPredictionInputs,
cx: &mut AsyncApp,
) -> Option<Self> {
let (edits, snapshot, edit_preview_task) = edited_buffer
.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[_]> =
- interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
+ interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits)?.into();
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
})
@@ -53,7 +73,10 @@ impl EditPrediction {
edits,
snapshot,
edit_preview,
+ inputs,
buffer: edited_buffer.clone(),
+ buffer_snapshotted_at,
+ response_received_at,
})
}
@@ -67,6 +90,10 @@ impl EditPrediction {
pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
self.snapshot.remote_id() == buffer.remote_id()
}
+
+ pub fn latency(&self) -> Duration {
+ self.response_received_at - self.buffer_snapshotted_at
+ }
}
impl std::fmt::Debug for EditPrediction {
@@ -147,6 +174,17 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
+ inputs: EditPredictionInputs {
+ events: vec![],
+ included_files: vec![],
+ cursor_point: cloud_llm_client::predict_edits_v3::Point {
+ line: cloud_llm_client::predict_edits_v3::Line(0),
+ column: 0,
+ },
+ cursor_path: Path::new("path.txt").into(),
+ },
+ buffer_snapshotted_at: Instant::now(),
+ response_received_at: Instant::now(),
};
cx.update(|cx| {
@@ -131,8 +131,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
fn discard(&mut self, cx: &mut Context<Self>) {
- self.zeta.update(cx, |zeta, _cx| {
- zeta.discard_current_prediction(&self.project);
+ self.zeta.update(cx, |zeta, cx| {
+ zeta.discard_current_prediction(&self.project, cx);
+ });
+ }
+
+ fn did_show(&mut self, cx: &mut Context<Self>) {
+ self.zeta.update(cx, |zeta, cx| {
+ zeta.did_show_current_prediction(&self.project, cx);
});
}
@@ -162,8 +168,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
let snapshot = buffer.snapshot();
let Some(edits) = prediction.interpolate(&snapshot) else {
- self.zeta.update(cx, |zeta, _cx| {
- zeta.discard_current_prediction(&self.project);
+ self.zeta.update(cx, |zeta, cx| {
+ zeta.discard_current_prediction(&self.project, cx);
});
return None;
};
@@ -1,8 +1,18 @@
-use crate::{CompletionDiffElement, EditPrediction, EditPredictionRating, Zeta};
-use editor::Editor;
-use gpui::{App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, actions, prelude::*};
-use language::language_settings;
+use crate::{EditPrediction, EditPredictionRating, Zeta};
+use buffer_diff::{BufferDiff, BufferDiffSnapshot};
+use cloud_zeta2_prompt::write_codeblock;
+use editor::{Editor, ExcerptRange, MultiBuffer};
+use gpui::{
+ App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable,
+ Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*,
+};
+use language::{LanguageRegistry, Point, language_settings};
+use markdown::{Markdown, MarkdownStyle};
+use settings::Settings as _;
+use std::fmt::Write;
+use std::sync::Arc;
use std::time::Duration;
+use theme::ThemeSettings;
use ui::{KeyBinding, List, ListItem, ListItemSpacing, Tooltip, prelude::*};
use workspace::{ModalView, Workspace};
@@ -10,41 +20,44 @@ actions!(
zeta,
[
/// Rates the active completion with a thumbs up.
- ThumbsUpActiveCompletion,
+ ThumbsUpActivePrediction,
/// Rates the active completion with a thumbs down.
- ThumbsDownActiveCompletion,
+ ThumbsDownActivePrediction,
/// Navigates to the next edit in the completion history.
NextEdit,
/// Navigates to the previous edit in the completion history.
PreviousEdit,
/// Focuses on the completions list.
- FocusCompletions,
+ FocusPredictions,
/// Previews the selected completion.
- PreviewCompletion,
+ PreviewPrediction,
]
);
-pub struct RateCompletionModal {
+pub struct RatePredictionsModal {
zeta: Entity<Zeta>,
- active_completion: Option<ActiveCompletion>,
+ language_registry: Arc<LanguageRegistry>,
+ active_prediction: Option<ActivePrediction>,
selected_index: usize,
+ diff_editor: Entity<Editor>,
focus_handle: FocusHandle,
_subscription: gpui::Subscription,
- current_view: RateCompletionView,
+ current_view: RatePredictionView,
}
-struct ActiveCompletion {
- completion: EditPrediction,
+struct ActivePrediction {
+ prediction: EditPrediction,
feedback_editor: Entity<Editor>,
+ formatted_inputs: Entity<Markdown>,
}
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
-enum RateCompletionView {
+enum RatePredictionView {
SuggestedEdits,
RawInput,
}
-impl RateCompletionView {
+impl RatePredictionView {
pub fn name(&self) -> &'static str {
match self {
Self::SuggestedEdits => "Suggested Edits",
@@ -53,25 +66,42 @@ impl RateCompletionView {
}
}
-impl RateCompletionModal {
+impl RatePredictionsModal {
pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
- if let Some(zeta) = Zeta::global(cx) {
- workspace.toggle_modal(window, cx, |_window, cx| RateCompletionModal::new(zeta, cx));
+ if let Some(zeta) = Zeta::try_global(cx) {
+ let language_registry = workspace.app_state().languages.clone();
+ workspace.toggle_modal(window, cx, |window, cx| {
+ RatePredictionsModal::new(zeta, language_registry, window, cx)
+ });
- telemetry::event!("Rate Completion Modal Open", source = "Edit Prediction");
+ telemetry::event!("Rate Prediction Modal Open", source = "Edit Prediction");
}
}
- pub fn new(zeta: Entity<Zeta>, cx: &mut Context<Self>) -> Self {
+ pub fn new(
+ zeta: Entity<Zeta>,
+ language_registry: Arc<LanguageRegistry>,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
let subscription = cx.observe(&zeta, |_, _, cx| cx.notify());
Self {
zeta,
+ language_registry,
selected_index: 0,
focus_handle: cx.focus_handle(),
- active_completion: None,
+ active_prediction: None,
_subscription: subscription,
- current_view: RateCompletionView::SuggestedEdits,
+ diff_editor: cx.new(|cx| {
+ let multibuffer = cx.new(|_| MultiBuffer::new(language::Capability::ReadOnly));
+ let mut editor = Editor::for_multibuffer(multibuffer, None, window, cx);
+ editor.disable_inline_diagnostics();
+ editor.set_expand_all_diff_hunks(cx);
+ editor.set_show_git_diff_gutter(false, cx);
+ editor
+ }),
+ current_view: RatePredictionView::SuggestedEdits,
}
}
@@ -83,7 +113,7 @@ impl RateCompletionModal {
self.selected_index += 1;
self.selected_index = usize::min(
self.selected_index,
- self.zeta.read(cx).shown_completions().count(),
+ self.zeta.read(cx).shown_predictions().count(),
);
cx.notify();
}
@@ -102,7 +132,7 @@ impl RateCompletionModal {
let next_index = self
.zeta
.read(cx)
- .shown_completions()
+ .shown_predictions()
.skip(self.selected_index)
.enumerate()
.skip(1) // Skip straight to the next item
@@ -122,7 +152,7 @@ impl RateCompletionModal {
let prev_index = self
.zeta
.read(cx)
- .shown_completions()
+ .shown_predictions()
.rev()
.skip((completions_len - 1) - self.selected_index)
.enumerate()
@@ -149,14 +179,14 @@ impl RateCompletionModal {
pub fn thumbs_up_active(
&mut self,
- _: &ThumbsUpActiveCompletion,
+ _: &ThumbsUpActivePrediction,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.zeta.update(cx, |zeta, cx| {
- if let Some(active) = &self.active_completion {
- zeta.rate_completion(
- &active.completion,
+ if let Some(active) = &self.active_prediction {
+ zeta.rate_prediction(
+ &active.prediction,
EditPredictionRating::Positive,
active.feedback_editor.read(cx).text(cx),
cx,
@@ -165,9 +195,9 @@ impl RateCompletionModal {
});
let current_completion = self
- .active_completion
+ .active_prediction
.as_ref()
- .map(|completion| completion.completion.clone());
+ .map(|completion| completion.prediction.clone());
self.select_completion(current_completion, false, window, cx);
self.select_next_edit(&Default::default(), window, cx);
self.confirm(&Default::default(), window, cx);
@@ -177,18 +207,18 @@ impl RateCompletionModal {
pub fn thumbs_down_active(
&mut self,
- _: &ThumbsDownActiveCompletion,
+ _: &ThumbsDownActivePrediction,
window: &mut Window,
cx: &mut Context<Self>,
) {
- if let Some(active) = &self.active_completion {
+ if let Some(active) = &self.active_prediction {
if active.feedback_editor.read(cx).text(cx).is_empty() {
return;
}
self.zeta.update(cx, |zeta, cx| {
- zeta.rate_completion(
- &active.completion,
+ zeta.rate_prediction(
+ &active.prediction,
EditPredictionRating::Negative,
active.feedback_editor.read(cx).text(cx),
cx,
@@ -197,9 +227,9 @@ impl RateCompletionModal {
}
let current_completion = self
- .active_completion
+ .active_prediction
.as_ref()
- .map(|completion| completion.completion.clone());
+ .map(|completion| completion.prediction.clone());
self.select_completion(current_completion, false, window, cx);
self.select_next_edit(&Default::default(), window, cx);
self.confirm(&Default::default(), window, cx);
@@ -209,7 +239,7 @@ impl RateCompletionModal {
fn focus_completions(
&mut self,
- _: &FocusCompletions,
+ _: &FocusPredictions,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -219,14 +249,14 @@ impl RateCompletionModal {
fn preview_completion(
&mut self,
- _: &PreviewCompletion,
+ _: &PreviewPrediction,
window: &mut Window,
cx: &mut Context<Self>,
) {
let completion = self
.zeta
.read(cx)
- .shown_completions()
+ .shown_predictions()
.skip(self.selected_index)
.take(1)
.next()
@@ -239,7 +269,7 @@ impl RateCompletionModal {
let completion = self
.zeta
.read(cx)
- .shown_completions()
+ .shown_predictions()
.skip(self.selected_index)
.take(1)
.next()
@@ -250,54 +280,145 @@ impl RateCompletionModal {
pub fn select_completion(
&mut self,
- completion: Option<EditPrediction>,
+ prediction: Option<EditPrediction>,
focus: bool,
window: &mut Window,
cx: &mut Context<Self>,
) {
// Avoid resetting completion rating if it's already selected.
- if let Some(completion) = completion.as_ref() {
+ if let Some(prediction) = prediction {
self.selected_index = self
.zeta
.read(cx)
- .shown_completions()
+ .shown_predictions()
.enumerate()
- .find(|(_, completion_b)| completion.id == completion_b.id)
+ .find(|(_, completion_b)| prediction.id == completion_b.id)
.map(|(ix, _)| ix)
.unwrap_or(self.selected_index);
cx.notify();
- if let Some(prev_completion) = self.active_completion.as_ref()
- && completion.id == prev_completion.completion.id
+ if let Some(prev_prediction) = self.active_prediction.as_ref()
+ && prediction.id == prev_prediction.prediction.id
{
if focus {
- window.focus(&prev_completion.feedback_editor.focus_handle(cx));
+ window.focus(&prev_prediction.feedback_editor.focus_handle(cx));
}
return;
}
+
+ self.diff_editor.update(cx, |editor, cx| {
+ let new_buffer = prediction.edit_preview.build_result_buffer(cx);
+ let new_buffer_snapshot = new_buffer.read(cx).snapshot();
+ let old_buffer_snapshot = prediction.snapshot.clone();
+ let new_buffer_id = new_buffer_snapshot.remote_id();
+
+ let range = prediction
+ .edit_preview
+ .compute_visible_range(&prediction.edits)
+ .unwrap_or(Point::zero()..Point::zero());
+ let start = Point::new(range.start.row.saturating_sub(5), 0);
+ let end = Point::new(range.end.row + 5, 0).min(new_buffer_snapshot.max_point());
+
+ let diff = cx.new::<BufferDiff>(|cx| {
+ let diff_snapshot = BufferDiffSnapshot::new_with_base_buffer(
+ new_buffer_snapshot.text.clone(),
+ Some(old_buffer_snapshot.text().into()),
+ old_buffer_snapshot.clone(),
+ cx,
+ );
+ let diff = BufferDiff::new(&new_buffer_snapshot, cx);
+ cx.spawn(async move |diff, cx| {
+ let diff_snapshot = diff_snapshot.await;
+ diff.update(cx, |diff, cx| {
+ diff.set_snapshot(diff_snapshot, &new_buffer_snapshot.text, cx);
+ })
+ })
+ .detach();
+ diff
+ });
+
+ editor.disable_header_for_buffer(new_buffer_id, cx);
+ editor.buffer().update(cx, |multibuffer, cx| {
+ multibuffer.clear(cx);
+ multibuffer.push_excerpts(
+ new_buffer,
+ vec![ExcerptRange {
+ context: start..end,
+ primary: start..end,
+ }],
+ cx,
+ );
+ multibuffer.add_diff(diff, cx);
+ });
+ });
+
+ let mut formatted_inputs = String::new();
+
+ write!(&mut formatted_inputs, "## Events\n\n").unwrap();
+
+ for event in &prediction.inputs.events {
+ write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
+ }
+
+ write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
+
+ for included_file in &prediction.inputs.included_files {
+ let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
+
+ write!(
+ &mut formatted_inputs,
+ "### {}\n\n",
+ included_file.path.display()
+ )
+ .unwrap();
+
+ write_codeblock(
+ &included_file.path,
+ &included_file.excerpts,
+ if included_file.path == prediction.inputs.cursor_path {
+ cursor_insertions
+ } else {
+ &[]
+ },
+ included_file.max_row,
+ false,
+ &mut formatted_inputs,
+ );
+ }
+
+ self.active_prediction = Some(ActivePrediction {
+ prediction,
+ feedback_editor: cx.new(|cx| {
+ let mut editor = Editor::multi_line(window, cx);
+ editor.disable_scrollbars_and_minimap(window, cx);
+ editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx);
+ editor.set_show_line_numbers(false, cx);
+ editor.set_show_git_diff_gutter(false, cx);
+ editor.set_show_code_actions(false, cx);
+ editor.set_show_runnables(false, cx);
+ editor.set_show_breakpoints(false, cx);
+ editor.set_show_wrap_guides(false, cx);
+ editor.set_show_indent_guides(false, cx);
+ editor.set_show_edit_predictions(Some(false), window, cx);
+ editor.set_placeholder_text("Add your feedbackโฆ", window, cx);
+ if focus {
+ cx.focus_self(window);
+ }
+ editor
+ }),
+ formatted_inputs: cx.new(|cx| {
+ Markdown::new(
+ formatted_inputs.into(),
+ Some(self.language_registry.clone()),
+ None,
+ cx,
+ )
+ }),
+ });
+ } else {
+ self.active_prediction = None;
}
- self.active_completion = completion.map(|completion| ActiveCompletion {
- completion,
- feedback_editor: cx.new(|cx| {
- let mut editor = Editor::multi_line(window, cx);
- editor.disable_scrollbars_and_minimap(window, cx);
- editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_git_diff_gutter(false, cx);
- editor.set_show_code_actions(false, cx);
- editor.set_show_runnables(false, cx);
- editor.set_show_breakpoints(false, cx);
- editor.set_show_wrap_guides(false, cx);
- editor.set_show_indent_guides(false, cx);
- editor.set_show_edit_predictions(Some(false), window, cx);
- editor.set_placeholder_text("Add your feedbackโฆ", window, cx);
- if focus {
- cx.focus_self(window);
- }
- editor
- }),
- });
cx.notify();
}
@@ -312,33 +433,31 @@ impl RateCompletionModal {
.child(
Button::new(
ElementId::Name("suggested-edits".into()),
- RateCompletionView::SuggestedEdits.name(),
+ RatePredictionView::SuggestedEdits.name(),
)
.label_size(LabelSize::Small)
.on_click(cx.listener(move |this, _, _window, cx| {
- this.current_view = RateCompletionView::SuggestedEdits;
+ this.current_view = RatePredictionView::SuggestedEdits;
cx.notify();
}))
- .toggle_state(self.current_view == RateCompletionView::SuggestedEdits),
+ .toggle_state(self.current_view == RatePredictionView::SuggestedEdits),
)
.child(
Button::new(
ElementId::Name("raw-input".into()),
- RateCompletionView::RawInput.name(),
+ RatePredictionView::RawInput.name(),
)
.label_size(LabelSize::Small)
.on_click(cx.listener(move |this, _, _window, cx| {
- this.current_view = RateCompletionView::RawInput;
+ this.current_view = RatePredictionView::RawInput;
cx.notify();
}))
- .toggle_state(self.current_view == RateCompletionView::RawInput),
+ .toggle_state(self.current_view == RatePredictionView::RawInput),
)
}
fn render_suggested_edits(&self, cx: &mut Context<Self>) -> Option<gpui::Stateful<Div>> {
- let active_completion = self.active_completion.as_ref()?;
let bg_color = cx.theme().colors().editor_background;
-
Some(
div()
.id("diff")
@@ -347,14 +466,18 @@ impl RateCompletionModal {
.bg(bg_color)
.overflow_scroll()
.whitespace_nowrap()
- .child(CompletionDiffElement::new(
- &active_completion.completion,
- cx,
- )),
+ .child(self.diff_editor.clone()),
)
}
- fn render_raw_input(&self, cx: &mut Context<Self>) -> Option<gpui::Stateful<Div>> {
+ fn render_raw_input(
+ &self,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Option<gpui::Stateful<Div>> {
+ let theme_settings = ThemeSettings::get_global(cx);
+ let buffer_font_size = theme_settings.buffer_font_size(cx);
+
Some(
v_flex()
.size_full()
@@ -368,30 +491,81 @@ impl RateCompletionModal {
.size_full()
.bg(cx.theme().colors().editor_background)
.overflow_scroll()
- .child(if let Some(active_completion) = &self.active_completion {
- format!(
- "{}\n{}",
- active_completion.completion.input_events,
- active_completion.completion.input_excerpt
+ .child(if let Some(active_prediction) = &self.active_prediction {
+ markdown::MarkdownElement::new(
+ active_prediction.formatted_inputs.clone(),
+ MarkdownStyle {
+ base_text_style: window.text_style(),
+ syntax: cx.theme().syntax().clone(),
+ code_block: StyleRefinement {
+ text: Some(TextStyleRefinement {
+ font_family: Some(
+ theme_settings.buffer_font.family.clone(),
+ ),
+ font_size: Some(buffer_font_size.into()),
+ ..Default::default()
+ }),
+ padding: EdgesRefinement {
+ top: Some(DefiniteLength::Absolute(
+ AbsoluteLength::Pixels(px(8.)),
+ )),
+ left: Some(DefiniteLength::Absolute(
+ AbsoluteLength::Pixels(px(8.)),
+ )),
+ right: Some(DefiniteLength::Absolute(
+ AbsoluteLength::Pixels(px(8.)),
+ )),
+ bottom: Some(DefiniteLength::Absolute(
+ AbsoluteLength::Pixels(px(8.)),
+ )),
+ },
+ margin: EdgesRefinement {
+ top: Some(Length::Definite(px(8.).into())),
+ left: Some(Length::Definite(px(0.).into())),
+ right: Some(Length::Definite(px(0.).into())),
+ bottom: Some(Length::Definite(px(12.).into())),
+ },
+ border_style: Some(BorderStyle::Solid),
+ border_widths: EdgesRefinement {
+ top: Some(AbsoluteLength::Pixels(px(1.))),
+ left: Some(AbsoluteLength::Pixels(px(1.))),
+ right: Some(AbsoluteLength::Pixels(px(1.))),
+ bottom: Some(AbsoluteLength::Pixels(px(1.))),
+ },
+ border_color: Some(cx.theme().colors().border_variant),
+ background: Some(
+ cx.theme().colors().editor_background.into(),
+ ),
+ ..Default::default()
+ },
+ ..Default::default()
+ },
)
+ .into_any_element()
} else {
- "No active completion".to_string()
+ div()
+ .child("No active completion".to_string())
+ .into_any_element()
}),
)
.id("raw-input-view"),
)
}
- fn render_active_completion(&mut self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
- let active_completion = self.active_completion.as_ref()?;
- let completion_id = active_completion.completion.id;
+ fn render_active_completion(
+ &mut self,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Option<impl IntoElement> {
+ let active_prediction = self.active_prediction.as_ref()?;
+ let completion_id = active_prediction.prediction.id.clone();
let focus_handle = &self.focus_handle(cx);
let border_color = cx.theme().colors().border;
let bg_color = cx.theme().colors().editor_background;
- let rated = self.zeta.read(cx).is_completion_rated(completion_id);
- let feedback_empty = active_completion
+ let rated = self.zeta.read(cx).is_prediction_rated(&completion_id);
+ let feedback_empty = active_prediction
.feedback_editor
.read(cx)
.text(cx)
@@ -412,10 +586,10 @@ impl RateCompletionModal {
.child(self.render_view_nav(cx))
.when_some(
match self.current_view {
- RateCompletionView::SuggestedEdits => {
+ RatePredictionView::SuggestedEdits => {
self.render_suggested_edits(cx)
}
- RateCompletionView::RawInput => self.render_raw_input(cx),
+ RatePredictionView::RawInput => self.render_raw_input(window, cx),
},
|this, element| this.child(element),
),
@@ -450,7 +624,7 @@ impl RateCompletionModal {
.h_40()
.pt_1()
.bg(bg_color)
- .child(active_completion.feedback_editor.clone()),
+ .child(active_prediction.feedback_editor.clone()),
)
})
.child(
@@ -472,7 +646,7 @@ impl RateCompletionModal {
)
.child(Label::new("Rated completion.").color(Color::Muted)),
)
- } else if active_completion.completion.edits.is_empty() {
+ } else if active_prediction.prediction.edits.is_empty() {
Some(
label_container
.child(
@@ -489,7 +663,7 @@ impl RateCompletionModal {
h_flex()
.gap_1()
.child(
- Button::new("bad", "Bad Completion")
+ Button::new("bad", "Bad Prediction")
.icon(IconName::ThumbsDown)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
@@ -500,14 +674,14 @@ impl RateCompletionModal {
))
})
.key_binding(KeyBinding::for_action_in(
- &ThumbsDownActiveCompletion,
+ &ThumbsDownActivePrediction,
focus_handle,
cx,
))
.on_click(cx.listener(move |this, _, window, cx| {
- if this.active_completion.is_some() {
+ if this.active_prediction.is_some() {
this.thumbs_down_active(
- &ThumbsDownActiveCompletion,
+ &ThumbsDownActivePrediction,
window,
cx,
);
@@ -515,20 +689,20 @@ impl RateCompletionModal {
})),
)
.child(
- Button::new("good", "Good Completion")
+ Button::new("good", "Good Prediction")
.icon(IconName::ThumbsUp)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.disabled(rated)
.key_binding(KeyBinding::for_action_in(
- &ThumbsUpActiveCompletion,
+ &ThumbsUpActivePrediction,
focus_handle,
cx,
))
.on_click(cx.listener(move |this, _, window, cx| {
- if this.active_completion.is_some() {
+ if this.active_prediction.is_some() {
this.thumbs_up_active(
- &ThumbsUpActiveCompletion,
+ &ThumbsUpActivePrediction,
window,
cx,
);
@@ -543,34 +717,32 @@ impl RateCompletionModal {
fn render_shown_completions(&self, cx: &Context<Self>) -> impl Iterator<Item = ListItem> {
self.zeta
.read(cx)
- .shown_completions()
+ .shown_predictions()
.cloned()
.enumerate()
.map(|(index, completion)| {
let selected = self
- .active_completion
+ .active_prediction
.as_ref()
- .is_some_and(|selected| selected.completion.id == completion.id);
- let rated = self.zeta.read(cx).is_completion_rated(completion.id);
+ .is_some_and(|selected| selected.prediction.id == completion.id);
+ let rated = self.zeta.read(cx).is_prediction_rated(&completion.id);
let (icon_name, icon_color, tooltip_text) =
match (rated, completion.edits.is_empty()) {
- (true, _) => (IconName::Check, Color::Success, "Rated Completion"),
+ (true, _) => (IconName::Check, Color::Success, "Rated Prediction"),
(false, true) => (IconName::File, Color::Muted, "No Edits Produced"),
(false, false) => (IconName::FileDiff, Color::Accent, "Edits Available"),
};
- let file_name = completion
- .path
- .file_name()
- .map(|f| f.to_string_lossy().into_owned())
- .unwrap_or("untitled".to_string());
- let file_path = completion
- .path
- .parent()
- .map(|p| p.to_string_lossy().into_owned());
-
- ListItem::new(completion.id)
+ let file = completion.buffer.read(cx).file();
+ let file_name = file
+ .as_ref()
+ .map_or(SharedString::new_static("untitled"), |file| {
+ file.file_name(cx).to_string().into()
+ });
+ let file_path = file.map(|file| file.path().as_unix_str().to_string());
+
+ ListItem::new(completion.id.clone())
.inset(true)
.spacing(ListItemSpacing::Sparse)
.focused(index == self.selected_index)
@@ -615,12 +787,12 @@ impl RateCompletionModal {
}
}
-impl Render for RateCompletionModal {
+impl Render for RatePredictionsModal {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let border_color = cx.theme().colors().border;
h_flex()
- .key_context("RateCompletionModal")
+ .key_context("RatePredictionModal")
.track_focus(&self.focus_handle)
.on_action(cx.listener(Self::dismiss))
.on_action(cx.listener(Self::confirm))
@@ -688,20 +860,20 @@ impl Render for RateCompletionModal {
),
),
)
- .children(self.render_active_completion(cx))
+ .children(self.render_active_completion(window, cx))
.on_mouse_down_out(cx.listener(|_, _, _, cx| cx.emit(DismissEvent)))
}
}
-impl EventEmitter<DismissEvent> for RateCompletionModal {}
+impl EventEmitter<DismissEvent> for RatePredictionsModal {}
-impl Focusable for RateCompletionModal {
+impl Focusable for RatePredictionsModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
-impl ModalView for RateCompletionModal {}
+impl ModalView for RatePredictionsModal {}
fn format_time_ago(elapsed: Duration) -> String {
let seconds = elapsed.as_secs();
@@ -2,7 +2,6 @@ use std::fmt;
use std::{path::Path, sync::Arc};
use serde::{Deserialize, Serialize};
-use util::rel_path::RelPath;
#[derive(Debug, Clone, Serialize)]
pub struct AutocompleteRequest {
@@ -91,34 +90,24 @@ pub struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
-pub(crate) fn write_event(event: crate::Event, f: &mut impl fmt::Write) -> fmt::Result {
+pub(crate) fn write_event(
+ event: &cloud_llm_client::predict_edits_v3::Event,
+ f: &mut impl fmt::Write,
+) -> fmt::Result {
match event {
- crate::Event::BufferChange {
- old_snapshot,
- new_snapshot,
+ cloud_llm_client::predict_edits_v3::Event::BufferChange {
+ old_path,
+ path,
+ diff,
..
} => {
- let old_path = old_snapshot
- .file()
- .map(|f| f.path().as_ref())
- .unwrap_or(RelPath::unix("untitled").unwrap());
- let new_path = new_snapshot
- .file()
- .map(|f| f.path().as_ref())
- .unwrap_or(RelPath::unix("untitled").unwrap());
- if old_path != new_path {
+ if old_path != path {
// TODO confirm how to do this for sweep
// writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
}
- let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
if !diff.is_empty() {
- write!(
- f,
- "File: {}:\n{}\n",
- new_path.display(util::paths::PathStyle::Posix),
- diff
- )?
+ write!(f, "File: {}:\n{}\n", path.display(), diff)?
}
fmt::Result::Ok(())
@@ -1,130 +1,178 @@
-mod completion_diff_element;
-mod init;
-mod input_excerpt;
-mod license_detection;
-mod onboarding_modal;
-mod onboarding_telemetry;
-mod rate_completion_modal;
-
-pub(crate) use completion_diff_element::*;
-use db::kvp::{Dismissable, KEY_VALUE_STORE};
-use db::smol::stream::StreamExt as _;
-use edit_prediction::DataCollectionState;
-use futures::channel::mpsc;
-pub use init::*;
-use license_detection::LicenseDetectionWatcher;
-pub use rate_completion_modal::*;
-
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Context as _, Result, anyhow, bail};
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
+use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
- PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
- ZED_VERSION_HEADER_NAME,
+ RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
};
-use collections::{HashMap, HashSet, VecDeque};
-use futures::AsyncReadExt;
+use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
+use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
+use collections::{HashMap, HashSet};
+use command_palette_hooks::CommandPaletteFilter;
+use db::kvp::{Dismissable, KEY_VALUE_STORE};
+use edit_prediction_context::{
+ DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
+ EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
+ SyntaxIndex, SyntaxIndexState,
+};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
+use futures::channel::{mpsc, oneshot};
+use futures::{AsyncReadExt as _, StreamExt as _};
use gpui::{
- App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SharedString, Subscription,
- Task, actions,
+ App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
+ http_client::{self, AsyncBody, Method},
+ prelude::*,
};
-use http_client::{AsyncBody, HttpClient, Method, Request, Response};
-use input_excerpt::excerpt_for_cursor_position;
use language::{
- Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
+ Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
};
+use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
-use project::{Project, ProjectPath};
+use lsp::DiagnosticSeverity;
+use open_ai::FunctionDefinition;
+use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
-use settings::WorktreeId;
-use std::collections::hash_map;
-use std::mem;
-use std::str::FromStr;
-use std::{
- cmp,
- fmt::Write,
- future::Future,
- ops::Range,
- path::Path,
- rc::Rc,
- sync::Arc,
- time::{Duration, Instant},
-};
+use serde::de::DeserializeOwned;
+use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
+use std::any::{Any as _, TypeId};
+use std::collections::{VecDeque, hash_map};
use telemetry_events::EditPredictionRating;
+use workspace::Workspace;
+
+use std::fmt::Write as _;
+use std::ops::Range;
+use std::path::Path;
+use std::rc::Rc;
+use std::str::FromStr as _;
+use std::sync::{Arc, LazyLock};
+use std::time::{Duration, Instant};
+use std::{env, mem};
use thiserror::Error;
-use util::ResultExt;
-use util::rel_path::RelPath;
-use uuid::Uuid;
+use util::rel_path::RelPathBuf;
+use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
-use worktree::Worktree;
-
-const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
-const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
-const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
-const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
-const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
-const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
-const MAX_CONTEXT_TOKENS: usize = 150;
-const MAX_REWRITE_TOKENS: usize = 350;
-const MAX_EVENT_TOKENS: usize = 500;
+pub mod assemble_excerpts;
+mod license_detection;
+mod onboarding_modal;
+mod prediction;
+mod provider;
+mod rate_prediction_modal;
+pub mod retrieval_search;
+mod sweep_ai;
+pub mod udiff;
+mod xml_edits;
+pub mod zeta1;
-/// Maximum number of events to track.
-const MAX_EVENT_COUNT: usize = 16;
+#[cfg(test)]
+mod zeta_tests;
+
+use crate::assemble_excerpts::assemble_excerpts;
+use crate::license_detection::LicenseDetectionWatcher;
+use crate::onboarding_modal::ZedPredictModal;
+pub use crate::prediction::EditPrediction;
+pub use crate::prediction::EditPredictionId;
+pub use crate::prediction::EditPredictionInputs;
+use crate::rate_prediction_modal::{
+ NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
+ ThumbsUpActivePrediction,
+};
+use crate::zeta1::request_prediction_with_zeta1;
+pub use provider::ZetaEditPredictionProvider;
actions!(
edit_prediction,
[
+ /// Resets the edit prediction onboarding state.
+ ResetOnboarding,
+ /// Opens the rate completions modal.
+ RateCompletions,
/// Clears the edit prediction history.
- ClearHistory
+ ClearHistory,
]
);
-#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
-pub struct EditPredictionId(Uuid);
+/// Maximum number of events to track.
+const EVENT_COUNT_MAX: usize = 6;
+const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
+const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
-impl From<EditPredictionId> for gpui::ElementId {
- fn from(value: EditPredictionId) -> Self {
- gpui::ElementId::Uuid(value.0)
- }
-}
+pub struct SweepFeatureFlag;
-impl std::fmt::Display for EditPredictionId {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
+impl FeatureFlag for SweepFeatureFlag {
+ const NAME: &str = "sweep-ai";
}
+pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
+ max_bytes: 512,
+ min_bytes: 128,
+ target_before_cursor_over_total_bytes: 0.5,
+};
-struct ZedPredictUpsell;
+pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
+ ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);
-impl Dismissable for ZedPredictUpsell {
- const KEY: &'static str = "dismissed-edit-predict-upsell";
+pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
+ excerpt: DEFAULT_EXCERPT_OPTIONS,
+};
- fn dismissed() -> bool {
- // To make this backwards compatible with older versions of Zed, we
- // check if the user has seen the previous Edit Prediction Onboarding
- // before, by checking the data collection choice which was written to
- // the database once the user clicked on "Accept and Enable"
- if KEY_VALUE_STORE
- .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
- .log_err()
- .is_some_and(|s| s.is_some())
- {
- return true;
+pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
+ EditPredictionContextOptions {
+ use_imports: true,
+ max_retrieved_declarations: 0,
+ excerpt: DEFAULT_EXCERPT_OPTIONS,
+ score: EditPredictionScoreOptions {
+ omit_excerpt_overlaps: true,
+ },
+ };
+
+pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
+ context: DEFAULT_CONTEXT_OPTIONS,
+ max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
+ max_diagnostic_bytes: 2048,
+ prompt_format: PromptFormat::DEFAULT,
+ file_indexing_parallelism: 1,
+ buffer_change_grouping_interval: Duration::from_secs(1),
+};
+
+static USE_OLLAMA: LazyLock<bool> =
+ LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
+static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
+ env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
+ "qwen3-coder:30b".to_string()
+ } else {
+ "yqvev8r3".to_string()
+ })
+});
+static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
+ match env::var("ZED_ZETA2_MODEL").as_deref() {
+ Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
+ Ok(model) => model,
+ Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
+ Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
+ }
+ .to_string()
+});
+static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
+ env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
+ if *USE_OLLAMA {
+ Some("http://localhost:11434/v1/chat/completions".into())
+ } else {
+ None
}
+ })
+});
- KEY_VALUE_STORE
- .read_kvp(Self::KEY)
- .log_err()
- .is_some_and(|s| s.is_some())
- }
-}
+pub struct Zeta2FeatureFlag;
-pub fn should_show_upsell_modal() -> bool {
- !ZedPredictUpsell::dismissed()
+impl FeatureFlag for Zeta2FeatureFlag {
+ const NAME: &'static str = "zeta2";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
}
#[derive(Clone)]
@@ -132,108 +180,291 @@ struct ZetaGlobal(Entity<Zeta>);
impl Global for ZetaGlobal {}
-#[derive(Clone)]
-pub struct EditPrediction {
- id: EditPredictionId,
- path: Arc<Path>,
- excerpt_range: Range<usize>,
- cursor_offset: usize,
- edits: Arc<[(Range<Anchor>, Arc<str>)]>,
- snapshot: BufferSnapshot,
- edit_preview: EditPreview,
- input_outline: Arc<str>,
- input_events: Arc<str>,
- input_excerpt: Arc<str>,
- output_excerpt: Arc<str>,
- buffer_snapshotted_at: Instant,
- response_received_at: Instant,
+pub struct Zeta {
+ client: Arc<Client>,
+ user_store: Entity<UserStore>,
+ llm_token: LlmApiToken,
+ _llm_token_subscription: Subscription,
+ projects: HashMap<EntityId, ZetaProject>,
+ options: ZetaOptions,
+ update_required: bool,
+ debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
+ #[cfg(feature = "eval-support")]
+ eval_cache: Option<Arc<dyn EvalCache>>,
+ edit_prediction_model: ZetaEditPredictionModel,
+ sweep_api_token: Option<String>,
+ sweep_ai_debug_info: Arc<str>,
+ data_collection_choice: DataCollectionChoice,
+ rejected_predictions: Vec<EditPredictionRejection>,
+ reject_predictions_tx: mpsc::UnboundedSender<()>,
+ reject_predictions_debounce_task: Option<Task<()>>,
+ shown_predictions: VecDeque<EditPrediction>,
+ rated_predictions: HashSet<EditPredictionId>,
}
-impl EditPrediction {
- fn latency(&self) -> Duration {
- self.response_received_at
- .duration_since(self.buffer_snapshotted_at)
- }
+#[derive(Default, PartialEq, Eq)]
+pub enum ZetaEditPredictionModel {
+ #[default]
+ Zeta1,
+ Zeta2,
+ Sweep,
+}
- fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
- edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
- }
+#[derive(Debug, Clone, PartialEq)]
+pub struct ZetaOptions {
+ pub context: ContextMode,
+ pub max_prompt_bytes: usize,
+ pub max_diagnostic_bytes: usize,
+ pub prompt_format: predict_edits_v3::PromptFormat,
+ pub file_indexing_parallelism: usize,
+ pub buffer_change_grouping_interval: Duration,
}
-impl std::fmt::Debug for EditPrediction {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("EditPrediction")
- .field("id", &self.id)
- .field("path", &self.path)
- .field("edits", &self.edits)
- .finish_non_exhaustive()
+#[derive(Debug, Clone, PartialEq)]
+pub enum ContextMode {
+ Agentic(AgenticContextOptions),
+ Syntax(EditPredictionContextOptions),
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct AgenticContextOptions {
+ pub excerpt: EditPredictionExcerptOptions,
+}
+
+impl ContextMode {
+ pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
+ match self {
+ ContextMode::Agentic(options) => &options.excerpt,
+ ContextMode::Syntax(options) => &options.excerpt,
+ }
}
}
-pub struct Zeta {
- projects: HashMap<EntityId, ZetaProject>,
- client: Arc<Client>,
- shown_completions: VecDeque<EditPrediction>,
- rated_completions: HashSet<EditPredictionId>,
- data_collection_choice: DataCollectionChoice,
- discarded_completions: Vec<EditPredictionRejection>,
- llm_token: LlmApiToken,
- _llm_token_subscription: Subscription,
- /// Whether an update to a newer version of Zed is required to continue using Zeta.
- update_required: bool,
- user_store: Entity<UserStore>,
- license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
- discard_completions_debounce_task: Option<Task<()>>,
- discard_completions_tx: mpsc::UnboundedSender<()>,
+#[derive(Debug)]
+pub enum ZetaDebugInfo {
+ ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
+ SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
+ SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
+ ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
+ EditPredictionRequested(ZetaEditPredictionDebugInfo),
+}
+
+#[derive(Debug)]
+pub struct ZetaContextRetrievalStartedDebugInfo {
+ pub project: Entity<Project>,
+ pub timestamp: Instant,
+ pub search_prompt: String,
+}
+
+#[derive(Debug)]
+pub struct ZetaContextRetrievalDebugInfo {
+ pub project: Entity<Project>,
+ pub timestamp: Instant,
+}
+
+#[derive(Debug)]
+pub struct ZetaEditPredictionDebugInfo {
+ pub inputs: EditPredictionInputs,
+ pub retrieval_time: Duration,
+ pub buffer: WeakEntity<Buffer>,
+ pub position: language::Anchor,
+ pub local_prompt: Result<String, String>,
+ pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
+}
+
+#[derive(Debug)]
+pub struct ZetaSearchQueryDebugInfo {
+ pub project: Entity<Project>,
+ pub timestamp: Instant,
+ pub search_queries: Vec<SearchToolQuery>,
}
+pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
+
struct ZetaProject {
- events: VecDeque<Event>,
+ syntax_index: Option<Entity<SyntaxIndex>>,
+ events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
+ last_event: Option<LastEvent>,
+ recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+ current_prediction: Option<CurrentEditPrediction>,
+ next_pending_prediction_id: usize,
+ pending_predictions: ArrayVec<PendingPrediction, 2>,
+ last_prediction_refresh: Option<(EntityId, Instant)>,
+ context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
+ refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
+ refresh_context_debounce_task: Option<Task<Option<()>>>,
+ refresh_context_timestamp: Option<Instant>,
+ license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ _subscription: gpui::Subscription,
}
-impl Zeta {
- pub fn global(cx: &mut App) -> Option<Entity<Self>> {
- cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
+impl ZetaProject {
+ pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
+ self.events
+ .iter()
+ .cloned()
+ .chain(
+ self.last_event
+ .as_ref()
+ .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
+ )
+ .collect()
}
+}
- pub fn register(
- worktree: Option<Entity<Worktree>>,
- client: Arc<Client>,
- user_store: Entity<UserStore>,
- cx: &mut App,
- ) -> Entity<Self> {
- let this = Self::global(cx).unwrap_or_else(|| {
- let entity = cx.new(|cx| Self::new(client, user_store, cx));
- cx.set_global(ZetaGlobal(entity.clone()));
- entity
- });
+#[derive(Debug, Clone)]
+struct CurrentEditPrediction {
+ pub requested_by: PredictionRequestedBy,
+ pub prediction: EditPrediction,
+ pub was_shown: bool,
+}
- this.update(cx, move |this, cx| {
- if let Some(worktree) = worktree {
- let worktree_id = worktree.read(cx).id();
- this.license_detection_watchers
- .entry(worktree_id)
- .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
- }
- });
+impl CurrentEditPrediction {
+ fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
+ let Some(new_edits) = self
+ .prediction
+ .interpolate(&self.prediction.buffer.read(cx))
+ else {
+ return false;
+ };
+
+ if self.prediction.buffer != old_prediction.prediction.buffer {
+ return true;
+ }
+
+ let Some(old_edits) = old_prediction
+ .prediction
+ .interpolate(&old_prediction.prediction.buffer.read(cx))
+ else {
+ return true;
+ };
- this
+ let requested_by_buffer_id = self.requested_by.buffer_id();
+
+ // This reduces the occurrence of UI thrash from replacing edits
+ //
+ // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
+ if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
+ && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
+ && old_edits.len() == 1
+ && new_edits.len() == 1
+ {
+ let (old_range, old_text) = &old_edits[0];
+ let (new_range, new_text) = &new_edits[0];
+ new_range == old_range && new_text.starts_with(old_text.as_ref())
+ } else {
+ true
+ }
}
+}
- pub fn clear_history(&mut self) {
- for zeta_project in self.projects.values_mut() {
- zeta_project.events.clear();
+#[derive(Debug, Clone)]
+enum PredictionRequestedBy {
+ DiagnosticsUpdate,
+ Buffer(EntityId),
+}
+
+impl PredictionRequestedBy {
+ pub fn buffer_id(&self) -> Option<EntityId> {
+ match self {
+ PredictionRequestedBy::DiagnosticsUpdate => None,
+ PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
}
}
+}
- pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- self.user_store.read(cx).edit_prediction_usage()
+struct PendingPrediction {
+ id: usize,
+ task: Task<Option<EditPredictionId>>,
+}
+
+/// A prediction from the perspective of a buffer.
+#[derive(Debug)]
+enum BufferEditPrediction<'a> {
+ Local { prediction: &'a EditPrediction },
+ Jump { prediction: &'a EditPrediction },
+}
+
+struct RegisteredBuffer {
+ snapshot: BufferSnapshot,
+ _subscriptions: [gpui::Subscription; 2],
+}
+
+struct LastEvent {
+ old_snapshot: BufferSnapshot,
+ new_snapshot: BufferSnapshot,
+ end_edit_anchor: Option<Anchor>,
+}
+
+impl LastEvent {
+ pub fn finalize(
+ &self,
+ license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ cx: &App,
+ ) -> Option<Arc<predict_edits_v3::Event>> {
+ let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
+ let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
+
+ let file = self.new_snapshot.file();
+ let old_file = self.old_snapshot.file();
+
+ let in_open_source_repo = [file, old_file].iter().all(|file| {
+ file.is_some_and(|file| {
+ license_detection_watchers
+ .get(&file.worktree_id(cx))
+ .is_some_and(|watcher| watcher.is_project_open_source())
+ })
+ });
+
+ let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
+
+ if path == old_path && diff.is_empty() {
+ None
+ } else {
+ Some(Arc::new(predict_edits_v3::Event::BufferChange {
+ old_path,
+ path,
+ diff,
+ in_open_source_repo,
+ // TODO: Actually detect if this edit was predicted or not
+ predicted: false,
+ }))
+ }
+ }
+}
+
+fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
+ if let Some(file) = snapshot.file() {
+ file.full_path(cx).into()
+ } else {
+ Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
+ }
+}
+
+impl Zeta {
+ pub fn try_global(cx: &App) -> Option<Entity<Self>> {
+ cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
+ }
+
+ pub fn global(
+ client: &Arc<Client>,
+ user_store: &Entity<UserStore>,
+ cx: &mut App,
+ ) -> Entity<Self> {
+ cx.try_global::<ZetaGlobal>()
+ .map(|global| global.0.clone())
+ .unwrap_or_else(|| {
+ let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
+ cx.set_global(ZetaGlobal(zeta.clone()));
+ zeta
+ })
}
- fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+ pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
let data_collection_choice = Self::load_data_collection_choice();
+
let (reject_tx, mut reject_rx) = mpsc::unbounded();
cx.spawn(async move |this, cx| {
while let Some(()) = reject_rx.next().await {
@@ -248,12 +479,8 @@ impl Zeta {
Self {
projects: HashMap::default(),
client,
- shown_completions: VecDeque::new(),
- rated_completions: HashSet::default(),
- discarded_completions: Vec::new(),
- discard_completions_debounce_task: None,
- discard_completions_tx: reject_tx,
- data_collection_choice,
+ user_store,
+ options: DEFAULT_OPTIONS,
llm_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
@@ -268,64 +495,85 @@ impl Zeta {
},
),
update_required: false,
- license_detection_watchers: HashMap::default(),
- user_store,
+ debug_tx: None,
+ #[cfg(feature = "eval-support")]
+ eval_cache: None,
+ edit_prediction_model: ZetaEditPredictionModel::Zeta2,
+ sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
+ .context("No SWEEP_AI_TOKEN environment variable set")
+ .log_err(),
+ data_collection_choice,
+ sweep_ai_debug_info: sweep_ai::debug_info(cx),
+ rejected_predictions: Vec::new(),
+ reject_predictions_debounce_task: None,
+ reject_predictions_tx: reject_tx,
+ rated_predictions: Default::default(),
+ shown_predictions: Default::default(),
}
}
- fn get_or_init_zeta_project(
- &mut self,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &mut ZetaProject {
- let project_id = project.entity_id();
- match self.projects.entry(project_id) {
- hash_map::Entry::Occupied(entry) => entry.into_mut(),
- hash_map::Entry::Vacant(entry) => {
- cx.observe_release(project, move |this, _, _cx| {
- this.projects.remove(&project_id);
- })
- .detach();
- entry.insert(ZetaProject {
- events: VecDeque::with_capacity(MAX_EVENT_COUNT),
- registered_buffers: HashMap::default(),
- })
- }
- }
+ pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
+ self.edit_prediction_model = model;
}
- fn push_event(zeta_project: &mut ZetaProject, event: Event) {
- let events = &mut zeta_project.events;
+ pub fn has_sweep_api_token(&self) -> bool {
+ self.sweep_api_token.is_some()
+ }
- if let Some(Event::BufferChange {
- new_snapshot: last_new_snapshot,
- timestamp: last_timestamp,
- ..
- }) = events.back_mut()
- {
- // Coalesce edits for the same buffer when they happen one after the other.
- let Event::BufferChange {
- old_snapshot,
- new_snapshot,
- timestamp,
- } = &event;
-
- if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
- && old_snapshot.remote_id() == last_new_snapshot.remote_id()
- && old_snapshot.version == last_new_snapshot.version
- {
- *last_new_snapshot = new_snapshot.clone();
- *last_timestamp = *timestamp;
- return;
- }
+ #[cfg(feature = "eval-support")]
+ pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
+ self.eval_cache = Some(cache);
+ }
+
+ pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
+ let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
+ self.debug_tx = Some(debug_watch_tx);
+ debug_watch_rx
+ }
+
+ pub fn options(&self) -> &ZetaOptions {
+ &self.options
+ }
+
+ pub fn set_options(&mut self, options: ZetaOptions) {
+ self.options = options;
+ }
+
+ pub fn clear_history(&mut self) {
+ for zeta_project in self.projects.values_mut() {
+ zeta_project.events.clear();
}
+ }
+
+ pub fn context_for_project(
+ &self,
+ project: &Entity<Project>,
+ ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
+ self.projects
+ .get(&project.entity_id())
+ .and_then(|project| {
+ Some(
+ project
+ .context
+ .as_ref()?
+ .iter()
+ .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
+ )
+ })
+ .into_iter()
+ .flatten()
+ }
- if events.len() >= MAX_EVENT_COUNT {
- // These are halved instead of popping to improve prompt caching.
- events.drain(..MAX_EVENT_COUNT / 2);
+ pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
+ if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
+ self.user_store.read(cx).edit_prediction_usage()
+ } else {
+ None
}
+ }
- events.push_back(event);
+ pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ self.get_or_init_zeta_project(project, cx);
}
pub fn register_buffer(
@@ -338,6 +586,69 @@ impl Zeta {
Self::register_buffer_impl(zeta_project, buffer, project, cx);
}
+ fn get_or_init_zeta_project(
+ &mut self,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> &mut ZetaProject {
+ self.projects
+ .entry(project.entity_id())
+ .or_insert_with(|| ZetaProject {
+ syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
+ Some(cx.new(|cx| {
+ SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
+ }))
+ } else {
+ None
+ },
+ events: VecDeque::new(),
+ last_event: None,
+ recent_paths: VecDeque::new(),
+ registered_buffers: HashMap::default(),
+ current_prediction: None,
+ pending_predictions: ArrayVec::new(),
+ next_pending_prediction_id: 0,
+ last_prediction_refresh: None,
+ context: None,
+ refresh_context_task: None,
+ refresh_context_debounce_task: None,
+ refresh_context_timestamp: None,
+ license_detection_watchers: HashMap::default(),
+ _subscription: cx.subscribe(&project, Self::handle_project_event),
+ })
+ }
+
+ fn handle_project_event(
+ &mut self,
+ project: Entity<Project>,
+ event: &project::Event,
+ cx: &mut Context<Self>,
+ ) {
+ // TODO [zeta2] init with recent paths
+ match event {
+ project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
+ let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+ let path = project.read(cx).path_for_entry(*active_entry_id, cx);
+ if let Some(path) = path {
+ if let Some(ix) = zeta_project
+ .recent_paths
+ .iter()
+ .position(|probe| probe == &path)
+ {
+ zeta_project.recent_paths.remove(ix);
+ }
+ zeta_project.recent_paths.push_front(path);
+ }
+ }
+ project::Event::DiagnosticsUpdated { .. } => {
+ self.refresh_prediction_from_diagnostics(project, cx);
+ }
+ _ => (),
+ }
+ }
+
fn register_buffer_impl<'a>(
zeta_project: &'a mut ZetaProject,
buffer: &Entity<Buffer>,
@@ -345,6 +656,28 @@ impl Zeta {
cx: &mut Context<Self>,
) -> &'a mut RegisteredBuffer {
let buffer_id = buffer.entity_id();
+
+ if let Some(file) = buffer.read(cx).file() {
+ let worktree_id = file.worktree_id(cx);
+ if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
+ zeta_project
+ .license_detection_watchers
+ .entry(worktree_id)
+ .or_insert_with(|| {
+ let project_entity_id = project.entity_id();
+ cx.observe_release(&worktree, move |this, _worktree, _cx| {
+ let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ zeta_project.license_detection_watchers.remove(&worktree_id);
+ })
+ .detach();
+ Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
+ });
+ }
+ }
+
match zeta_project.registered_buffers.entry(buffer_id) {
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
@@ -376,2037 +709,2755 @@ impl Zeta {
}
}
- fn request_completion_impl<F, R>(
+ fn report_changes_for_buffer(
&mut self,
- project: &Entity<Project>,
buffer: &Entity<Buffer>,
- cursor: language::Anchor,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
- perform_predict_edits: F,
- ) -> Task<Result<Option<EditPrediction>>>
- where
- F: FnOnce(PerformPredictEditsParams) -> R + 'static,
- R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
- + Send
- + 'static,
- {
- let buffer = buffer.clone();
- let buffer_snapshotted_at = Instant::now();
- let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
- let zeta = cx.entity();
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let app_version = AppVersion::global(cx);
-
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- let mut events = Vec::with_capacity(zeta_project.events.len());
- events.extend(zeta_project.events.iter().cloned());
- let events = Arc::new(events);
-
- let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
- let can_collect_file = self.can_collect_file(file, cx);
- let git_info = if can_collect_file {
- git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
- } else {
- None
- };
- (git_info, can_collect_file)
- } else {
- (None, false)
- };
-
- let full_path: Arc<Path> = snapshot
- .file()
- .map(|f| Arc::from(f.full_path(cx).as_path()))
- .unwrap_or_else(|| Arc::from(Path::new("untitled")));
- let full_path_str = full_path.to_string_lossy().into_owned();
- let cursor_point = cursor.to_point(&snapshot);
- let cursor_offset = cursor_point.to_offset(&snapshot);
- let prompt_for_events = {
- let events = events.clone();
- move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
- };
- let gather_task = gather_context(
- full_path_str,
- &snapshot,
- cursor_point,
- prompt_for_events,
- cx,
- );
-
- cx.spawn(async move |this, cx| {
- let GatherContextOutput {
- mut body,
- editable_range,
- included_events_count,
- } = gather_task.await?;
- let done_gathering_context_at = Instant::now();
-
- let included_events = &events[events.len() - included_events_count..events.len()];
- body.can_collect_data = can_collect_file
- && this
- .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
- .unwrap_or(false);
- if body.can_collect_data {
- body.git_info = git_info;
- }
-
- log::debug!(
- "Events:\n{}\nExcerpt:\n{:?}",
- body.input_events,
- body.input_excerpt
- );
-
- let input_outline = body.outline.clone().unwrap_or_default();
- let input_events = body.input_events.clone();
- let input_excerpt = body.input_excerpt.clone();
-
- let response = perform_predict_edits(PerformPredictEditsParams {
- client,
- llm_token,
- app_version,
- body,
- })
- .await;
- let (response, usage) = match response {
- Ok(response) => response,
- Err(err) => {
- if err.is::<ZedUpdateRequiredError>() {
- cx.update(|cx| {
- zeta.update(cx, |zeta, _cx| {
- zeta.update_required = true;
- });
-
- let error_message: SharedString = err.to_string().into();
- show_app_notification(
- NotificationId::unique::<ZedUpdateRequiredError>(),
- cx,
- move |cx| {
- cx.new(|cx| {
- ErrorMessagePrompt::new(error_message.clone(), cx)
- .with_link_button(
- "Update Zed",
- "https://zed.dev/releases",
- )
- })
- },
- );
- })
- .ok();
- }
+ ) {
+ let project_state = self.get_or_init_zeta_project(project, cx);
+ let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
- return Err(err);
- }
- };
+ let new_snapshot = buffer.read(cx).snapshot();
+ if new_snapshot.version == registered_buffer.snapshot.version {
+ return;
+ }
- let received_response_at = Instant::now();
- log::debug!("completion response: {}", &response.output_excerpt);
+ let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+ let end_edit_anchor = new_snapshot
+ .anchored_edits_since::<Point>(&old_snapshot.version)
+ .last()
+ .map(|(_, range)| range.end);
+ let events = &mut project_state.events;
- if let Some(usage) = usage {
- this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
+ if let Some(LastEvent {
+ new_snapshot: last_new_snapshot,
+ end_edit_anchor: last_end_edit_anchor,
+ ..
+ }) = project_state.last_event.as_mut()
+ {
+ let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
+ == last_new_snapshot.remote_id()
+ && old_snapshot.version == last_new_snapshot.version;
+
+ let should_coalesce = is_next_snapshot_of_same_buffer
+ && end_edit_anchor
+ .as_ref()
+ .zip(last_end_edit_anchor.as_ref())
+ .is_some_and(|(a, b)| {
+ let a = a.to_point(&new_snapshot);
+ let b = b.to_point(&new_snapshot);
+ a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
});
- })
- .ok();
+
+ if should_coalesce {
+ *last_end_edit_anchor = end_edit_anchor;
+ *last_new_snapshot = new_snapshot;
+ return;
}
+ }
- let edit_prediction = Self::process_completion_response(
- response,
- buffer,
- &snapshot,
- editable_range,
- cursor_offset,
- full_path,
- input_outline,
- input_events,
- input_excerpt,
- buffer_snapshotted_at,
- cx,
- )
- .await;
+ if events.len() + 1 >= EVENT_COUNT_MAX {
+ events.pop_front();
+ }
- let finished_at = Instant::now();
-
- // record latency for ~1% of requests
- if rand::random::<u8>() <= 2 {
- telemetry::event!(
- "Edit Prediction Request",
- context_latency = done_gathering_context_at
- .duration_since(buffer_snapshotted_at)
- .as_millis(),
- request_latency = received_response_at
- .duration_since(done_gathering_context_at)
- .as_millis(),
- process_latency = finished_at.duration_since(received_response_at).as_millis()
- );
- }
+ if let Some(event) = project_state.last_event.take() {
+ events.extend(event.finalize(&project_state.license_detection_watchers, cx));
+ }
- edit_prediction
- })
+ project_state.last_event = Some(LastEvent {
+ old_snapshot,
+ new_snapshot,
+ end_edit_anchor,
+ });
}
- #[cfg(any(test, feature = "test-support"))]
- pub fn fake_completion(
- &mut self,
- project: &Entity<Project>,
+ fn current_prediction_for_buffer(
+ &self,
buffer: &Entity<Buffer>,
- position: language::Anchor,
- response: PredictEditsResponse,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
- self.request_completion_impl(project, buffer, position, cx, |_params| {
- std::future::ready(Ok((response, None)))
- })
- }
-
- pub fn request_completion(
- &mut self,
project: &Entity<Project>,
- buffer: &Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
- self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
- }
-
- pub fn perform_predict_edits(
- params: PerformPredictEditsParams,
- ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
- async move {
- let PerformPredictEditsParams {
- client,
- llm_token,
- app_version,
- body,
- ..
- } = params;
-
- let http_client = client.http_client();
- let mut token = llm_token.acquire(&client).await?;
- let mut did_retry = false;
-
- loop {
- let request_builder = http_client::Request::builder().method(Method::POST);
- let request_builder =
- if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
- request_builder.uri(predict_edits_url)
- } else {
- request_builder.uri(
- http_client
- .build_zed_llm_url("/predict_edits/v2", &[])?
- .as_ref(),
- )
- };
- let request = request_builder
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", token))
- .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
- .body(serde_json::to_string(&body)?.into())?;
+ cx: &App,
+ ) -> Option<BufferEditPrediction<'_>> {
+ let project_state = self.projects.get(&project.entity_id())?;
- let mut response = http_client.send(request).await?;
+ let CurrentEditPrediction {
+ requested_by,
+ prediction,
+ ..
+ } = project_state.current_prediction.as_ref()?;
- if let Some(minimum_required_version) = response
- .headers()
- .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
- .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
- {
- anyhow::ensure!(
- app_version >= minimum_required_version,
- ZedUpdateRequiredError {
- minimum_version: minimum_required_version
- }
- );
+ if prediction.targets_buffer(buffer.read(cx)) {
+ Some(BufferEditPrediction::Local { prediction })
+ } else {
+ let show_jump = match requested_by {
+ PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
+ requested_by_buffer_id == &buffer.entity_id()
}
+ PredictionRequestedBy::DiagnosticsUpdate => true,
+ };
- if response.status().is_success() {
- let usage = EditPredictionUsage::from_headers(response.headers()).ok();
-
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Ok((serde_json::from_str(&body)?, usage));
- } else if !did_retry
- && response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
- {
- did_retry = true;
- token = llm_token.refresh(&client).await?;
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!(
- "error predicting edits.\nStatus: {:?}\nBody: {}",
- response.status(),
- body
- );
- }
+ if show_jump {
+ Some(BufferEditPrediction::Jump { prediction })
+ } else {
+ None
}
}
}
- fn accept_edit_prediction(
- &mut self,
- request_id: EditPredictionId,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
+ fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ match self.edit_prediction_model {
+ ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
+ ZetaEditPredictionModel::Sweep => return,
+ }
+
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+
+ let Some(prediction) = project_state.current_prediction.take() else {
+ return;
+ };
+ let request_id = prediction.prediction.id.to_string();
+ for pending_prediction in mem::take(&mut project_state.pending_predictions) {
+ self.cancel_pending_prediction(pending_prediction, cx);
+ }
+
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
cx.spawn(async move |this, cx| {
- let http_client = client.http_client();
- let mut response = llm_token_retry(&llm_token, &client, |token| {
- let request_builder = http_client::Request::builder().method(Method::POST);
- let request_builder =
- if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
- request_builder.uri(accept_prediction_url)
- } else {
- request_builder.uri(
- http_client
- .build_zed_llm_url("/predict_edits/accept", &[])?
- .as_ref(),
- )
- };
- Ok(request_builder
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", token))
- .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
- .body(
- serde_json::to_string(&AcceptEditPredictionBody {
- request_id: request_id.0.to_string(),
- })?
- .into(),
- )?)
- })
- .await?;
-
- if let Some(minimum_required_version) = response
- .headers()
- .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
- .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
- && app_version < minimum_required_version
- {
- return Err(anyhow!(ZedUpdateRequiredError {
- minimum_version: minimum_required_version
- }));
- }
-
- if response.status().is_success() {
- if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
- this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
- });
- })?;
- }
-
- Ok(())
+ let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
+ http_client::Url::parse(&predict_edits_url)?
} else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- Err(anyhow!(
- "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
- response.status(),
- body
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/accept", &[])?
+ };
+
+ let response = cx
+ .background_spawn(Self::send_api_request::<()>(
+ move |builder| {
+ let req = builder.uri(url.as_ref()).body(
+ serde_json::to_string(&AcceptEditPredictionBody {
+ request_id: request_id.clone(),
+ })?
+ .into(),
+ );
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
))
- }
+ .await;
+
+ Self::handle_api_response(&this, response, cx)?;
+ anyhow::Ok(())
})
+ .detach_and_log_err(cx);
}
fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ match self.edit_prediction_model {
+ ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
+ ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
+ }
+
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
- let last_rejection = self.discarded_completions.last().cloned();
- let body = serde_json::to_string(&RejectEditPredictionsBody {
- rejections: self.discarded_completions.clone(),
- })
- .ok();
-
+ let last_rejection = self.rejected_predictions.last().cloned();
let Some(last_rejection) = last_rejection else {
return Task::ready(anyhow::Ok(()));
};
+ let body = serde_json::to_string(&RejectEditPredictionsBody {
+ rejections: self.rejected_predictions.clone(),
+ })
+ .ok();
+
cx.spawn(async move |this, cx| {
- let http_client = client.http_client();
- let mut response = llm_token_retry(&llm_token, &client, |token| {
- let request_builder = http_client::Request::builder().method(Method::POST);
- let request_builder = request_builder.uri(
- http_client
- .build_zed_llm_url("/predict_edits/reject", &[])?
- .as_ref(),
- );
- Ok(request_builder
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", token))
- .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
- .body(
- body.as_ref()
- .context("failed to serialize body")?
- .clone()
- .into(),
- )?)
+ let url = client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/reject", &[])?;
+
+ cx.background_spawn(Self::send_api_request::<()>(
+ move |builder| {
+ let req = builder.uri(url.as_ref()).body(body.clone().into());
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ ))
+ .await
+ .context("Failed to reject edit predictions")?;
+
+ this.update(cx, |this, _| {
+ if let Some(ix) = this
+ .rejected_predictions
+ .iter()
+ .position(|rejection| rejection.request_id == last_rejection.request_id)
+ {
+ this.rejected_predictions.drain(..ix + 1);
+ }
})
- .await?;
+ })
+ }
- if let Some(minimum_required_version) = response
- .headers()
- .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
- .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
- && app_version < minimum_required_version
- {
- return Err(anyhow!(ZedUpdateRequiredError {
- minimum_version: minimum_required_version
- }));
+ fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ project_state.pending_predictions.clear();
+ if let Some(prediction) = project_state.current_prediction.take() {
+ self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
}
+ };
+ }
- if response.status().is_success() {
- this.update(cx, |this, _| {
- if let Some(ix) = this
- .discarded_completions
- .iter()
- .position(|rejection| rejection.request_id == last_rejection.request_id)
- {
- this.discarded_completions.drain(..ix + 1);
+ fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ if let Some(current_prediction) = project_state.current_prediction.as_mut() {
+ if !current_prediction.was_shown {
+ current_prediction.was_shown = true;
+ self.shown_predictions
+ .push_front(current_prediction.prediction.clone());
+ if self.shown_predictions.len() > 50 {
+ let completion = self.shown_predictions.pop_back().unwrap();
+ self.rated_predictions.remove(&completion.id);
}
- })
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- Err(anyhow!(
- "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
- response.status(),
- body
- ))
+ }
+ }
+ }
+ }
+
+ fn discard_prediction(
+ &mut self,
+ prediction_id: EditPredictionId,
+ was_shown: bool,
+ cx: &mut Context<Self>,
+ ) {
+ self.rejected_predictions.push(EditPredictionRejection {
+ request_id: prediction_id.to_string(),
+ was_shown,
+ });
+
+ let reached_request_limit =
+ self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
+ let reject_tx = self.reject_predictions_tx.clone();
+ self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
+ const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
+ if !reached_request_limit {
+ cx.background_executor()
+ .timer(DISCARD_COMPLETIONS_DEBOUNCE)
+ .await;
}
+ reject_tx.unbounded_send(()).log_err();
+ }));
+ }
+
+ fn cancel_pending_prediction(
+ &self,
+ pending_prediction: PendingPrediction,
+ cx: &mut Context<Self>,
+ ) {
+ cx.spawn(async move |this, cx| {
+ let Some(prediction_id) = pending_prediction.task.await else {
+ return;
+ };
+
+ this.update(cx, |this, cx| {
+ this.discard_prediction(prediction_id, false, cx);
+ })
+ .ok();
})
+ .detach()
+ }
+
+ fn is_refreshing(&self, project: &Entity<Project>) -> bool {
+ self.projects
+ .get(&project.entity_id())
+ .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
}
- fn process_completion_response(
- prediction_response: PredictEditsResponse,
+ pub fn refresh_prediction_from_buffer(
+ &mut self,
+ project: Entity<Project>,
buffer: Entity<Buffer>,
- snapshot: &BufferSnapshot,
- editable_range: Range<usize>,
- cursor_offset: usize,
- path: Arc<Path>,
- input_outline: String,
- input_events: String,
- input_excerpt: String,
- buffer_snapshotted_at: Instant,
- cx: &AsyncApp,
- ) -> Task<Result<Option<EditPrediction>>> {
- let snapshot = snapshot.clone();
- let request_id = prediction_response.request_id;
- let output_excerpt = prediction_response.output_excerpt;
- cx.spawn(async move |cx| {
- let output_excerpt: Arc<str> = output_excerpt.into();
-
- let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
- .background_spawn({
- let output_excerpt = output_excerpt.clone();
- let editable_range = editable_range.clone();
- let snapshot = snapshot.clone();
- async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
+ position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) {
+ self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
+ let Some(request_task) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(&project, &buffer, position, cx)
})
- .await?
- .into();
-
- let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
- let edits = edits.clone();
- move |buffer, cx| {
- let new_snapshot = buffer.snapshot();
- let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
- edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
- .into();
- Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
- }
- })?
+ .log_err()
else {
- return anyhow::Ok(None);
+ return Task::ready(anyhow::Ok(None));
};
- let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?;
-
- let edit_preview = edit_preview.await;
-
- Ok(Some(EditPrediction {
- id: EditPredictionId(request_id),
- path,
- excerpt_range: editable_range,
- cursor_offset,
- edits,
- edit_preview,
- snapshot,
- input_outline: input_outline.into(),
- input_events: input_events.into(),
- input_excerpt: input_excerpt.into(),
- output_excerpt,
- buffer_snapshotted_at,
- response_received_at: Instant::now(),
- }))
+ let project = project.clone();
+ cx.spawn(async move |cx| {
+ if let Some(prediction) = request_task.await? {
+ let id = prediction.id.clone();
+ this.update(cx, |this, cx| {
+ let project_state = this
+ .projects
+ .get_mut(&project.entity_id())
+ .context("Project not found")?;
+
+ let new_prediction = CurrentEditPrediction {
+ requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
+ prediction: prediction,
+ was_shown: false,
+ };
+
+ if project_state
+ .current_prediction
+ .as_ref()
+ .is_none_or(|old_prediction| {
+ new_prediction.should_replace_prediction(&old_prediction, cx)
+ })
+ {
+ project_state.current_prediction = Some(new_prediction);
+ cx.notify();
+ }
+ anyhow::Ok(())
+ })??;
+ Ok(Some(id))
+ } else {
+ Ok(None)
+ }
+ })
})
}
- fn parse_edits(
- output_excerpt: Arc<str>,
- editable_range: Range<usize>,
- snapshot: &BufferSnapshot,
- ) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
- let content = output_excerpt.replace(CURSOR_MARKER, "");
-
- let start_markers = content
- .match_indices(EDITABLE_REGION_START_MARKER)
- .collect::<Vec<_>>();
- anyhow::ensure!(
- start_markers.len() == 1,
- "expected exactly one start marker, found {}",
- start_markers.len()
- );
+ pub fn refresh_prediction_from_diagnostics(
+ &mut self,
+ project: Entity<Project>,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
- let end_markers = content
- .match_indices(EDITABLE_REGION_END_MARKER)
- .collect::<Vec<_>>();
- anyhow::ensure!(
- end_markers.len() == 1,
- "expected exactly one end marker, found {}",
- end_markers.len()
- );
-
- let sof_markers = content
- .match_indices(START_OF_FILE_MARKER)
- .collect::<Vec<_>>();
- anyhow::ensure!(
- sof_markers.len() <= 1,
- "expected at most one start-of-file marker, found {}",
- sof_markers.len()
- );
+ // Prefer predictions from buffer
+ if zeta_project.current_prediction.is_some() {
+ return;
+ };
- let codefence_start = start_markers[0].0;
- let content = &content[codefence_start..];
+ self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
+ let Some(open_buffer_task) = project
+ .update(cx, |project, cx| {
+ project
+ .active_entry()
+ .and_then(|entry| project.path_for_entry(entry, cx))
+ .map(|path| project.open_buffer(path, cx))
+ })
+ .log_err()
+ .flatten()
+ else {
+ return Task::ready(anyhow::Ok(None));
+ };
- let newline_ix = content.find('\n').context("could not find newline")?;
- let content = &content[newline_ix + 1..];
+ cx.spawn(async move |cx| {
+ let active_buffer = open_buffer_task.await?;
+ let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+ let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer,
+ &snapshot,
+ Default::default(),
+ Default::default(),
+ &project,
+ cx,
+ )
+ .await?
+ else {
+ return anyhow::Ok(None);
+ };
- let codefence_end = content
- .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
- .context("could not find end marker")?;
- let new_text = &content[..codefence_end];
+ let Some(prediction) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(&project, &jump_buffer, jump_position, cx)
+ })?
+ .await?
+ else {
+ return anyhow::Ok(None);
+ };
- let old_text = snapshot
- .text_for_range(editable_range.clone())
- .collect::<String>();
+ let id = prediction.id.clone();
+ this.update(cx, |this, cx| {
+ if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+ zeta_project.current_prediction.get_or_insert_with(|| {
+ cx.notify();
+ CurrentEditPrediction {
+ requested_by: PredictionRequestedBy::DiagnosticsUpdate,
+ prediction,
+ was_shown: false,
+ }
+ });
+ }
+ })?;
- Ok(Self::compute_edits(
- old_text,
- new_text,
- editable_range.start,
- snapshot,
- ))
+ anyhow::Ok(Some(id))
+ })
+ });
}
- pub fn compute_edits(
- old_text: String,
- new_text: &str,
- offset: usize,
- snapshot: &BufferSnapshot,
- ) -> Vec<(Range<Anchor>, Arc<str>)> {
- text_diff(&old_text, new_text)
- .into_iter()
- .map(|(mut old_range, new_text)| {
- old_range.start += offset;
- old_range.end += offset;
+ #[cfg(not(test))]
+ pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+ #[cfg(test)]
+ pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
- let prefix_len = common_prefix(
- snapshot.chars_for_range(old_range.clone()),
- new_text.chars(),
- );
- old_range.start += prefix_len;
+ fn queue_prediction_refresh(
+ &mut self,
+ project: Entity<Project>,
+ throttle_entity: EntityId,
+ cx: &mut Context<Self>,
+ do_refresh: impl FnOnce(
+ WeakEntity<Self>,
+ &mut AsyncApp,
+ ) -> Task<Result<Option<EditPredictionId>>>
+ + 'static,
+ ) {
+ let zeta_project = self.get_or_init_zeta_project(&project, cx);
+ let pending_prediction_id = zeta_project.next_pending_prediction_id;
+ zeta_project.next_pending_prediction_id += 1;
+ let last_request = zeta_project.last_prediction_refresh;
- let suffix_len = common_prefix(
- snapshot.reversed_chars_for_range(old_range.clone()),
- new_text[prefix_len..].chars().rev(),
- );
- old_range.end = old_range.end.saturating_sub(suffix_len);
+ // TODO report cancelled requests like in zeta1
+ let task = cx.spawn(async move |this, cx| {
+ if let Some((last_entity, last_timestamp)) = last_request
+ && throttle_entity == last_entity
+ && let Some(timeout) =
+ (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
+ {
+ cx.background_executor().timer(timeout).await;
+ }
- let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
- let range = if old_range.is_empty() {
- let anchor = snapshot.anchor_after(old_range.start);
- anchor..anchor
- } else {
- snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
- };
- (range, new_text)
+ let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();
+
+ // When a prediction completes, remove it from the pending list, and cancel
+ // any pending predictions that were enqueued before it.
+ this.update(cx, |this, cx| {
+ let zeta_project = this.get_or_init_zeta_project(&project, cx);
+ let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
+ for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
+ if pending_prediction.id == pending_prediction_id {
+ pending_predictions.remove(ix);
+ for pending_prediction in pending_predictions.drain(0..ix) {
+ this.cancel_pending_prediction(pending_prediction, cx)
+ }
+ break;
+ }
+ }
+ this.get_or_init_zeta_project(&project, cx)
+ .pending_predictions = pending_predictions;
+ cx.notify();
})
- .collect()
- }
+ .ok();
- pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
- self.rated_completions.contains(&completion_id)
- }
+ edit_prediction_id
+ });
- pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
- self.shown_completions.push_front(completion.clone());
- if self.shown_completions.len() > 50 {
- let completion = self.shown_completions.pop_back().unwrap();
- self.rated_completions.remove(&completion.id);
+ if zeta_project.pending_predictions.len() <= 1 {
+ zeta_project.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ });
+ } else if zeta_project.pending_predictions.len() == 2 {
+ let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
+ zeta_project.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ });
+ self.cancel_pending_prediction(pending_prediction, cx);
}
- cx.notify();
}
- pub fn rate_completion(
+ pub fn request_prediction(
&mut self,
- completion: &EditPrediction,
- rating: EditPredictionRating,
- feedback: String,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ position: language::Anchor,
cx: &mut Context<Self>,
- ) {
- self.rated_completions.insert(completion.id);
- telemetry::event!(
- "Edit Prediction Rated",
- rating,
- input_events = completion.input_events,
- input_excerpt = completion.input_excerpt,
- input_outline = completion.input_outline,
- output_excerpt = completion.output_excerpt,
- feedback
- );
- self.client.telemetry().flush_events().detach();
- cx.notify();
- }
-
- pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
- self.shown_completions.iter()
- }
-
- pub fn shown_completions_len(&self) -> usize {
- self.shown_completions.len()
+ ) -> Task<Result<Option<EditPrediction>>> {
+ match self.edit_prediction_model {
+ ZetaEditPredictionModel::Zeta1 => {
+ request_prediction_with_zeta1(self, project, active_buffer, position, cx)
+ }
+ ZetaEditPredictionModel::Zeta2 => {
+ self.request_prediction_with_zeta2(project, active_buffer, position, cx)
+ }
+ ZetaEditPredictionModel::Sweep => {
+ self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
+ }
+ }
}
- fn report_changes_for_buffer(
+ fn request_prediction_with_sweep(
&mut self,
- buffer: &Entity<Buffer>,
project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ allow_jump: bool,
cx: &mut Context<Self>,
- ) -> BufferSnapshot {
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
+ ) -> Task<Result<Option<EditPrediction>>> {
+ let snapshot = active_buffer.read(cx).snapshot();
+ let debug_info = self.sweep_ai_debug_info.clone();
+ let Some(api_token) = self.sweep_api_token.clone() else {
+ return Task::ready(Ok(None));
+ };
+ let full_path: Arc<Path> = snapshot
+ .file()
+ .map(|file| file.full_path(cx))
+ .unwrap_or_else(|| "untitled".into())
+ .into();
+
+ let project_file = project::File::from_dyn(snapshot.file());
+ let repo_name = project_file
+ .map(|file| file.worktree.read(cx).root_name_str())
+ .unwrap_or("untitled")
+ .into();
+ let offset = position.to_offset(&snapshot);
+
+ let project_state = self.get_or_init_zeta_project(project, cx);
+ let events = project_state.events(cx);
+ let has_events = !events.is_empty();
+ let recent_buffers = project_state.recent_paths.iter().cloned();
+ let http_client = cx.http_client();
+
+ let recent_buffer_snapshots = recent_buffers
+ .filter_map(|project_path| {
+ let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+ if active_buffer == &buffer {
+ None
+ } else {
+ Some(buffer.read(cx).snapshot())
+ }
+ })
+ .take(3)
+ .collect::<Vec<_>>();
- let new_snapshot = buffer.read(cx).snapshot();
- if new_snapshot.version != registered_buffer.snapshot.version {
- let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- Self::push_event(
- zeta_project,
- Event::BufferChange {
- old_snapshot,
- new_snapshot: new_snapshot.clone(),
- timestamp: Instant::now(),
- },
- );
- }
+ const DIAGNOSTIC_LINES_RANGE: u32 = 20;
- new_snapshot
- }
+ let cursor_point = position.to_point(&snapshot);
+ let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
+ let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
+ let diagnostic_search_range =
+ Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
+ let buffer_snapshotted_at = Instant::now();
- fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
- self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
- }
+ let result = cx.background_spawn({
+ let snapshot = snapshot.clone();
+ let diagnostic_search_range = diagnostic_search_range.clone();
+ async move {
+ let text = snapshot.text();
- fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
- if !self.data_collection_choice.is_enabled() {
- return false;
- }
- let mut last_checked_file = None;
- for event in events {
- match event {
- Event::BufferChange {
- old_snapshot,
- new_snapshot,
- ..
- } => {
- if let Some(old_file) = old_snapshot.file()
- && let Some(new_file) = new_snapshot.file()
- {
- if let Some(last_checked_file) = last_checked_file
- && Arc::ptr_eq(last_checked_file, old_file)
- && Arc::ptr_eq(last_checked_file, new_file)
- {
- continue;
- }
- if !self.can_collect_file(old_file, cx) {
- return false;
- }
- if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
- {
- return false;
+ let mut recent_changes = String::new();
+ for event in &events {
+ sweep_ai::write_event(event.as_ref(), &mut recent_changes).unwrap();
+ }
+
+ let mut file_chunks = recent_buffer_snapshots
+ .into_iter()
+ .map(|snapshot| {
+ let end_point = Point::new(30, 0).min(snapshot.max_point());
+ sweep_ai::FileChunk {
+ content: snapshot.text_for_range(Point::zero()..end_point).collect(),
+ file_path: snapshot
+ .file()
+ .map(|f| f.path().as_unix_str())
+ .unwrap_or("untitled")
+ .to_string(),
+ start_line: 0,
+ end_line: end_point.row as usize,
+ timestamp: snapshot.file().and_then(|file| {
+ Some(
+ file.disk_state()
+ .mtime()?
+ .to_seconds_and_nanos_for_persistence()?
+ .0,
+ )
+ }),
}
- last_checked_file = Some(new_file);
- } else {
- return false;
- }
+ })
+ .collect::<Vec<_>>();
+
+ let diagnostic_entries =
+ snapshot.diagnostics_in_range(diagnostic_search_range, false);
+ let mut diagnostic_content = String::new();
+ let mut diagnostic_count = 0;
+
+ for entry in diagnostic_entries {
+ let start_point: Point = entry.range.start;
+
+ let severity = match entry.diagnostic.severity {
+ DiagnosticSeverity::ERROR => "error",
+ DiagnosticSeverity::WARNING => "warning",
+ DiagnosticSeverity::INFORMATION => "info",
+ DiagnosticSeverity::HINT => "hint",
+ _ => continue,
+ };
+
+ diagnostic_count += 1;
+
+ writeln!(
+ &mut diagnostic_content,
+ "{} at line {}: {}",
+ severity,
+ start_point.row + 1,
+ entry.diagnostic.message
+ )?;
+ }
+
+ if !diagnostic_content.is_empty() {
+ file_chunks.push(sweep_ai::FileChunk {
+ file_path: format!("Diagnostics for {}", full_path.display()),
+ start_line: 0,
+ end_line: diagnostic_count,
+ content: diagnostic_content,
+ timestamp: None,
+ });
+ }
+
+ let request_body = sweep_ai::AutocompleteRequest {
+ debug_info,
+ repo_name,
+ file_path: full_path.clone(),
+ file_contents: text.clone(),
+ original_file_contents: text,
+ cursor_position: offset,
+ recent_changes: recent_changes.clone(),
+ changes_above_cursor: true,
+ multiple_suggestions: false,
+ branch: None,
+ file_chunks,
+ retrieval_chunks: vec![],
+ recent_user_actions: vec![],
+ // TODO
+ privacy_mode_enabled: false,
+ };
+
+ let mut buf: Vec<u8> = Vec::new();
+ let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+ serde_json::to_writer(writer, &request_body)?;
+ let body: AsyncBody = buf.into();
+
+ let inputs = EditPredictionInputs {
+ events,
+ included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
+ path: full_path.clone(),
+ max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
+ excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
+ start_line: cloud_llm_client::predict_edits_v3::Line(0),
+ text: request_body.file_contents.into(),
+ }],
+ }],
+ cursor_point: cloud_llm_client::predict_edits_v3::Point {
+ column: cursor_point.column,
+ line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
+ },
+ cursor_path: full_path.clone(),
+ };
+
+ const SWEEP_API_URL: &str =
+ "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+
+ let request = http_client::Request::builder()
+ .uri(SWEEP_API_URL)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_token))
+ .header("Connection", "keep-alive")
+ .header("Content-Encoding", "br")
+ .method(Method::POST)
+ .body(body)?;
+
+ let mut response = http_client.send(request).await?;
+
+ let mut body: Vec<u8> = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+
+ let response_received_at = Instant::now();
+ if !response.status().is_success() {
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ String::from_utf8_lossy(&body),
+ );
+ };
+
+ let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
+
+ let old_text = snapshot
+ .text_for_range(response.start_index..response.end_index)
+ .collect::<String>();
+ let edits = language::text_diff(&old_text, &response.completion)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(response.start_index + range.start)
+ ..snapshot.anchor_before(response.start_index + range.end),
+ text,
+ )
+ })
+ .collect::<Vec<_>>();
+
+ anyhow::Ok((
+ response.autocomplete_id,
+ edits,
+ snapshot,
+ response_received_at,
+ inputs,
+ ))
+ }
+ });
+
+ let buffer = active_buffer.clone();
+ let project = project.clone();
+ let active_buffer = active_buffer.clone();
+
+ cx.spawn(async move |this, cx| {
+ let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
+
+ if edits.is_empty() {
+ if has_events
+ && allow_jump
+ && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer,
+ &snapshot,
+ diagnostic_search_range,
+ cursor_point,
+ &project,
+ cx,
+ )
+ .await?
+ {
+ return this
+ .update(cx, |this, cx| {
+ this.request_prediction_with_sweep(
+ &project,
+ &jump_buffer,
+ jump_position,
+ false,
+ cx,
+ )
+ })?
+ .await;
}
+
+ return anyhow::Ok(None);
}
- }
- true
- }
- fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
- if !file.is_local() || file.is_private() {
- return false;
- }
- self.license_detection_watchers
- .get(&file.worktree_id(cx))
- .is_some_and(|watcher| watcher.is_project_open_source())
+ anyhow::Ok(
+ EditPrediction::new(
+ EditPredictionId(id.into()),
+ &buffer,
+ &old_snapshot,
+ edits.into(),
+ buffer_snapshotted_at,
+ response_received_at,
+ inputs,
+ cx,
+ )
+ .await,
+ )
+ })
}
- fn load_data_collection_choice() -> DataCollectionChoice {
- let choice = KEY_VALUE_STORE
- .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
- .log_err()
- .flatten();
+ async fn next_diagnostic_location(
+ active_buffer: Entity<Buffer>,
+ active_buffer_snapshot: &BufferSnapshot,
+ active_buffer_diagnostic_search_range: Range<Point>,
+ active_buffer_cursor_point: Point,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
+ // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
+ let mut jump_location = active_buffer_snapshot
+ .diagnostic_groups(None)
+ .into_iter()
+ .filter_map(|(_, group)| {
+ let range = &group.entries[group.primary_ix]
+ .range
+ .to_point(&active_buffer_snapshot);
+ if range.overlaps(&active_buffer_diagnostic_search_range) {
+ None
+ } else {
+ Some(range.start)
+ }
+ })
+ .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
+ .map(|position| {
+ (
+ active_buffer.clone(),
+ active_buffer_snapshot.anchor_before(position),
+ )
+ });
- match choice.as_deref() {
- Some("true") => DataCollectionChoice::Enabled,
- Some("false") => DataCollectionChoice::Disabled,
- Some(_) => {
- log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
- DataCollectionChoice::NotAnswered
+ if jump_location.is_none() {
+ let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
+ let file = buffer.file()?;
+
+ Some(ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path().clone(),
+ })
+ })?;
+
+ let buffer_task = project.update(cx, |project, cx| {
+ let (path, _, _) = project
+ .diagnostic_summaries(false, cx)
+ .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
+ .max_by_key(|(path, _, _)| {
+ // find the buffer with errors that shares most parent directories
+ path.path
+ .components()
+ .zip(
+ active_buffer_path
+ .as_ref()
+ .map(|p| p.path.components())
+ .unwrap_or_default(),
+ )
+ .take_while(|(a, b)| a == b)
+ .count()
+ })?;
+
+ Some(project.open_buffer(path, cx))
+ })?;
+
+ if let Some(buffer_task) = buffer_task {
+ let closest_buffer = buffer_task.await?;
+
+ jump_location = closest_buffer
+ .read_with(cx, |buffer, _cx| {
+ buffer
+ .buffer_diagnostics(None)
+ .into_iter()
+ .min_by_key(|entry| entry.diagnostic.severity)
+ .map(|entry| entry.range.start)
+ })?
+ .map(|position| (closest_buffer, position));
}
- None => DataCollectionChoice::NotAnswered,
}
- }
- fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
- self.data_collection_choice = self.data_collection_choice.toggle();
- let new_choice = self.data_collection_choice;
- db::write_and_log(cx, move || {
- KEY_VALUE_STORE.write_kvp(
- ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
- new_choice.is_enabled().to_string(),
- )
- });
+ anyhow::Ok(jump_location)
}
- fn discard_completion(
+ fn request_prediction_with_zeta2(
&mut self,
- completion_id: EditPredictionId,
- was_shown: bool,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ position: language::Anchor,
cx: &mut Context<Self>,
- ) {
- self.discarded_completions.push(EditPredictionRejection {
- request_id: completion_id.to_string(),
- was_shown,
- });
+ ) -> Task<Result<Option<EditPrediction>>> {
+ let project_state = self.projects.get(&project.entity_id());
- let reached_request_limit =
- self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
- let discard_completions_tx = self.discard_completions_tx.clone();
- self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
- const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
- if !reached_request_limit {
- cx.background_executor()
- .timer(DISCARD_COMPLETIONS_DEBOUNCE)
- .await;
- }
- discard_completions_tx.unbounded_send(()).log_err();
- }));
- }
-}
+ let index_state = project_state.and_then(|state| {
+ state
+ .syntax_index
+ .as_ref()
+ .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
+ });
+ let options = self.options.clone();
+ let active_snapshot = active_buffer.read(cx).snapshot();
+ let buffer_snapshotted_at = Instant::now();
+ let Some(excerpt_path) = active_snapshot
+ .file()
+ .map(|path| -> Arc<Path> { path.full_path(cx).into() })
+ else {
+ return Task::ready(Err(anyhow!("No file path for excerpt")));
+ };
+ let client = self.client.clone();
+ let llm_token = self.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ let worktree_snapshots = project
+ .read(cx)
+ .worktrees(cx)
+ .map(|worktree| worktree.read(cx).snapshot())
+ .collect::<Vec<_>>();
+ let debug_tx = self.debug_tx.clone();
-pub struct PerformPredictEditsParams {
- pub client: Arc<Client>,
- pub llm_token: LlmApiToken,
- pub app_version: Version,
- pub body: PredictEditsBody,
-}
+ let events = project_state
+ .map(|state| state.events(cx))
+ .unwrap_or_default();
-#[derive(Error, Debug)]
-#[error(
- "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
-)]
-pub struct ZedUpdateRequiredError {
- minimum_version: Version,
-}
+ let diagnostics = active_snapshot.diagnostic_sets().clone();
-fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
- a.zip(b)
- .take_while(|(a, b)| a == b)
- .map(|(a, _)| a.len_utf8())
- .sum()
-}
+ let file = active_buffer.read(cx).file();
+ let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
+ let mut path = f.worktree.read(cx).absolutize(&f.path);
+ if path.pop() { Some(path) } else { None }
+ });
-fn git_info_for_file(
- project: &Entity<Project>,
- project_path: &ProjectPath,
- cx: &App,
-) -> Option<PredictEditsGitInfo> {
- let git_store = project.read(cx).git_store().read(cx);
- if let Some((repository, _repo_path)) =
- git_store.repository_and_path_for_project_path(project_path, cx)
- {
- let repository = repository.read(cx);
- let head_sha = repository
- .head_commit
+ // TODO data collection
+ let can_collect_data = file
.as_ref()
- .map(|head_commit| head_commit.sha.to_string());
- let remote_origin_url = repository.remote_origin_url.clone();
- let remote_upstream_url = repository.remote_upstream_url.clone();
- if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
- return None;
- }
- Some(PredictEditsGitInfo {
- head_sha,
- remote_origin_url,
- remote_upstream_url,
- })
- } else {
- None
- }
-}
+ .map_or(false, |file| self.can_collect_file(project, file, cx));
+
+ let empty_context_files = HashMap::default();
+ let context_files = project_state
+ .and_then(|project_state| project_state.context.as_ref())
+ .unwrap_or(&empty_context_files);
+
+ #[cfg(feature = "eval-support")]
+ let parsed_fut = futures::future::join_all(
+ context_files
+ .keys()
+ .map(|buffer| buffer.read(cx).parsing_idle()),
+ );
-pub struct GatherContextOutput {
- pub body: PredictEditsBody,
- pub editable_range: Range<usize>,
- pub included_events_count: usize,
-}
+ let mut included_files = context_files
+ .iter()
+ .filter_map(|(buffer_entity, ranges)| {
+ let buffer = buffer_entity.read(cx);
+ Some((
+ buffer_entity.clone(),
+ buffer.snapshot(),
+ buffer.file()?.full_path(cx).into(),
+ ranges.clone(),
+ ))
+ })
+ .collect::<Vec<_>>();
-pub fn gather_context(
- full_path_str: String,
- snapshot: &BufferSnapshot,
- cursor_point: language::Point,
- prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
- cx: &App,
-) -> Task<Result<GatherContextOutput>> {
- cx.background_spawn({
- let snapshot = snapshot.clone();
- async move {
- let input_excerpt = excerpt_for_cursor_position(
- cursor_point,
- &full_path_str,
- &snapshot,
- MAX_REWRITE_TOKENS,
- MAX_CONTEXT_TOKENS,
- );
- let (input_events, included_events_count) = prompt_for_events();
- let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
-
- let body = PredictEditsBody {
- input_events,
- input_excerpt: input_excerpt.prompt,
- can_collect_data: false,
- diagnostic_groups: None,
- git_info: None,
- outline: None,
- speculated_output: None,
- };
+ included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
+ (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
+ });
- Ok(GatherContextOutput {
- body,
- editable_range,
- included_events_count,
- })
- }
- })
-}
+ #[cfg(feature = "eval-support")]
+ let eval_cache = self.eval_cache.clone();
-fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
- let mut result = String::new();
- for (ix, event) in events.iter().rev().enumerate() {
- let event_string = event.to_prompt();
- let event_tokens = guess_token_count(event_string.len());
- if event_tokens > remaining_tokens {
- return (result, ix);
- }
+ let request_task = cx.background_spawn({
+ let active_buffer = active_buffer.clone();
+ async move {
+ #[cfg(feature = "eval-support")]
+ parsed_fut.await;
- if !result.is_empty() {
- result.insert_str(0, "\n\n");
- }
- result.insert_str(0, &event_string);
- remaining_tokens -= event_tokens;
- }
- return (result, events.len());
-}
+ let index_state = if let Some(index_state) = index_state {
+ Some(index_state.lock_owned().await)
+ } else {
+ None
+ };
-struct RegisteredBuffer {
- snapshot: BufferSnapshot,
- _subscriptions: [gpui::Subscription; 2],
-}
+ let cursor_offset = position.to_offset(&active_snapshot);
+ let cursor_point = cursor_offset.to_point(&active_snapshot);
-#[derive(Clone)]
-pub enum Event {
- BufferChange {
- old_snapshot: BufferSnapshot,
- new_snapshot: BufferSnapshot,
- timestamp: Instant,
- },
-}
+ let before_retrieval = Instant::now();
-impl Event {
- fn to_prompt(&self) -> String {
- match self {
- Event::BufferChange {
- old_snapshot,
- new_snapshot,
- ..
- } => {
- let mut prompt = String::new();
-
- let old_path = old_snapshot
- .file()
- .map(|f| f.path().as_ref())
- .unwrap_or(RelPath::unix("untitled").unwrap());
- let new_path = new_snapshot
- .file()
- .map(|f| f.path().as_ref())
- .unwrap_or(RelPath::unix("untitled").unwrap());
- if old_path != new_path {
- writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
+ let (diagnostic_groups, diagnostic_groups_truncated) =
+ Self::gather_nearby_diagnostics(
+ cursor_offset,
+ &diagnostics,
+ &active_snapshot,
+ options.max_diagnostic_bytes,
+ );
+
+ let cloud_request = match options.context {
+ ContextMode::Agentic(context_options) => {
+ let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ &active_snapshot,
+ &context_options.excerpt,
+ index_state.as_deref(),
+ ) else {
+ return Ok((None, None));
+ };
+
+ let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
+ ..active_snapshot.anchor_before(excerpt.range.end);
+
+ if let Some(buffer_ix) =
+ included_files.iter().position(|(_, snapshot, _, _)| {
+ snapshot.remote_id() == active_snapshot.remote_id()
+ })
+ {
+ let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
+ ranges.push(excerpt_anchor_range);
+ retrieval_search::merge_anchor_ranges(ranges, buffer);
+ let last_ix = included_files.len() - 1;
+ included_files.swap(buffer_ix, last_ix);
+ } else {
+ included_files.push((
+ active_buffer.clone(),
+ active_snapshot.clone(),
+ excerpt_path.clone(),
+ vec![excerpt_anchor_range],
+ ));
+ }
+
+ let included_files = included_files
+ .iter()
+ .map(|(_, snapshot, path, ranges)| {
+ let ranges = ranges
+ .iter()
+ .map(|range| {
+ let point_range = range.to_point(&snapshot);
+ Line(point_range.start.row)..Line(point_range.end.row)
+ })
+ .collect::<Vec<_>>();
+ let excerpts = assemble_excerpts(&snapshot, ranges);
+ predict_edits_v3::IncludedFile {
+ path: path.clone(),
+ max_row: Line(snapshot.max_point().row),
+ excerpts,
+ }
+ })
+ .collect::<Vec<_>>();
+
+ predict_edits_v3::PredictEditsRequest {
+ excerpt_path,
+ excerpt: String::new(),
+ excerpt_line_range: Line(0)..Line(0),
+ excerpt_range: 0..0,
+ cursor_point: predict_edits_v3::Point {
+ line: predict_edits_v3::Line(cursor_point.row),
+ column: cursor_point.column,
+ },
+ included_files,
+ referenced_declarations: vec![],
+ events,
+ can_collect_data,
+ diagnostic_groups,
+ diagnostic_groups_truncated,
+ debug_info: debug_tx.is_some(),
+ prompt_max_bytes: Some(options.max_prompt_bytes),
+ prompt_format: options.prompt_format,
+ // TODO [zeta2]
+ signatures: vec![],
+ excerpt_parent: None,
+ git_info: None,
+ }
+ }
+ ContextMode::Syntax(context_options) => {
+ let Some(context) = EditPredictionContext::gather_context(
+ cursor_point,
+ &active_snapshot,
+ parent_abs_path.as_deref(),
+ &context_options,
+ index_state.as_deref(),
+ ) else {
+ return Ok((None, None));
+ };
+
+ make_syntax_context_cloud_request(
+ excerpt_path,
+ context,
+ events,
+ can_collect_data,
+ diagnostic_groups,
+ diagnostic_groups_truncated,
+ None,
+ debug_tx.is_some(),
+ &worktree_snapshots,
+ index_state.as_deref(),
+ Some(options.max_prompt_bytes),
+ options.prompt_format,
+ )
+ }
+ };
+
+ let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
+
+ let inputs = EditPredictionInputs {
+ included_files: cloud_request.included_files,
+ events: cloud_request.events,
+ cursor_point: cloud_request.cursor_point,
+ cursor_path: cloud_request.excerpt_path,
+ };
+
+ let retrieval_time = Instant::now() - before_retrieval;
+
+ let debug_response_tx = if let Some(debug_tx) = &debug_tx {
+ let (response_tx, response_rx) = oneshot::channel();
+
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
+ ZetaEditPredictionDebugInfo {
+ inputs: inputs.clone(),
+ retrieval_time,
+ buffer: active_buffer.downgrade(),
+ local_prompt: match prompt_result.as_ref() {
+ Ok((prompt, _)) => Ok(prompt.clone()),
+ Err(err) => Err(err.to_string()),
+ },
+ position,
+ response_rx,
+ },
+ ))
+ .ok();
+ Some(response_tx)
+ } else {
+ None
+ };
+
+ if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
+ if let Some(debug_response_tx) = debug_response_tx {
+ debug_response_tx
+ .send((Err("Request skipped".to_string()), Duration::ZERO))
+ .ok();
+ }
+ anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
- let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
- if !diff.is_empty() {
- write!(
- prompt,
- "User edited {:?}:\n```diff\n{}\n```",
- new_path, diff
- )
- .unwrap();
+ let (prompt, _) = prompt_result?;
+ let generation_params =
+ cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
+ let request = open_ai::Request {
+ model: EDIT_PREDICTIONS_MODEL_ID.clone(),
+ messages: vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }],
+ stream: false,
+ max_completion_tokens: None,
+ stop: generation_params.stop.unwrap_or_default(),
+ temperature: generation_params.temperature.unwrap_or(0.7),
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: vec![],
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ log::trace!("Sending edit prediction request");
+
+ let before_request = Instant::now();
+ let response = Self::send_raw_llm_request(
+ request,
+ client,
+ llm_token,
+ app_version,
+ #[cfg(feature = "eval-support")]
+ eval_cache,
+ #[cfg(feature = "eval-support")]
+ EvalCacheEntryKind::Prediction,
+ )
+ .await;
+ let received_response_at = Instant::now();
+ let request_time = received_response_at - before_request;
+
+ log::trace!("Got edit prediction response");
+
+ if let Some(debug_response_tx) = debug_response_tx {
+ debug_response_tx
+ .send((
+ response
+ .as_ref()
+ .map_err(|err| err.to_string())
+ .map(|response| response.0.clone()),
+ request_time,
+ ))
+ .ok();
}
- prompt
+ let (res, usage) = response?;
+ let request_id = EditPredictionId(res.id.clone().into());
+ let Some(mut output_text) = text_from_response(res) else {
+ return Ok((None, usage));
+ };
+
+ if output_text.contains(CURSOR_MARKER) {
+ log::trace!("Stripping out {CURSOR_MARKER} from response");
+ output_text = output_text.replace(CURSOR_MARKER, "");
+ }
+
+ let get_buffer_from_context = |path: &Path| {
+ included_files
+ .iter()
+ .find_map(|(_, buffer, probe_path, ranges)| {
+ if probe_path.as_ref() == path {
+ Some((buffer, ranges.as_slice()))
+ } else {
+ None
+ }
+ })
+ };
+
+ let (edited_buffer_snapshot, edits) = match options.prompt_format {
+ PromptFormat::NumLinesUniDiff => {
+ // TODO: Implement parsing of multi-file diffs
+ crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
+ }
+ PromptFormat::Minimal
+ | PromptFormat::MinimalQwen
+ | PromptFormat::SeedCoder1120 => {
+ if output_text.contains("--- a/\n+++ b/\nNo edits") {
+ let edits = vec![];
+ (&active_snapshot, edits)
+ } else {
+ crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
+ }
+ }
+ PromptFormat::OldTextNewText => {
+ crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
+ .await?
+ }
+ _ => {
+ bail!("unsupported prompt format {}", options.prompt_format)
+ }
+ };
+
+ let edited_buffer = included_files
+ .iter()
+ .find_map(|(buffer, snapshot, _, _)| {
+ if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
+ Some(buffer.clone())
+ } else {
+ None
+ }
+ })
+ .context("Failed to find buffer in included_buffers")?;
+
+ anyhow::Ok((
+ Some((
+ request_id,
+ inputs,
+ edited_buffer,
+ edited_buffer_snapshot.clone(),
+ edits,
+ received_response_at,
+ )),
+ usage,
+ ))
}
- }
- }
-}
+ });
-#[derive(Debug, Clone)]
-struct CurrentEditPrediction {
- buffer_id: EntityId,
- completion: EditPrediction,
- was_shown: bool,
- was_accepted: bool,
-}
+ cx.spawn({
+ async move |this, cx| {
+ let Some((
+ id,
+ inputs,
+ edited_buffer,
+ edited_buffer_snapshot,
+ edits,
+ received_response_at,
+ )) = Self::handle_api_response(&this, request_task.await, cx)?
+ else {
+ return Ok(None);
+ };
-impl CurrentEditPrediction {
- fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
- if self.buffer_id != old_completion.buffer_id {
- return true;
- }
+ // TODO telemetry: duration, etc
+ Ok(EditPrediction::new(
+ id,
+ &edited_buffer,
+ &edited_buffer_snapshot,
+ edits.into(),
+ buffer_snapshotted_at,
+ received_response_at,
+ inputs,
+ cx,
+ )
+ .await)
+ }
+ })
+ }
- let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
- return true;
- };
- let Some(new_edits) = self.completion.interpolate(snapshot) else {
- return false;
+ async fn send_raw_llm_request(
+ request: open_ai::Request,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
+ #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
+ ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
+ let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
+ http_client::Url::parse(&predict_edits_url)?
+ } else {
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/raw", &[])?
};
- if old_edits.len() == 1 && new_edits.len() == 1 {
- let (old_range, old_text) = &old_edits[0];
- let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text.as_ref())
- } else {
- true
- }
- }
-}
+ #[cfg(feature = "eval-support")]
+ let cache_key = if let Some(cache) = eval_cache {
+ use collections::FxHasher;
+ use std::hash::{Hash, Hasher};
-struct PendingCompletion {
- id: usize,
- task: Task<()>,
-}
+ let mut hasher = FxHasher::default();
+ url.hash(&mut hasher);
+ let request_str = serde_json::to_string_pretty(&request)?;
+ request_str.hash(&mut hasher);
+ let hash = hasher.finish();
-#[derive(Debug, Clone, Copy)]
-pub enum DataCollectionChoice {
- NotAnswered,
- Enabled,
- Disabled,
-}
+ let key = (eval_cache_kind, hash);
+ if let Some(response_str) = cache.read(key) {
+ return Ok((serde_json::from_str(&response_str)?, None));
+ }
-impl DataCollectionChoice {
- pub fn is_enabled(self) -> bool {
- match self {
- Self::Enabled => true,
- Self::NotAnswered | Self::Disabled => false,
- }
- }
+ Some((cache, request_str, key))
+ } else {
+ None
+ };
- pub fn is_answered(self) -> bool {
- match self {
- Self::Enabled | Self::Disabled => true,
- Self::NotAnswered => false,
+ let (response, usage) = Self::send_api_request(
+ |builder| {
+ let req = builder
+ .uri(url.as_ref())
+ .body(serde_json::to_string(&request)?.into());
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ )
+ .await?;
+
+ #[cfg(feature = "eval-support")]
+ if let Some((cache, request, key)) = cache_key {
+ cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
}
+
+ Ok((response, usage))
}
- #[must_use]
- pub fn toggle(&self) -> DataCollectionChoice {
- match self {
- Self::Enabled => Self::Disabled,
- Self::Disabled => Self::Enabled,
- Self::NotAnswered => Self::Enabled,
+ fn handle_api_response<T>(
+ this: &WeakEntity<Self>,
+ response: Result<(T, Option<EditPredictionUsage>)>,
+ cx: &mut gpui::AsyncApp,
+ ) -> Result<T> {
+ match response {
+ Ok((data, usage)) => {
+ if let Some(usage) = usage {
+ this.update(cx, |this, cx| {
+ this.user_store.update(cx, |user_store, cx| {
+ user_store.update_edit_prediction_usage(usage, cx);
+ });
+ })
+ .ok();
+ }
+ Ok(data)
+ }
+ Err(err) => {
+ if err.is::<ZedUpdateRequiredError>() {
+ cx.update(|cx| {
+ this.update(cx, |this, _cx| {
+ this.update_required = true;
+ })
+ .ok();
+
+ let error_message: SharedString = err.to_string().into();
+ show_app_notification(
+ NotificationId::unique::<ZedUpdateRequiredError>(),
+ cx,
+ move |cx| {
+ cx.new(|cx| {
+ ErrorMessagePrompt::new(error_message.clone(), cx)
+ .with_link_button("Update Zed", "https://zed.dev/releases")
+ })
+ },
+ );
+ })
+ .ok();
+ }
+ Err(err)
+ }
}
}
-}
-impl From<bool> for DataCollectionChoice {
- fn from(value: bool) -> Self {
- match value {
- true => DataCollectionChoice::Enabled,
- false => DataCollectionChoice::Disabled,
+ async fn send_api_request<Res>(
+ build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ ) -> Result<(Res, Option<EditPredictionUsage>)>
+ where
+ Res: DeserializeOwned,
+ {
+ let http_client = client.http_client();
+ let mut token = llm_token.acquire(&client).await?;
+ let mut did_retry = false;
+
+ loop {
+ let request_builder = http_client::Request::builder().method(Method::POST);
+
+ let request = build(
+ request_builder
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", token))
+ .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
+ )?;
+
+ let mut response = http_client.send(request).await?;
+
+ if let Some(minimum_required_version) = response
+ .headers()
+ .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+ .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
+ {
+ anyhow::ensure!(
+ app_version >= minimum_required_version,
+ ZedUpdateRequiredError {
+ minimum_version: minimum_required_version
+ }
+ );
+ }
+
+ if response.status().is_success() {
+ let usage = EditPredictionUsage::from_headers(response.headers()).ok();
+
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+ return Ok((serde_json::from_slice(&body)?, usage));
+ } else if !did_retry
+ && response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
+ {
+ did_retry = true;
+ token = llm_token.refresh(&client).await?;
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ body
+ );
+ }
}
}
-}
-async fn llm_token_retry(
- llm_token: &LlmApiToken,
- client: &Arc<Client>,
- build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
-) -> Result<Response<AsyncBody>> {
- let mut did_retry = false;
- let http_client = client.http_client();
- let mut token = llm_token.acquire(client).await?;
- loop {
- let request = build_request(token.clone())?;
- let response = http_client.send(request).await?;
-
- if !did_retry
- && !response.status().is_success()
- && response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
- {
- did_retry = true;
- token = llm_token.refresh(client).await?;
- continue;
+ pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
+ pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
+
+ // Refresh the related excerpts when the user just beguns editing after
+ // an idle period, and after they pause editing.
+ fn refresh_context_if_needed(
+ &mut self,
+ project: &Entity<Project>,
+ buffer: &Entity<language::Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) {
+ if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
+ return;
}
- return Ok(response);
- }
-}
+ let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
-pub struct ZetaEditPredictionProvider {
- zeta: Entity<Zeta>,
- singleton_buffer: Option<Entity<Buffer>>,
- pending_completions: ArrayVec<PendingCompletion, 2>,
- canceled_completions: HashMap<usize, Task<()>>,
- next_pending_completion_id: usize,
- current_completion: Option<CurrentEditPrediction>,
- last_request_timestamp: Instant,
- project: Entity<Project>,
-}
+ let now = Instant::now();
+ let was_idle = zeta_project
+ .refresh_context_timestamp
+ .map_or(true, |timestamp| {
+ now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
+ });
+ zeta_project.refresh_context_timestamp = Some(now);
+ zeta_project.refresh_context_debounce_task = Some(cx.spawn({
+ let buffer = buffer.clone();
+ let project = project.clone();
+ async move |this, cx| {
+ if was_idle {
+ log::debug!("refetching edit prediction context after idle");
+ } else {
+ cx.background_executor()
+ .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
+ .await;
+ log::debug!("refetching edit prediction context after pause");
+ }
+ this.update(cx, |this, cx| {
+ let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
-impl ZetaEditPredictionProvider {
- pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+ if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+ zeta_project.refresh_context_task = Some(task.log_err());
+ };
+ })
+ .ok()
+ }
+ }));
+ }
- pub fn new(
- zeta: Entity<Zeta>,
+ // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
+ // and avoid spawning more than one concurrent task.
+ pub fn refresh_context(
+ &mut self,
project: Entity<Project>,
- singleton_buffer: Option<Entity<Buffer>>,
+ buffer: Entity<language::Buffer>,
+ cursor_position: language::Anchor,
cx: &mut Context<Self>,
- ) -> Self {
- cx.on_release(|this, cx| {
- this.take_current_edit_prediction(cx);
- })
- .detach();
+ ) -> Task<Result<()>> {
+ let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
+ return Task::ready(anyhow::Ok(()));
+ };
- Self {
- zeta,
- singleton_buffer,
- pending_completions: ArrayVec::new(),
- canceled_completions: HashMap::default(),
- next_pending_completion_id: 0,
- current_completion: None,
- last_request_timestamp: Instant::now(),
- project,
+ let ContextMode::Agentic(options) = &self.options().context else {
+ return Task::ready(anyhow::Ok(()));
+ };
+
+ let snapshot = buffer.read(cx).snapshot();
+ let cursor_point = cursor_position.to_point(&snapshot);
+ let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ &snapshot,
+ &options.excerpt,
+ None,
+ ) else {
+ return Task::ready(Ok(()));
+ };
+
+ let app_version = AppVersion::global(cx);
+ let client = self.client.clone();
+ let llm_token = self.llm_token.clone();
+ let debug_tx = self.debug_tx.clone();
+ let current_file_path: Arc<Path> = snapshot
+ .file()
+ .map(|f| f.full_path(cx).into())
+ .unwrap_or_else(|| Path::new("untitled").into());
+
+ let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
+ predict_edits_v3::PlanContextRetrievalRequest {
+ excerpt: cursor_excerpt.text(&snapshot).body,
+ excerpt_path: current_file_path,
+ excerpt_line_range: cursor_excerpt.line_range,
+ cursor_file_max_row: Line(snapshot.max_point().row),
+ events: zeta_project.events(cx),
+ },
+ ) {
+ Ok(prompt) => prompt,
+ Err(err) => {
+ return Task::ready(Err(err));
+ }
+ };
+
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
+ ZetaContextRetrievalStartedDebugInfo {
+ project: project.clone(),
+ timestamp: Instant::now(),
+ search_prompt: prompt.clone(),
+ },
+ ))
+ .ok();
}
- }
- fn take_current_edit_prediction(&mut self, cx: &mut App) {
- if let Some(completion) = self.current_completion.take() {
- if !completion.was_accepted {
- self.zeta.update(cx, |zeta, cx| {
- zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
- });
+ pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
+ let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
+ language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
+ );
+
+ let description = schema
+ .get("description")
+ .and_then(|description| description.as_str())
+ .unwrap()
+ .to_string();
+
+ (schema.into(), description)
+ });
+
+ let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
+
+ let request = open_ai::Request {
+ model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
+ messages: vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }],
+ stream: false,
+ max_completion_tokens: None,
+ stop: Default::default(),
+ temperature: 0.7,
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: vec![open_ai::ToolDefinition::Function {
+ function: FunctionDefinition {
+ name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
+ description: Some(tool_description),
+ parameters: Some(tool_schema),
+ },
+ }],
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ #[cfg(feature = "eval-support")]
+ let eval_cache = self.eval_cache.clone();
+
+ cx.spawn(async move |this, cx| {
+ log::trace!("Sending search planning request");
+ let response = Self::send_raw_llm_request(
+ request,
+ client,
+ llm_token,
+ app_version,
+ #[cfg(feature = "eval-support")]
+ eval_cache.clone(),
+ #[cfg(feature = "eval-support")]
+ EvalCacheEntryKind::Context,
+ )
+ .await;
+ let mut response = Self::handle_api_response(&this, response, cx)?;
+ log::trace!("Got search planning response");
+
+ let choice = response
+ .choices
+ .pop()
+ .context("No choices in retrieval response")?;
+ let open_ai::RequestMessage::Assistant {
+ content: _,
+ tool_calls,
+ } = choice.message
+ else {
+ anyhow::bail!("Retrieval response didn't include an assistant message");
+ };
+
+ let mut queries: Vec<SearchToolQuery> = Vec::new();
+ for tool_call in tool_calls {
+ let open_ai::ToolCallContent::Function { function } = tool_call.content;
+ if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
+ log::warn!(
+ "Context retrieval response tried to call an unknown tool: {}",
+ function.name
+ );
+
+ continue;
+ }
+
+ let input: SearchToolInput = serde_json::from_str(&function.arguments)
+ .with_context(|| format!("invalid search json {}", &function.arguments))?;
+ queries.extend(input.queries);
}
- }
- }
-}
-impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
- fn name() -> &'static str {
- "zed-predict"
- }
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
+ ZetaSearchQueryDebugInfo {
+ project: project.clone(),
+ timestamp: Instant::now(),
+ search_queries: queries.clone(),
+ },
+ ))
+ .ok();
+ }
- fn display_name() -> &'static str {
- "Zed's Edit Predictions"
- }
+ log::trace!("Running retrieval search: {queries:#?}");
- fn show_completions_in_menu() -> bool {
- true
- }
+ let related_excerpts_result = retrieval_search::run_retrieval_searches(
+ queries,
+ project.clone(),
+ #[cfg(feature = "eval-support")]
+ eval_cache,
+ cx,
+ )
+ .await;
- fn show_tab_accept_marker() -> bool {
- true
- }
+ log::trace!("Search queries executed");
+
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
+ ZetaContextRetrievalDebugInfo {
+ project: project.clone(),
+ timestamp: Instant::now(),
+ },
+ ))
+ .ok();
+ }
- fn data_collection_state(&self, cx: &App) -> DataCollectionState {
- if let Some(buffer) = &self.singleton_buffer
- && let Some(file) = buffer.read(cx).file()
- {
- let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
- if self.zeta.read(cx).data_collection_choice.is_enabled() {
- DataCollectionState::Enabled {
- is_project_open_source,
+ this.update(cx, |this, _cx| {
+ let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
+ return Ok(());
+ };
+ zeta_project.refresh_context_task.take();
+ if let Some(debug_tx) = &this.debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
+ ZetaContextRetrievalDebugInfo {
+ project,
+ timestamp: Instant::now(),
+ },
+ ))
+ .ok();
}
- } else {
- DataCollectionState::Disabled {
- is_project_open_source,
+ match related_excerpts_result {
+ Ok(excerpts) => {
+ zeta_project.context = Some(excerpts);
+ Ok(())
+ }
+ Err(error) => Err(error),
}
- }
- } else {
- return DataCollectionState::Disabled {
- is_project_open_source: false,
- };
- }
- }
-
- fn toggle_data_collection(&mut self, cx: &mut App) {
- self.zeta
- .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
- }
-
- fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- self.zeta.read(cx).usage(cx)
- }
-
- fn is_enabled(
- &self,
- _buffer: &Entity<Buffer>,
- _cursor_position: language::Anchor,
- _cx: &App,
- ) -> bool {
- true
- }
- fn is_refreshing(&self, _cx: &App) -> bool {
- !self.pending_completions.is_empty()
+ })?
+ })
}
- fn refresh(
+ pub fn set_context(
&mut self,
- buffer: Entity<Buffer>,
- position: language::Anchor,
- _debounce: bool,
- cx: &mut Context<Self>,
+ project: Entity<Project>,
+ context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
) {
- if self.zeta.read(cx).update_required {
- return;
+ if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
+ zeta_project.context = Some(context);
}
+ }
- if self
- .zeta
- .read(cx)
- .user_store
- .read_with(cx, |user_store, _cx| {
- user_store.account_too_young() || user_store.has_overdue_invoices()
- })
- {
- return;
+ fn gather_nearby_diagnostics(
+ cursor_offset: usize,
+ diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
+ snapshot: &BufferSnapshot,
+ max_diagnostics_bytes: usize,
+ ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
+ // TODO: Could make this more efficient
+ let mut diagnostic_groups = Vec::new();
+ for (language_server_id, diagnostics) in diagnostic_sets {
+ let mut groups = Vec::new();
+ diagnostics.groups(*language_server_id, &mut groups, &snapshot);
+ diagnostic_groups.extend(
+ groups
+ .into_iter()
+ .map(|(_, group)| group.resolve::<usize>(&snapshot)),
+ );
}
- if let Some(current_completion) = self.current_completion.as_ref() {
- let snapshot = buffer.read(cx).snapshot();
- if current_completion
- .completion
- .interpolate(&snapshot)
- .is_some()
- {
- return;
+ // sort by proximity to cursor
+ diagnostic_groups.sort_by_key(|group| {
+ let range = &group.entries[group.primary_ix].range;
+ if range.start >= cursor_offset {
+ range.start - cursor_offset
+ } else if cursor_offset >= range.end {
+ cursor_offset - range.end
+ } else {
+ (cursor_offset - range.start).min(range.end - cursor_offset)
+ }
+ });
+
+ let mut results = Vec::new();
+ let mut diagnostic_groups_truncated = false;
+ let mut diagnostics_byte_count = 0;
+ for group in diagnostic_groups {
+ let raw_value = serde_json::value::to_raw_value(&group).unwrap();
+ diagnostics_byte_count += raw_value.get().len();
+ if diagnostics_byte_count > max_diagnostics_bytes {
+ diagnostic_groups_truncated = true;
+ break;
}
+ results.push(predict_edits_v3::DiagnosticGroup(raw_value));
}
- let pending_completion_id = self.next_pending_completion_id;
- self.next_pending_completion_id += 1;
- let last_request_timestamp = self.last_request_timestamp;
+ (results, diagnostic_groups_truncated)
+ }
- let project = self.project.clone();
- let task = cx.spawn(async move |this, cx| {
- if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
- .checked_duration_since(Instant::now())
- {
- cx.background_executor().timer(timeout).await;
- }
+ // TODO: Dedupe with similar code in request_prediction?
+ pub fn cloud_request_for_zeta_cli(
+ &mut self,
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
+ let project_state = self.projects.get(&project.entity_id());
+
+ let index_state = project_state.and_then(|state| {
+ state
+ .syntax_index
+ .as_ref()
+ .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
+ });
+ let options = self.options.clone();
+ let snapshot = buffer.read(cx).snapshot();
+ let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
+ return Task::ready(Err(anyhow!("No file path for excerpt")));
+ };
+ let worktree_snapshots = project
+ .read(cx)
+ .worktrees(cx)
+ .map(|worktree| worktree.read(cx).snapshot())
+ .collect::<Vec<_>>();
- let completion_request = this.update(cx, |this, cx| {
- this.last_request_timestamp = Instant::now();
- this.zeta.update(cx, |zeta, cx| {
- zeta.request_completion(&project, &buffer, position, cx)
- })
- });
+ let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
+ let mut path = f.worktree.read(cx).absolutize(&f.path);
+ if path.pop() { Some(path) } else { None }
+ });
- let completion = match completion_request {
- Ok(completion_request) => {
- let completion_request = completion_request.await;
- completion_request.map(|c| {
- c.map(|completion| CurrentEditPrediction {
- buffer_id: buffer.entity_id(),
- completion,
- was_shown: false,
- was_accepted: false,
- })
- })
- }
- Err(error) => Err(error),
+ cx.background_spawn(async move {
+ let index_state = if let Some(index_state) = index_state {
+ Some(index_state.lock_owned().await)
+ } else {
+ None
};
- let discarded = this
- .update(cx, |this, cx| {
- if this
- .pending_completions
- .first()
- .is_some_and(|completion| completion.id == pending_completion_id)
- {
- this.pending_completions.remove(0);
- } else {
- if let Some(discarded) = this.pending_completions.drain(..).next() {
- this.canceled_completions
- .insert(discarded.id, discarded.task);
- }
- }
-
- let canceled = this.canceled_completions.remove(&pending_completion_id);
+ let cursor_point = position.to_point(&snapshot);
- if canceled.is_some()
- && let Ok(Some(new_completion)) = &completion
- {
- this.zeta.update(cx, |zeta, cx| {
- zeta.discard_completion(new_completion.completion.id, false, cx);
- });
- return true;
+ let debug_info = true;
+ EditPredictionContext::gather_context(
+ cursor_point,
+ &snapshot,
+ parent_abs_path.as_deref(),
+ match &options.context {
+ ContextMode::Agentic(_) => {
+ // TODO
+ panic!("Llm mode not supported in zeta cli yet");
}
-
- cx.notify();
- false
- })
- .ok()
- .unwrap_or(true);
-
- if discarded {
- return;
- }
-
- let Some(new_completion) = completion
- .context("edit prediction failed")
- .log_err()
- .flatten()
- else {
- return;
- };
-
- this.update(cx, |this, cx| {
- if let Some(old_completion) = this.current_completion.as_ref() {
- let snapshot = buffer.read(cx).snapshot();
- if new_completion.should_replace_completion(old_completion, &snapshot) {
- this.zeta.update(cx, |zeta, cx| {
- zeta.completion_shown(&new_completion.completion, cx);
- });
- this.take_current_edit_prediction(cx);
- this.current_completion = Some(new_completion);
+ ContextMode::Syntax(edit_prediction_context_options) => {
+ edit_prediction_context_options
}
- } else {
- this.zeta.update(cx, |zeta, cx| {
- zeta.completion_shown(&new_completion.completion, cx);
- });
- this.current_completion = Some(new_completion);
- }
-
- cx.notify();
+ },
+ index_state.as_deref(),
+ )
+ .context("Failed to select excerpt")
+ .map(|context| {
+ make_syntax_context_cloud_request(
+ excerpt_path.into(),
+ context,
+ // TODO pass everything
+ Vec::new(),
+ false,
+ Vec::new(),
+ false,
+ None,
+ debug_info,
+ &worktree_snapshots,
+ index_state.as_deref(),
+ Some(options.max_prompt_bytes),
+ options.prompt_format,
+ )
})
- .ok();
- });
-
- // We always maintain at most two pending completions. When we already
- // have two, we replace the newest one.
- if self.pending_completions.len() <= 1 {
- self.pending_completions.push(PendingCompletion {
- id: pending_completion_id,
- task,
- });
- } else if self.pending_completions.len() == 2 {
- if let Some(discarded) = self.pending_completions.pop() {
- self.canceled_completions
- .insert(discarded.id, discarded.task);
- }
- self.pending_completions.push(PendingCompletion {
- id: pending_completion_id,
- task,
- });
- }
+ })
}
- fn cycle(
+ pub fn wait_for_initial_indexing(
&mut self,
- _buffer: Entity<Buffer>,
- _cursor_position: language::Anchor,
- _direction: edit_prediction::Direction,
- _cx: &mut Context<Self>,
- ) {
- // Right now we don't support cycling.
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ let zeta_project = self.get_or_init_zeta_project(project, cx);
+ if let Some(syntax_index) = &zeta_project.syntax_index {
+ syntax_index.read(cx).wait_for_initial_file_indexing(cx)
+ } else {
+ Task::ready(Ok(()))
+ }
}
- fn accept(&mut self, cx: &mut Context<Self>) {
- let completion = self.current_completion.as_mut();
- if let Some(completion) = completion {
- completion.was_accepted = true;
- self.zeta
- .update(cx, |zeta, cx| {
- zeta.accept_edit_prediction(completion.completion.id, cx)
- })
- .detach();
+ fn is_file_open_source(
+ &self,
+ project: &Entity<Project>,
+ file: &Arc<dyn File>,
+ cx: &App,
+ ) -> bool {
+ if !file.is_local() || file.is_private() {
+ return false;
}
- self.pending_completions.clear();
+ let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
+ return false;
+ };
+ zeta_project
+ .license_detection_watchers
+ .get(&file.worktree_id(cx))
+ .as_ref()
+ .is_some_and(|watcher| watcher.is_project_open_source())
}
- fn discard(&mut self, cx: &mut Context<Self>) {
- self.pending_completions.clear();
- self.take_current_edit_prediction(cx);
+ fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
+ self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
}
- fn did_show(&mut self, _cx: &mut Context<Self>) {
- if let Some(current_completion) = self.current_completion.as_mut() {
- current_completion.was_shown = true;
+ fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
+ if !self.data_collection_choice.is_enabled() {
+ return false;
}
+ events.iter().all(|event| {
+ matches!(
+ event.as_ref(),
+ Event::BufferChange {
+ in_open_source_repo: true,
+ ..
+ }
+ )
+ })
}
- fn suggest(
- &mut self,
- buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
- let CurrentEditPrediction {
- buffer_id,
- completion,
- ..
- } = self.current_completion.as_mut()?;
-
- // Invalidate previous completion if it was generated for a different buffer.
- if *buffer_id != buffer.entity_id() {
- self.take_current_edit_prediction(cx);
- return None;
- }
-
- let buffer = buffer.read(cx);
- let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
- self.take_current_edit_prediction(cx);
- return None;
- };
-
- let cursor_row = cursor_position.to_point(buffer).row;
- let (closest_edit_ix, (closest_edit_range, _)) =
- edits.iter().enumerate().min_by_key(|(_, (range, _))| {
- let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
- let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
- cmp::min(distance_from_start, distance_from_end)
- })?;
+ fn load_data_collection_choice() -> DataCollectionChoice {
+ let choice = KEY_VALUE_STORE
+ .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
+ .log_err()
+ .flatten();
- let mut edit_start_ix = closest_edit_ix;
- for (range, _) in edits[..edit_start_ix].iter().rev() {
- let distance_from_closest_edit =
- closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
- if distance_from_closest_edit <= 1 {
- edit_start_ix -= 1;
- } else {
- break;
+ match choice.as_deref() {
+ Some("true") => DataCollectionChoice::Enabled,
+ Some("false") => DataCollectionChoice::Disabled,
+ Some(_) => {
+ log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
+ DataCollectionChoice::NotAnswered
}
+ None => DataCollectionChoice::NotAnswered,
}
+ }
- let mut edit_end_ix = closest_edit_ix + 1;
- for (range, _) in &edits[edit_end_ix..] {
- let distance_from_closest_edit =
- range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
- if distance_from_closest_edit <= 1 {
- edit_end_ix += 1;
- } else {
- break;
- }
- }
+ pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
+ self.shown_predictions.iter()
+ }
- Some(edit_prediction::EditPrediction::Local {
- id: Some(completion.id.to_string().into()),
- edits: edits[edit_start_ix..edit_end_ix].to_vec(),
- edit_preview: Some(completion.edit_preview.clone()),
- })
+ pub fn shown_completions_len(&self) -> usize {
+ self.shown_predictions.len()
}
-}
-/// Typical number of string bytes per token for the purposes of limiting model input. This is
-/// intentionally low to err on the side of underestimating limits.
-const BYTES_PER_TOKEN_GUESS: usize = 3;
+ pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
+ self.rated_predictions.contains(id)
+ }
-fn guess_token_count(bytes: usize) -> usize {
- bytes / BYTES_PER_TOKEN_GUESS
+ pub fn rate_prediction(
+ &mut self,
+ prediction: &EditPrediction,
+ rating: EditPredictionRating,
+ feedback: String,
+ cx: &mut Context<Self>,
+ ) {
+ self.rated_predictions.insert(prediction.id.clone());
+ telemetry::event!(
+ "Edit Prediction Rated",
+ rating,
+ inputs = prediction.inputs,
+ output = prediction.edit_preview.as_unified_diff(&prediction.edits),
+ feedback
+ );
+ self.client.telemetry().flush_events().detach();
+ cx.notify();
+ }
}
-#[cfg(test)]
-mod tests {
- use client::test::FakeServer;
- use clock::{FakeSystemClock, ReplicaId};
- use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
- use gpui::TestAppContext;
- use http_client::FakeHttpClient;
- use indoc::indoc;
- use language::Point;
- use parking_lot::Mutex;
- use serde_json::json;
- use settings::SettingsStore;
- use util::{path, rel_path::rel_path};
-
- use super::*;
-
- const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
+pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
+ let choice = res.choices.pop()?;
+ let output_text = match choice.message {
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(content)),
+ ..
+ } => content,
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Multipart(mut content)),
+ ..
+ } => {
+ if content.is_empty() {
+ log::error!("No output from Baseten completion response");
+ return None;
+ }
- #[gpui::test]
- async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
- let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
- let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
- to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
- });
+ match content.remove(0) {
+ open_ai::MessagePart::Text { text } => text,
+ open_ai::MessagePart::Image { .. } => {
+ log::error!("Expected text, got an image");
+ return None;
+ }
+ }
+ }
+ _ => {
+ log::error!("Invalid response message: {:?}", choice.message);
+ return None;
+ }
+ };
+ Some(output_text)
+}
- let edit_preview = cx
- .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
- .await;
+#[derive(Error, Debug)]
+#[error(
+ "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
+)]
+pub struct ZedUpdateRequiredError {
+ minimum_version: Version,
+}
- let completion = EditPrediction {
- edits,
- edit_preview,
- path: Path::new("").into(),
- snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
- id: EditPredictionId(Uuid::new_v4()),
- excerpt_range: 0..0,
- cursor_offset: 0,
- input_outline: "".into(),
- input_events: "".into(),
- input_excerpt: "".into(),
- output_excerpt: "".into(),
- buffer_snapshotted_at: Instant::now(),
- response_received_at: Instant::now(),
+fn make_syntax_context_cloud_request(
+ excerpt_path: Arc<Path>,
+ context: EditPredictionContext,
+ events: Vec<Arc<predict_edits_v3::Event>>,
+ can_collect_data: bool,
+ diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
+ diagnostic_groups_truncated: bool,
+ git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
+ debug_info: bool,
+ worktrees: &Vec<worktree::Snapshot>,
+ index_state: Option<&SyntaxIndexState>,
+ prompt_max_bytes: Option<usize>,
+ prompt_format: PromptFormat,
+) -> predict_edits_v3::PredictEditsRequest {
+ let mut signatures = Vec::new();
+ let mut declaration_to_signature_index = HashMap::default();
+ let mut referenced_declarations = Vec::new();
+
+ for snippet in context.declarations {
+ let project_entry_id = snippet.declaration.project_entry_id();
+ let Some(path) = worktrees.iter().find_map(|worktree| {
+ worktree.entry_for_id(project_entry_id).map(|entry| {
+ let mut full_path = RelPathBuf::new();
+ full_path.push(worktree.root_name());
+ full_path.push(&entry.path);
+ full_path
+ })
+ }) else {
+ continue;
};
- cx.update(|cx| {
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..5, "REM".into()), (9..11, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..2, "REM".into()), (6..8, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.undo(cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..5, "REM".into()), (9..11, "".into())]
- );
+ let parent_index = index_state.and_then(|index_state| {
+ snippet.declaration.parent().and_then(|parent| {
+ add_signature(
+ parent,
+ &mut declaration_to_signature_index,
+ &mut signatures,
+ index_state,
+ )
+ })
+ });
- buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(3..3, "EM".into()), (7..9, "".into())]
- );
+ let (text, text_is_truncated) = snippet.declaration.item_text();
+ referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
+ path: path.as_std_path().into(),
+ text: text.into(),
+ range: snippet.declaration.item_line_range(),
+ text_is_truncated,
+ signature_range: snippet.declaration.signature_range_in_item_text(),
+ parent_index,
+ signature_score: snippet.score(DeclarationStyle::Signature),
+ declaration_score: snippet.score(DeclarationStyle::Declaration),
+ score_components: snippet.components,
+ });
+ }
- buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into()), (8..10, "".into())]
- );
+ let excerpt_parent = index_state.and_then(|index_state| {
+ context
+ .excerpt
+ .parent_declarations
+ .last()
+ .and_then(|(parent, _)| {
+ add_signature(
+ *parent,
+ &mut declaration_to_signature_index,
+ &mut signatures,
+ index_state,
+ )
+ })
+ });
+
+ predict_edits_v3::PredictEditsRequest {
+ excerpt_path,
+ excerpt: context.excerpt_text.body,
+ excerpt_line_range: context.excerpt.line_range,
+ excerpt_range: context.excerpt.range,
+ cursor_point: predict_edits_v3::Point {
+ line: predict_edits_v3::Line(context.cursor_point.row),
+ column: context.cursor_point.column,
+ },
+ referenced_declarations,
+ included_files: vec![],
+ signatures,
+ excerpt_parent,
+ events,
+ can_collect_data,
+ diagnostic_groups,
+ diagnostic_groups_truncated,
+ git_info,
+ debug_info,
+ prompt_max_bytes,
+ prompt_format,
+ }
+}
- buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(9..11, "".into())]
- );
+fn add_signature(
+ declaration_id: DeclarationId,
+ declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
+ signatures: &mut Vec<Signature>,
+ index: &SyntaxIndexState,
+) -> Option<usize> {
+ if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
+ return Some(*signature_index);
+ }
+ let Some(parent_declaration) = index.declaration(declaration_id) else {
+ log::error!("bug: missing parent declaration");
+ return None;
+ };
+ let parent_index = parent_declaration.parent().and_then(|parent| {
+ add_signature(parent, declaration_to_signature_index, signatures, index)
+ });
+ let (text, text_is_truncated) = parent_declaration.signature_text();
+ let signature_index = signatures.len();
+ signatures.push(Signature {
+ text: text.into(),
+ text_is_truncated,
+ parent_index,
+ range: parent_declaration.signature_line_range(),
+ });
+ declaration_to_signature_index.insert(declaration_id, signature_index);
+ Some(signature_index)
+}
- buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into()), (8..10, "".into())]
- );
+#[cfg(feature = "eval-support")]
+pub type EvalCacheKey = (EvalCacheEntryKind, u64);
- buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into())]
- );
+#[cfg(feature = "eval-support")]
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum EvalCacheEntryKind {
+ Context,
+ Search,
+ Prediction,
+}
- buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
- assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
- })
+#[cfg(feature = "eval-support")]
+impl std::fmt::Display for EvalCacheEntryKind {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ EvalCacheEntryKind::Search => write!(f, "search"),
+ EvalCacheEntryKind::Context => write!(f, "context"),
+ EvalCacheEntryKind::Prediction => write!(f, "prediction"),
+ }
}
+}
- #[gpui::test]
- async fn test_clean_up_diff(cx: &mut TestAppContext) {
- init_test(cx);
-
- assert_eq!(
- apply_edit_prediction(
- indoc! {"
- fn main() {
- let word_1 = \"lorem\";
- let range = word.len()..word.len();
- }
- "},
- indoc! {"
- <|editable_region_start|>
- fn main() {
- let word_1 = \"lorem\";
- let range = word_1.len()..word_1.len();
- }
+#[cfg(feature = "eval-support")]
+pub trait EvalCache: Send + Sync {
+ fn read(&self, key: EvalCacheKey) -> Option<String>;
+ fn write(&self, key: EvalCacheKey, input: &str, value: &str);
+}
- <|editable_region_end|>
- "},
- cx,
- )
- .await,
- indoc! {"
- fn main() {
- let word_1 = \"lorem\";
- let range = word_1.len()..word_1.len();
- }
- "},
- );
+#[derive(Debug, Clone, Copy)]
+pub enum DataCollectionChoice {
+ NotAnswered,
+ Enabled,
+ Disabled,
+}
- assert_eq!(
- apply_edit_prediction(
- indoc! {"
- fn main() {
- let story = \"the quick\"
- }
- "},
- indoc! {"
- <|editable_region_start|>
- fn main() {
- let story = \"the quick brown fox jumps over the lazy dog\";
- }
+impl DataCollectionChoice {
+ pub fn is_enabled(self) -> bool {
+ match self {
+ Self::Enabled => true,
+ Self::NotAnswered | Self::Disabled => false,
+ }
+ }
- <|editable_region_end|>
- "},
- cx,
- )
- .await,
- indoc! {"
- fn main() {
- let story = \"the quick brown fox jumps over the lazy dog\";
- }
- "},
- );
+ pub fn is_answered(self) -> bool {
+ match self {
+ Self::Enabled | Self::Disabled => true,
+ Self::NotAnswered => false,
+ }
}
- #[gpui::test]
- async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
- init_test(cx);
-
- let buffer_content = "lorem\n";
- let completion_response = indoc! {"
- ```animals.js
- <|start_of_file|>
- <|editable_region_start|>
- lorem
- ipsum
- <|editable_region_end|>
- ```"};
+ #[must_use]
+ pub fn toggle(&self) -> DataCollectionChoice {
+ match self {
+ Self::Enabled => Self::Disabled,
+ Self::Disabled => Self::Enabled,
+ Self::NotAnswered => Self::Enabled,
+ }
+ }
+}
- assert_eq!(
- apply_edit_prediction(buffer_content, completion_response, cx).await,
- "lorem\nipsum"
- );
+impl From<bool> for DataCollectionChoice {
+ fn from(value: bool) -> Self {
+ match value {
+ true => DataCollectionChoice::Enabled,
+ false => DataCollectionChoice::Disabled,
+ }
}
+}
- #[gpui::test]
- async fn test_can_collect_data(cx: &mut TestAppContext) {
- init_test(cx);
+struct ZedPredictUpsell;
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
- .await;
+impl Dismissable for ZedPredictUpsell {
+ const KEY: &'static str = "dismissed-edit-predict-upsell";
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/project/src/main.rs"), cx)
- })
- .await
- .unwrap();
+ fn dismissed() -> bool {
+ // To make this backwards compatible with older versions of Zed, we
+ // check if the user has seen the previous Edit Prediction Onboarding
+ // before, by checking the data collection choice which was written to
+ // the database once the user clicked on "Accept and Enable"
+ if KEY_VALUE_STORE
+ .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
+ .log_err()
+ .is_some_and(|s| s.is_some())
+ {
+ return true;
+ }
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
+ KEY_VALUE_STORE
+ .read_kvp(Self::KEY)
+ .log_err()
+ .is_some_and(|s| s.is_some())
+ }
+}
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
+pub fn should_show_upsell_modal() -> bool {
+ !ZedPredictUpsell::dismissed()
+}
+
+pub fn init(cx: &mut App) {
+ feature_gate_predict_edits_actions(cx);
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Disabled
+ cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
+ workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
+ if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
+ RatePredictionsModal::toggle(workspace, window, cx);
+ }
});
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
+ workspace.register_action(
+ move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
+ ZedPredictModal::toggle(
+ workspace,
+ workspace.user_store().clone(),
+ workspace.client().clone(),
+ window,
+ cx,
+ )
+ },
);
- }
- #[gpui::test]
- async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
-
- let buffer = cx.new(|_cx| {
- Buffer::remote(
- language::BufferId::new(1).unwrap(),
- ReplicaId::new(1),
- language::Capability::ReadWrite,
- "fn main() {\n println!(\"Hello\");\n}",
- )
+ workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
+ update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
+ settings
+ .project
+ .all_languages
+ .features
+ .get_or_insert_default()
+ .edit_prediction_provider = Some(EditPredictionProvider::None)
+ });
});
+ })
+ .detach();
+}
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
+fn feature_gate_predict_edits_actions(cx: &mut App) {
+ let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
+ let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
+ let zeta_all_action_types = [
+ TypeId::of::<RateCompletions>(),
+ TypeId::of::<ResetOnboarding>(),
+ zed_actions::OpenZedPredictOnboarding.type_id(),
+ TypeId::of::<ClearHistory>(),
+ TypeId::of::<ThumbsUpActivePrediction>(),
+ TypeId::of::<ThumbsDownActivePrediction>(),
+ TypeId::of::<NextEdit>(),
+ TypeId::of::<PreviousEdit>(),
+ ];
+
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.hide_action_types(&rate_completion_action_types);
+ filter.hide_action_types(&reset_onboarding_action_types);
+ filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
+ });
+
+ cx.observe_global::<SettingsStore>(move |cx| {
+ let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
+ let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();
+
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ if is_ai_disabled {
+ filter.hide_action_types(&zeta_all_action_types);
+ } else if has_feature_flag {
+ filter.show_action_types(&rate_completion_action_types);
+ } else {
+ filter.hide_action_types(&rate_completion_action_types);
+ }
});
+ })
+ .detach();
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
- }
+ cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
+ if !DisableAiSettings::get_global(cx).disable_ai {
+ if is_enabled {
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.show_action_types(&rate_completion_action_types);
+ });
+ } else {
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.hide_action_types(&rate_completion_action_types);
+ });
+ }
+ }
+ })
+ .detach();
+}
- #[gpui::test]
- async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
- init_test(cx);
+#[cfg(test)]
+mod tests {
+ use std::{path::Path, sync::Arc};
+
+ use client::UserStore;
+ use clock::FakeSystemClock;
+ use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
+ use futures::{
+ AsyncReadExt, StreamExt,
+ channel::{mpsc, oneshot},
+ };
+ use gpui::{
+ Entity, TestAppContext,
+ http_client::{FakeHttpClient, Response},
+ prelude::*,
+ };
+ use indoc::indoc;
+ use language::OffsetRangeExt as _;
+ use open_ai::Usage;
+ use pretty_assertions::{assert_eq, assert_matches};
+ use project::{FakeFs, Project};
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+ use uuid::Uuid;
- let fs = project::FakeFs::new(cx.executor());
+ use crate::{BufferEditPrediction, Zeta};
+
+ #[gpui::test]
+ async fn test_current_state(cx: &mut TestAppContext) {
+ let (zeta, mut req_rx) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
fs.insert_tree(
- path!("/project"),
+ "/root",
json!({
- "LICENSE": BSD_0_TXT,
- ".env": "SECRET_KEY=secret"
+ "1.txt": "Hello!\nHow\nBye\n",
+ "2.txt": "Hola!\nComo\nAdios\n"
}),
)
.await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_project(&project, cx);
+ });
+
+ let buffer1 = project
.update(cx, |project, cx| {
- project.open_local_buffer("/project/.env", cx)
+ let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
+ project.open_buffer(path, cx)
})
.await
.unwrap();
+ let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot1.anchor_before(language::Point::new(1, 3));
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
- }
+ // Prediction for current file
- #[gpui::test]
- async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
- init_test(cx);
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
+ });
+ let (_request, respond_tx) = req_rx.next().await.unwrap();
+
+ respond_tx
+ .send(model_response(indoc! {r"
+ --- a/root/1.txt
+ +++ b/root/1.txt
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "}))
+ .unwrap();
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
- let buffer = cx.new(|cx| Buffer::local("", cx));
+ cx.run_until_parked();
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
+ zeta.read_with(cx, |zeta, cx| {
+ let prediction = zeta
+ .current_prediction_for_buffer(&buffer1, &project, cx)
+ .unwrap();
+ assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
- }
+ // Context refresh
+ let refresh_task = zeta.update(cx, |zeta, cx| {
+ zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
+ });
+ let (_request, respond_tx) = req_rx.next().await.unwrap();
+ respond_tx
+ .send(open_ai::Response {
+ id: Uuid::new_v4().to_string(),
+ object: "response".into(),
+ created: 0,
+ model: "model".into(),
+ choices: vec![open_ai::Choice {
+ index: 0,
+ message: open_ai::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![open_ai::ToolCall {
+ id: "search".into(),
+ content: open_ai::ToolCallContent::Function {
+ function: open_ai::FunctionContent {
+ name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
+ .to_string(),
+ arguments: serde_json::to_string(&SearchToolInput {
+ queries: Box::new([SearchToolQuery {
+ glob: "root/2.txt".to_string(),
+ syntax_node: vec![],
+ content: Some(".".into()),
+ }]),
+ })
+ .unwrap(),
+ },
+ },
+ }],
+ },
+ finish_reason: None,
+ }],
+ usage: Usage {
+ prompt_tokens: 0,
+ completion_tokens: 0,
+ total_tokens: 0,
+ },
+ })
+ .unwrap();
+ refresh_task.await.unwrap();
- #[gpui::test]
- async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
- init_test(cx);
+ zeta.update(cx, |zeta, cx| {
+ zeta.discard_current_prediction(&project, cx);
+ });
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
- .await;
+ // Prediction for another file
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
+ });
+ let (_request, respond_tx) = req_rx.next().await.unwrap();
+ respond_tx
+ .send(model_response(indoc! {r#"
+ --- a/root/2.txt
+ +++ b/root/2.txt
+ Hola!
+ -Como
+ +Como estas?
+ Adios
+ "#}))
+ .unwrap();
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ let prediction = zeta
+ .current_prediction_for_buffer(&buffer1, &project, cx)
+ .unwrap();
+ assert_matches!(
+ prediction,
+ BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
+ );
+ });
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
+ let buffer2 = project
.update(cx, |project, cx| {
- project.open_local_buffer("/project/main.rs", cx)
+ let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
+ project.open_buffer(path, cx)
})
.await
.unwrap();
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
+ zeta.read_with(cx, |zeta, cx| {
+ let prediction = zeta
+ .current_prediction_for_buffer(&buffer2, &project, cx)
+ .unwrap();
+ assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
}
#[gpui::test]
- async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
+ async fn test_simple_request(cx: &mut TestAppContext) {
+ let (zeta, mut req_rx) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
fs.insert_tree(
- path!("/open_source_worktree"),
- json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
)
.await;
- fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
- .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
- let project = Project::test(
- fs.clone(),
- [
- path!("/open_source_worktree").as_ref(),
- path!("/closed_source_worktree").as_ref(),
- ],
- cx,
- )
- .await;
let buffer = project
.update(cx, |project, cx| {
- project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
})
.await
.unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
+ let prediction_task = zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &buffer, position, cx)
});
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- let closed_source_file = project
- .update(cx, |project, cx| {
- let worktree2 = project
- .worktree_for_root_name("closed_source_worktree", cx)
- .unwrap();
- worktree2.update(cx, |worktree2, cx| {
- worktree2.load_file(rel_path("main.rs"), cx)
- })
- })
- .await
- .unwrap()
- .file;
+ let (_, respond_tx) = req_rx.next().await.unwrap();
+
+ // TODO Put back when we have a structured request again
+ // assert_eq!(
+ // request.excerpt_path.as_ref(),
+ // Path::new(path!("root/foo.md"))
+ // );
+ // assert_eq!(
+ // request.cursor_point,
+ // Point {
+ // line: Line(1),
+ // column: 3
+ // }
+ // );
+
+ respond_tx
+ .send(model_response(indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "}))
+ .unwrap();
- buffer.update(cx, |buffer, cx| {
- buffer.file_updated(closed_source_file, cx);
- });
+ let prediction = prediction_task.await.unwrap().unwrap();
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(prediction.edits.len(), 1);
assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
+ prediction.edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 3)
);
+ assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
#[gpui::test]
- async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
+ async fn test_request_events(cx: &mut TestAppContext) {
+ let (zeta, mut req_rx) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
fs.insert_tree(
- path!("/worktree1"),
- json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
+ "/root",
+ json!({
+ "foo.md": "Hello!\n\nBye\n"
+ }),
)
.await;
- fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
- .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
- let project = Project::test(
- fs.clone(),
- [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
- cx,
- )
- .await;
let buffer = project
.update(cx, |project, cx| {
- project.open_local_buffer(path!("/worktree1/main.rs"), cx)
- })
- .await
- .unwrap();
- let private_buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/worktree2/file.rs"), cx)
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
})
.await
.unwrap();
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(&buffer, &project, cx);
});
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(vec![(7..7, "How")], None, cx);
+ });
- // this has a side effect of registering the buffer to watch for edits
- run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
- private_buffer.update(cx, |private_buffer, cx| {
- private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
+ let prediction_task = zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &buffer, position, cx)
});
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
+ let (request, respond_tx) = req_rx.next().await.unwrap();
+
+ let prompt = prompt_from_request(&request);
+ assert!(
+ prompt.contains(indoc! {"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ -1,3 +1,3 @@
+ Hello!
+ -
+ +How
+ Bye
+ "}),
+ "{prompt}"
);
- // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
- // included
- buffer.update(cx, |buffer, cx| {
- buffer.edit(
- [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
- None,
- cx,
- );
- });
+ respond_tx
+ .send(model_response(indoc! {r#"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "#}))
+ .unwrap();
+
+ let prediction = prediction_task.await.unwrap().unwrap();
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(prediction.edits.len(), 1);
assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
+ prediction.edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 3)
);
+ assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
+ }
+
+ // Skipped until we start including diagnostics in prompt
+ // #[gpui::test]
+ // async fn test_request_diagnostics(cx: &mut TestAppContext) {
+ // let (zeta, mut req_rx) = init_test(cx);
+ // let fs = FakeFs::new(cx.executor());
+ // fs.insert_tree(
+ // "/root",
+ // json!({
+ // "foo.md": "Hello!\nBye"
+ // }),
+ // )
+ // .await;
+ // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
+ // let diagnostic = lsp::Diagnostic {
+ // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+ // severity: Some(lsp::DiagnosticSeverity::ERROR),
+ // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
+ // ..Default::default()
+ // };
+
+ // project.update(cx, |project, cx| {
+ // project.lsp_store().update(cx, |lsp_store, cx| {
+ // // Create some diagnostics
+ // lsp_store
+ // .update_diagnostics(
+ // LanguageServerId(0),
+ // lsp::PublishDiagnosticsParams {
+ // uri: path_to_buffer_uri.clone(),
+ // diagnostics: vec![diagnostic],
+ // version: None,
+ // },
+ // None,
+ // language::DiagnosticSourceKind::Pushed,
+ // &[],
+ // cx,
+ // )
+ // .unwrap();
+ // });
+ // });
+
+ // let buffer = project
+ // .update(cx, |project, cx| {
+ // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ // project.open_buffer(path, cx)
+ // })
+ // .await
+ // .unwrap();
+
+ // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ // let position = snapshot.anchor_before(language::Point::new(0, 0));
+
+ // let _prediction_task = zeta.update(cx, |zeta, cx| {
+ // zeta.request_prediction(&project, &buffer, position, cx)
+ // });
+
+ // let (request, _respond_tx) = req_rx.next().await.unwrap();
+
+ // assert_eq!(request.diagnostic_groups.len(), 1);
+ // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
+ // .unwrap();
+ // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
+ // assert_eq!(
+ // value,
+ // json!({
+ // "entries": [{
+ // "range": {
+ // "start": 8,
+ // "end": 10
+ // },
+ // "diagnostic": {
+ // "source": null,
+ // "code": null,
+ // "code_description": null,
+ // "severity": 1,
+ // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
+ // "markdown": null,
+ // "group_id": 0,
+ // "is_primary": true,
+ // "is_disk_based": false,
+ // "is_unnecessary": false,
+ // "source_kind": "Pushed",
+ // "data": null,
+ // "underline": true
+ // }
+ // }],
+ // "primary_ix": 0
+ // })
+ // );
+ // }
+
+ fn model_response(text: &str) -> open_ai::Response {
+ open_ai::Response {
+ id: Uuid::new_v4().to_string(),
+ object: "response".into(),
+ created: 0,
+ model: "model".into(),
+ choices: vec![open_ai::Choice {
+ index: 0,
+ message: open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(text.to_string())),
+ tool_calls: vec![],
+ },
+ finish_reason: None,
+ }],
+ usage: Usage {
+ prompt_tokens: 0,
+ completion_tokens: 0,
+ total_tokens: 0,
+ },
+ }
}
- fn init_test(cx: &mut TestAppContext) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
- }
-
- async fn apply_edit_prediction(
- buffer_content: &str,
- completion_response: &str,
- cx: &mut TestAppContext,
- ) -> String {
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
- let (zeta, _, response) = make_test_zeta(&project, cx).await;
- *response.lock() = completion_response.to_string();
- let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
- buffer.update(cx, |buffer, cx| {
- buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
- });
- buffer.read_with(cx, |buffer, _| buffer.text())
- }
-
- async fn run_edit_prediction(
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- zeta: &Entity<Zeta>,
- cx: &mut TestAppContext,
- ) -> EditPrediction {
- let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
- zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
- cx.background_executor.run_until_parked();
- let completion_task = zeta.update(cx, |zeta, cx| {
- zeta.request_completion(&project, buffer, cursor, cx)
- });
- completion_task.await.unwrap().unwrap()
+ fn prompt_from_request(request: &open_ai::Request) -> &str {
+ assert_eq!(request.messages.len(), 1);
+ let open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(content),
+ ..
+ } = &request.messages[0]
+ else {
+ panic!(
+ "Request does not have single user message of type Plain. {:#?}",
+ request
+ );
+ };
+ content
}
- async fn make_test_zeta(
- project: &Entity<Project>,
+ fn init_test(
cx: &mut TestAppContext,
) -> (
Entity<Zeta>,
- Arc<Mutex<Option<PredictEditsBody>>>,
- Arc<Mutex<String>>,
+ mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
) {
- let default_response = indoc! {"
- ```main.rs
- <|start_of_file|>
- <|editable_region_start|>
- hello world
- <|editable_region_end|>
- ```"
- };
- let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
- let completion_response: Arc<Mutex<String>> =
- Arc::new(Mutex::new(default_response.to_string()));
- let http_client = FakeHttpClient::create({
- let captured_request = captured_request.clone();
- let completion_response = completion_response.clone();
- move |req| {
- let captured_request = captured_request.clone();
- let completion_response = completion_response.clone();
- async move {
- match (req.method(), req.uri().path()) {
- (&Method::POST, "/client/llm_tokens") => {
- Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&CreateLlmTokenResponse {
- token: LlmToken("the-llm-token".to_string()),
- })
- .unwrap()
- .into(),
- )
- .unwrap())
- }
- (&Method::POST, "/predict_edits/v2") => {
- let mut request_body = String::new();
- req.into_body().read_to_string(&mut request_body).await?;
- *captured_request.lock() =
- Some(serde_json::from_str(&request_body).unwrap());
- Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&PredictEditsResponse {
- request_id: Uuid::new_v4().to_string(),
- output_excerpt: completion_response.lock().clone(),
- })
- .unwrap()
- .into(),
- )
- .unwrap())
- }
- _ => Ok(http_client::Response::builder()
- .status(404)
- .body("Not Found".into())
- .unwrap()),
+ cx.update(move |cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ zlog::init_test();
+
+ let (req_tx, req_rx) = mpsc::unbounded();
+
+ let http_client = FakeHttpClient::create({
+ move |req| {
+ let uri = req.uri().path().to_string();
+ let mut body = req.into_body();
+ let req_tx = req_tx.clone();
+ async move {
+ let resp = match uri.as_str() {
+ "/client/llm_tokens" => serde_json::to_string(&json!({
+ "token": "test"
+ }))
+ .unwrap(),
+ "/predict_edits/raw" => {
+ let mut buf = Vec::new();
+ body.read_to_end(&mut buf).await.ok();
+ let req = serde_json::from_slice(&buf).unwrap();
+
+ let (res_tx, res_rx) = oneshot::channel();
+ req_tx.unbounded_send((req, res_tx)).unwrap();
+ serde_json::to_string(&res_rx.await?).unwrap()
+ }
+ _ => {
+ panic!("Unexpected path: {}", uri)
+ }
+ };
+
+ Ok(Response::builder().body(resp.into()).unwrap())
}
}
- }
- });
-
- let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
- cx.update(|cx| {
- RefreshLlmTokenListener::register(client.clone(), cx);
- });
- let _server = FakeServer::for_client(42, &client, cx).await;
-
- let zeta = cx.new(|cx| {
- let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
-
- let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
- for worktree in worktrees {
- let worktree_id = worktree.read(cx).id();
- zeta.license_detection_watchers
- .entry(worktree_id)
- .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
- }
-
- zeta
- });
+ });
- (zeta, captured_request, completion_response)
- }
+ let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
+ client.cloud_client().set_credentials(1, "test".into());
- fn to_completion_edits(
- iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
- buffer: &Entity<Buffer>,
- cx: &App,
- ) -> Vec<(Range<Anchor>, Arc<str>)> {
- let buffer = buffer.read(cx);
- iterator
- .into_iter()
- .map(|(range, text)| {
- (
- buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
- text,
- )
- })
- .collect()
- }
+ language_model::init(client.clone(), cx);
- fn from_completion_edits(
- editor_edits: &[(Range<Anchor>, Arc<str>)],
- buffer: &Entity<Buffer>,
- cx: &App,
- ) -> Vec<(Range<usize>, Arc<str>)> {
- let buffer = buffer.read(cx);
- editor_edits
- .iter()
- .map(|(range, text)| {
- (
- range.start.to_offset(buffer)..range.end.to_offset(buffer),
- text.clone(),
- )
- })
- .collect()
- }
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ let zeta = Zeta::global(&client, &user_store, cx);
- #[ctor::ctor]
- fn init_logger() {
- zlog::init_test();
+ (zeta, req_rx)
+ })
}
}
@@ -0,0 +1,500 @@
+mod input_excerpt;
+
+use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
+
+use crate::{
+ EditPredictionId, ZedUpdateRequiredError, Zeta,
+ prediction::{EditPrediction, EditPredictionInputs},
+};
+use anyhow::{Context as _, Result};
+use cloud_llm_client::{
+ PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
+};
+use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
+use input_excerpt::excerpt_for_cursor_position;
+use language::{
+ Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
+};
+use project::{Project, ProjectPath};
+use release_channel::AppVersion;
+use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+
+const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
+const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
+const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
+const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
+
+pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
+pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
+pub(crate) const MAX_EVENT_TOKENS: usize = 500;
+
+pub(crate) fn request_prediction_with_zeta1(
+ zeta: &mut Zeta,
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ cx: &mut Context<Zeta>,
+) -> Task<Result<Option<EditPrediction>>> {
+ let buffer = buffer.clone();
+ let buffer_snapshotted_at = Instant::now();
+ let snapshot = buffer.read(cx).snapshot();
+ let client = zeta.client.clone();
+ let llm_token = zeta.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+
+ let zeta_project = zeta.get_or_init_zeta_project(project, cx);
+ let events = Arc::new(zeta_project.events(cx));
+
+ let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
+ let can_collect_file = zeta.can_collect_file(project, file, cx);
+ let git_info = if can_collect_file {
+ git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
+ } else {
+ None
+ };
+ (git_info, can_collect_file)
+ } else {
+ (None, false)
+ };
+
+ let full_path: Arc<Path> = snapshot
+ .file()
+ .map(|f| Arc::from(f.full_path(cx).as_path()))
+ .unwrap_or_else(|| Arc::from(Path::new("untitled")));
+ let full_path_str = full_path.to_string_lossy().into_owned();
+ let cursor_point = position.to_point(&snapshot);
+ let prompt_for_events = {
+ let events = events.clone();
+ move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
+ };
+ let gather_task = gather_context(
+ full_path_str,
+ &snapshot,
+ cursor_point,
+ prompt_for_events,
+ cx,
+ );
+
+ cx.spawn(async move |this, cx| {
+ let GatherContextOutput {
+ mut body,
+ context_range,
+ editable_range,
+ included_events_count,
+ } = gather_task.await?;
+ let done_gathering_context_at = Instant::now();
+
+ let included_events = &events[events.len() - included_events_count..events.len()];
+ body.can_collect_data = can_collect_file
+ && this
+ .read_with(cx, |this, _| this.can_collect_events(included_events))
+ .unwrap_or(false);
+ if body.can_collect_data {
+ body.git_info = git_info;
+ }
+
+ log::debug!(
+ "Events:\n{}\nExcerpt:\n{:?}",
+ body.input_events,
+ body.input_excerpt
+ );
+
+ let http_client = client.http_client();
+
+ let response = Zeta::send_api_request::<PredictEditsResponse>(
+ |request| {
+ let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
+ predict_edits_url
+ } else {
+ http_client
+ .build_zed_llm_url("/predict_edits/v2", &[])?
+ .as_str()
+ .into()
+ };
+ Ok(request
+ .uri(uri)
+ .body(serde_json::to_string(&body)?.into())?)
+ },
+ client,
+ llm_token,
+ app_version,
+ )
+ .await;
+
+ let inputs = EditPredictionInputs {
+ events: included_events.into(),
+ included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
+ path: full_path.clone(),
+ max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
+ excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
+ start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
+ text: snapshot
+ .text_for_range(context_range)
+ .collect::<String>()
+ .into(),
+ }],
+ }],
+ cursor_point: cloud_llm_client::predict_edits_v3::Point {
+ column: cursor_point.column,
+ line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
+ },
+ cursor_path: full_path,
+ };
+
+ // let response = perform_predict_edits(PerformPredictEditsParams {
+ // client,
+ // llm_token,
+ // app_version,
+ // body,
+ // })
+ // .await;
+
+ let (response, usage) = match response {
+ Ok(response) => response,
+ Err(err) => {
+ if err.is::<ZedUpdateRequiredError>() {
+ cx.update(|cx| {
+ this.update(cx, |zeta, _cx| {
+ zeta.update_required = true;
+ })
+ .ok();
+
+ let error_message: SharedString = err.to_string().into();
+ show_app_notification(
+ NotificationId::unique::<ZedUpdateRequiredError>(),
+ cx,
+ move |cx| {
+ cx.new(|cx| {
+ ErrorMessagePrompt::new(error_message.clone(), cx)
+ .with_link_button("Update Zed", "https://zed.dev/releases")
+ })
+ },
+ );
+ })
+ .ok();
+ }
+
+ return Err(err);
+ }
+ };
+
+ let received_response_at = Instant::now();
+ log::debug!("completion response: {}", &response.output_excerpt);
+
+ if let Some(usage) = usage {
+ this.update(cx, |this, cx| {
+ this.user_store.update(cx, |user_store, cx| {
+ user_store.update_edit_prediction_usage(usage, cx);
+ });
+ })
+ .ok();
+ }
+
+ let edit_prediction = process_completion_response(
+ response,
+ buffer,
+ &snapshot,
+ editable_range,
+ inputs,
+ buffer_snapshotted_at,
+ received_response_at,
+ cx,
+ )
+ .await;
+
+ let finished_at = Instant::now();
+
+ // record latency for ~1% of requests
+ if rand::random::<u8>() <= 2 {
+ telemetry::event!(
+ "Edit Prediction Request",
+ context_latency = done_gathering_context_at
+ .duration_since(buffer_snapshotted_at)
+ .as_millis(),
+ request_latency = received_response_at
+ .duration_since(done_gathering_context_at)
+ .as_millis(),
+ process_latency = finished_at.duration_since(received_response_at).as_millis()
+ );
+ }
+
+ edit_prediction
+ })
+}
+
+fn process_completion_response(
+ prediction_response: PredictEditsResponse,
+ buffer: Entity<Buffer>,
+ snapshot: &BufferSnapshot,
+ editable_range: Range<usize>,
+ inputs: EditPredictionInputs,
+ buffer_snapshotted_at: Instant,
+ received_response_at: Instant,
+ cx: &AsyncApp,
+) -> Task<Result<Option<EditPrediction>>> {
+ let snapshot = snapshot.clone();
+ let request_id = prediction_response.request_id;
+ let output_excerpt = prediction_response.output_excerpt;
+ cx.spawn(async move |cx| {
+ let output_excerpt: Arc<str> = output_excerpt.into();
+
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
+ .background_spawn({
+ let output_excerpt = output_excerpt.clone();
+ let editable_range = editable_range.clone();
+ let snapshot = snapshot.clone();
+ async move { parse_edits(output_excerpt, editable_range, &snapshot) }
+ })
+ .await?
+ .into();
+
+ Ok(EditPrediction::new(
+ EditPredictionId(request_id.into()),
+ &buffer,
+ &snapshot,
+ edits,
+ buffer_snapshotted_at,
+ received_response_at,
+ inputs,
+ cx,
+ )
+ .await)
+ })
+}
+
+fn parse_edits(
+ output_excerpt: Arc<str>,
+ editable_range: Range<usize>,
+ snapshot: &BufferSnapshot,
+) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
+ let content = output_excerpt.replace(CURSOR_MARKER, "");
+
+ let start_markers = content
+ .match_indices(EDITABLE_REGION_START_MARKER)
+ .collect::<Vec<_>>();
+ anyhow::ensure!(
+ start_markers.len() == 1,
+ "expected exactly one start marker, found {}",
+ start_markers.len()
+ );
+
+ let end_markers = content
+ .match_indices(EDITABLE_REGION_END_MARKER)
+ .collect::<Vec<_>>();
+ anyhow::ensure!(
+ end_markers.len() == 1,
+ "expected exactly one end marker, found {}",
+ end_markers.len()
+ );
+
+ let sof_markers = content
+ .match_indices(START_OF_FILE_MARKER)
+ .collect::<Vec<_>>();
+ anyhow::ensure!(
+ sof_markers.len() <= 1,
+ "expected at most one start-of-file marker, found {}",
+ sof_markers.len()
+ );
+
+ let codefence_start = start_markers[0].0;
+ let content = &content[codefence_start..];
+
+ let newline_ix = content.find('\n').context("could not find newline")?;
+ let content = &content[newline_ix + 1..];
+
+ let codefence_end = content
+ .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
+ .context("could not find end marker")?;
+ let new_text = &content[..codefence_end];
+
+ let old_text = snapshot
+ .text_for_range(editable_range.clone())
+ .collect::<String>();
+
+ Ok(compute_edits(
+ old_text,
+ new_text,
+ editable_range.start,
+ snapshot,
+ ))
+}
+
+pub fn compute_edits(
+ old_text: String,
+ new_text: &str,
+ offset: usize,
+ snapshot: &BufferSnapshot,
+) -> Vec<(Range<Anchor>, Arc<str>)> {
+ text_diff(&old_text, new_text)
+ .into_iter()
+ .map(|(mut old_range, new_text)| {
+ old_range.start += offset;
+ old_range.end += offset;
+
+ let prefix_len = common_prefix(
+ snapshot.chars_for_range(old_range.clone()),
+ new_text.chars(),
+ );
+ old_range.start += prefix_len;
+
+ let suffix_len = common_prefix(
+ snapshot.reversed_chars_for_range(old_range.clone()),
+ new_text[prefix_len..].chars().rev(),
+ );
+ old_range.end = old_range.end.saturating_sub(suffix_len);
+
+ let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
+ let range = if old_range.is_empty() {
+ let anchor = snapshot.anchor_after(old_range.start);
+ anchor..anchor
+ } else {
+ snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
+ };
+ (range, new_text)
+ })
+ .collect()
+}
+
+fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
+ a.zip(b)
+ .take_while(|(a, b)| a == b)
+ .map(|(a, _)| a.len_utf8())
+ .sum()
+}
+
+fn git_info_for_file(
+ project: &Entity<Project>,
+ project_path: &ProjectPath,
+ cx: &App,
+) -> Option<PredictEditsGitInfo> {
+ let git_store = project.read(cx).git_store().read(cx);
+ if let Some((repository, _repo_path)) =
+ git_store.repository_and_path_for_project_path(project_path, cx)
+ {
+ let repository = repository.read(cx);
+ let head_sha = repository
+ .head_commit
+ .as_ref()
+ .map(|head_commit| head_commit.sha.to_string());
+ let remote_origin_url = repository.remote_origin_url.clone();
+ let remote_upstream_url = repository.remote_upstream_url.clone();
+ if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
+ return None;
+ }
+ Some(PredictEditsGitInfo {
+ head_sha,
+ remote_origin_url,
+ remote_upstream_url,
+ })
+ } else {
+ None
+ }
+}
+
+pub struct GatherContextOutput {
+ pub body: PredictEditsBody,
+ pub context_range: Range<Point>,
+ pub editable_range: Range<usize>,
+ pub included_events_count: usize,
+}
+
+pub fn gather_context(
+ full_path_str: String,
+ snapshot: &BufferSnapshot,
+ cursor_point: language::Point,
+ prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
+ cx: &App,
+) -> Task<Result<GatherContextOutput>> {
+ cx.background_spawn({
+ let snapshot = snapshot.clone();
+ async move {
+ let input_excerpt = excerpt_for_cursor_position(
+ cursor_point,
+ &full_path_str,
+ &snapshot,
+ MAX_REWRITE_TOKENS,
+ MAX_CONTEXT_TOKENS,
+ );
+ let (input_events, included_events_count) = prompt_for_events();
+ let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
+
+ let body = PredictEditsBody {
+ input_events,
+ input_excerpt: input_excerpt.prompt,
+ can_collect_data: false,
+ diagnostic_groups: None,
+ git_info: None,
+ outline: None,
+ speculated_output: None,
+ };
+
+ Ok(GatherContextOutput {
+ body,
+ context_range: input_excerpt.context_range,
+ editable_range,
+ included_events_count,
+ })
+ }
+ })
+}
+
+fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
+ let mut result = String::new();
+ for (ix, event) in events.iter().rev().enumerate() {
+ let event_string = format_event(event.as_ref());
+ let event_tokens = guess_token_count(event_string.len());
+ if event_tokens > remaining_tokens {
+ return (result, ix);
+ }
+
+ if !result.is_empty() {
+ result.insert_str(0, "\n\n");
+ }
+ result.insert_str(0, &event_string);
+ remaining_tokens -= event_tokens;
+ }
+ return (result, events.len());
+}
+
+pub fn format_event(event: &Event) -> String {
+ match event {
+ Event::BufferChange {
+ path,
+ old_path,
+ diff,
+ ..
+ } => {
+ let mut prompt = String::new();
+
+ if old_path != path {
+ writeln!(
+ prompt,
+ "User renamed {} to {}\n",
+ old_path.display(),
+ path.display()
+ )
+ .unwrap();
+ }
+
+ if !diff.is_empty() {
+ write!(
+ prompt,
+ "User edited {}:\n```diff\n{}\n```",
+ path.display(),
+ diff
+ )
+ .unwrap();
+ }
+
+ prompt
+ }
+ }
+}
+
+/// Typical number of string bytes per token for the purposes of limiting model input. This is
+/// intentionally low to err on the side of underestimating limits.
+pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
+
+fn guess_token_count(bytes: usize) -> usize {
+ bytes / BYTES_PER_TOKEN_GUESS
+}
@@ -1,4 +1,4 @@
-use crate::{
+use super::{
CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
guess_token_count,
};
@@ -7,6 +7,7 @@ use std::{fmt::Write, ops::Range};
#[derive(Debug)]
pub struct InputExcerpt {
+ pub context_range: Range<Point>,
pub editable_range: Range<Point>,
pub prompt: String,
}
@@ -63,6 +64,7 @@ pub fn excerpt_for_cursor_position(
write!(prompt, "\n```").unwrap();
InputExcerpt {
+ context_range,
editable_range,
prompt,
}
@@ -124,7 +126,7 @@ mod tests {
use super::*;
use gpui::{App, AppContext};
use indoc::indoc;
- use language::{Buffer, Language, LanguageConfig, LanguageMatcher};
+ use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
use std::sync::Arc;
#[gpui::test]
@@ -0,0 +1,671 @@
+use client::test::FakeServer;
+use clock::{FakeSystemClock, ReplicaId};
+use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
+use cloud_llm_client::{PredictEditsBody, PredictEditsResponse};
+use gpui::TestAppContext;
+use http_client::FakeHttpClient;
+use indoc::indoc;
+use language::Point;
+use parking_lot::Mutex;
+use serde_json::json;
+use settings::SettingsStore;
+use util::{path, rel_path::rel_path};
+
+use crate::zeta1::MAX_EVENT_TOKENS;
+
+use super::*;
+
+const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
+
+#[gpui::test]
+async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
+ let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
+ to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
+ });
+
+ let edit_preview = cx
+ .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
+ .await;
+
+ let completion = EditPrediction {
+ edits,
+ edit_preview,
+ buffer: buffer.clone(),
+ snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
+ id: EditPredictionId("the-id".into()),
+ inputs: EditPredictionInputs {
+ events: Default::default(),
+ included_files: Default::default(),
+ cursor_point: cloud_llm_client::predict_edits_v3::Point {
+ line: Line(0),
+ column: 0,
+ },
+ cursor_path: Path::new("").into(),
+ },
+ buffer_snapshotted_at: Instant::now(),
+ response_received_at: Instant::now(),
+ };
+
+ cx.update(|cx| {
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..5, "REM".into()), (9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..2, "REM".into()), (6..8, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..5, "REM".into()), (9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(3..3, "EM".into()), (7..9, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into()), (8..10, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into()), (8..10, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
+ assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
+ })
+}
+
+#[gpui::test]
+async fn test_clean_up_diff(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ assert_eq!(
+ apply_edit_prediction(
+ indoc! {"
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word.len()..word.len();
+ }
+ "},
+ indoc! {"
+ <|editable_region_start|>
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word_1.len()..word_1.len();
+ }
+
+ <|editable_region_end|>
+ "},
+ cx,
+ )
+ .await,
+ indoc! {"
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word_1.len()..word_1.len();
+ }
+ "},
+ );
+
+ assert_eq!(
+ apply_edit_prediction(
+ indoc! {"
+ fn main() {
+ let story = \"the quick\"
+ }
+ "},
+ indoc! {"
+ <|editable_region_start|>
+ fn main() {
+ let story = \"the quick brown fox jumps over the lazy dog\";
+ }
+
+ <|editable_region_end|>
+ "},
+ cx,
+ )
+ .await,
+ indoc! {"
+ fn main() {
+ let story = \"the quick brown fox jumps over the lazy dog\";
+ }
+ "},
+ );
+}
+
+#[gpui::test]
+async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let buffer_content = "lorem\n";
+ let completion_response = indoc! {"
+ ```animals.js
+ <|start_of_file|>
+ <|editable_region_start|>
+ lorem
+ ipsum
+ <|editable_region_end|>
+ ```"};
+
+ assert_eq!(
+ apply_edit_prediction(buffer_content, completion_response, cx).await,
+ "lorem\nipsum"
+ );
+}
+
+#[gpui::test]
+async fn test_can_collect_data(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
+ zeta.update(cx, |zeta, _cx| {
+ zeta.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ zeta.update(cx, |zeta, _cx| {
+ zeta.data_collection_choice = DataCollectionChoice::Disabled
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [], cx).await;
+
+ let buffer = cx.new(|_cx| {
+ Buffer::remote(
+ language::BufferId::new(1).unwrap(),
+ ReplicaId::new(1),
+ language::Capability::ReadWrite,
+ "fn main() {\n println!(\"Hello\");\n}",
+ )
+ });
+
+ let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
+ zeta.update(cx, |zeta, _cx| {
+ zeta.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "LICENSE": BSD_0_TXT,
+ ".env": "SECRET_KEY=secret"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer("/project/.env", cx)
+ })
+ .await
+ .unwrap();
+
+ let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
+ zeta.update(cx, |zeta, _cx| {
+ zeta.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [], cx).await;
+ let buffer = cx.new(|cx| Buffer::local("", cx));
+
+ let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
+ zeta.update(cx, |zeta, _cx| {
+ zeta.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer("/project/main.rs", cx)
+ })
+ .await
+ .unwrap();
+
+ let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
+ zeta.update(cx, |zeta, _cx| {
+ zeta.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/open_source_worktree"),
+ json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
+ )
+ .await;
+ fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
+ .await;
+
+ let project = Project::test(
+ fs.clone(),
+ [
+ path!("/open_source_worktree").as_ref(),
+ path!("/closed_source_worktree").as_ref(),
+ ],
+ cx,
+ )
+ .await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
+ zeta.update(cx, |zeta, _cx| {
+ zeta.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ let closed_source_file = project
+ .update(cx, |project, cx| {
+ let worktree2 = project
+ .worktree_for_root_name("closed_source_worktree", cx)
+ .unwrap();
+ worktree2.update(cx, |worktree2, cx| {
+ worktree2.load_file(rel_path("main.rs"), cx)
+ })
+ })
+ .await
+ .unwrap()
+ .file;
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.file_updated(closed_source_file, cx);
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/worktree1"),
+ json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
+ )
+ .await;
+ fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
+ .await;
+
+ let project = Project::test(
+ fs.clone(),
+ [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
+ cx,
+ )
+ .await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/worktree1/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+ let private_buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/worktree2/file.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
+ zeta.update(cx, |zeta, _cx| {
+ zeta.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ // this has a side effect of registering the buffer to watch for edits
+ run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+
+ private_buffer.update(cx, |private_buffer, cx| {
+ private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+
+ // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
+ // included
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(
+ [(
+ 0..0,
+ " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
+ )],
+ None,
+ cx,
+ );
+ });
+
+ run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+}
+
+fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ });
+}
+
+async fn apply_edit_prediction(
+ buffer_content: &str,
+ completion_response: &str,
+ cx: &mut TestAppContext,
+) -> String {
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
+ let (zeta, _, response) = make_test_zeta(&project, cx).await;
+ *response.lock() = completion_response.to_string();
+ let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
+ });
+ buffer.read_with(cx, |buffer, _| buffer.text())
+}
+
+async fn run_edit_prediction(
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ zeta: &Entity<Zeta>,
+ cx: &mut TestAppContext,
+) -> EditPrediction {
+ let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
+ zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
+ cx.background_executor.run_until_parked();
+ let prediction_task = zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, buffer, cursor, cx)
+ });
+ prediction_task.await.unwrap().unwrap()
+}
+
+async fn make_test_zeta(
+ project: &Entity<Project>,
+ cx: &mut TestAppContext,
+) -> (
+ Entity<Zeta>,
+ Arc<Mutex<Option<PredictEditsBody>>>,
+ Arc<Mutex<String>>,
+) {
+ let default_response = indoc! {"
+ ```main.rs
+ <|start_of_file|>
+ <|editable_region_start|>
+ hello world
+ <|editable_region_end|>
+ ```"
+ };
+ let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
+ let completion_response: Arc<Mutex<String>> =
+ Arc::new(Mutex::new(default_response.to_string()));
+ let http_client = FakeHttpClient::create({
+ let captured_request = captured_request.clone();
+ let completion_response = completion_response.clone();
+ let mut next_request_id = 0;
+ move |req| {
+ let captured_request = captured_request.clone();
+ let completion_response = completion_response.clone();
+ async move {
+ match (req.method(), req.uri().path()) {
+ (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&CreateLlmTokenResponse {
+ token: LlmToken("the-llm-token".to_string()),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ (&Method::POST, "/predict_edits/v2") => {
+ let mut request_body = String::new();
+ req.into_body().read_to_string(&mut request_body).await?;
+ *captured_request.lock() =
+ Some(serde_json::from_str(&request_body).unwrap());
+ next_request_id += 1;
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&PredictEditsResponse {
+ request_id: format!("request-{next_request_id}"),
+ output_excerpt: completion_response.lock().clone(),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap())
+ }
+ _ => Ok(http_client::Response::builder()
+ .status(404)
+ .body("Not Found".into())
+ .unwrap()),
+ }
+ }
+ }
+ });
+
+ let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+ cx.update(|cx| {
+ RefreshLlmTokenListener::register(client.clone(), cx);
+ });
+ let _server = FakeServer::for_client(42, &client, cx).await;
+
+ let zeta = cx.new(|cx| {
+ let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
+ zeta.set_edit_prediction_model(ZetaEditPredictionModel::Zeta1);
+
+ let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
+ for worktree in worktrees {
+ let worktree_id = worktree.read(cx).id();
+ zeta.get_or_init_zeta_project(project, cx)
+ .license_detection_watchers
+ .entry(worktree_id)
+ .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
+ }
+
+ zeta
+ });
+
+ (zeta, captured_request, completion_response)
+}
+
+fn to_completion_edits(
+ iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
+ buffer: &Entity<Buffer>,
+ cx: &App,
+) -> Vec<(Range<Anchor>, Arc<str>)> {
+ let buffer = buffer.read(cx);
+ iterator
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
+ text,
+ )
+ })
+ .collect()
+}
+
+fn from_completion_edits(
+ editor_edits: &[(Range<Anchor>, Arc<str>)],
+ buffer: &Entity<Buffer>,
+ cx: &App,
+) -> Vec<(Range<usize>, Arc<str>)> {
+ let buffer = buffer.read(cx);
+ editor_edits
+ .iter()
+ .map(|(range, text)| {
+ (
+ range.start.to_offset(buffer)..range.end.to_offset(buffer),
+ text.clone(),
+ )
+ })
+ .collect()
+}
+
+#[ctor::ctor]
+fn init_logger() {
+ zlog::init_test();
+}
@@ -1,61 +0,0 @@
-[package]
-name = "zeta2"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/zeta2.rs"
-
-[features]
-eval-support = []
-
-[dependencies]
-anyhow.workspace = true
-arrayvec.workspace = true
-brotli.workspace = true
-chrono.workspace = true
-client.workspace = true
-cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
-collections.workspace = true
-edit_prediction.workspace = true
-edit_prediction_context.workspace = true
-feature_flags.workspace = true
-futures.workspace = true
-gpui.workspace = true
-indoc.workspace = true
-language.workspace = true
-language_model.workspace = true
-log.workspace = true
-lsp.workspace = true
-open_ai.workspace = true
-pretty_assertions.workspace = true
-project.workspace = true
-release_channel.workspace = true
-semver.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-smol.workspace = true
-strsim.workspace = true
-thiserror.workspace = true
-util.workspace = true
-uuid.workspace = true
-workspace.workspace = true
-worktree.workspace = true
-
-[dev-dependencies]
-clock = { workspace = true, features = ["test-support"] }
-cloud_llm_client = { workspace = true, features = ["test-support"] }
-gpui = { workspace = true, features = ["test-support"] }
-lsp.workspace = true
-indoc.workspace = true
-language = { workspace = true, features = ["test-support"] }
-language_model = { workspace = true, features = ["test-support"] }
-project = { workspace = true, features = ["test-support"] }
-settings = { workspace = true, features = ["test-support"] }
-zlog.workspace = true
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,2968 +0,0 @@
-use anyhow::{Context as _, Result, anyhow, bail};
-use arrayvec::ArrayVec;
-use chrono::TimeDelta;
-use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
-use cloud_llm_client::{
- AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
- ZED_VERSION_HEADER_NAME,
-};
-use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
-use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
-use collections::HashMap;
-use edit_prediction_context::{
- DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
- EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
- SyntaxIndex, SyntaxIndexState,
-};
-use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
-use futures::AsyncReadExt as _;
-use futures::channel::{mpsc, oneshot};
-use gpui::http_client::{AsyncBody, Method};
-use gpui::{
- App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity,
- http_client, prelude::*,
-};
-use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
-use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, RefreshLlmTokenListener};
-use lsp::DiagnosticSeverity;
-use open_ai::FunctionDefinition;
-use project::{Project, ProjectPath};
-use release_channel::AppVersion;
-use semver::Version;
-use serde::de::DeserializeOwned;
-use std::collections::{VecDeque, hash_map};
-
-use std::fmt::Write;
-use std::ops::Range;
-use std::path::Path;
-use std::str::FromStr;
-use std::sync::{Arc, LazyLock};
-use std::time::{Duration, Instant};
-use std::{env, mem};
-use thiserror::Error;
-use util::rel_path::RelPathBuf;
-use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
-use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
-
-pub mod assemble_excerpts;
-mod prediction;
-mod provider;
-pub mod retrieval_search;
-mod sweep_ai;
-pub mod udiff;
-mod xml_edits;
-
-use crate::assemble_excerpts::assemble_excerpts;
-pub use crate::prediction::EditPrediction;
-pub use crate::prediction::EditPredictionId;
-pub use provider::ZetaEditPredictionProvider;
-
-/// Maximum number of events to track.
-const EVENT_COUNT_MAX_SWEEP: usize = 6;
-const EVENT_COUNT_MAX_ZETA: usize = 16;
-const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
-
-pub struct SweepFeatureFlag;
-
-impl FeatureFlag for SweepFeatureFlag {
- const NAME: &str = "sweep-ai";
-}
-pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
- max_bytes: 512,
- min_bytes: 128,
- target_before_cursor_over_total_bytes: 0.5,
-};
-
-pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
- ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);
-
-pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
- excerpt: DEFAULT_EXCERPT_OPTIONS,
-};
-
-pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
- EditPredictionContextOptions {
- use_imports: true,
- max_retrieved_declarations: 0,
- excerpt: DEFAULT_EXCERPT_OPTIONS,
- score: EditPredictionScoreOptions {
- omit_excerpt_overlaps: true,
- },
- };
-
-pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
- context: DEFAULT_CONTEXT_OPTIONS,
- max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
- max_diagnostic_bytes: 2048,
- prompt_format: PromptFormat::DEFAULT,
- file_indexing_parallelism: 1,
- buffer_change_grouping_interval: Duration::from_secs(1),
-};
-
-static USE_OLLAMA: LazyLock<bool> =
- LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
-static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
- env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
- "qwen3-coder:30b".to_string()
- } else {
- "yqvev8r3".to_string()
- })
-});
-static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
- match env::var("ZED_ZETA2_MODEL").as_deref() {
- Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
- Ok(model) => model,
- Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
- Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
- }
- .to_string()
-});
-static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
- env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
- if *USE_OLLAMA {
- Some("http://localhost:11434/v1/chat/completions".into())
- } else {
- None
- }
- })
-});
-
-pub struct Zeta2FeatureFlag;
-
-impl FeatureFlag for Zeta2FeatureFlag {
- const NAME: &'static str = "zeta2";
-
- fn enabled_for_staff() -> bool {
- false
- }
-}
-
-#[derive(Clone)]
-struct ZetaGlobal(Entity<Zeta>);
-
-impl Global for ZetaGlobal {}
-
-pub struct Zeta {
- client: Arc<Client>,
- user_store: Entity<UserStore>,
- llm_token: LlmApiToken,
- _llm_token_subscription: Subscription,
- projects: HashMap<EntityId, ZetaProject>,
- options: ZetaOptions,
- update_required: bool,
- debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
- #[cfg(feature = "eval-support")]
- eval_cache: Option<Arc<dyn EvalCache>>,
- edit_prediction_model: ZetaEditPredictionModel,
- sweep_api_token: Option<String>,
- sweep_ai_debug_info: Arc<str>,
-}
-
-#[derive(PartialEq, Eq)]
-pub enum ZetaEditPredictionModel {
- ZedCloud,
- Sweep,
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct ZetaOptions {
- pub context: ContextMode,
- pub max_prompt_bytes: usize,
- pub max_diagnostic_bytes: usize,
- pub prompt_format: predict_edits_v3::PromptFormat,
- pub file_indexing_parallelism: usize,
- pub buffer_change_grouping_interval: Duration,
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub enum ContextMode {
- Agentic(AgenticContextOptions),
- Syntax(EditPredictionContextOptions),
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct AgenticContextOptions {
- pub excerpt: EditPredictionExcerptOptions,
-}
-
-impl ContextMode {
- pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
- match self {
- ContextMode::Agentic(options) => &options.excerpt,
- ContextMode::Syntax(options) => &options.excerpt,
- }
- }
-}
-
-#[derive(Debug)]
-pub enum ZetaDebugInfo {
- ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
- SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
- SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
- ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
- EditPredictionRequested(ZetaEditPredictionDebugInfo),
-}
-
-#[derive(Debug)]
-pub struct ZetaContextRetrievalStartedDebugInfo {
- pub project: Entity<Project>,
- pub timestamp: Instant,
- pub search_prompt: String,
-}
-
-#[derive(Debug)]
-pub struct ZetaContextRetrievalDebugInfo {
- pub project: Entity<Project>,
- pub timestamp: Instant,
-}
-
-#[derive(Debug)]
-pub struct ZetaEditPredictionDebugInfo {
- pub request: predict_edits_v3::PredictEditsRequest,
- pub retrieval_time: TimeDelta,
- pub buffer: WeakEntity<Buffer>,
- pub position: language::Anchor,
- pub local_prompt: Result<String, String>,
- pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
-}
-
-#[derive(Debug)]
-pub struct ZetaSearchQueryDebugInfo {
- pub project: Entity<Project>,
- pub timestamp: Instant,
- pub search_queries: Vec<SearchToolQuery>,
-}
-
-pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
-
-struct ZetaProject {
- syntax_index: Option<Entity<SyntaxIndex>>,
- events: VecDeque<Event>,
- recent_paths: VecDeque<ProjectPath>,
- registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
- current_prediction: Option<CurrentEditPrediction>,
- next_pending_prediction_id: usize,
- pending_predictions: ArrayVec<PendingPrediction, 2>,
- last_prediction_refresh: Option<(EntityId, Instant)>,
- context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
- refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
- refresh_context_debounce_task: Option<Task<Option<()>>>,
- refresh_context_timestamp: Option<Instant>,
- _subscription: gpui::Subscription,
-}
-
-#[derive(Debug, Clone)]
-struct CurrentEditPrediction {
- pub requested_by: PredictionRequestedBy,
- pub prediction: EditPrediction,
-}
-
-impl CurrentEditPrediction {
- fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
- let Some(new_edits) = self
- .prediction
- .interpolate(&self.prediction.buffer.read(cx))
- else {
- return false;
- };
-
- if self.prediction.buffer != old_prediction.prediction.buffer {
- return true;
- }
-
- let Some(old_edits) = old_prediction
- .prediction
- .interpolate(&old_prediction.prediction.buffer.read(cx))
- else {
- return true;
- };
-
- let requested_by_buffer_id = self.requested_by.buffer_id();
-
- // This reduces the occurrence of UI thrash from replacing edits
- //
- // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
- if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
- && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
- && old_edits.len() == 1
- && new_edits.len() == 1
- {
- let (old_range, old_text) = &old_edits[0];
- let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text.as_ref())
- } else {
- true
- }
- }
-}
-
-#[derive(Debug, Clone)]
-enum PredictionRequestedBy {
- DiagnosticsUpdate,
- Buffer(EntityId),
-}
-
-impl PredictionRequestedBy {
- pub fn buffer_id(&self) -> Option<EntityId> {
- match self {
- PredictionRequestedBy::DiagnosticsUpdate => None,
- PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
- }
- }
-}
-
-struct PendingPrediction {
- id: usize,
- _task: Task<()>,
-}
-
-/// A prediction from the perspective of a buffer.
-#[derive(Debug)]
-enum BufferEditPrediction<'a> {
- Local { prediction: &'a EditPrediction },
- Jump { prediction: &'a EditPrediction },
-}
-
-struct RegisteredBuffer {
- snapshot: BufferSnapshot,
- _subscriptions: [gpui::Subscription; 2],
-}
-
-#[derive(Clone)]
-pub enum Event {
- BufferChange {
- old_snapshot: BufferSnapshot,
- new_snapshot: BufferSnapshot,
- end_edit_anchor: Option<Anchor>,
- timestamp: Instant,
- },
-}
-
-impl Event {
- pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
- match self {
- Event::BufferChange {
- old_snapshot,
- new_snapshot,
- ..
- } => {
- let path = new_snapshot.file().map(|f| f.full_path(cx));
-
- let old_path = old_snapshot.file().and_then(|f| {
- let old_path = f.full_path(cx);
- if Some(&old_path) != path.as_ref() {
- Some(old_path)
- } else {
- None
- }
- });
-
- // TODO [zeta2] move to bg?
- let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
-
- if path == old_path && diff.is_empty() {
- None
- } else {
- Some(predict_edits_v3::Event::BufferChange {
- old_path,
- path,
- diff,
- //todo: Actually detect if this edit was predicted or not
- predicted: false,
- })
- }
- }
- }
- }
-
- pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
- match self {
- Event::BufferChange { new_snapshot, .. } => new_snapshot
- .file()
- .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
- }
- }
-}
-
-impl Zeta {
- pub fn try_global(cx: &App) -> Option<Entity<Self>> {
- cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
- }
-
- pub fn global(
- client: &Arc<Client>,
- user_store: &Entity<UserStore>,
- cx: &mut App,
- ) -> Entity<Self> {
- cx.try_global::<ZetaGlobal>()
- .map(|global| global.0.clone())
- .unwrap_or_else(|| {
- let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
- cx.set_global(ZetaGlobal(zeta.clone()));
- zeta
- })
- }
-
- pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
-
- Self {
- projects: HashMap::default(),
- client,
- user_store,
- options: DEFAULT_OPTIONS,
- llm_token: LlmApiToken::default(),
- _llm_token_subscription: cx.subscribe(
- &refresh_llm_token_listener,
- |this, _listener, _event, cx| {
- let client = this.client.clone();
- let llm_token = this.llm_token.clone();
- cx.spawn(async move |_this, _cx| {
- llm_token.refresh(&client).await?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- },
- ),
- update_required: false,
- debug_tx: None,
- #[cfg(feature = "eval-support")]
- eval_cache: None,
- edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
- sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
- .context("No SWEEP_AI_TOKEN environment variable set")
- .log_err(),
- sweep_ai_debug_info: sweep_ai::debug_info(cx),
- }
- }
-
- pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
- self.edit_prediction_model = model;
- }
-
- pub fn has_sweep_api_token(&self) -> bool {
- self.sweep_api_token.is_some()
- }
-
- #[cfg(feature = "eval-support")]
- pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
- self.eval_cache = Some(cache);
- }
-
- pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
- let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
- self.debug_tx = Some(debug_watch_tx);
- debug_watch_rx
- }
-
- pub fn options(&self) -> &ZetaOptions {
- &self.options
- }
-
- pub fn set_options(&mut self, options: ZetaOptions) {
- self.options = options;
- }
-
- pub fn clear_history(&mut self) {
- for zeta_project in self.projects.values_mut() {
- zeta_project.events.clear();
- }
- }
-
- pub fn history_for_project(
- &self,
- project: &Entity<Project>,
- ) -> impl DoubleEndedIterator<Item = &Event> {
- self.projects
- .get(&project.entity_id())
- .map(|project| project.events.iter())
- .into_iter()
- .flatten()
- }
-
- pub fn context_for_project(
- &self,
- project: &Entity<Project>,
- ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
- self.projects
- .get(&project.entity_id())
- .and_then(|project| {
- Some(
- project
- .context
- .as_ref()?
- .iter()
- .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
- )
- })
- .into_iter()
- .flatten()
- }
-
- pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
- self.user_store.read(cx).edit_prediction_usage()
- } else {
- None
- }
- }
-
- pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
- self.get_or_init_zeta_project(project, cx);
- }
-
- pub fn register_buffer(
- &mut self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- Self::register_buffer_impl(zeta_project, buffer, project, cx);
- }
-
- fn get_or_init_zeta_project(
- &mut self,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &mut ZetaProject {
- self.projects
- .entry(project.entity_id())
- .or_insert_with(|| ZetaProject {
- syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
- Some(cx.new(|cx| {
- SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
- }))
- } else {
- None
- },
- events: VecDeque::new(),
- recent_paths: VecDeque::new(),
- registered_buffers: HashMap::default(),
- current_prediction: None,
- pending_predictions: ArrayVec::new(),
- next_pending_prediction_id: 0,
- last_prediction_refresh: None,
- context: None,
- refresh_context_task: None,
- refresh_context_debounce_task: None,
- refresh_context_timestamp: None,
- _subscription: cx.subscribe(&project, Self::handle_project_event),
- })
- }
-
- fn handle_project_event(
- &mut self,
- project: Entity<Project>,
- event: &project::Event,
- cx: &mut Context<Self>,
- ) {
- // TODO [zeta2] init with recent paths
- match event {
- project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
- let path = project.read(cx).path_for_entry(*active_entry_id, cx);
- if let Some(path) = path {
- if let Some(ix) = zeta_project
- .recent_paths
- .iter()
- .position(|probe| probe == &path)
- {
- zeta_project.recent_paths.remove(ix);
- }
- zeta_project.recent_paths.push_front(path);
- }
- }
- project::Event::DiagnosticsUpdated { .. } => {
- self.refresh_prediction_from_diagnostics(project, cx);
- }
- _ => (),
- }
- }
-
- fn register_buffer_impl<'a>(
- zeta_project: &'a mut ZetaProject,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &'a mut RegisteredBuffer {
- let buffer_id = buffer.entity_id();
- match zeta_project.registered_buffers.entry(buffer_id) {
- hash_map::Entry::Occupied(entry) => entry.into_mut(),
- hash_map::Entry::Vacant(entry) => {
- let snapshot = buffer.read(cx).snapshot();
- let project_entity_id = project.entity_id();
- entry.insert(RegisteredBuffer {
- snapshot,
- _subscriptions: [
- cx.subscribe(buffer, {
- let project = project.downgrade();
- move |this, buffer, event, cx| {
- if let language::BufferEvent::Edited = event
- && let Some(project) = project.upgrade()
- {
- this.report_changes_for_buffer(&buffer, &project, cx);
- }
- }
- }),
- cx.observe_release(buffer, move |this, _buffer, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
- else {
- return;
- };
- zeta_project.registered_buffers.remove(&buffer_id);
- }),
- ],
- })
- }
- }
- }
-
- fn report_changes_for_buffer(
- &mut self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let event_count_max = match self.edit_prediction_model {
- ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
- ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
- };
-
- let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
- let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
-
- let new_snapshot = buffer.read(cx).snapshot();
- if new_snapshot.version == registered_buffer.snapshot.version {
- return;
- }
-
- let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- let end_edit_anchor = new_snapshot
- .anchored_edits_since::<Point>(&old_snapshot.version)
- .last()
- .map(|(_, range)| range.end);
- let events = &mut sweep_ai_project.events;
-
- if let Some(Event::BufferChange {
- new_snapshot: last_new_snapshot,
- end_edit_anchor: last_end_edit_anchor,
- ..
- }) = events.back_mut()
- {
- let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
- == last_new_snapshot.remote_id()
- && old_snapshot.version == last_new_snapshot.version;
-
- let should_coalesce = is_next_snapshot_of_same_buffer
- && end_edit_anchor
- .as_ref()
- .zip(last_end_edit_anchor.as_ref())
- .is_some_and(|(a, b)| {
- let a = a.to_point(&new_snapshot);
- let b = b.to_point(&new_snapshot);
- a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
- });
-
- if should_coalesce {
- *last_end_edit_anchor = end_edit_anchor;
- *last_new_snapshot = new_snapshot;
- return;
- }
- }
-
- if events.len() >= event_count_max {
- events.pop_front();
- }
-
- events.push_back(Event::BufferChange {
- old_snapshot,
- new_snapshot,
- end_edit_anchor,
- timestamp: Instant::now(),
- });
- }
-
- fn current_prediction_for_buffer(
- &self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &App,
- ) -> Option<BufferEditPrediction<'_>> {
- let project_state = self.projects.get(&project.entity_id())?;
-
- let CurrentEditPrediction {
- requested_by,
- prediction,
- } = project_state.current_prediction.as_ref()?;
-
- if prediction.targets_buffer(buffer.read(cx)) {
- Some(BufferEditPrediction::Local { prediction })
- } else {
- let show_jump = match requested_by {
- PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
- requested_by_buffer_id == &buffer.entity_id()
- }
- PredictionRequestedBy::DiagnosticsUpdate => true,
- };
-
- if show_jump {
- Some(BufferEditPrediction::Jump { prediction })
- } else {
- None
- }
- }
- }
-
- fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
- if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
- return;
- }
-
- let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- let Some(prediction) = project_state.current_prediction.take() else {
- return;
- };
- let request_id = prediction.prediction.id.to_string();
- project_state.pending_predictions.clear();
-
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let app_version = AppVersion::global(cx);
- cx.spawn(async move |this, cx| {
- let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
- http_client::Url::parse(&predict_edits_url)?
- } else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/accept", &[])?
- };
-
- let response = cx
- .background_spawn(Self::send_api_request::<()>(
- move |builder| {
- let req = builder.uri(url.as_ref()).body(
- serde_json::to_string(&AcceptEditPredictionBody {
- request_id: request_id.clone(),
- })?
- .into(),
- );
- Ok(req?)
- },
- client,
- llm_token,
- app_version,
- ))
- .await;
-
- Self::handle_api_response(&this, response, cx)?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- }
-
- fn discard_current_prediction(&mut self, project: &Entity<Project>) {
- if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
- project_state.current_prediction.take();
- project_state.pending_predictions.clear();
- };
- }
-
- fn is_refreshing(&self, project: &Entity<Project>) -> bool {
- self.projects
- .get(&project.entity_id())
- .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
- }
-
- pub fn refresh_prediction_from_buffer(
- &mut self,
- project: Entity<Project>,
- buffer: Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) {
- self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
- let Some(request_task) = this
- .update(cx, |this, cx| {
- this.request_prediction(&project, &buffer, position, cx)
- })
- .log_err()
- else {
- return Task::ready(anyhow::Ok(()));
- };
-
- let project = project.clone();
- cx.spawn(async move |cx| {
- if let Some(prediction) = request_task.await? {
- this.update(cx, |this, cx| {
- let project_state = this
- .projects
- .get_mut(&project.entity_id())
- .context("Project not found")?;
-
- let new_prediction = CurrentEditPrediction {
- requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
- prediction: prediction,
- };
-
- if project_state
- .current_prediction
- .as_ref()
- .is_none_or(|old_prediction| {
- new_prediction.should_replace_prediction(&old_prediction, cx)
- })
- {
- project_state.current_prediction = Some(new_prediction);
- cx.notify();
- }
- anyhow::Ok(())
- })??;
- }
- Ok(())
- })
- })
- }
-
- pub fn refresh_prediction_from_diagnostics(
- &mut self,
- project: Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- // Prefer predictions from buffer
- if zeta_project.current_prediction.is_some() {
- return;
- };
-
- self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
- let Some(open_buffer_task) = project
- .update(cx, |project, cx| {
- project
- .active_entry()
- .and_then(|entry| project.path_for_entry(entry, cx))
- .map(|path| project.open_buffer(path, cx))
- })
- .log_err()
- .flatten()
- else {
- return Task::ready(anyhow::Ok(()));
- };
-
- cx.spawn(async move |cx| {
- let active_buffer = open_buffer_task.await?;
- let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
- let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer,
- &snapshot,
- Default::default(),
- Default::default(),
- &project,
- cx,
- )
- .await?
- else {
- return anyhow::Ok(());
- };
-
- let Some(prediction) = this
- .update(cx, |this, cx| {
- this.request_prediction(&project, &jump_buffer, jump_position, cx)
- })?
- .await?
- else {
- return anyhow::Ok(());
- };
-
- this.update(cx, |this, cx| {
- if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
- zeta_project.current_prediction.get_or_insert_with(|| {
- cx.notify();
- CurrentEditPrediction {
- requested_by: PredictionRequestedBy::DiagnosticsUpdate,
- prediction,
- }
- });
- }
- })?;
-
- anyhow::Ok(())
- })
- });
- }
-
- #[cfg(not(test))]
- pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
- #[cfg(test)]
- pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
-
- fn queue_prediction_refresh(
- &mut self,
- project: Entity<Project>,
- throttle_entity: EntityId,
- cx: &mut Context<Self>,
- do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
- ) {
- let zeta_project = self.get_or_init_zeta_project(&project, cx);
- let pending_prediction_id = zeta_project.next_pending_prediction_id;
- zeta_project.next_pending_prediction_id += 1;
- let last_request = zeta_project.last_prediction_refresh;
-
- // TODO report cancelled requests like in zeta1
- let task = cx.spawn(async move |this, cx| {
- if let Some((last_entity, last_timestamp)) = last_request
- && throttle_entity == last_entity
- && let Some(timeout) =
- (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
- {
- cx.background_executor().timer(timeout).await;
- }
-
- do_refresh(this.clone(), cx).await.log_err();
-
- this.update(cx, |this, cx| {
- let zeta_project = this.get_or_init_zeta_project(&project, cx);
-
- if zeta_project.pending_predictions[0].id == pending_prediction_id {
- zeta_project.pending_predictions.remove(0);
- } else {
- zeta_project.pending_predictions.clear();
- }
-
- cx.notify();
- })
- .ok();
- });
-
- if zeta_project.pending_predictions.len() <= 1 {
- zeta_project.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- _task: task,
- });
- } else if zeta_project.pending_predictions.len() == 2 {
- zeta_project.pending_predictions.pop();
- zeta_project.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- _task: task,
- });
- }
- }
-
- pub fn request_prediction(
- &mut self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
- match self.edit_prediction_model {
- ZetaEditPredictionModel::ZedCloud => {
- self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
- }
- ZetaEditPredictionModel::Sweep => {
- self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
- }
- }
- }
-
- fn request_prediction_with_sweep(
- &mut self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- position: language::Anchor,
- allow_jump: bool,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
- let snapshot = active_buffer.read(cx).snapshot();
- let debug_info = self.sweep_ai_debug_info.clone();
- let Some(api_token) = self.sweep_api_token.clone() else {
- return Task::ready(Ok(None));
- };
- let full_path: Arc<Path> = snapshot
- .file()
- .map(|file| file.full_path(cx))
- .unwrap_or_else(|| "untitled".into())
- .into();
-
- let project_file = project::File::from_dyn(snapshot.file());
- let repo_name = project_file
- .map(|file| file.worktree.read(cx).root_name_str())
- .unwrap_or("untitled")
- .into();
- let offset = position.to_offset(&snapshot);
-
- let project_state = self.get_or_init_zeta_project(project, cx);
- let events = project_state.events.clone();
- let has_events = !events.is_empty();
- let recent_buffers = project_state.recent_paths.iter().cloned();
- let http_client = cx.http_client();
-
- let recent_buffer_snapshots = recent_buffers
- .filter_map(|project_path| {
- let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
- if active_buffer == &buffer {
- None
- } else {
- Some(buffer.read(cx).snapshot())
- }
- })
- .take(3)
- .collect::<Vec<_>>();
-
- const DIAGNOSTIC_LINES_RANGE: u32 = 20;
-
- let cursor_point = position.to_point(&snapshot);
- let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
- let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
- let diagnostic_search_range =
- Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
-
- let result = cx.background_spawn({
- let snapshot = snapshot.clone();
- let diagnostic_search_range = diagnostic_search_range.clone();
- async move {
- let text = snapshot.text();
-
- let mut recent_changes = String::new();
- for event in events {
- sweep_ai::write_event(event, &mut recent_changes).unwrap();
- }
-
- let mut file_chunks = recent_buffer_snapshots
- .into_iter()
- .map(|snapshot| {
- let end_point = Point::new(30, 0).min(snapshot.max_point());
- sweep_ai::FileChunk {
- content: snapshot.text_for_range(Point::zero()..end_point).collect(),
- file_path: snapshot
- .file()
- .map(|f| f.path().as_unix_str())
- .unwrap_or("untitled")
- .to_string(),
- start_line: 0,
- end_line: end_point.row as usize,
- timestamp: snapshot.file().and_then(|file| {
- Some(
- file.disk_state()
- .mtime()?
- .to_seconds_and_nanos_for_persistence()?
- .0,
- )
- }),
- }
- })
- .collect::<Vec<_>>();
-
- let diagnostic_entries =
- snapshot.diagnostics_in_range(diagnostic_search_range, false);
- let mut diagnostic_content = String::new();
- let mut diagnostic_count = 0;
-
- for entry in diagnostic_entries {
- let start_point: Point = entry.range.start;
-
- let severity = match entry.diagnostic.severity {
- DiagnosticSeverity::ERROR => "error",
- DiagnosticSeverity::WARNING => "warning",
- DiagnosticSeverity::INFORMATION => "info",
- DiagnosticSeverity::HINT => "hint",
- _ => continue,
- };
-
- diagnostic_count += 1;
-
- writeln!(
- &mut diagnostic_content,
- "{} at line {}: {}",
- severity,
- start_point.row + 1,
- entry.diagnostic.message
- )?;
- }
-
- if !diagnostic_content.is_empty() {
- file_chunks.push(sweep_ai::FileChunk {
- file_path: format!("Diagnostics for {}", full_path.display()),
- start_line: 0,
- end_line: diagnostic_count,
- content: diagnostic_content,
- timestamp: None,
- });
- }
-
- let request_body = sweep_ai::AutocompleteRequest {
- debug_info,
- repo_name,
- file_path: full_path.clone(),
- file_contents: text.clone(),
- original_file_contents: text,
- cursor_position: offset,
- recent_changes: recent_changes.clone(),
- changes_above_cursor: true,
- multiple_suggestions: false,
- branch: None,
- file_chunks,
- retrieval_chunks: vec![],
- recent_user_actions: vec![],
- // TODO
- privacy_mode_enabled: false,
- };
-
- let mut buf: Vec<u8> = Vec::new();
- let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
- serde_json::to_writer(writer, &request_body)?;
- let body: AsyncBody = buf.into();
-
- const SWEEP_API_URL: &str =
- "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
-
- let request = http_client::Request::builder()
- .uri(SWEEP_API_URL)
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_token))
- .header("Connection", "keep-alive")
- .header("Content-Encoding", "br")
- .method(Method::POST)
- .body(body)?;
-
- let mut response = http_client.send(request).await?;
-
- let mut body: Vec<u8> = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
-
- if !response.status().is_success() {
- anyhow::bail!(
- "Request failed with status: {:?}\nBody: {}",
- response.status(),
- String::from_utf8_lossy(&body),
- );
- };
-
- let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
-
- let old_text = snapshot
- .text_for_range(response.start_index..response.end_index)
- .collect::<String>();
- let edits = language::text_diff(&old_text, &response.completion)
- .into_iter()
- .map(|(range, text)| {
- (
- snapshot.anchor_after(response.start_index + range.start)
- ..snapshot.anchor_before(response.start_index + range.end),
- text,
- )
- })
- .collect::<Vec<_>>();
-
- anyhow::Ok((response.autocomplete_id, edits, snapshot))
- }
- });
-
- let buffer = active_buffer.clone();
- let project = project.clone();
- let active_buffer = active_buffer.clone();
-
- cx.spawn(async move |this, cx| {
- let (id, edits, old_snapshot) = result.await?;
-
- if edits.is_empty() {
- if has_events
- && allow_jump
- && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer,
- &snapshot,
- diagnostic_search_range,
- cursor_point,
- &project,
- cx,
- )
- .await?
- {
- return this
- .update(cx, |this, cx| {
- this.request_prediction_with_sweep(
- &project,
- &jump_buffer,
- jump_position,
- false,
- cx,
- )
- })?
- .await;
- }
-
- return anyhow::Ok(None);
- }
-
- let Some((edits, new_snapshot, preview_task)) =
- buffer.read_with(cx, |buffer, cx| {
- let new_snapshot = buffer.snapshot();
-
- let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
- edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
- .into();
- let preview_task = buffer.preview_edits(edits.clone(), cx);
-
- Some((edits, new_snapshot, preview_task))
- })?
- else {
- return anyhow::Ok(None);
- };
-
- let prediction = EditPrediction {
- id: EditPredictionId(id.into()),
- edits,
- snapshot: new_snapshot,
- edit_preview: preview_task.await,
- buffer,
- };
-
- anyhow::Ok(Some(prediction))
- })
- }
-
- async fn next_diagnostic_location(
- active_buffer: Entity<Buffer>,
- active_buffer_snapshot: &BufferSnapshot,
- active_buffer_diagnostic_search_range: Range<Point>,
- active_buffer_cursor_point: Point,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
- // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
- let mut jump_location = active_buffer_snapshot
- .diagnostic_groups(None)
- .into_iter()
- .filter_map(|(_, group)| {
- let range = &group.entries[group.primary_ix]
- .range
- .to_point(&active_buffer_snapshot);
- if range.overlaps(&active_buffer_diagnostic_search_range) {
- None
- } else {
- Some(range.start)
- }
- })
- .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
- .map(|position| {
- (
- active_buffer.clone(),
- active_buffer_snapshot.anchor_before(position),
- )
- });
-
- if jump_location.is_none() {
- let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
- let file = buffer.file()?;
-
- Some(ProjectPath {
- worktree_id: file.worktree_id(cx),
- path: file.path().clone(),
- })
- })?;
-
- let buffer_task = project.update(cx, |project, cx| {
- let (path, _, _) = project
- .diagnostic_summaries(false, cx)
- .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
- .max_by_key(|(path, _, _)| {
- // find the buffer with errors that shares most parent directories
- path.path
- .components()
- .zip(
- active_buffer_path
- .as_ref()
- .map(|p| p.path.components())
- .unwrap_or_default(),
- )
- .take_while(|(a, b)| a == b)
- .count()
- })?;
-
- Some(project.open_buffer(path, cx))
- })?;
-
- if let Some(buffer_task) = buffer_task {
- let closest_buffer = buffer_task.await?;
-
- jump_location = closest_buffer
- .read_with(cx, |buffer, _cx| {
- buffer
- .buffer_diagnostics(None)
- .into_iter()
- .min_by_key(|entry| entry.diagnostic.severity)
- .map(|entry| entry.range.start)
- })?
- .map(|position| (closest_buffer, position));
- }
- }
-
- anyhow::Ok(jump_location)
- }
-
- fn request_prediction_with_zed_cloud(
- &mut self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
- let project_state = self.projects.get(&project.entity_id());
-
- let index_state = project_state.and_then(|state| {
- state
- .syntax_index
- .as_ref()
- .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
- });
- let options = self.options.clone();
- let active_snapshot = active_buffer.read(cx).snapshot();
- let Some(excerpt_path) = active_snapshot
- .file()
- .map(|path| -> Arc<Path> { path.full_path(cx).into() })
- else {
- return Task::ready(Err(anyhow!("No file path for excerpt")));
- };
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let app_version = AppVersion::global(cx);
- let worktree_snapshots = project
- .read(cx)
- .worktrees(cx)
- .map(|worktree| worktree.read(cx).snapshot())
- .collect::<Vec<_>>();
- let debug_tx = self.debug_tx.clone();
-
- let events = project_state
- .map(|state| {
- state
- .events
- .iter()
- .filter_map(|event| event.to_request_event(cx))
- .collect::<Vec<_>>()
- })
- .unwrap_or_default();
-
- let diagnostics = active_snapshot.diagnostic_sets().clone();
-
- let parent_abs_path =
- project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
- let mut path = f.worktree.read(cx).absolutize(&f.path);
- if path.pop() { Some(path) } else { None }
- });
-
- // TODO data collection
- let can_collect_data = cx.is_staff();
-
- let empty_context_files = HashMap::default();
- let context_files = project_state
- .and_then(|project_state| project_state.context.as_ref())
- .unwrap_or(&empty_context_files);
-
- #[cfg(feature = "eval-support")]
- let parsed_fut = futures::future::join_all(
- context_files
- .keys()
- .map(|buffer| buffer.read(cx).parsing_idle()),
- );
-
- let mut included_files = context_files
- .iter()
- .filter_map(|(buffer_entity, ranges)| {
- let buffer = buffer_entity.read(cx);
- Some((
- buffer_entity.clone(),
- buffer.snapshot(),
- buffer.file()?.full_path(cx).into(),
- ranges.clone(),
- ))
- })
- .collect::<Vec<_>>();
-
- included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
- (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
- });
-
- #[cfg(feature = "eval-support")]
- let eval_cache = self.eval_cache.clone();
-
- let request_task = cx.background_spawn({
- let active_buffer = active_buffer.clone();
- async move {
- #[cfg(feature = "eval-support")]
- parsed_fut.await;
-
- let index_state = if let Some(index_state) = index_state {
- Some(index_state.lock_owned().await)
- } else {
- None
- };
-
- let cursor_offset = position.to_offset(&active_snapshot);
- let cursor_point = cursor_offset.to_point(&active_snapshot);
-
- let before_retrieval = chrono::Utc::now();
-
- let (diagnostic_groups, diagnostic_groups_truncated) =
- Self::gather_nearby_diagnostics(
- cursor_offset,
- &diagnostics,
- &active_snapshot,
- options.max_diagnostic_bytes,
- );
-
- let cloud_request = match options.context {
- ContextMode::Agentic(context_options) => {
- let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &active_snapshot,
- &context_options.excerpt,
- index_state.as_deref(),
- ) else {
- return Ok((None, None));
- };
-
- let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
- ..active_snapshot.anchor_before(excerpt.range.end);
-
- if let Some(buffer_ix) =
- included_files.iter().position(|(_, snapshot, _, _)| {
- snapshot.remote_id() == active_snapshot.remote_id()
- })
- {
- let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
- ranges.push(excerpt_anchor_range);
- retrieval_search::merge_anchor_ranges(ranges, buffer);
- let last_ix = included_files.len() - 1;
- included_files.swap(buffer_ix, last_ix);
- } else {
- included_files.push((
- active_buffer.clone(),
- active_snapshot.clone(),
- excerpt_path.clone(),
- vec![excerpt_anchor_range],
- ));
- }
-
- let included_files = included_files
- .iter()
- .map(|(_, snapshot, path, ranges)| {
- let ranges = ranges
- .iter()
- .map(|range| {
- let point_range = range.to_point(&snapshot);
- Line(point_range.start.row)..Line(point_range.end.row)
- })
- .collect::<Vec<_>>();
- let excerpts = assemble_excerpts(&snapshot, ranges);
- predict_edits_v3::IncludedFile {
- path: path.clone(),
- max_row: Line(snapshot.max_point().row),
- excerpts,
- }
- })
- .collect::<Vec<_>>();
-
- predict_edits_v3::PredictEditsRequest {
- excerpt_path,
- excerpt: String::new(),
- excerpt_line_range: Line(0)..Line(0),
- excerpt_range: 0..0,
- cursor_point: predict_edits_v3::Point {
- line: predict_edits_v3::Line(cursor_point.row),
- column: cursor_point.column,
- },
- included_files,
- referenced_declarations: vec![],
- events,
- can_collect_data,
- diagnostic_groups,
- diagnostic_groups_truncated,
- debug_info: debug_tx.is_some(),
- prompt_max_bytes: Some(options.max_prompt_bytes),
- prompt_format: options.prompt_format,
- // TODO [zeta2]
- signatures: vec![],
- excerpt_parent: None,
- git_info: None,
- }
- }
- ContextMode::Syntax(context_options) => {
- let Some(context) = EditPredictionContext::gather_context(
- cursor_point,
- &active_snapshot,
- parent_abs_path.as_deref(),
- &context_options,
- index_state.as_deref(),
- ) else {
- return Ok((None, None));
- };
-
- make_syntax_context_cloud_request(
- excerpt_path,
- context,
- events,
- can_collect_data,
- diagnostic_groups,
- diagnostic_groups_truncated,
- None,
- debug_tx.is_some(),
- &worktree_snapshots,
- index_state.as_deref(),
- Some(options.max_prompt_bytes),
- options.prompt_format,
- )
- }
- };
-
- let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
-
- let retrieval_time = chrono::Utc::now() - before_retrieval;
-
- let debug_response_tx = if let Some(debug_tx) = &debug_tx {
- let (response_tx, response_rx) = oneshot::channel();
-
- debug_tx
- .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
- ZetaEditPredictionDebugInfo {
- request: cloud_request.clone(),
- retrieval_time,
- buffer: active_buffer.downgrade(),
- local_prompt: match prompt_result.as_ref() {
- Ok((prompt, _)) => Ok(prompt.clone()),
- Err(err) => Err(err.to_string()),
- },
- position,
- response_rx,
- },
- ))
- .ok();
- Some(response_tx)
- } else {
- None
- };
-
- if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((Err("Request skipped".to_string()), TimeDelta::zero()))
- .ok();
- }
- anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
- }
-
- let (prompt, _) = prompt_result?;
- let generation_params =
- cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
- let request = open_ai::Request {
- model: EDIT_PREDICTIONS_MODEL_ID.clone(),
- messages: vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt),
- }],
- stream: false,
- max_completion_tokens: None,
- stop: generation_params.stop.unwrap_or_default(),
- temperature: generation_params.temperature.unwrap_or(0.7),
- tool_choice: None,
- parallel_tool_calls: None,
- tools: vec![],
- prompt_cache_key: None,
- reasoning_effort: None,
- };
-
- log::trace!("Sending edit prediction request");
-
- let before_request = chrono::Utc::now();
- let response = Self::send_raw_llm_request(
- request,
- client,
- llm_token,
- app_version,
- #[cfg(feature = "eval-support")]
- eval_cache,
- #[cfg(feature = "eval-support")]
- EvalCacheEntryKind::Prediction,
- )
- .await;
- let request_time = chrono::Utc::now() - before_request;
-
- log::trace!("Got edit prediction response");
-
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((
- response
- .as_ref()
- .map_err(|err| err.to_string())
- .map(|response| response.0.clone()),
- request_time,
- ))
- .ok();
- }
-
- let (res, usage) = response?;
- let request_id = EditPredictionId(res.id.clone().into());
- let Some(mut output_text) = text_from_response(res) else {
- return Ok((None, usage));
- };
-
- if output_text.contains(CURSOR_MARKER) {
- log::trace!("Stripping out {CURSOR_MARKER} from response");
- output_text = output_text.replace(CURSOR_MARKER, "");
- }
-
- let get_buffer_from_context = |path: &Path| {
- included_files
- .iter()
- .find_map(|(_, buffer, probe_path, ranges)| {
- if probe_path.as_ref() == path {
- Some((buffer, ranges.as_slice()))
- } else {
- None
- }
- })
- };
-
- let (edited_buffer_snapshot, edits) = match options.prompt_format {
- PromptFormat::NumLinesUniDiff => {
- // TODO: Implement parsing of multi-file diffs
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- PromptFormat::Minimal
- | PromptFormat::MinimalQwen
- | PromptFormat::SeedCoder1120 => {
- if output_text.contains("--- a/\n+++ b/\nNo edits") {
- let edits = vec![];
- (&active_snapshot, edits)
- } else {
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- }
- PromptFormat::OldTextNewText => {
- crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
- .await?
- }
- _ => {
- bail!("unsupported prompt format {}", options.prompt_format)
- }
- };
-
- let edited_buffer = included_files
- .iter()
- .find_map(|(buffer, snapshot, _, _)| {
- if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
- Some(buffer.clone())
- } else {
- None
- }
- })
- .context("Failed to find buffer in included_buffers")?;
-
- anyhow::Ok((
- Some((
- request_id,
- edited_buffer,
- edited_buffer_snapshot.clone(),
- edits,
- )),
- usage,
- ))
- }
- });
-
- cx.spawn({
- async move |this, cx| {
- let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
- Self::handle_api_response(&this, request_task.await, cx)?
- else {
- return Ok(None);
- };
-
- // TODO telemetry: duration, etc
- Ok(
- EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
- .await,
- )
- }
- })
- }
-
- async fn send_raw_llm_request(
- request: open_ai::Request,
- client: Arc<Client>,
- llm_token: LlmApiToken,
- app_version: Version,
- #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
- #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
- ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
- let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
- http_client::Url::parse(&predict_edits_url)?
- } else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/raw", &[])?
- };
-
- #[cfg(feature = "eval-support")]
- let cache_key = if let Some(cache) = eval_cache {
- use collections::FxHasher;
- use std::hash::{Hash, Hasher};
-
- let mut hasher = FxHasher::default();
- url.hash(&mut hasher);
- let request_str = serde_json::to_string_pretty(&request)?;
- request_str.hash(&mut hasher);
- let hash = hasher.finish();
-
- let key = (eval_cache_kind, hash);
- if let Some(response_str) = cache.read(key) {
- return Ok((serde_json::from_str(&response_str)?, None));
- }
-
- Some((cache, request_str, key))
- } else {
- None
- };
-
- let (response, usage) = Self::send_api_request(
- |builder| {
- let req = builder
- .uri(url.as_ref())
- .body(serde_json::to_string(&request)?.into());
- Ok(req?)
- },
- client,
- llm_token,
- app_version,
- )
- .await?;
-
- #[cfg(feature = "eval-support")]
- if let Some((cache, request, key)) = cache_key {
- cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
- }
-
- Ok((response, usage))
- }
-
- fn handle_api_response<T>(
- this: &WeakEntity<Self>,
- response: Result<(T, Option<EditPredictionUsage>)>,
- cx: &mut gpui::AsyncApp,
- ) -> Result<T> {
- match response {
- Ok((data, usage)) => {
- if let Some(usage) = usage {
- this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
- });
- })
- .ok();
- }
- Ok(data)
- }
- Err(err) => {
- if err.is::<ZedUpdateRequiredError>() {
- cx.update(|cx| {
- this.update(cx, |this, _cx| {
- this.update_required = true;
- })
- .ok();
-
- let error_message: SharedString = err.to_string().into();
- show_app_notification(
- NotificationId::unique::<ZedUpdateRequiredError>(),
- cx,
- move |cx| {
- cx.new(|cx| {
- ErrorMessagePrompt::new(error_message.clone(), cx)
- .with_link_button("Update Zed", "https://zed.dev/releases")
- })
- },
- );
- })
- .ok();
- }
- Err(err)
- }
- }
- }
-
- async fn send_api_request<Res>(
- build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
- client: Arc<Client>,
- llm_token: LlmApiToken,
- app_version: Version,
- ) -> Result<(Res, Option<EditPredictionUsage>)>
- where
- Res: DeserializeOwned,
- {
- let http_client = client.http_client();
- let mut token = llm_token.acquire(&client).await?;
- let mut did_retry = false;
-
- loop {
- let request_builder = http_client::Request::builder().method(Method::POST);
-
- let request = build(
- request_builder
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", token))
- .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
- )?;
-
- let mut response = http_client.send(request).await?;
-
- if let Some(minimum_required_version) = response
- .headers()
- .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
- .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
- {
- anyhow::ensure!(
- app_version >= minimum_required_version,
- ZedUpdateRequiredError {
- minimum_version: minimum_required_version
- }
- );
- }
-
- if response.status().is_success() {
- let usage = EditPredictionUsage::from_headers(response.headers()).ok();
-
- let mut body = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
- return Ok((serde_json::from_slice(&body)?, usage));
- } else if !did_retry
- && response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
- {
- did_retry = true;
- token = llm_token.refresh(&client).await?;
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!(
- "Request failed with status: {:?}\nBody: {}",
- response.status(),
- body
- );
- }
- }
- }
-
- pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
- pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
-
- // Refresh the related excerpts when the user just beguns editing after
- // an idle period, and after they pause editing.
- fn refresh_context_if_needed(
- &mut self,
- project: &Entity<Project>,
- buffer: &Entity<language::Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) {
- if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
- return;
- }
-
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- let now = Instant::now();
- let was_idle = zeta_project
- .refresh_context_timestamp
- .map_or(true, |timestamp| {
- now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
- });
- zeta_project.refresh_context_timestamp = Some(now);
- zeta_project.refresh_context_debounce_task = Some(cx.spawn({
- let buffer = buffer.clone();
- let project = project.clone();
- async move |this, cx| {
- if was_idle {
- log::debug!("refetching edit prediction context after idle");
- } else {
- cx.background_executor()
- .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
- .await;
- log::debug!("refetching edit prediction context after pause");
- }
- this.update(cx, |this, cx| {
- let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
-
- if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
- zeta_project.refresh_context_task = Some(task.log_err());
- };
- })
- .ok()
- }
- }));
- }
-
- // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
- // and avoid spawning more than one concurrent task.
- pub fn refresh_context(
- &mut self,
- project: Entity<Project>,
- buffer: Entity<language::Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
- return Task::ready(anyhow::Ok(()));
- };
-
- let ContextMode::Agentic(options) = &self.options().context else {
- return Task::ready(anyhow::Ok(()));
- };
-
- let snapshot = buffer.read(cx).snapshot();
- let cursor_point = cursor_position.to_point(&snapshot);
- let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &snapshot,
- &options.excerpt,
- None,
- ) else {
- return Task::ready(Ok(()));
- };
-
- let app_version = AppVersion::global(cx);
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let debug_tx = self.debug_tx.clone();
- let current_file_path: Arc<Path> = snapshot
- .file()
- .map(|f| f.full_path(cx).into())
- .unwrap_or_else(|| Path::new("untitled").into());
-
- let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
- predict_edits_v3::PlanContextRetrievalRequest {
- excerpt: cursor_excerpt.text(&snapshot).body,
- excerpt_path: current_file_path,
- excerpt_line_range: cursor_excerpt.line_range,
- cursor_file_max_row: Line(snapshot.max_point().row),
- events: zeta_project
- .events
- .iter()
- .filter_map(|ev| ev.to_request_event(cx))
- .collect(),
- },
- ) {
- Ok(prompt) => prompt,
- Err(err) => {
- return Task::ready(Err(err));
- }
- };
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
- ZetaContextRetrievalStartedDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- search_prompt: prompt.clone(),
- },
- ))
- .ok();
- }
-
- pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
- let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
- language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
- );
-
- let description = schema
- .get("description")
- .and_then(|description| description.as_str())
- .unwrap()
- .to_string();
-
- (schema.into(), description)
- });
-
- let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
-
- let request = open_ai::Request {
- model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
- messages: vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt),
- }],
- stream: false,
- max_completion_tokens: None,
- stop: Default::default(),
- temperature: 0.7,
- tool_choice: None,
- parallel_tool_calls: None,
- tools: vec![open_ai::ToolDefinition::Function {
- function: FunctionDefinition {
- name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
- description: Some(tool_description),
- parameters: Some(tool_schema),
- },
- }],
- prompt_cache_key: None,
- reasoning_effort: None,
- };
-
- #[cfg(feature = "eval-support")]
- let eval_cache = self.eval_cache.clone();
-
- cx.spawn(async move |this, cx| {
- log::trace!("Sending search planning request");
- let response = Self::send_raw_llm_request(
- request,
- client,
- llm_token,
- app_version,
- #[cfg(feature = "eval-support")]
- eval_cache.clone(),
- #[cfg(feature = "eval-support")]
- EvalCacheEntryKind::Context,
- )
- .await;
- let mut response = Self::handle_api_response(&this, response, cx)?;
- log::trace!("Got search planning response");
-
- let choice = response
- .choices
- .pop()
- .context("No choices in retrieval response")?;
- let open_ai::RequestMessage::Assistant {
- content: _,
- tool_calls,
- } = choice.message
- else {
- anyhow::bail!("Retrieval response didn't include an assistant message");
- };
-
- let mut queries: Vec<SearchToolQuery> = Vec::new();
- for tool_call in tool_calls {
- let open_ai::ToolCallContent::Function { function } = tool_call.content;
- if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
- log::warn!(
- "Context retrieval response tried to call an unknown tool: {}",
- function.name
- );
-
- continue;
- }
-
- let input: SearchToolInput = serde_json::from_str(&function.arguments)
- .with_context(|| format!("invalid search json {}", &function.arguments))?;
- queries.extend(input.queries);
- }
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
- ZetaSearchQueryDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- search_queries: queries.clone(),
- },
- ))
- .ok();
- }
-
- log::trace!("Running retrieval search: {queries:#?}");
-
- let related_excerpts_result = retrieval_search::run_retrieval_searches(
- queries,
- project.clone(),
- #[cfg(feature = "eval-support")]
- eval_cache,
- cx,
- )
- .await;
-
- log::trace!("Search queries executed");
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
- ZetaContextRetrievalDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- },
- ))
- .ok();
- }
-
- this.update(cx, |this, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
- return Ok(());
- };
- zeta_project.refresh_context_task.take();
- if let Some(debug_tx) = &this.debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
- ZetaContextRetrievalDebugInfo {
- project,
- timestamp: Instant::now(),
- },
- ))
- .ok();
- }
- match related_excerpts_result {
- Ok(excerpts) => {
- zeta_project.context = Some(excerpts);
- Ok(())
- }
- Err(error) => Err(error),
- }
- })?
- })
- }
-
- pub fn set_context(
- &mut self,
- project: Entity<Project>,
- context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
- ) {
- if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
- zeta_project.context = Some(context);
- }
- }
-
- fn gather_nearby_diagnostics(
- cursor_offset: usize,
- diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
- snapshot: &BufferSnapshot,
- max_diagnostics_bytes: usize,
- ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
- // TODO: Could make this more efficient
- let mut diagnostic_groups = Vec::new();
- for (language_server_id, diagnostics) in diagnostic_sets {
- let mut groups = Vec::new();
- diagnostics.groups(*language_server_id, &mut groups, &snapshot);
- diagnostic_groups.extend(
- groups
- .into_iter()
- .map(|(_, group)| group.resolve::<usize>(&snapshot)),
- );
- }
-
- // sort by proximity to cursor
- diagnostic_groups.sort_by_key(|group| {
- let range = &group.entries[group.primary_ix].range;
- if range.start >= cursor_offset {
- range.start - cursor_offset
- } else if cursor_offset >= range.end {
- cursor_offset - range.end
- } else {
- (cursor_offset - range.start).min(range.end - cursor_offset)
- }
- });
-
- let mut results = Vec::new();
- let mut diagnostic_groups_truncated = false;
- let mut diagnostics_byte_count = 0;
- for group in diagnostic_groups {
- let raw_value = serde_json::value::to_raw_value(&group).unwrap();
- diagnostics_byte_count += raw_value.get().len();
- if diagnostics_byte_count > max_diagnostics_bytes {
- diagnostic_groups_truncated = true;
- break;
- }
- results.push(predict_edits_v3::DiagnosticGroup(raw_value));
- }
-
- (results, diagnostic_groups_truncated)
- }
-
- // TODO: Dedupe with similar code in request_prediction?
- pub fn cloud_request_for_zeta_cli(
- &mut self,
- project: &Entity<Project>,
- buffer: &Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
- let project_state = self.projects.get(&project.entity_id());
-
- let index_state = project_state.and_then(|state| {
- state
- .syntax_index
- .as_ref()
- .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
- });
- let options = self.options.clone();
- let snapshot = buffer.read(cx).snapshot();
- let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
- return Task::ready(Err(anyhow!("No file path for excerpt")));
- };
- let worktree_snapshots = project
- .read(cx)
- .worktrees(cx)
- .map(|worktree| worktree.read(cx).snapshot())
- .collect::<Vec<_>>();
-
- let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
- let mut path = f.worktree.read(cx).absolutize(&f.path);
- if path.pop() { Some(path) } else { None }
- });
-
- cx.background_spawn(async move {
- let index_state = if let Some(index_state) = index_state {
- Some(index_state.lock_owned().await)
- } else {
- None
- };
-
- let cursor_point = position.to_point(&snapshot);
-
- let debug_info = true;
- EditPredictionContext::gather_context(
- cursor_point,
- &snapshot,
- parent_abs_path.as_deref(),
- match &options.context {
- ContextMode::Agentic(_) => {
- // TODO
- panic!("Llm mode not supported in zeta cli yet");
- }
- ContextMode::Syntax(edit_prediction_context_options) => {
- edit_prediction_context_options
- }
- },
- index_state.as_deref(),
- )
- .context("Failed to select excerpt")
- .map(|context| {
- make_syntax_context_cloud_request(
- excerpt_path.into(),
- context,
- // TODO pass everything
- Vec::new(),
- false,
- Vec::new(),
- false,
- None,
- debug_info,
- &worktree_snapshots,
- index_state.as_deref(),
- Some(options.max_prompt_bytes),
- options.prompt_format,
- )
- })
- })
- }
-
- pub fn wait_for_initial_indexing(
- &mut self,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- if let Some(syntax_index) = &zeta_project.syntax_index {
- syntax_index.read(cx).wait_for_initial_file_indexing(cx)
- } else {
- Task::ready(Ok(()))
- }
- }
-}
-
-pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
- let choice = res.choices.pop()?;
- let output_text = match choice.message {
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(content)),
- ..
- } => content,
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Multipart(mut content)),
- ..
- } => {
- if content.is_empty() {
- log::error!("No output from Baseten completion response");
- return None;
- }
-
- match content.remove(0) {
- open_ai::MessagePart::Text { text } => text,
- open_ai::MessagePart::Image { .. } => {
- log::error!("Expected text, got an image");
- return None;
- }
- }
- }
- _ => {
- log::error!("Invalid response message: {:?}", choice.message);
- return None;
- }
- };
- Some(output_text)
-}
-
-#[derive(Error, Debug)]
-#[error(
- "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
-)]
-pub struct ZedUpdateRequiredError {
- minimum_version: Version,
-}
-
-fn make_syntax_context_cloud_request(
- excerpt_path: Arc<Path>,
- context: EditPredictionContext,
- events: Vec<predict_edits_v3::Event>,
- can_collect_data: bool,
- diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
- diagnostic_groups_truncated: bool,
- git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
- debug_info: bool,
- worktrees: &Vec<worktree::Snapshot>,
- index_state: Option<&SyntaxIndexState>,
- prompt_max_bytes: Option<usize>,
- prompt_format: PromptFormat,
-) -> predict_edits_v3::PredictEditsRequest {
- let mut signatures = Vec::new();
- let mut declaration_to_signature_index = HashMap::default();
- let mut referenced_declarations = Vec::new();
-
- for snippet in context.declarations {
- let project_entry_id = snippet.declaration.project_entry_id();
- let Some(path) = worktrees.iter().find_map(|worktree| {
- worktree.entry_for_id(project_entry_id).map(|entry| {
- let mut full_path = RelPathBuf::new();
- full_path.push(worktree.root_name());
- full_path.push(&entry.path);
- full_path
- })
- }) else {
- continue;
- };
-
- let parent_index = index_state.and_then(|index_state| {
- snippet.declaration.parent().and_then(|parent| {
- add_signature(
- parent,
- &mut declaration_to_signature_index,
- &mut signatures,
- index_state,
- )
- })
- });
-
- let (text, text_is_truncated) = snippet.declaration.item_text();
- referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
- path: path.as_std_path().into(),
- text: text.into(),
- range: snippet.declaration.item_line_range(),
- text_is_truncated,
- signature_range: snippet.declaration.signature_range_in_item_text(),
- parent_index,
- signature_score: snippet.score(DeclarationStyle::Signature),
- declaration_score: snippet.score(DeclarationStyle::Declaration),
- score_components: snippet.components,
- });
- }
-
- let excerpt_parent = index_state.and_then(|index_state| {
- context
- .excerpt
- .parent_declarations
- .last()
- .and_then(|(parent, _)| {
- add_signature(
- *parent,
- &mut declaration_to_signature_index,
- &mut signatures,
- index_state,
- )
- })
- });
-
- predict_edits_v3::PredictEditsRequest {
- excerpt_path,
- excerpt: context.excerpt_text.body,
- excerpt_line_range: context.excerpt.line_range,
- excerpt_range: context.excerpt.range,
- cursor_point: predict_edits_v3::Point {
- line: predict_edits_v3::Line(context.cursor_point.row),
- column: context.cursor_point.column,
- },
- referenced_declarations,
- included_files: vec![],
- signatures,
- excerpt_parent,
- events,
- can_collect_data,
- diagnostic_groups,
- diagnostic_groups_truncated,
- git_info,
- debug_info,
- prompt_max_bytes,
- prompt_format,
- }
-}
-
-fn add_signature(
- declaration_id: DeclarationId,
- declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
- signatures: &mut Vec<Signature>,
- index: &SyntaxIndexState,
-) -> Option<usize> {
- if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
- return Some(*signature_index);
- }
- let Some(parent_declaration) = index.declaration(declaration_id) else {
- log::error!("bug: missing parent declaration");
- return None;
- };
- let parent_index = parent_declaration.parent().and_then(|parent| {
- add_signature(parent, declaration_to_signature_index, signatures, index)
- });
- let (text, text_is_truncated) = parent_declaration.signature_text();
- let signature_index = signatures.len();
- signatures.push(Signature {
- text: text.into(),
- text_is_truncated,
- parent_index,
- range: parent_declaration.signature_line_range(),
- });
- declaration_to_signature_index.insert(declaration_id, signature_index);
- Some(signature_index)
-}
-
-#[cfg(feature = "eval-support")]
-pub type EvalCacheKey = (EvalCacheEntryKind, u64);
-
-#[cfg(feature = "eval-support")]
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub enum EvalCacheEntryKind {
- Context,
- Search,
- Prediction,
-}
-
-#[cfg(feature = "eval-support")]
-impl std::fmt::Display for EvalCacheEntryKind {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- EvalCacheEntryKind::Search => write!(f, "search"),
- EvalCacheEntryKind::Context => write!(f, "context"),
- EvalCacheEntryKind::Prediction => write!(f, "prediction"),
- }
- }
-}
-
-#[cfg(feature = "eval-support")]
-pub trait EvalCache: Send + Sync {
- fn read(&self, key: EvalCacheKey) -> Option<String>;
- fn write(&self, key: EvalCacheKey, input: &str, value: &str);
-}
-
-#[cfg(test)]
-mod tests {
- use std::{path::Path, sync::Arc};
-
- use client::UserStore;
- use clock::FakeSystemClock;
- use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
- use futures::{
- AsyncReadExt, StreamExt,
- channel::{mpsc, oneshot},
- };
- use gpui::{
- Entity, TestAppContext,
- http_client::{FakeHttpClient, Response},
- prelude::*,
- };
- use indoc::indoc;
- use language::OffsetRangeExt as _;
- use open_ai::Usage;
- use pretty_assertions::{assert_eq, assert_matches};
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
- use uuid::Uuid;
-
- use crate::{BufferEditPrediction, Zeta};
-
- #[gpui::test]
- async fn test_current_state(cx: &mut TestAppContext) {
- let (zeta, mut req_rx) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "1.txt": "Hello!\nHow\nBye\n",
- "2.txt": "Hola!\nComo\nAdios\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- zeta.update(cx, |zeta, cx| {
- zeta.register_project(&project, cx);
- });
-
- let buffer1 = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot1.anchor_before(language::Point::new(1, 3));
-
- // Prediction for current file
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
- });
- let (_request, respond_tx) = req_rx.next().await.unwrap();
-
- respond_tx
- .send(model_response(indoc! {r"
- --- a/root/1.txt
- +++ b/root/1.txt
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
- .unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer1, &project, cx)
- .unwrap();
- assert_matches!(prediction, BufferEditPrediction::Local { .. });
- });
-
- // Context refresh
- let refresh_task = zeta.update(cx, |zeta, cx| {
- zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
- });
- let (_request, respond_tx) = req_rx.next().await.unwrap();
- respond_tx
- .send(open_ai::Response {
- id: Uuid::new_v4().to_string(),
- object: "response".into(),
- created: 0,
- model: "model".into(),
- choices: vec![open_ai::Choice {
- index: 0,
- message: open_ai::RequestMessage::Assistant {
- content: None,
- tool_calls: vec![open_ai::ToolCall {
- id: "search".into(),
- content: open_ai::ToolCallContent::Function {
- function: open_ai::FunctionContent {
- name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
- .to_string(),
- arguments: serde_json::to_string(&SearchToolInput {
- queries: Box::new([SearchToolQuery {
- glob: "root/2.txt".to_string(),
- syntax_node: vec![],
- content: Some(".".into()),
- }]),
- })
- .unwrap(),
- },
- },
- }],
- },
- finish_reason: None,
- }],
- usage: Usage {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 0,
- },
- })
- .unwrap();
- refresh_task.await.unwrap();
-
- zeta.update(cx, |zeta, _cx| {
- zeta.discard_current_prediction(&project);
- });
-
- // Prediction for another file
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
- });
- let (_request, respond_tx) = req_rx.next().await.unwrap();
- respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/2.txt
- +++ b/root/2.txt
- Hola!
- -Como
- +Como estas?
- Adios
- "#}))
- .unwrap();
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer1, &project, cx)
- .unwrap();
- assert_matches!(
- prediction,
- BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
- );
- });
-
- let buffer2 = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer2, &project, cx)
- .unwrap();
- assert_matches!(prediction, BufferEditPrediction::Local { .. });
- });
- }
-
- #[gpui::test]
- async fn test_simple_request(cx: &mut TestAppContext) {
- let (zeta, mut req_rx) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, position, cx)
- });
-
- let (_, respond_tx) = req_rx.next().await.unwrap();
-
- // TODO Put back when we have a structured request again
- // assert_eq!(
- // request.excerpt_path.as_ref(),
- // Path::new(path!("root/foo.md"))
- // );
- // assert_eq!(
- // request.cursor_point,
- // Point {
- // line: Line(1),
- // column: 3
- // }
- // );
-
- respond_tx
- .send(model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
- .unwrap();
-
- let prediction = prediction_task.await.unwrap().unwrap();
-
- assert_eq!(prediction.edits.len(), 1);
- assert_eq!(
- prediction.edits[0].0.to_point(&snapshot).start,
- language::Point::new(1, 3)
- );
- assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
- }
-
- #[gpui::test]
- async fn test_request_events(cx: &mut TestAppContext) {
- let (zeta, mut req_rx) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\n\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
-
- zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(&buffer, &project, cx);
- });
-
- buffer.update(cx, |buffer, cx| {
- buffer.edit(vec![(7..7, "How")], None, cx);
- });
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, position, cx)
- });
-
- let (request, respond_tx) = req_rx.next().await.unwrap();
-
- let prompt = prompt_from_request(&request);
- assert!(
- prompt.contains(indoc! {"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ -1,3 +1,3 @@
- Hello!
- -
- +How
- Bye
- "}),
- "{prompt}"
- );
-
- respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "#}))
- .unwrap();
-
- let prediction = prediction_task.await.unwrap().unwrap();
-
- assert_eq!(prediction.edits.len(), 1);
- assert_eq!(
- prediction.edits[0].0.to_point(&snapshot).start,
- language::Point::new(1, 3)
- );
- assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
- }
-
- // Skipped until we start including diagnostics in prompt
- // #[gpui::test]
- // async fn test_request_diagnostics(cx: &mut TestAppContext) {
- // let (zeta, mut req_rx) = init_test(cx);
- // let fs = FakeFs::new(cx.executor());
- // fs.insert_tree(
- // "/root",
- // json!({
- // "foo.md": "Hello!\nBye"
- // }),
- // )
- // .await;
- // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
- // let diagnostic = lsp::Diagnostic {
- // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
- // severity: Some(lsp::DiagnosticSeverity::ERROR),
- // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
- // ..Default::default()
- // };
-
- // project.update(cx, |project, cx| {
- // project.lsp_store().update(cx, |lsp_store, cx| {
- // // Create some diagnostics
- // lsp_store
- // .update_diagnostics(
- // LanguageServerId(0),
- // lsp::PublishDiagnosticsParams {
- // uri: path_to_buffer_uri.clone(),
- // diagnostics: vec![diagnostic],
- // version: None,
- // },
- // None,
- // language::DiagnosticSourceKind::Pushed,
- // &[],
- // cx,
- // )
- // .unwrap();
- // });
- // });
-
- // let buffer = project
- // .update(cx, |project, cx| {
- // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- // project.open_buffer(path, cx)
- // })
- // .await
- // .unwrap();
-
- // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- // let position = snapshot.anchor_before(language::Point::new(0, 0));
-
- // let _prediction_task = zeta.update(cx, |zeta, cx| {
- // zeta.request_prediction(&project, &buffer, position, cx)
- // });
-
- // let (request, _respond_tx) = req_rx.next().await.unwrap();
-
- // assert_eq!(request.diagnostic_groups.len(), 1);
- // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
- // .unwrap();
- // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
- // assert_eq!(
- // value,
- // json!({
- // "entries": [{
- // "range": {
- // "start": 8,
- // "end": 10
- // },
- // "diagnostic": {
- // "source": null,
- // "code": null,
- // "code_description": null,
- // "severity": 1,
- // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
- // "markdown": null,
- // "group_id": 0,
- // "is_primary": true,
- // "is_disk_based": false,
- // "is_unnecessary": false,
- // "source_kind": "Pushed",
- // "data": null,
- // "underline": true
- // }
- // }],
- // "primary_ix": 0
- // })
- // );
- // }
-
- fn model_response(text: &str) -> open_ai::Response {
- open_ai::Response {
- id: Uuid::new_v4().to_string(),
- object: "response".into(),
- created: 0,
- model: "model".into(),
- choices: vec![open_ai::Choice {
- index: 0,
- message: open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(text.to_string())),
- tool_calls: vec![],
- },
- finish_reason: None,
- }],
- usage: Usage {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 0,
- },
- }
- }
-
- fn prompt_from_request(request: &open_ai::Request) -> &str {
- assert_eq!(request.messages.len(), 1);
- let open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(content),
- ..
- } = &request.messages[0]
- else {
- panic!(
- "Request does not have single user message of type Plain. {:#?}",
- request
- );
- };
- content
- }
-
- fn init_test(
- cx: &mut TestAppContext,
- ) -> (
- Entity<Zeta>,
- mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
- ) {
- cx.update(move |cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- zlog::init_test();
-
- let (req_tx, req_rx) = mpsc::unbounded();
-
- let http_client = FakeHttpClient::create({
- move |req| {
- let uri = req.uri().path().to_string();
- let mut body = req.into_body();
- let req_tx = req_tx.clone();
- async move {
- let resp = match uri.as_str() {
- "/client/llm_tokens" => serde_json::to_string(&json!({
- "token": "test"
- }))
- .unwrap(),
- "/predict_edits/raw" => {
- let mut buf = Vec::new();
- body.read_to_end(&mut buf).await.ok();
- let req = serde_json::from_slice(&buf).unwrap();
-
- let (res_tx, res_rx) = oneshot::channel();
- req_tx.unbounded_send((req, res_tx)).unwrap();
- serde_json::to_string(&res_rx.await?).unwrap()
- }
- _ => {
- panic!("Unexpected path: {}", uri)
- }
- };
-
- Ok(Response::builder().body(resp.into()).unwrap())
- }
- }
- });
-
- let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
- client.cloud_client().set_credentials(1, "test".into());
-
- language_model::init(client.clone(), cx);
-
- let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = Zeta::global(&client, &user_store, cx);
-
- (zeta, req_rx)
- })
- }
-}
@@ -13,7 +13,6 @@ path = "src/zeta2_tools.rs"
[dependencies]
anyhow.workspace = true
-chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
@@ -24,9 +23,7 @@ feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
-log.workspace = true
multi_buffer.workspace = true
-ordered-float.workspace = true
project.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -36,7 +33,7 @@ ui.workspace = true
ui_input.workspace = true
util.workspace = true
workspace.workspace = true
-zeta2.workspace = true
+zeta.workspace = true
[dev-dependencies]
clap.workspace = true
@@ -25,7 +25,7 @@ use ui::{
v_flex,
};
use workspace::Item;
-use zeta2::{
+use zeta::{
Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
ZetaSearchQueryDebugInfo,
};
@@ -1,30 +1,26 @@
mod zeta2_context_view;
-use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc};
+use std::{str::FromStr, sync::Arc, time::Duration};
-use chrono::TimeDelta;
use client::{Client, UserStore};
-use cloud_llm_client::predict_edits_v3::{
- DeclarationScoreComponents, PredictEditsRequest, PromptFormat,
-};
+use cloud_llm_client::predict_edits_v3::PromptFormat;
use collections::HashMap;
-use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
+use editor::{Editor, EditorEvent, EditorMode, MultiBuffer};
use feature_flags::FeatureFlagAppExt as _;
use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared};
use gpui::{
- CursorStyle, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
- WeakEntity, actions, prelude::*,
+ Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
+ prelude::*,
};
-use language::{Buffer, DiskState};
-use ordered_float::OrderedFloat;
-use project::{Project, WorktreeId, telemetry_snapshot::TelemetrySnapshot};
+use language::Buffer;
+use project::{Project, telemetry_snapshot::TelemetrySnapshot};
use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, prelude::*};
use ui_input::InputField;
-use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
+use util::ResultExt;
use workspace::{Item, SplitDirection, Workspace};
-use zeta2::{
- AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, Zeta, Zeta2FeatureFlag,
- ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
+use zeta::{
+ AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, EditPredictionInputs, Zeta,
+ Zeta2FeatureFlag, ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
};
use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions};
@@ -99,7 +95,6 @@ pub struct Zeta2Inspector {
cursor_context_ratio_input: Entity<InputField>,
max_prompt_bytes_input: Entity<InputField>,
context_mode: ContextModeState,
- active_view: ActiveView,
zeta: Entity<Zeta>,
_active_editor_subscription: Option<Subscription>,
_update_state_task: Task<()>,
@@ -113,21 +108,14 @@ pub enum ContextModeState {
},
}
-#[derive(PartialEq)]
-enum ActiveView {
- Context,
- Inference,
-}
-
struct LastPrediction {
- context_editor: Entity<Editor>,
prompt_editor: Entity<Editor>,
- retrieval_time: TimeDelta,
- request_time: Option<TimeDelta>,
+ retrieval_time: Duration,
+ request_time: Option<Duration>,
buffer: WeakEntity<Buffer>,
position: language::Anchor,
state: LastPredictionState,
- request: PredictEditsRequest,
+ inputs: EditPredictionInputs,
project_snapshot: Shared<Task<Arc<TelemetrySnapshot>>>,
_task: Option<Task<()>>,
}
@@ -175,7 +163,6 @@ impl Zeta2Inspector {
focus_handle: cx.focus_handle(),
project: project.clone(),
last_prediction: None,
- active_view: ActiveView::Inference,
max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx),
min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
@@ -305,7 +292,7 @@ impl Zeta2Inspector {
ContextMode::Syntax(context_options) => {
let max_retrieved_declarations = match &this.context_mode {
ContextModeState::Llm => {
- zeta2::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
+ zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
}
ContextModeState::Syntax {
max_retrieved_declarations,
@@ -340,22 +327,10 @@ impl Zeta2Inspector {
fn update_last_prediction(
&mut self,
- prediction: zeta2::ZetaDebugInfo,
+ prediction: zeta::ZetaDebugInfo,
window: &mut Window,
cx: &mut Context<Self>,
) {
- let project = self.project.read(cx);
- let path_style = project.path_style(cx);
- let Some(worktree_id) = project
- .worktrees(cx)
- .next()
- .map(|worktree| worktree.read(cx).id())
- else {
- log::error!("Open a worktree to use edit prediction debug view");
- self.last_prediction.take();
- return;
- };
-
self._update_state_task = cx.spawn_in(window, {
let language_registry = self.project.read(cx).languages().clone();
async move |this, cx| {
@@ -364,11 +339,10 @@ impl Zeta2Inspector {
return;
};
for ext in prediction
- .request
- .referenced_declarations
+ .inputs
+ .included_files
.iter()
- .filter_map(|snippet| snippet.path.extension())
- .chain(prediction.request.excerpt_path.extension())
+ .filter_map(|file| file.path.extension())
{
if !languages.contains_key(ext) {
// Most snippets are gonna be the same language,
@@ -391,90 +365,6 @@ impl Zeta2Inspector {
let json_language = language_registry.language_for_name("Json").await.log_err();
this.update_in(cx, |this, window, cx| {
- let context_editor = cx.new(|cx| {
- let mut excerpt_score_components = HashMap::default();
-
- let multibuffer = cx.new(|cx| {
- let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
- let excerpt_file = Arc::new(ExcerptMetadataFile {
- title: RelPath::unix("Cursor Excerpt").unwrap().into(),
- path_style,
- worktree_id,
- });
-
- let excerpt_buffer = cx.new(|cx| {
- let mut buffer =
- Buffer::local(prediction.request.excerpt.clone(), cx);
- if let Some(language) = prediction
- .request
- .excerpt_path
- .extension()
- .and_then(|ext| languages.get(ext))
- {
- buffer.set_language(language.clone(), cx);
- }
- buffer.file_updated(excerpt_file, cx);
- buffer
- });
-
- multibuffer.push_excerpts(
- excerpt_buffer,
- [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
- cx,
- );
-
- let mut declarations =
- prediction.request.referenced_declarations.clone();
- declarations.sort_unstable_by_key(|declaration| {
- Reverse(OrderedFloat(declaration.declaration_score))
- });
-
- for snippet in &declarations {
- let snippet_file = Arc::new(ExcerptMetadataFile {
- title: RelPath::unix(&format!(
- "{} (Score: {})",
- snippet.path.display(),
- snippet.declaration_score
- ))
- .unwrap()
- .into(),
- path_style,
- worktree_id,
- });
-
- let excerpt_buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(snippet.text.clone(), cx);
- buffer.file_updated(snippet_file, cx);
- if let Some(ext) = snippet.path.extension()
- && let Some(language) = languages.get(ext)
- {
- buffer.set_language(language.clone(), cx);
- }
- buffer
- });
-
- let excerpt_ids = multibuffer.push_excerpts(
- excerpt_buffer,
- [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
- cx,
- );
- let excerpt_id = excerpt_ids.first().unwrap();
-
- excerpt_score_components
- .insert(*excerpt_id, snippet.score_components.clone());
- }
-
- multibuffer
- });
-
- let mut editor =
- Editor::new(EditorMode::full(), multibuffer, None, window, cx);
- editor.register_addon(ZetaContextAddon {
- excerpt_score_components,
- });
- editor
- });
-
let ZetaEditPredictionDebugInfo {
response_rx,
position,
@@ -606,7 +496,6 @@ impl Zeta2Inspector {
let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx);
this.last_prediction = Some(LastPrediction {
- context_editor,
prompt_editor: cx.new(|cx| {
let buffer = cx.new(|cx| {
let mut buffer =
@@ -632,7 +521,7 @@ impl Zeta2Inspector {
.foreground_executor()
.spawn(async move { Arc::new(project_snapshot_task.await) })
.shared(),
- request: prediction.request,
+ inputs: prediction.inputs,
_task: Some(task),
});
cx.notify();
@@ -664,9 +553,6 @@ impl Zeta2Inspector {
let Some(last_prediction) = self.last_prediction.as_mut() else {
return;
};
- if !last_prediction.request.can_collect_data {
- return;
- }
let project_snapshot_task = last_prediction.project_snapshot.clone();
@@ -718,7 +604,7 @@ impl Zeta2Inspector {
id = request_id,
kind = kind,
text = text,
- request = last_prediction.request,
+ request = last_prediction.inputs,
project_snapshot = project_snapshot,
);
})
@@ -727,17 +613,6 @@ impl Zeta2Inspector {
.detach();
}
- fn focus_feedback(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- if let Some(last_prediction) = self.last_prediction.as_mut() {
- if let LastPredictionState::Success {
- feedback_editor, ..
- } = &mut last_prediction.state
- {
- feedback_editor.focus_handle(cx).focus(window);
- }
- };
- }
-
fn render_options(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
v_flex()
.gap_2()
@@ -747,11 +622,11 @@ impl Zeta2Inspector {
.justify_between()
.child(
ui::Button::new("reset-options", "Reset")
- .disabled(self.zeta.read(cx).options() == &zeta2::DEFAULT_OPTIONS)
+ .disabled(self.zeta.read(cx).options() == &zeta::DEFAULT_OPTIONS)
.style(ButtonStyle::Outlined)
.size(ButtonSize::Large)
.on_click(cx.listener(|this, _, window, cx| {
- this.set_options_state(&zeta2::DEFAULT_OPTIONS, window, cx);
+ this.set_options_state(&zeta::DEFAULT_OPTIONS, window, cx);
})),
),
)
@@ -915,42 +790,6 @@ impl Zeta2Inspector {
)
}
- fn render_tabs(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
- if self.last_prediction.is_none() {
- return None;
- };
-
- Some(
- ui::ToggleButtonGroup::single_row(
- "prediction",
- [
- ui::ToggleButtonSimple::new(
- "Context",
- cx.listener(|this, _, _, cx| {
- this.active_view = ActiveView::Context;
- cx.notify();
- }),
- ),
- ui::ToggleButtonSimple::new(
- "Inference",
- cx.listener(|this, _, window, cx| {
- this.active_view = ActiveView::Inference;
- this.focus_feedback(window, cx);
- cx.notify();
- }),
- ),
- ],
- )
- .style(ui::ToggleButtonGroupStyle::Outlined)
- .selected_index(if self.active_view == ActiveView::Context {
- 0
- } else {
- 1
- })
- .into_any_element(),
- )
- }
-
fn render_stats(&self) -> Option<Div> {
let Some(prediction) = self.last_prediction.as_ref() else {
return None;
@@ -970,15 +809,15 @@ impl Zeta2Inspector {
)
}
- fn render_duration(name: &'static str, time: Option<chrono::TimeDelta>) -> Div {
+ fn render_duration(name: &'static str, time: Option<Duration>) -> Div {
h_flex()
.gap_1()
.child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
.child(match time {
- Some(time) => Label::new(if time.num_microseconds().unwrap_or(0) >= 1000 {
- format!("{} ms", time.num_milliseconds())
+ Some(time) => Label::new(if time.as_micros() >= 1000 {
+ format!("{} ms", time.as_millis())
} else {
- format!("{} ยตs", time.num_microseconds().unwrap_or(0))
+ format!("{} ยตs", time.as_micros())
})
.size(LabelSize::Small),
None => Label::new("...").size(LabelSize::Small),
@@ -1006,144 +845,135 @@ impl Zeta2Inspector {
}
fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
- match &self.active_view {
- ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
- ActiveView::Inference => h_flex()
- .items_start()
- .w_full()
- .flex_1()
- .border_t_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().colors().editor_background)
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .p_4()
- .h_full()
- .child(
- h_flex()
- .justify_between()
- .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
- .child(match prediction.state {
- LastPredictionState::Requested
- | LastPredictionState::Failed { .. } => ui::Chip::new("Local")
- .bg_color(cx.theme().status().warning_background)
- .label_color(Color::Success),
- LastPredictionState::Success { .. } => ui::Chip::new("Cloud")
- .bg_color(cx.theme().status().success_background)
- .label_color(Color::Success),
- }),
- )
- .child(prediction.prompt_editor.clone()),
- )
- .child(ui::vertical_divider())
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .h_full()
- .child(
+ h_flex()
+ .items_start()
+ .w_full()
+ .flex_1()
+ .border_t_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ v_flex()
+ .flex_1()
+ .gap_2()
+ .p_4()
+ .h_full()
+ .child(
+ h_flex()
+ .justify_between()
+ .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
+ .child(match prediction.state {
+ LastPredictionState::Requested
+ | LastPredictionState::Failed { .. } => ui::Chip::new("Local")
+ .bg_color(cx.theme().status().warning_background)
+ .label_color(Color::Success),
+ LastPredictionState::Success { .. } => ui::Chip::new("Cloud")
+ .bg_color(cx.theme().status().success_background)
+ .label_color(Color::Success),
+ }),
+ )
+ .child(prediction.prompt_editor.clone()),
+ )
+ .child(ui::vertical_divider())
+ .child(
+ v_flex()
+ .flex_1()
+ .gap_2()
+ .h_full()
+ .child(
+ v_flex()
+ .flex_1()
+ .gap_2()
+ .p_4()
+ .child(
+ ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall),
+ )
+ .child(match &prediction.state {
+ LastPredictionState::Success {
+ model_response_editor,
+ ..
+ } => model_response_editor.clone().into_any_element(),
+ LastPredictionState::Requested => v_flex()
+ .gap_2()
+ .child(Label::new("Loading...").buffer_font(cx))
+ .into_any_element(),
+ LastPredictionState::Failed { message } => v_flex()
+ .gap_2()
+ .max_w_96()
+ .child(Label::new(message.clone()).buffer_font(cx))
+ .into_any_element(),
+ }),
+ )
+ .child(ui::divider())
+ .child(
+ if let LastPredictionState::Success {
+ feedback_editor,
+ feedback: feedback_state,
+ ..
+ } = &prediction.state
+ {
v_flex()
- .flex_1()
+ .key_context("Zeta2Feedback")
+ .on_action(cx.listener(Self::handle_rate_positive))
+ .on_action(cx.listener(Self::handle_rate_negative))
.gap_2()
- .p_4()
+ .p_2()
+ .child(feedback_editor.clone())
.child(
- ui::Headline::new("Model Response")
- .size(ui::HeadlineSize::XSmall),
- )
- .child(match &prediction.state {
- LastPredictionState::Success {
- model_response_editor,
- ..
- } => model_response_editor.clone().into_any_element(),
- LastPredictionState::Requested => v_flex()
- .gap_2()
- .child(Label::new("Loading...").buffer_font(cx))
- .into_any_element(),
- LastPredictionState::Failed { message } => v_flex()
- .gap_2()
- .max_w_96()
- .child(Label::new(message.clone()).buffer_font(cx))
- .into_any_element(),
- }),
- )
- .child(ui::divider())
- .child(
- if prediction.request.can_collect_data
- && let LastPredictionState::Success {
- feedback_editor,
- feedback: feedback_state,
- ..
- } = &prediction.state
- {
- v_flex()
- .key_context("Zeta2Feedback")
- .on_action(cx.listener(Self::handle_rate_positive))
- .on_action(cx.listener(Self::handle_rate_negative))
- .gap_2()
- .p_2()
- .child(feedback_editor.clone())
- .child(
- h_flex()
- .justify_end()
- .w_full()
- .child(
- ButtonLike::new("rate-positive")
- .when(
- *feedback_state == Some(Feedback::Positive),
- |this| this.style(ButtonStyle::Filled),
- )
- .child(
- KeyBinding::for_action(
- &Zeta2RatePredictionPositive,
- cx,
- )
- .size(TextSize::Small.rems(cx)),
- )
- .child(ui::Icon::new(ui::IconName::ThumbsUp))
- .on_click(cx.listener(
- |this, _, window, cx| {
- this.handle_rate_positive(
- &Zeta2RatePredictionPositive,
- window,
- cx,
- );
- },
- )),
- )
- .child(
- ButtonLike::new("rate-negative")
- .when(
- *feedback_state == Some(Feedback::Negative),
- |this| this.style(ButtonStyle::Filled),
+ h_flex()
+ .justify_end()
+ .w_full()
+ .child(
+ ButtonLike::new("rate-positive")
+ .when(
+ *feedback_state == Some(Feedback::Positive),
+ |this| this.style(ButtonStyle::Filled),
+ )
+ .child(
+ KeyBinding::for_action(
+ &Zeta2RatePredictionPositive,
+ cx,
)
- .child(
- KeyBinding::for_action(
- &Zeta2RatePredictionNegative,
- cx,
- )
- .size(TextSize::Small.rems(cx)),
+ .size(TextSize::Small.rems(cx)),
+ )
+ .child(ui::Icon::new(ui::IconName::ThumbsUp))
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.handle_rate_positive(
+ &Zeta2RatePredictionPositive,
+ window,
+ cx,
+ );
+ })),
+ )
+ .child(
+ ButtonLike::new("rate-negative")
+ .when(
+ *feedback_state == Some(Feedback::Negative),
+ |this| this.style(ButtonStyle::Filled),
+ )
+ .child(
+ KeyBinding::for_action(
+ &Zeta2RatePredictionNegative,
+ cx,
)
- .child(ui::Icon::new(ui::IconName::ThumbsDown))
- .on_click(cx.listener(
- |this, _, window, cx| {
- this.handle_rate_negative(
- &Zeta2RatePredictionNegative,
- window,
- cx,
- );
- },
- )),
- ),
- )
- .into_any()
- } else {
- Empty.into_any_element()
- },
- ),
- ),
- }
+ .size(TextSize::Small.rems(cx)),
+ )
+ .child(ui::Icon::new(ui::IconName::ThumbsDown))
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.handle_rate_negative(
+ &Zeta2RatePredictionNegative,
+ window,
+ cx,
+ );
+ })),
+ ),
+ )
+ .into_any()
+ } else {
+ Empty.into_any_element()
+ },
+ ),
+ )
}
}
@@ -1178,8 +1008,7 @@ impl Render for Zeta2Inspector {
.h_full()
.justify_between()
.child(self.render_options(window, cx))
- .gap_4()
- .children(self.render_tabs(cx)),
+ .gap_4(),
)
.child(ui::vertical_divider())
.children(self.render_stats()),
@@ -1187,104 +1016,3 @@ impl Render for Zeta2Inspector {
.child(self.render_content(window, cx))
}
}
-
-// Using same approach as commit view
-
-struct ExcerptMetadataFile {
- title: Arc<RelPath>,
- worktree_id: WorktreeId,
- path_style: PathStyle,
-}
-
-impl language::File for ExcerptMetadataFile {
- fn as_local(&self) -> Option<&dyn language::LocalFile> {
- None
- }
-
- fn disk_state(&self) -> DiskState {
- DiskState::New
- }
-
- fn path(&self) -> &Arc<RelPath> {
- &self.title
- }
-
- fn full_path(&self, _: &App) -> PathBuf {
- self.title.as_std_path().to_path_buf()
- }
-
- fn file_name<'a>(&'a self, _: &'a App) -> &'a str {
- self.title.file_name().unwrap()
- }
-
- fn path_style(&self, _: &App) -> PathStyle {
- self.path_style
- }
-
- fn worktree_id(&self, _: &App) -> WorktreeId {
- self.worktree_id
- }
-
- fn to_proto(&self, _: &App) -> language::proto::File {
- unimplemented!()
- }
-
- fn is_private(&self) -> bool {
- false
- }
-}
-
-struct ZetaContextAddon {
- excerpt_score_components: HashMap<editor::ExcerptId, DeclarationScoreComponents>,
-}
-
-impl editor::Addon for ZetaContextAddon {
- fn to_any(&self) -> &dyn std::any::Any {
- self
- }
-
- fn render_buffer_header_controls(
- &self,
- excerpt_info: &multi_buffer::ExcerptInfo,
- _window: &Window,
- _cx: &App,
- ) -> Option<AnyElement> {
- let score_components = self.excerpt_score_components.get(&excerpt_info.id)?.clone();
-
- Some(
- div()
- .id(excerpt_info.id.to_proto() as usize)
- .child(ui::Icon::new(IconName::Info))
- .cursor(CursorStyle::PointingHand)
- .tooltip(move |_, cx| {
- cx.new(|_| ScoreComponentsTooltip::new(&score_components))
- .into()
- })
- .into_any(),
- )
- }
-}
-
-struct ScoreComponentsTooltip {
- text: SharedString,
-}
-
-impl ScoreComponentsTooltip {
- fn new(components: &DeclarationScoreComponents) -> Self {
- Self {
- text: format!("{:#?}", components).into(),
- }
- }
-}
-
-impl Render for ScoreComponentsTooltip {
- fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- div().pl_2().pt_2p5().child(
- div()
- .elevation_2(cx)
- .py_1()
- .px_2()
- .child(ui::Label::new(self.text.clone()).buffer_font(cx)),
- )
- }
-}
@@ -53,8 +53,7 @@ terminal_view.workspace = true
toml.workspace = true
util.workspace = true
watch.workspace = true
-zeta.workspace = true
-zeta2 = { workspace = true, features = ["eval-support"] }
+zeta = { workspace = true, features = ["eval-support"] }
zlog.workspace = true
[dev-dependencies]
@@ -9,7 +9,7 @@ use collections::HashSet;
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
-use zeta2::{Zeta, udiff::DiffLine};
+use zeta::{Zeta, udiff::DiffLine};
use crate::{
EvaluateArguments, PredictionOptions,
@@ -26,7 +26,7 @@ use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
-use zeta2::udiff::OpenedBuffers;
+use zeta::udiff::OpenedBuffers;
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
@@ -557,7 +557,7 @@ impl NamedExample {
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'_>> {
- zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
+ zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}
@@ -31,7 +31,7 @@ use serde_json::json;
use std::io::{self};
use std::time::Duration;
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
-use zeta2::ContextMode;
+use zeta::ContextMode;
#[derive(Parser, Debug)]
#[command(name = "zeta")]
@@ -193,13 +193,14 @@ pub struct EvaluateArguments {
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
+ Zeta1,
#[default]
Zeta2,
Sweep,
}
-fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
- zeta2::ZetaOptions {
+fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
+ zeta::ZetaOptions {
context: ContextMode::Syntax(EditPredictionContextOptions {
max_retrieved_declarations: args.max_retrieved_definitions,
use_imports: !args.disable_imports_gathering,
@@ -397,7 +398,7 @@ async fn zeta2_syntax_context(
let output = cx
.update(|cx| {
let zeta = cx.new(|cx| {
- zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
+ zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
});
let indexing_done_task = zeta.update(cx, |zeta, cx| {
zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
@@ -435,7 +436,7 @@ async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
-) -> Result<zeta::GatherContextOutput> {
+) -> Result<zeta::zeta1::GatherContextOutput> {
let LoadedContext {
full_path_str,
snapshot,
@@ -450,7 +451,7 @@ async fn zeta1_context(
let prompt_for_events = move || (events, 0);
cx.update(|cx| {
- zeta::gather_context(
+ zeta::zeta1::gather_context(
full_path_str,
&snapshot,
clipped_cursor,
@@ -21,7 +21,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
-use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
+use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
pub async fn run_predict(
args: PredictArguments,
@@ -47,12 +47,13 @@ pub fn setup_zeta(
cx: &mut AsyncApp,
) -> Result<Entity<Zeta>> {
let zeta =
- cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
+ cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
zeta.update(cx, |zeta, _cx| {
let model = match provider {
- PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud,
- PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep,
+ PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1,
+ PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2,
+ PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep,
};
zeta.set_edit_prediction_model(model);
})?;
@@ -142,25 +143,25 @@ pub async fn perform_predict(
let mut search_queries_executed_at = None;
while let Some(event) = debug_rx.next().await {
match event {
- zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
}
- zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+ zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
search_queries_generated_at = Some(info.timestamp);
fs::write(
example_run_dir.join("search_queries.json"),
serde_json::to_string_pretty(&info.search_queries).unwrap(),
)?;
}
- zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+ zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
search_queries_executed_at = Some(info.timestamp);
}
- zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
- zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
+ zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
+ zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
@@ -170,9 +171,9 @@ pub async fn perform_predict(
let mut result = result.lock().unwrap();
result.prompt_len = prompt.chars().count();
- for included_file in request.request.included_files {
+ for included_file in request.inputs.included_files {
let insertions =
- vec![(request.request.cursor_point, CURSOR_MARKER)];
+ vec![(request.inputs.cursor_point, CURSOR_MARKER)];
result.excerpts.extend(included_file.excerpts.iter().map(
|excerpt| ActualExcerpt {
path: included_file.path.components().skip(1).collect(),
@@ -182,7 +183,7 @@ pub async fn perform_predict(
write_codeblock(
&included_file.path,
included_file.excerpts.iter(),
- if included_file.path == request.request.excerpt_path {
+ if included_file.path == request.inputs.cursor_path {
&insertions
} else {
&[]
@@ -196,7 +197,7 @@ pub async fn perform_predict(
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response = zeta2::text_from_response(response).unwrap_or_default();
+ let response = zeta::text_from_response(response).unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
@@ -267,20 +268,7 @@ pub async fn perform_predict(
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
result.diff = prediction
- .map(|prediction| {
- let old_text = prediction.snapshot.text();
- let new_text = prediction
- .buffer
- .update(cx, |buffer, cx| {
- let branch = buffer.branch(cx);
- branch.update(cx, |branch, cx| {
- branch.edit(prediction.edits.iter().cloned(), None, cx);
- branch.text()
- })
- })
- .unwrap();
- language::unified_diff(&old_text, &new_text)
- })
+ .and_then(|prediction| prediction.edit_preview.as_unified_diff(&prediction.edits))
.unwrap_or_default();
anyhow::Ok(result)
@@ -32,7 +32,7 @@ use std::{
time::Duration,
};
use util::paths::PathStyle;
-use zeta2::ContextMode;
+use zeta::ContextMode;
use crate::headless::ZetaCliAppState;
use crate::source_location::SourceLocation;
@@ -44,7 +44,7 @@ pub async fn retrieval_stats(
only_extension: Option<String>,
file_limit: Option<usize>,
skip_files: Option<usize>,
- options: zeta2::ZetaOptions,
+ options: zeta::ZetaOptions,
cx: &mut AsyncApp,
) -> Result<String> {
let ContextMode::Syntax(context_options) = options.context.clone() else {