From 78535890330b684dcea8e5bdbf387f02d2b4f2da Mon Sep 17 00:00:00 2001
From: Agus Zubiaga
Date: Mon, 12 Jan 2026 14:42:50 -0300
Subject: [PATCH] ep: Use non-chat completions for /predict/raw (#46633)

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld
Co-authored-by: Oleksiy Syvokon
---
 .../cloud_llm_client/src/predict_edits_v3.rs  | 33 ++++++++++++++
 crates/edit_prediction/src/edit_prediction.rs |  8 ++--
 .../src/edit_prediction_tests.rs              | 45 ++++++-------------
 crates/edit_prediction/src/zeta2.rs           | 24 ++++------
 4 files changed, 60 insertions(+), 50 deletions(-)

diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs
index 9e590dc4cf48a82ecdda8b007c38ab15f3b602be..49300a8b7a42df169e2d450d5d835a0f8aa99776 100644
--- a/crates/cloud_llm_client/src/predict_edits_v3.rs
+++ b/crates/cloud_llm_client/src/predict_edits_v3.rs
@@ -1,6 +1,7 @@
 use chrono::Duration;
 use serde::{Deserialize, Serialize};
 use std::{
+    borrow::Cow,
     fmt::{Display, Write as _},
     ops::{Add, Range, Sub},
     path::Path,
@@ -214,6 +215,38 @@ impl Sub for Line {
     }
 }
 
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionRequest {
+    pub model: String,
+    pub prompt: String,
+    pub max_tokens: Option<u32>,
+    pub temperature: Option<f32>,
+    pub stop: Vec<Cow<'static, str>>,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<RawCompletionChoice>,
+    pub usage: RawCompletionUsage,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionChoice {
+    pub text: String,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct RawCompletionUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs
index 9d301846123c47f7451309525919e85670766ac9..39482839c9d3f2e9af39e50476032ea0472ef7a6 100644
--- a/crates/edit_prediction/src/edit_prediction.rs
+++ b/crates/edit_prediction/src/edit_prediction.rs
@@ -1,7 +1,9 @@
 use anyhow::Result;
 use arrayvec::ArrayVec;
 use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
+use cloud_llm_client::predict_edits_v3::{
+    self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
+};
 use cloud_llm_client::{
     EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
     MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
@@ -1884,13 +1886,13 @@ impl EditPredictionStore {
     }
 
     async fn send_raw_llm_request(
-        request: open_ai::Request,
+        request: RawCompletionRequest,
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: Version,
         #[cfg(feature = "cli-support")] eval_cache: Option<Arc<EvalCache>>,
         #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
-    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
+    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
         let url = client
             .http_client()
             .build_zed_llm_url("/predict_edits/raw", &[])?;
diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs
index 607cec2460e91ea5f48f34816e0d78ba72b755cd..e463f1b6ee30a48001493390ca9c06615e066ef8 100644
--- a/crates/edit_prediction/src/edit_prediction_tests.rs
+++ b/crates/edit_prediction/src/edit_prediction_tests.rs
@@ -6,6 +6,9 @@ use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
 use cloud_llm_client::{
     EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
     RejectEditPredictionsBody,
+    predict_edits_v3::{
+        RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage,
+    },
 };
 use futures::{
     AsyncReadExt, StreamExt,
@@ -18,7 +21,6 @@ use gpui::{
 use indoc::indoc;
 use language::Point;
 use lsp::LanguageServerId;
-use open_ai::Usage;
 use parking_lot::Mutex;
 use pretty_assertions::{assert_eq, assert_matches};
 use project::{FakeFs, Project};
@@ -1325,13 +1327,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
 // }
 
 // Generate a model response that would apply the given diff to the active file.
-fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
-    let prompt = match &request.messages[0] {
-        open_ai::RequestMessage::User {
-            content: open_ai::MessageContent::Plain(content),
-        } => content,
-        _ => panic!("unexpected request {request:?}"),
-    };
+fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse {
+    let prompt = &request.prompt;
 
     let open = "<|editable_region_start|>\n";
     let close = "<|editable_region_end|>";
@@ -1342,20 +1339,16 @@ fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Re
     let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
     let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
 
-    open_ai::Response {
+    RawCompletionResponse {
         id: Uuid::new_v4().to_string(),
-        object: "response".into(),
+        object: "text_completion".into(),
         created: 0,
         model: "model".into(),
-        choices: vec![open_ai::Choice {
-            index: 0,
-            message: open_ai::RequestMessage::Assistant {
-                content: Some(open_ai::MessageContent::Plain(new_excerpt)),
-                tool_calls: vec![],
-            },
+        choices: vec![RawCompletionChoice {
+            text: new_excerpt,
             finish_reason: None,
         }],
-        usage: Usage {
+        usage: RawCompletionUsage {
             prompt_tokens: 0,
             completion_tokens: 0,
             total_tokens: 0,
@@ -1363,23 +1356,13 @@ fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Re
     }
 }
 
-fn prompt_from_request(request: &open_ai::Request) -> &str {
-    assert_eq!(request.messages.len(), 1);
-    let open_ai::RequestMessage::User {
-        content: open_ai::MessageContent::Plain(content),
-        ..
-    } = &request.messages[0]
-    else {
-        panic!(
-            "Request does not have single user message of type Plain. {:#?}",
-            request
-        );
-    };
-    content
+fn prompt_from_request(request: &RawCompletionRequest) -> &str {
+    &request.prompt
 }
 
 struct RequestChannels {
-    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
+    predict:
+        mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender<RawCompletionResponse>)>,
     reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
 }
 
diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs
index f332155d7abb4331bcb2bb9f5fbebf437dd779de..035bf663461a22462a1a56aaa5e21383edd30a58 100644
--- a/crates/edit_prediction/src/zeta2.rs
+++ b/crates/edit_prediction/src/zeta2.rs
@@ -1,6 +1,5 @@
 #[cfg(feature = "cli-support")]
 use crate::EvalCacheEntryKind;
-use crate::open_ai_response::text_from_response;
 use crate::prediction::EditPredictionResult;
 use crate::{
     CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
@@ -8,6 +7,7 @@ use crate::{
     EditPredictionStore,
 };
 use anyhow::{Result, anyhow};
+use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 use gpui::{App, Task, prelude::*};
 use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
@@ -75,20 +75,12 @@ pub fn request_prediction_with_zeta2(
             .ok();
     }
 
-    let request = open_ai::Request {
+    let request = RawCompletionRequest {
         model: EDIT_PREDICTIONS_MODEL_ID.clone(),
-        messages: vec![open_ai::RequestMessage::User {
-            content: open_ai::MessageContent::Plain(prompt),
-        }],
-        stream: false,
-        max_completion_tokens: None,
-        stop: Default::default(),
-        temperature: Default::default(),
-        tool_choice: None,
-        parallel_tool_calls: None,
-        tools: vec![],
-        prompt_cache_key: None,
-        reasoning_effort: None,
+        prompt,
+        temperature: None,
+        stop: vec![],
+        max_tokens: None,
     };
 
     log::trace!("Sending edit prediction request");
@@ -108,9 +100,9 @@ pub fn request_prediction_with_zeta2(
 
     log::trace!("Got edit prediction response");
 
-    let (res, usage) = response?;
+    let (mut res, usage) = response?;
     let request_id = EditPredictionId(res.id.clone().into());
-    let Some(mut output_text) = text_from_response(res) else {
+    let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else {
        return Ok((Some((request_id, None)), usage));
     };