Detailed changes
@@ -3133,13 +3133,13 @@ name = "cloud_llm_client"
version = "0.1.0"
dependencies = [
"anyhow",
- "chrono",
"indoc",
"pretty_assertions",
"serde",
"serde_json",
"strum 0.27.2",
"uuid",
+ "zeta_prompt",
]
[[package]]
@@ -3247,7 +3247,7 @@ name = "codestral"
version = "0.1.0"
dependencies = [
"anyhow",
- "edit_prediction_context",
+ "edit_prediction",
"edit_prediction_types",
"futures 0.3.31",
"gpui",
@@ -5336,7 +5336,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"clock",
- "cloud_llm_client",
"collections",
"env_logger 0.11.8",
"futures 0.3.31",
@@ -16,11 +16,11 @@ path = "src/cloud_llm_client.rs"
[dependencies]
anyhow.workspace = true
-chrono.workspace = true
serde = { workspace = true, features = ["derive", "rc"] }
serde_json.workspace = true
strum = { workspace = true, features = ["derive"] }
uuid = { workspace = true, features = ["serde"] }
+zeta_prompt.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true
@@ -1,219 +1,5 @@
-use chrono::Duration;
use serde::{Deserialize, Serialize};
-use std::{
- borrow::Cow,
- fmt::{Display, Write as _},
- ops::{Add, Range, Sub},
- path::Path,
- sync::Arc,
-};
-use strum::EnumIter;
-use uuid::Uuid;
-
-use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger};
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PlanContextRetrievalRequest {
- pub excerpt: String,
- pub excerpt_path: Arc<Path>,
- pub excerpt_line_range: Range<Line>,
- pub cursor_file_max_row: Line,
- pub events: Vec<Arc<Event>>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsRequest {
- pub excerpt: String,
- pub excerpt_path: Arc<Path>,
- /// Within file
- pub excerpt_range: Range<usize>,
- pub excerpt_line_range: Range<Line>,
- pub cursor_point: Point,
- /// Within `signatures`
- pub excerpt_parent: Option<usize>,
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub related_files: Vec<RelatedFile>,
- pub events: Vec<Arc<Event>>,
- #[serde(default)]
- pub can_collect_data: bool,
- /// Info about the git repository state, only present when can_collect_data is true.
- #[serde(skip_serializing_if = "Option::is_none", default)]
- pub git_info: Option<PredictEditsGitInfo>,
- // Only available to staff
- #[serde(default)]
- pub debug_info: bool,
- #[serde(skip_serializing_if = "Option::is_none", default)]
- pub prompt_max_bytes: Option<usize>,
- #[serde(default)]
- pub prompt_format: PromptFormat,
- #[serde(default)]
- pub trigger: PredictEditsRequestTrigger,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct RelatedFile {
- pub path: Arc<Path>,
- pub max_row: Line,
- pub excerpts: Vec<Excerpt>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Excerpt {
- pub start_line: Line,
- pub text: Arc<str>,
-}
-
-#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
-pub enum PromptFormat {
-    /// XML old_text/new_text
- OldTextNewText,
- /// Prompt format intended for use via edit_prediction_cli
- OnlySnippets,
- /// One-sentence instructions used in fine-tuned models
- Minimal,
- /// One-sentence instructions + FIM-like template
- MinimalQwen,
- /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
- SeedCoder1120,
-}
-
-impl PromptFormat {
- pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
-}
-
-impl Default for PromptFormat {
- fn default() -> Self {
- Self::DEFAULT
- }
-}
-
-impl PromptFormat {
- pub fn iter() -> impl Iterator<Item = Self> {
- <Self as strum::IntoEnumIterator>::iter()
- }
-}
-
-impl std::fmt::Display for PromptFormat {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
- PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
- PromptFormat::Minimal => write!(f, "Minimal"),
- PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
- PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
- }
- }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
-#[serde(tag = "event")]
-pub enum Event {
- BufferChange {
- path: Arc<Path>,
- old_path: Arc<Path>,
- diff: String,
- predicted: bool,
- in_open_source_repo: bool,
- },
-}
-
-impl Display for Event {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Event::BufferChange {
- path,
- old_path,
- diff,
- predicted,
- ..
- } => {
- if *predicted {
- write!(
- f,
- "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
- DiffPathFmt(old_path),
- DiffPathFmt(path)
- )
- } else {
- write!(
- f,
- "--- a/{}\n+++ b/{}\n{diff}",
- DiffPathFmt(old_path),
- DiffPathFmt(path)
- )
- }
- }
- }
- }
-}
-
-/// Always formats the path as a Unix-style path in diffs, using `/` as the separator.
-pub struct DiffPathFmt<'a>(pub &'a Path);
-
-impl<'a> std::fmt::Display for DiffPathFmt<'a> {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- let mut is_first = true;
- for component in self.0.components() {
- if !is_first {
- f.write_char('/')?;
- } else {
- is_first = false;
- }
- write!(f, "{}", component.as_os_str().display())?;
- }
- Ok(())
- }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsResponse {
- pub request_id: Uuid,
- pub edits: Vec<Edit>,
- pub debug_info: Option<DebugInfo>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct DebugInfo {
- pub prompt: String,
- pub prompt_planning_time: Duration,
- pub model_response: String,
- pub inference_time: Duration,
- pub parsing_time: Duration,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Edit {
- pub path: Arc<Path>,
- pub range: Range<Line>,
- pub content: String,
-}
-
-#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
-pub struct Point {
- pub line: Line,
- pub column: u32,
-}
-
-#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
-#[serde(transparent)]
-pub struct Line(pub u32);
-
-impl Add for Line {
- type Output = Self;
-
- fn add(self, rhs: Self) -> Self::Output {
- Self(self.0 + rhs.0)
- }
-}
-
-impl Sub for Line {
- type Output = Self;
-
- fn sub(self, rhs: Self) -> Self::Output {
- Self(self.0 - rhs.0)
- }
-}
+use std::borrow::Cow;
#[derive(Debug, Deserialize, Serialize)]
pub struct RawCompletionRequest {
@@ -226,6 +12,22 @@ pub struct RawCompletionRequest {
pub stop: Vec<Cow<'static, str>>,
}
+#[derive(Debug, Serialize, Deserialize)]
+pub struct PredictEditsV3Request {
+ #[serde(flatten)]
+ pub input: zeta_prompt::ZetaPromptInput,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub model: Option<String>,
+ #[serde(default)]
+ pub prompt_version: zeta_prompt::ZetaVersion,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct PredictEditsV3Response {
+ pub request_id: String,
+ pub output: String,
+}
+
#[derive(Debug, Deserialize, Serialize)]
pub struct RawCompletionResponse {
pub id: String,
@@ -248,86 +50,3 @@ pub struct RawCompletionUsage {
pub completion_tokens: u32,
pub total_tokens: u32,
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use indoc::indoc;
- use pretty_assertions::assert_eq;
-
- #[test]
- fn test_event_display() {
- let ev = Event::BufferChange {
- path: Path::new("untitled").into(),
- old_path: Path::new("untitled").into(),
- diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
- predicted: false,
- in_open_source_repo: true,
- };
- assert_eq!(
- ev.to_string(),
- indoc! {"
- --- a/untitled
- +++ b/untitled
- @@ -1,2 +1,2 @@
- -a
- -b
- "}
- );
-
- let ev = Event::BufferChange {
- path: Path::new("foo/bar.txt").into(),
- old_path: Path::new("foo/bar.txt").into(),
- diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
- predicted: false,
- in_open_source_repo: true,
- };
- assert_eq!(
- ev.to_string(),
- indoc! {"
- --- a/foo/bar.txt
- +++ b/foo/bar.txt
- @@ -1,2 +1,2 @@
- -a
- -b
- "}
- );
-
- let ev = Event::BufferChange {
- path: Path::new("abc.txt").into(),
- old_path: Path::new("123.txt").into(),
- diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
- predicted: false,
- in_open_source_repo: true,
- };
- assert_eq!(
- ev.to_string(),
- indoc! {"
- --- a/123.txt
- +++ b/abc.txt
- @@ -1,2 +1,2 @@
- -a
- -b
- "}
- );
-
- let ev = Event::BufferChange {
- path: Path::new("abc.txt").into(),
- old_path: Path::new("123.txt").into(),
- diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
- predicted: true,
- in_open_source_repo: true,
- };
- assert_eq!(
- ev.to_string(),
- indoc! {"
- // User accepted prediction:
- --- a/123.txt
- +++ b/abc.txt
- @@ -1,2 +1,2 @@
- -a
- -b
- "}
- );
- }
-}
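
A note on the new wire types above: `#[serde(flatten)]` lifts the fields of `zeta_prompt::ZetaPromptInput` to the top level of the request JSON, and `#[serde(default)]` lets a client omit `prompt_version` entirely. A minimal round-trip sketch, using simplified stand-in types rather than the real `zeta_prompt` definitions:

```rust
use serde::{Deserialize, Serialize};

// Stand-in for zeta_prompt::ZetaPromptInput (assumption: the real type has
// more fields; only the serde behavior is illustrated here).
#[derive(Debug, Serialize, Deserialize)]
struct PromptInput {
    cursor_excerpt: String,
}

// Stand-in for zeta_prompt::ZetaVersion, keeping the default-variant shape.
#[derive(Debug, Default, Serialize, Deserialize)]
enum Version {
    #[default]
    V0114180EditableRegion,
}

#[derive(Debug, Serialize, Deserialize)]
struct PredictEditsV3Request {
    #[serde(flatten)]
    input: PromptInput,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(default)]
    prompt_version: Version,
}

fn main() {
    // `cursor_excerpt` arrives at the top level thanks to `flatten`, and the
    // missing `prompt_version` falls back to the default variant.
    let body = r#"{"cursor_excerpt": "fn main() {}"}"#;
    let request: PredictEditsV3Request = serde_json::from_str(body).unwrap();
    assert!(request.model.is_none());
    println!("{}", serde_json::to_string(&request).unwrap());
}
```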
@@ -11,7 +11,7 @@ path = "src/codestral.rs"
[dependencies]
anyhow.workspace = true
edit_prediction_types.workspace = true
-edit_prediction_context.workspace = true
+edit_prediction.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
@@ -1,5 +1,5 @@
-use anyhow::{Context as _, Result};
-use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
+use anyhow::Result;
+use edit_prediction::cursor_excerpt;
use edit_prediction_types::{EditPrediction, EditPredictionDelegate};
use futures::AsyncReadExt;
use gpui::{App, Context, Entity, Task};
@@ -15,16 +15,10 @@ use std::{
sync::Arc,
time::{Duration, Instant},
};
-use text::ToOffset;
+use text::{OffsetRangeExt as _, ToOffset};
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
-const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
- max_bytes: 1050,
- min_bytes: 525,
- target_before_cursor_over_total_bytes: 0.66,
-};
-
/// Represents a completion that has been received and processed from Codestral.
/// This struct maintains the state needed to interpolate the completion as the user types.
#[derive(Clone)]
@@ -235,19 +229,27 @@ impl EditPredictionDelegate for CodestralEditPredictionDelegate {
let cursor_offset = cursor_position.to_offset(&snapshot);
let cursor_point = cursor_offset.to_point(&snapshot);
- let excerpt = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &snapshot,
- &EXCERPT_OPTIONS,
- )
- .context("Line containing cursor doesn't fit in excerpt max bytes")?;
- let excerpt_text = excerpt.text(&snapshot);
+ const MAX_CONTEXT_TOKENS: usize = 150;
+ const MAX_REWRITE_TOKENS: usize = 350;
+
+ let (_, context_range) =
+ cursor_excerpt::editable_and_context_ranges_for_cursor_position(
+ cursor_point,
+ &snapshot,
+ MAX_REWRITE_TOKENS,
+ MAX_CONTEXT_TOKENS,
+ );
+
+ let context_range = context_range.to_offset(&snapshot);
+ let excerpt_text = snapshot
+ .text_for_range(context_range.clone())
+ .collect::<String>();
let cursor_within_excerpt = cursor_offset
- .saturating_sub(excerpt.range.start)
- .min(excerpt_text.body.len());
- let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
- let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
+ .saturating_sub(context_range.start)
+ .min(excerpt_text.len());
+ let prompt = excerpt_text[..cursor_within_excerpt].to_string();
+ let suffix = excerpt_text[cursor_within_excerpt..].to_string();
let completion_text = match Self::fetch_completion(
http_client,
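
The new prompt/suffix split above clamps the cursor offset into the excerpt before slicing. A standalone sketch of that arithmetic (byte offsets here are illustrative):

```rust
use std::ops::Range;

// Splits excerpt text into a FIM-style (prompt, suffix) pair around the
// cursor. `context_range` is the excerpt's byte range within the whole
// buffer; the cursor offset is clamped into the excerpt before slicing.
fn split_at_cursor(
    excerpt_text: &str,
    context_range: Range<usize>,
    cursor_offset: usize,
) -> (String, String) {
    let cursor_within_excerpt = cursor_offset
        .saturating_sub(context_range.start)
        .min(excerpt_text.len());
    (
        excerpt_text[..cursor_within_excerpt].to_string(),
        excerpt_text[cursor_within_excerpt..].to_string(),
    )
}

fn main() {
    // The excerpt starts at buffer offset 100; the cursor sits at offset 111,
    // i.e. just after the first line of the excerpt.
    let (prompt, suffix) = split_at_cursor("let x = 1;\nlet y = 2;\n", 100..122, 111);
    assert_eq!(prompt, "let x = 1;\n");
    assert_eq!(suffix, "let y = 2;\n");
}
```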
@@ -2,7 +2,7 @@ use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{
- self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
+ PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
};
use cloud_llm_client::{
EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
@@ -12,7 +12,6 @@ use cloud_llm_client::{
use collections::{HashMap, HashSet};
use copilot::Copilot;
use db::kvp::{Dismissable, KEY_VALUE_STORE};
-use edit_prediction_context::EditPredictionExcerptOptions;
use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
@@ -39,6 +38,7 @@ use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
use std::collections::{VecDeque, hash_map};
use text::Edit;
use workspace::Workspace;
+use zeta_prompt::ZetaPromptInput;
use zeta_prompt::ZetaVersion;
use std::ops::Range;
@@ -113,27 +113,8 @@ impl FeatureFlag for MercuryFeatureFlag {
const NAME: &str = "mercury";
}
-pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
- context: EditPredictionExcerptOptions {
- max_bytes: 512,
- min_bytes: 128,
- target_before_cursor_over_total_bytes: 0.5,
- },
- prompt_format: PromptFormat::DEFAULT,
-};
-
-static USE_OLLAMA: LazyLock<bool> =
- LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
-
-static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
- match env::var("ZED_ZETA2_MODEL").as_deref() {
- Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
- Ok(model) => model,
- Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
- Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
- }
- .to_string()
-});
+static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
+ LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
pub struct Zeta2FeatureFlag;
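
The hunk above replaces the Ollama/Baseten model table with a single optional override. A minimal sketch of the same `LazyLock` pattern:

```rust
use std::{env, sync::LazyLock};

// Read once on first access; None unless ZED_ZETA_MODEL is set. When absent,
// the request omits the `model` field and the server presumably applies its
// own default.
static MODEL_ID: LazyLock<Option<String>> =
    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());

fn main() {
    match MODEL_ID.as_deref() {
        Some(model) => println!("overriding edit prediction model: {model}"),
        None => println!("no override; server default applies"),
    }
}
```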
@@ -167,10 +148,7 @@ pub struct EditPredictionStore {
_llm_token_subscription: Subscription,
projects: HashMap<EntityId, ProjectState>,
use_context: bool,
- options: ZetaOptions,
update_required: bool,
- #[cfg(feature = "cli-support")]
- eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
pub sweep_ai: SweepAi,
pub mercury: Mercury,
@@ -206,12 +184,6 @@ pub struct EditPredictionModelInput {
pub user_actions: Vec<UserActionRecord>,
}
-#[derive(Debug, Clone, PartialEq)]
-pub struct ZetaOptions {
- pub context: EditPredictionExcerptOptions,
- pub prompt_format: predict_edits_v3::PromptFormat,
-}
-
#[derive(Debug)]
pub enum DebugEvent {
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
@@ -248,8 +220,6 @@ pub struct EditPredictionFinishedDebugEvent {
pub model_output: Option<String>,
}
-pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
-
const USER_ACTION_HISTORY_SIZE: usize = 16;
#[derive(Clone, Debug)]
@@ -641,7 +611,6 @@ impl EditPredictionStore {
projects: HashMap::default(),
client,
user_store,
- options: DEFAULT_OPTIONS,
use_context: false,
llm_token,
_llm_token_subscription: cx.subscribe(
@@ -657,8 +626,6 @@ impl EditPredictionStore {
},
),
update_required: false,
- #[cfg(feature = "cli-support")]
- eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2 {
version: Default::default(),
},
@@ -671,17 +638,7 @@ impl EditPredictionStore {
shown_predictions: Default::default(),
custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
- Err(_) => {
- if *USE_OLLAMA {
- Some(
- Url::parse("http://localhost:11434/v1/chat/completions")
- .unwrap()
- .into(),
- )
- } else {
- None
- }
- }
+ Err(_) => None,
},
};
@@ -718,19 +675,6 @@ impl EditPredictionStore {
self.mercury.api_token.read(cx).has_key()
}
- #[cfg(feature = "cli-support")]
- pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
- self.eval_cache = Some(cache);
- }
-
- pub fn options(&self) -> &ZetaOptions {
- &self.options
- }
-
- pub fn set_options(&mut self, options: ZetaOptions) {
- self.options = options;
- }
-
pub fn set_use_context(&mut self, use_context: bool) {
self.use_context = use_context;
}
@@ -1946,8 +1890,6 @@ impl EditPredictionStore {
custom_url: Option<Arc<Url>>,
llm_token: LlmApiToken,
app_version: Version,
- #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
- #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
let url = if let Some(custom_url) = custom_url {
custom_url.as_ref().clone()
@@ -1957,28 +1899,39 @@ impl EditPredictionStore {
.build_zed_llm_url("/predict_edits/raw", &[])?
};
- #[cfg(feature = "cli-support")]
- let cache_key = if let Some(cache) = eval_cache {
- use collections::FxHasher;
- use std::hash::{Hash, Hasher};
-
- let mut hasher = FxHasher::default();
- url.hash(&mut hasher);
- let request_str = serde_json::to_string_pretty(&request)?;
- request_str.hash(&mut hasher);
- let hash = hasher.finish();
-
- let key = (eval_cache_kind, hash);
- if let Some(response_str) = cache.read(key) {
- return Ok((serde_json::from_str(&response_str)?, None));
- }
+ Self::send_api_request(
+ |builder| {
+ let req = builder
+ .uri(url.as_ref())
+ .body(serde_json::to_string(&request)?.into());
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ true,
+ )
+ .await
+ }
- Some((cache, request_str, key))
- } else {
- None
+ pub(crate) async fn send_v3_request(
+ input: ZetaPromptInput,
+ prompt_version: ZetaVersion,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
+ let url = client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/v3", &[])?;
+
+ let request = PredictEditsV3Request {
+ input,
+ model: EDIT_PREDICTIONS_MODEL_ID.clone(),
+ prompt_version,
};
- let (response, usage) = Self::send_api_request(
+ Self::send_api_request(
|builder| {
let req = builder
.uri(url.as_ref())
@@ -1990,14 +1943,7 @@ impl EditPredictionStore {
app_version,
true,
)
- .await?;
-
- #[cfg(feature = "cli-support")]
- if let Some((cache, request, key)) = cache_key {
- cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
- }
-
- Ok((response, usage))
+ .await
}
fn handle_api_response<T>(
@@ -2282,34 +2228,6 @@ pub struct ZedUpdateRequiredError {
minimum_version: Version,
}
-#[cfg(feature = "cli-support")]
-pub type EvalCacheKey = (EvalCacheEntryKind, u64);
-
-#[cfg(feature = "cli-support")]
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub enum EvalCacheEntryKind {
- Context,
- Search,
- Prediction,
-}
-
-#[cfg(feature = "cli-support")]
-impl std::fmt::Display for EvalCacheEntryKind {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- EvalCacheEntryKind::Search => write!(f, "search"),
- EvalCacheEntryKind::Context => write!(f, "context"),
- EvalCacheEntryKind::Prediction => write!(f, "prediction"),
- }
- }
-}
-
-#[cfg(feature = "cli-support")]
-pub trait EvalCache: Send + Sync {
- fn read(&self, key: EvalCacheKey) -> Option<String>;
- fn write(&self, key: EvalCacheKey, input: &str, value: &str);
-}
-
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
NotAnswered,
@@ -6,9 +6,7 @@ use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
RejectEditPredictionsBody,
- predict_edits_v3::{
- RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage,
- },
+ predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
};
use futures::{
AsyncReadExt, StreamExt,
@@ -72,7 +70,7 @@ async fn test_current_state(cx: &mut TestAppContext) {
respond_tx
.send(model_response(
- request,
+ &request,
indoc! {r"
--- a/root/1.txt
+++ b/root/1.txt
@@ -129,7 +127,7 @@ async fn test_current_state(cx: &mut TestAppContext) {
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(
- request,
+ &request,
indoc! {r#"
--- a/root/2.txt
+++ b/root/2.txt
@@ -213,7 +211,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
respond_tx
.send(model_response(
- request,
+ &request,
indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ -290,7 +288,7 @@ async fn test_request_events(cx: &mut TestAppContext) {
respond_tx
.send(model_response(
- request,
+ &request,
indoc! {r#"
--- a/root/foo.md
+++ b/root/foo.md
@@ -622,8 +620,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
});
let (request, respond_tx) = requests.predict.next().await.unwrap();
- let response = model_response(request, "");
- let id = response.id.clone();
+ let response = model_response(&request, "");
+ let id = response.request_id.clone();
respond_tx.send(response).unwrap();
cx.run_until_parked();
@@ -682,8 +680,8 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
buffer.set_text("Hello!\nHow are you?\nBye", cx);
});
- let response = model_response(request, SIMPLE_DIFF);
- let id = response.id.clone();
+ let response = model_response(&request, SIMPLE_DIFF);
+ let id = response.request_id.clone();
respond_tx.send(response).unwrap();
cx.run_until_parked();
@@ -747,8 +745,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
});
let (request, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(request, SIMPLE_DIFF);
- let first_id = first_response.id.clone();
+ let first_response = model_response(&request, SIMPLE_DIFF);
+ let first_id = first_response.request_id.clone();
respond_tx.send(first_response).unwrap();
cx.run_until_parked();
@@ -770,8 +768,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
});
let (request, respond_tx) = requests.predict.next().await.unwrap();
- let second_response = model_response(request, SIMPLE_DIFF);
- let second_id = second_response.id.clone();
+ let second_response = model_response(&request, SIMPLE_DIFF);
+ let second_id = second_response.request_id.clone();
respond_tx.send(second_response).unwrap();
cx.run_until_parked();
@@ -829,8 +827,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
});
let (request, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(request, SIMPLE_DIFF);
- let first_id = first_response.id.clone();
+ let first_response = model_response(&request, SIMPLE_DIFF);
+ let first_id = first_response.request_id.clone();
respond_tx.send(first_response).unwrap();
cx.run_until_parked();
@@ -854,7 +852,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
let (request, respond_tx) = requests.predict.next().await.unwrap();
// worse than current prediction
let second_response = model_response(
- request,
+ &request,
indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ -865,7 +863,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
Bye
"},
);
- let second_id = second_response.id.clone();
+ let second_id = second_response.request_id.clone();
respond_tx.send(second_response).unwrap();
cx.run_until_parked();
@@ -935,8 +933,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
cx.run_until_parked();
// second responds first
- let second_response = model_response(request, SIMPLE_DIFF);
- let second_id = second_response.id.clone();
+ let second_response = model_response(&request, SIMPLE_DIFF);
+ let second_id = second_response.request_id.clone();
respond_second.send(second_response).unwrap();
cx.run_until_parked();
@@ -953,8 +951,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
);
});
- let first_response = model_response(request1, SIMPLE_DIFF);
- let first_id = first_response.id.clone();
+ let first_response = model_response(&request1, SIMPLE_DIFF);
+ let first_id = first_response.request_id.clone();
respond_first.send(first_response).unwrap();
cx.run_until_parked();
@@ -1046,8 +1044,8 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
let (request3, respond_third) = requests.predict.next().await.unwrap();
- let first_response = model_response(request1, SIMPLE_DIFF);
- let first_id = first_response.id.clone();
+ let first_response = model_response(&request1, SIMPLE_DIFF);
+ let first_id = first_response.request_id.clone();
respond_first.send(first_response).unwrap();
cx.run_until_parked();
@@ -1064,8 +1062,8 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
- let cancelled_response = model_response(request2, SIMPLE_DIFF);
- let cancelled_id = cancelled_response.id.clone();
+ let cancelled_response = model_response(&request2, SIMPLE_DIFF);
+ let cancelled_id = cancelled_response.request_id.clone();
respond_second.send(cancelled_response).unwrap();
cx.run_until_parked();
@@ -1082,8 +1080,8 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
- let third_response = model_response(request3, SIMPLE_DIFF);
- let third_response_id = third_response.id.clone();
+ let third_response = model_response(&request3, SIMPLE_DIFF);
+ let third_response_id = third_response.request_id.clone();
respond_third.send(third_response).unwrap();
cx.run_until_parked();
@@ -1327,50 +1325,26 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// }
// Generate a model response that would apply the given diff to the active file.
-fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse {
- let prompt = &request.prompt;
-
- let current_marker = "<|fim_middle|>current\n";
- let updated_marker = "<|fim_middle|>updated\n";
- let suffix_marker = "<|fim_suffix|>\n";
- let cursor = "<|user_cursor|>";
-
- let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
- let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
- let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
- // In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content.
- // Strip that out to get just the editable region.
- let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) {
- &excerpt[..suffix_pos]
- } else {
- &excerpt
- };
- let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap();
-
- RawCompletionResponse {
- id: Uuid::new_v4().to_string(),
- object: "text_completion".into(),
- created: 0,
- model: "model".into(),
- choices: vec![RawCompletionChoice {
- text: new_excerpt,
- finish_reason: None,
- }],
- usage: RawCompletionUsage {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 0,
- },
+fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
+ let excerpt =
+ request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
+ let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
+
+ PredictEditsV3Response {
+ request_id: Uuid::new_v4().to_string(),
+ output: new_excerpt,
}
}
-fn prompt_from_request(request: &RawCompletionRequest) -> &str {
- &request.prompt
+fn prompt_from_request(request: &PredictEditsV3Request) -> String {
+ zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
}
struct RequestChannels {
- predict:
- mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender<RawCompletionResponse>)>,
+ predict: mpsc::UnboundedReceiver<(
+ PredictEditsV3Request,
+ oneshot::Sender<PredictEditsV3Response>,
+ )>,
reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
@@ -1397,7 +1371,7 @@ fn init_test_with_fake_client(
"token": "test"
}))
.unwrap(),
- "/predict_edits/raw" => {
+ "/predict_edits/v3" => {
let mut buf = Vec::new();
body.read_to_end(&mut buf).await.ok();
let req = serde_json::from_slice(&buf).unwrap();
@@ -1677,20 +1651,9 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte
// Model returns output WITH a trailing newline, even though the buffer doesn't have one.
// Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
- let response = RawCompletionResponse {
- id: Uuid::new_v4().to_string(),
- object: "text_completion".into(),
- created: 0,
- model: "model".into(),
- choices: vec![RawCompletionChoice {
- text: "hello world\n".to_string(),
- finish_reason: None,
- }],
- usage: RawCompletionUsage {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 0,
- },
+ let response = PredictEditsV3Response {
+ request_id: Uuid::new_v4().to_string(),
+ output: "hello world\n".to_string(),
};
respond_tx.send(response).unwrap();
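
A minimal sketch of the normalization this test depends on, under the assumption that at most one trailing newline is stripped from both sides before diffing (the real implementation may normalize differently):

```rust
// Strip at most one trailing newline so the model output and the buffer text
// are compared on equal footing before diffing.
fn normalize_trailing_newline(text: &str) -> &str {
    text.strip_suffix('\n').unwrap_or(text)
}

fn main() {
    let buffer_text = "hello world"; // buffer has no trailing newline
    let model_output = "hello world\n"; // model returned one anyway

    // After normalization the two sides are identical, so the diff is empty
    // and no spurious newline gets inserted into the buffer.
    assert_eq!(
        normalize_trailing_newline(model_output),
        normalize_trailing_newline(buffer_text)
    );
}
```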
@@ -15,8 +15,8 @@ use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
-const MAX_CONTEXT_TOKENS: usize = 150;
-const MAX_REWRITE_TOKENS: usize = 350;
+const MAX_REWRITE_TOKENS: usize = 150;
+const MAX_CONTEXT_TOKENS: usize = 350;
pub struct Mercury {
pub api_token: Entity<ApiKeyState>,
@@ -1,5 +1,3 @@
-#[cfg(feature = "cli-support")]
-use crate::EvalCacheEntryKind;
use crate::prediction::EditPredictionResult;
use crate::{
CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
@@ -22,8 +20,8 @@ pub const MAX_CONTEXT_TOKENS: usize = 350;
pub fn max_editable_tokens(version: ZetaVersion) -> usize {
match version {
- ZetaVersion::V0112_MiddleAtEnd | ZetaVersion::V0113_Ordered => 150,
- ZetaVersion::V0114_180EditableRegion => 180,
+ ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
+ ZetaVersion::V0114180EditableRegion => 180,
}
}
@@ -42,7 +40,7 @@ pub fn request_prediction_with_zeta2(
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer_snapshotted_at = Instant::now();
- let url = store.custom_predict_edits_url.clone();
+ let custom_url = store.custom_predict_edits_url.clone();
let Some(excerpt_path) = snapshot
.file()
@@ -55,9 +53,6 @@ pub fn request_prediction_with_zeta2(
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
- #[cfg(feature = "cli-support")]
- let eval_cache = store.eval_cache.clone();
-
let request_task = cx.background_spawn({
async move {
let cursor_offset = position.to_offset(&snapshot);
@@ -70,49 +65,68 @@ pub fn request_prediction_with_zeta2(
zeta_version,
);
- let prompt = format_zeta_prompt(&prompt_input, zeta_version);
-
if let Some(debug_tx) = &debug_tx {
+ let prompt = format_zeta_prompt(&prompt_input, zeta_version);
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: buffer.downgrade(),
- prompt: Some(prompt.clone()),
+ prompt: Some(prompt),
position,
},
))
.ok();
}
- let request = RawCompletionRequest {
- model: EDIT_PREDICTIONS_MODEL_ID.clone(),
- prompt,
- temperature: None,
- stop: vec![],
- max_tokens: Some(2048),
- };
-
log::trace!("Sending edit prediction request");
- let response = EditPredictionStore::send_raw_llm_request(
- request,
- client,
- url,
- llm_token,
- app_version,
- #[cfg(feature = "cli-support")]
- eval_cache,
- #[cfg(feature = "cli-support")]
- EvalCacheEntryKind::Prediction,
- )
- .await;
+ let (request_id, output_text, usage) = if let Some(custom_url) = custom_url {
+ // Use raw endpoint with custom URL
+ let prompt = format_zeta_prompt(&prompt_input, zeta_version);
+ let request = RawCompletionRequest {
+ model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(),
+ prompt,
+ temperature: None,
+ stop: vec![],
+ max_tokens: Some(2048),
+ };
+
+ let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
+ request,
+ client,
+ Some(custom_url),
+ llm_token,
+ app_version,
+ )
+ .await?;
+
+ let request_id = EditPredictionId(response.id.clone().into());
+ let output_text = response.choices.pop().map(|choice| choice.text);
+ (request_id, output_text, usage)
+ } else {
+ let (response, usage) = EditPredictionStore::send_v3_request(
+ prompt_input.clone(),
+ zeta_version,
+ client,
+ llm_token,
+ app_version,
+ )
+ .await?;
+
+ let request_id = EditPredictionId(response.request_id.into());
+ let output_text = if response.output.is_empty() {
+ None
+ } else {
+ Some(response.output)
+ };
+ (request_id, output_text, usage)
+ };
+
let received_response_at = Instant::now();
log::trace!("Got edit prediction response");
- let (mut res, usage) = response?;
- let request_id = EditPredictionId(res.id.clone().into());
- let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else {
+ let Some(mut output_text) = output_text else {
return Ok((Some((request_id, None)), usage));
};
@@ -14,7 +14,6 @@ path = "src/edit_prediction_context.rs"
[dependencies]
anyhow.workspace = true
clock.workspace = true
-cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -20,12 +20,9 @@ use util::{RangeExt as _, ResultExt};
mod assemble_excerpts;
#[cfg(test)]
mod edit_prediction_context_tests;
-mod excerpt;
#[cfg(test)]
mod fake_definition_lsp;
-pub use cloud_llm_client::predict_edits_v3::Line;
-pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
const IDENTIFIER_LINE_COUNT: u32 = 3;
@@ -1,556 +0,0 @@
-use cloud_llm_client::predict_edits_v3::Line;
-use language::{BufferSnapshot, LanguageId, Point, ToOffset as _, ToPoint as _};
-use std::ops::Range;
-use tree_sitter::{Node, TreeCursor};
-use util::RangeExt;
-
-// TODO:
-//
-// - Test parent signatures
-//
-// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
-// planning.
-//
-// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
-// paragraph).
-//
-// - Truncation of long lines.
-//
-// - Filter outer syntax layers that don't support edit prediction.
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct EditPredictionExcerptOptions {
- /// Limit for the number of bytes in the window around the cursor.
- pub max_bytes: usize,
- /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
- /// in an excerpt smaller than this, it will fall back on line-based selection.
- pub min_bytes: usize,
- /// Target ratio of bytes before the cursor divided by total bytes in the window.
- pub target_before_cursor_over_total_bytes: f32,
-}
-
-#[derive(Debug, Clone)]
-pub struct EditPredictionExcerpt {
- pub range: Range<usize>,
- pub line_range: Range<Line>,
- pub size: usize,
-}
-
-#[derive(Debug, Clone)]
-pub struct EditPredictionExcerptText {
- pub body: String,
- pub language_id: Option<LanguageId>,
-}
-
-impl EditPredictionExcerpt {
- pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
- let body = buffer
- .text_for_range(self.range.clone())
- .collect::<String>();
- let language_id = buffer.language().map(|l| l.id());
- EditPredictionExcerptText { body, language_id }
- }
-
- /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
-    /// on TreeSitter structure and approximately targeting a goal ratio of bytes before vs after the
- /// cursor.
- ///
- /// When `index` is provided, the excerpt will include the signatures of parent outline items.
- ///
- /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
- /// expansion.
- ///
- /// Returns `None` if the line around the cursor doesn't fit.
- pub fn select_from_buffer(
- query_point: Point,
- buffer: &BufferSnapshot,
- options: &EditPredictionExcerptOptions,
- ) -> Option<Self> {
- if buffer.len() <= options.max_bytes {
- log::debug!(
- "using entire file for excerpt since source length ({}) <= window max bytes ({})",
- buffer.len(),
- options.max_bytes
- );
- let offset_range = 0..buffer.len();
- let line_range = Line(0)..Line(buffer.max_point().row);
- return Some(EditPredictionExcerpt::new(offset_range, line_range));
- }
-
- let query_offset = query_point.to_offset(buffer);
- let query_line_range = query_point.row..query_point.row + 1;
- let query_range = Point::new(query_line_range.start, 0).to_offset(buffer)
- ..Point::new(query_line_range.end, 0).to_offset(buffer);
- if query_range.len() >= options.max_bytes {
- return None;
- }
-
- let excerpt_selector = ExcerptSelector {
- query_offset,
- query_range,
- query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
- buffer,
- options,
- };
-
- if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() {
- if excerpt.size >= options.min_bytes {
- return Some(excerpt);
- }
- log::debug!(
- "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
- excerpt.size,
- options.min_bytes
- );
- } else {
- log::debug!(
- "couldn't find excerpt via tree-sitter, falling back on line-based selection"
- );
- }
-
- excerpt_selector.select_lines()
- }
-
- fn new(range: Range<usize>, line_range: Range<Line>) -> Self {
- Self {
- size: range.len(),
- range,
- line_range,
- }
- }
-
- fn with_expanded_range(&self, new_range: Range<usize>, new_line_range: Range<Line>) -> Self {
- if !new_range.contains_inclusive(&self.range) {
- // this is an issue because parent_signature_ranges may be incorrect
- log::error!("bug: with_expanded_range called with disjoint range");
- }
- Self::new(new_range, new_line_range)
- }
-
- fn parent_signatures_size(&self) -> usize {
- self.size - self.range.len()
- }
-}
-
-struct ExcerptSelector<'a> {
- query_offset: usize,
- query_range: Range<usize>,
- query_line_range: Range<Line>,
- buffer: &'a BufferSnapshot,
- options: &'a EditPredictionExcerptOptions,
-}
-
-impl<'a> ExcerptSelector<'a> {
- /// Finds the largest node that is smaller than the window size and contains `query_range`.
- fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
- let selected_layer_root = self.select_syntax_layer()?;
- let mut cursor = selected_layer_root.walk();
-
- loop {
- let line_start = node_line_start(cursor.node());
- let line_end = node_line_end(cursor.node());
- let line_range = Line(line_start.row)..Line(line_end.row);
- let excerpt_range =
- line_start.to_offset(&self.buffer)..line_end.to_offset(&self.buffer);
- if excerpt_range.contains_inclusive(&self.query_range) {
- let excerpt = self.make_excerpt(excerpt_range, line_range);
- if excerpt.size <= self.options.max_bytes {
- return Some(self.expand_to_siblings(&mut cursor, excerpt));
- }
- } else {
- // TODO: Should still be able to handle this case via AST nodes. For example, this
- // can happen if the cursor is between two methods in a large class file.
- return None;
- }
-
- if cursor
- .goto_first_child_for_byte(self.query_range.start)
- .is_none()
- {
- return None;
- }
- }
- }
-
- /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
- fn select_syntax_layer(&self) -> Option<Node<'_>> {
- let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
- let mut largest: Option<Node<'_>> = None;
- for layer in self
- .buffer
- .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
- {
- let layer_range = layer.node().byte_range();
- if !layer_range.contains_inclusive(&self.query_range) {
- continue;
- }
-
- if layer_range.len() > self.options.max_bytes {
- match &smallest_exceeding_max_len {
- None => smallest_exceeding_max_len = Some(layer.node()),
- Some(existing) => {
- if layer_range.len() < existing.byte_range().len() {
- smallest_exceeding_max_len = Some(layer.node());
- }
- }
- }
- } else {
- match &largest {
- None => largest = Some(layer.node()),
- Some(existing) if layer_range.len() > existing.byte_range().len() => {
- largest = Some(layer.node())
- }
- _ => {}
- }
- }
- }
-
- smallest_exceeding_max_len.or(largest)
- }
-
- // motivation for this and `goto_previous_named_sibling` is to avoid including things like
- // trailing unnamed "}" in body nodes
- fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
- while cursor.goto_next_sibling() {
- if cursor.node().is_named() {
- return true;
- }
- }
- false
- }
-
- fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
- while cursor.goto_previous_sibling() {
- if cursor.node().is_named() {
- return true;
- }
- }
- false
- }
-
- fn expand_to_siblings(
- &self,
- cursor: &mut TreeCursor,
- mut excerpt: EditPredictionExcerpt,
- ) -> EditPredictionExcerpt {
- let mut forward_cursor = cursor.clone();
- let backward_cursor = cursor;
- let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
- let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
- loop {
- if backward_done && forward_done {
- break;
- }
-
- let mut forward = None;
- while !forward_done {
- let new_end_point = node_line_end(forward_cursor.node());
- let new_end = new_end_point.to_offset(&self.buffer);
- if new_end > excerpt.range.end {
- let new_excerpt = excerpt.with_expanded_range(
- excerpt.range.start..new_end,
- excerpt.line_range.start..Line(new_end_point.row),
- );
- if new_excerpt.size <= self.options.max_bytes {
- forward = Some(new_excerpt);
- break;
- } else {
- log::debug!("halting forward expansion, as it doesn't fit");
- forward_done = true;
- break;
- }
- }
- forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
- }
-
- let mut backward = None;
- while !backward_done {
- let new_start_point = node_line_start(backward_cursor.node());
- let new_start = new_start_point.to_offset(&self.buffer);
- if new_start < excerpt.range.start {
- let new_excerpt = excerpt.with_expanded_range(
- new_start..excerpt.range.end,
- Line(new_start_point.row)..excerpt.line_range.end,
- );
- if new_excerpt.size <= self.options.max_bytes {
- backward = Some(new_excerpt);
- break;
- } else {
- log::debug!("halting backward expansion, as it doesn't fit");
- backward_done = true;
- break;
- }
- }
- backward_done = !Self::goto_previous_named_sibling(backward_cursor);
- }
-
- let go_forward = match (forward, backward) {
- (Some(forward), Some(backward)) => {
- let go_forward = self.is_better_excerpt(&forward, &backward);
- if go_forward {
- excerpt = forward;
- } else {
- excerpt = backward;
- }
- go_forward
- }
- (Some(forward), None) => {
- log::debug!("expanding forward, since backward expansion has halted");
- excerpt = forward;
- true
- }
- (None, Some(backward)) => {
- log::debug!("expanding backward, since forward expansion has halted");
- excerpt = backward;
- false
- }
- (None, None) => break,
- };
-
- if go_forward {
- forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
- } else {
- backward_done = !Self::goto_previous_named_sibling(backward_cursor);
- }
- }
-
- excerpt
- }
-
- fn select_lines(&self) -> Option<EditPredictionExcerpt> {
- // early return if line containing query_offset is already too large
- let excerpt = self.make_excerpt(self.query_range.clone(), self.query_line_range.clone());
- if excerpt.size > self.options.max_bytes {
- log::debug!(
- "excerpt for cursor line is {} bytes, which exceeds the window",
- excerpt.size
- );
- return None;
- }
- let signatures_size = excerpt.parent_signatures_size();
- let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
-
- let before_bytes =
- (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
-
- let start_line = {
- let offset = self.query_offset.saturating_sub(before_bytes);
- let point = offset.to_point(self.buffer);
- Line(point.row + 1)
- };
- let start_offset = Point::new(start_line.0, 0).to_offset(&self.buffer);
- let end_line = {
- let offset = start_offset + bytes_remaining;
- let point = offset.to_point(self.buffer);
- Line(point.row)
- };
- let end_offset = Point::new(end_line.0, 0).to_offset(&self.buffer);
-
- // this could be expanded further since recalculated `signature_size` may be smaller, but
- // skipping that for now for simplicity
- //
- // TODO: could also consider checking if lines immediately before / after fit.
- let excerpt = self.make_excerpt(start_offset..end_offset, start_line..end_line);
- if excerpt.size > self.options.max_bytes {
- log::error!(
- "bug: line-based excerpt selection has size {}, \
- which is {} bytes larger than the max size",
- excerpt.size,
- excerpt.size - self.options.max_bytes
- );
- }
- return Some(excerpt);
- }
-
- fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
- EditPredictionExcerpt::new(range, line_range)
- }
-
- /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
- fn is_better_excerpt(
- &self,
- forward: &EditPredictionExcerpt,
- backward: &EditPredictionExcerpt,
- ) -> bool {
- let forward_ratio = self.excerpt_range_ratio(forward);
- let backward_ratio = self.excerpt_range_ratio(backward);
- let forward_delta =
- (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
- let backward_delta =
- (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
- let forward_is_better = forward_delta <= backward_delta;
- if forward_is_better {
- log::debug!(
- "expanding forward since {} is closer than {} to {}",
- forward_ratio,
- backward_ratio,
- self.options.target_before_cursor_over_total_bytes
- );
- } else {
- log::debug!(
- "expanding backward since {} is closer than {} to {}",
- backward_ratio,
- forward_ratio,
- self.options.target_before_cursor_over_total_bytes
- );
- }
- forward_is_better
- }
-
- /// Returns the ratio of bytes before the cursor over bytes within the range.
- fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
- let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
-            log::error!("bug: edit prediction cursor offset is outside the excerpt");
- return 0.0;
- };
- bytes_before_cursor as f32 / excerpt.range.len() as f32
- }
-}
-
-fn node_line_start(node: Node) -> Point {
- Point::new(node.start_position().row as u32, 0)
-}
-
-fn node_line_end(node: Node) -> Point {
- Point::new(node.end_position().row as u32 + 1, 0)
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use gpui::{AppContext, TestAppContext};
- use language::Buffer;
- use util::test::{generate_marked_text, marked_text_offsets_by};
-
- fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
- buffer.read_with(cx, |buffer, _| buffer.snapshot())
- }
-
- fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
-        let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
-        (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
- }
-
- fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
- let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);
-
- let buffer = create_buffer(&text, cx);
- let cursor_point = cursor.to_point(&buffer);
-
- let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
- .expect("Should select an excerpt");
- pretty_assertions::assert_eq!(
- generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
- generate_marked_text(&text, &[expected_excerpt], false)
- );
- assert!(excerpt.size <= options.max_bytes);
- assert!(excerpt.range.contains(&cursor));
- }
-
- #[gpui::test]
- fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
- zlog::init_test();
- let text = r#"
-fn main() {
- let x = 1;
-«    let ˇy = 2;
-»    let z = 3;
-}"#;
-
- let options = EditPredictionExcerptOptions {
- max_bytes: 20,
- min_bytes: 10,
- target_before_cursor_over_total_bytes: 0.5,
- };
-
- check_example(options, text, cx);
- }
-
- #[gpui::test]
- fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
- zlog::init_test();
- let text = r#"
-fn foo() {}
-
-«fn main() {
-    let x = 1;
-    let ˇy = 2;
-    let z = 3;
-}
-»
-fn bar() {}"#;
-
- let options = EditPredictionExcerptOptions {
- max_bytes: 65,
- min_bytes: 10,
- target_before_cursor_over_total_bytes: 0.5,
- };
-
- check_example(options, text, cx);
- }
-
- #[gpui::test]
- fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
- zlog::init_test();
- let text = r#"
-fn main() {
-«    let x = 1;
-    let ˇy = 2;
-    let z = 3;
-»}"#;
-
- let options = EditPredictionExcerptOptions {
- max_bytes: 50,
- min_bytes: 10,
- target_before_cursor_over_total_bytes: 0.5,
- };
-
- check_example(options, text, cx);
- }
-
- #[gpui::test]
- fn test_line_based_selection(cx: &mut TestAppContext) {
- zlog::init_test();
- let text = r#"
-fn main() {
- let x = 1;
-«    if true {
-        let ˇy = 2;
-    }
-    let z = 3;
-»}"#;
-
- let options = EditPredictionExcerptOptions {
- max_bytes: 60,
- min_bytes: 45,
- target_before_cursor_over_total_bytes: 0.5,
- };
-
- check_example(options, text, cx);
- }
-
- #[gpui::test]
- fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
- zlog::init_test();
- let text = r#"
- fn main() {
-«    let a = 1;
-    let b = 2;
-    let c = 3;
-    let ˇd = 4;
-    let e = 5;
-    let f = 6;
-»
- let g = 7;
- }"#;
-
- let options = EditPredictionExcerptOptions {
- max_bytes: 120,
- min_bytes: 10,
- target_before_cursor_over_total_bytes: 0.6,
- };
-
- check_example(options, text, cx);
- }
-}
@@ -33,10 +33,10 @@ pub struct ZetaPromptInput {
)]
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
- V0112_MiddleAtEnd,
- V0113_Ordered,
+ V0112MiddleAtEnd,
+ V0113Ordered,
#[default]
- V0114_180EditableRegion,
+ V0114180EditableRegion,
}
impl std::fmt::Display for ZetaVersion {
@@ -134,10 +134,10 @@ pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> Stri
write_edit_history_section(&mut prompt, input);
match version {
- ZetaVersion::V0112_MiddleAtEnd => {
+ ZetaVersion::V0112MiddleAtEnd => {
v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
}
- ZetaVersion::V0113_Ordered | ZetaVersion::V0114_180EditableRegion => {
+ ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
}
}
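
As the dispatch above shows, V0113Ordered and V0114180EditableRegion share one cursor-excerpt writer. A schematic analog (stub writers and markers, not the real `zeta_prompt` formatters) of why the exhaustive match keeps version handling honest: adding a variant fails to compile until every dispatch site routes it to a writer.

```rust
#[derive(Clone, Copy, Default)]
enum Version {
    V0112MiddleAtEnd,
    V0113Ordered,
    #[default]
    V0114180EditableRegion,
}

// Stub section writers; the real ones live in the v0112_middle_at_end and
// v0113_ordered modules.
fn write_middle_at_end(prompt: &mut String, excerpt: &str) {
    prompt.push_str(excerpt);
    prompt.push_str("\n[cursor section at end]\n");
}

fn write_ordered(prompt: &mut String, excerpt: &str) {
    prompt.push_str("[cursor section in order]\n");
    prompt.push_str(excerpt);
}

fn format_prompt(excerpt: &str, version: Version) -> String {
    let mut prompt = String::new();
    // Exhaustive match: a new Version variant is a compile error here until
    // it is routed to a writer.
    match version {
        Version::V0112MiddleAtEnd => write_middle_at_end(&mut prompt, excerpt),
        Version::V0113Ordered | Version::V0114180EditableRegion => {
            write_ordered(&mut prompt, excerpt)
        }
    }
    prompt
}

fn main() {
    println!("{}", format_prompt("fn main() {}", Version::default()));
}
```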