diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 39482839c9d3f2e9af39e50476032ea0472ef7a6..79d5ac8ba4e638509b8da02677d13575747e6fcc 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1888,14 +1888,19 @@ impl EditPredictionStore { async fn send_raw_llm_request( request: RawCompletionRequest, client: Arc, + custom_url: Option>, llm_token: LlmApiToken, app_version: Version, #[cfg(feature = "cli-support")] eval_cache: Option>, #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind, ) -> Result<(RawCompletionResponse, Option)> { - let url = client - .http_client() - .build_zed_llm_url("/predict_edits/raw", &[])?; + let url = if let Some(custom_url) = custom_url { + custom_url.as_ref().clone() + } else { + client + .http_client() + .build_zed_llm_url("/predict_edits/raw", &[])? + }; #[cfg(feature = "cli-support")] let cache_key = if let Some(cache) = eval_cache { diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index b8fc7ce37fb5d793963b97dec6deb2855bf47ecc..270a4c6f589f43e9fe80604995b40d4464492faf 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -35,6 +35,7 @@ pub fn request_prediction_with_zeta2( cx: &mut Context, ) -> Task>> { let buffer_snapshotted_at = Instant::now(); + let url = store.custom_predict_edits_url.clone(); let Some(excerpt_path) = snapshot .file() @@ -88,6 +89,7 @@ pub fn request_prediction_with_zeta2( let response = EditPredictionStore::send_raw_llm_request( request, client, + url, llm_token, app_version, #[cfg(feature = "cli-support")] diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 7a53bea32f68245fe45fe3a9aea16aa240b6655d..d2691246147ad27acac9961ee848997231bff8b4 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -1,7 +1,7 @@ use crate::{ PredictionProvider, PromptFormat, anthropic_client::AnthropicClient, - example::{Example, ExamplePrediction}, + example::{Example, ExamplePrediction, ExamplePrompt}, format_prompt::{TeacherPrompt, run_format_prompt}, headless::EpAppState, load_project::run_load_project, @@ -123,6 +123,13 @@ pub async fn run_prediction( if let Some(prompt) = request.prompt { fs::write(run_dir.join("prediction_prompt.md"), &prompt)?; + if provider == PredictionProvider::Zeta2 { + updated_example.prompt.get_or_insert(ExamplePrompt { + input: prompt, + expected_output: String::new(), + format: PromptFormat::Zeta2, + }); + } } } DebugEvent::EditPredictionFinished(request) => { diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index cb09431098a474086c8ec48c43cb29d264eeb83b..ab8976f1810bb3fa88cdadf06bbac42f18003806 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -109,6 +109,9 @@ fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { prompt.push_str("<|fim_suffix|>\n"); prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + if !prompt.ends_with('\n') { + prompt.push('\n'); + } prompt.push_str("<|fim_middle|>current\n"); prompt.push_str( @@ -119,7 +122,6 @@ fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { prompt.push_str( &input.cursor_excerpt[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], ); - if !prompt.ends_with('\n') { prompt.push('\n'); }