ep: Use non-chat completions for /predict_edits/raw (#46633)

Created by Agus Zubiaga, Max Brunsfeld, and Oleksiy Syvokon

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>

Change summary

crates/cloud_llm_client/src/predict_edits_v3.rs     | 33 +++++++++++
crates/edit_prediction/src/edit_prediction.rs       |  8 +-
crates/edit_prediction/src/edit_prediction_tests.rs | 45 ++++----------
crates/edit_prediction/src/zeta2.rs                 | 24 ++-----
4 files changed, 60 insertions(+), 50 deletions(-)

Detailed changes

crates/cloud_llm_client/src/predict_edits_v3.rs

@@ -1,6 +1,7 @@
 use chrono::Duration;
 use serde::{Deserialize, Serialize};
 use std::{
+    borrow::Cow,
     fmt::{Display, Write as _},
     ops::{Add, Range, Sub},
     path::Path,
@@ -214,6 +215,38 @@ impl Sub for Line {
     }
 }
 
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionRequest {
+    pub model: String,
+    pub prompt: String,
+    pub max_tokens: Option<u32>,
+    pub temperature: Option<f32>,
+    pub stop: Vec<Cow<'static, str>>,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<RawCompletionChoice>,
+    pub usage: RawCompletionUsage,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionChoice {
+    pub text: String,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct RawCompletionUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
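
These new types mirror the plain (non-chat) completions wire format: a single prompt string in the request, and per-choice text in the response. Below is a minimal sketch of how a RawCompletionRequest serializes, assuming only serde and serde_json; the struct is restated locally so the snippet stands alone, and the model/prompt values are placeholders.

use serde::{Deserialize, Serialize};
use std::borrow::Cow;

// Restated locally from the definitions above so the sketch compiles on its own.
#[derive(Debug, Deserialize, Serialize)]
struct RawCompletionRequest {
    model: String,
    prompt: String,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    stop: Vec<Cow<'static, str>>,
}

fn main() {
    let request = RawCompletionRequest {
        // Placeholder values; in the real code the model id comes from
        // EDIT_PREDICTIONS_MODEL_ID and the prompt from the planned context.
        model: "example-model".into(),
        prompt: "<editable_region>\nfn main() {}\n</editable_region>".into(),
        max_tokens: None,
        temperature: None,
        stop: vec![],
    };
    // Serializes to a plain-completions payload with a `prompt` string
    // rather than a chat-style `messages` array.
    println!("{}", serde_json::to_string_pretty(&request).unwrap());
}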

crates/edit_prediction/src/edit_prediction.rs

@@ -1,7 +1,9 @@
 use anyhow::Result;
 use arrayvec::ArrayVec;
 use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
+use cloud_llm_client::predict_edits_v3::{
+    self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
+};
 use cloud_llm_client::{
     EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
     MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
@@ -1884,13 +1886,13 @@ impl EditPredictionStore {
     }
 
     async fn send_raw_llm_request(
-        request: open_ai::Request,
+        request: RawCompletionRequest,
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: Version,
         #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
         #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
-    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
+    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
         let url = client
             .http_client()
             .build_zed_llm_url("/predict_edits/raw", &[])?;
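
With this change, send_raw_llm_request decodes the endpoint's plain-completions body directly into RawCompletionResponse instead of open_ai::Response. A minimal sketch of that decoding step, with the response structs restated locally and a made-up JSON body, assuming serde and serde_json:

use serde::{Deserialize, Serialize};

// Restated locally from predict_edits_v3.rs so the sketch stands alone.
#[derive(Debug, Deserialize, Serialize)]
struct RawCompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<RawCompletionChoice>,
    usage: RawCompletionUsage,
}

#[derive(Debug, Deserialize, Serialize)]
struct RawCompletionChoice {
    text: String,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize, Serialize)]
struct RawCompletionUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

fn main() {
    // Hypothetical response body in the plain-completions shape; the field
    // names match the structs above, the values are invented.
    let body = r#"{
        "id": "cmpl-123",
        "object": "text_completion",
        "created": 0,
        "model": "example-model",
        "choices": [{ "text": "updated excerpt", "finish_reason": "stop" }],
        "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
    }"#;
    let response: RawCompletionResponse = serde_json::from_str(body).unwrap();
    assert_eq!(response.choices[0].text, "updated excerpt");
}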

crates/edit_prediction/src/edit_prediction_tests.rs

@@ -6,6 +6,9 @@ use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
 use cloud_llm_client::{
     EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
     RejectEditPredictionsBody,
+    predict_edits_v3::{
+        RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage,
+    },
 };
 use futures::{
     AsyncReadExt, StreamExt,
@@ -18,7 +21,6 @@ use gpui::{
 use indoc::indoc;
 use language::Point;
 use lsp::LanguageServerId;
-use open_ai::Usage;
 use parking_lot::Mutex;
 use pretty_assertions::{assert_eq, assert_matches};
 use project::{FakeFs, Project};
@@ -1325,13 +1327,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
 // }
 
 // Generate a model response that would apply the given diff to the active file.
-fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
-    let prompt = match &request.messages[0] {
-        open_ai::RequestMessage::User {
-            content: open_ai::MessageContent::Plain(content),
-        } => content,
-        _ => panic!("unexpected request {request:?}"),
-    };
+fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse {
+    let prompt = &request.prompt;
 
     let open = "<editable_region>\n";
     let close = "</editable_region>";
@@ -1342,20 +1339,16 @@ fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Re
     let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
     let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
 
-    open_ai::Response {
+    RawCompletionResponse {
         id: Uuid::new_v4().to_string(),
-        object: "response".into(),
+        object: "text_completion".into(),
         created: 0,
         model: "model".into(),
-        choices: vec![open_ai::Choice {
-            index: 0,
-            message: open_ai::RequestMessage::Assistant {
-                content: Some(open_ai::MessageContent::Plain(new_excerpt)),
-                tool_calls: vec![],
-            },
+        choices: vec![RawCompletionChoice {
+            text: new_excerpt,
             finish_reason: None,
         }],
-        usage: Usage {
+        usage: RawCompletionUsage {
             prompt_tokens: 0,
             completion_tokens: 0,
             total_tokens: 0,
@@ -1363,23 +1356,13 @@ fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Re
     }
 }
 
-fn prompt_from_request(request: &open_ai::Request) -> &str {
-    assert_eq!(request.messages.len(), 1);
-    let open_ai::RequestMessage::User {
-        content: open_ai::MessageContent::Plain(content),
-        ..
-    } = &request.messages[0]
-    else {
-        panic!(
-            "Request does not have single user message of type Plain. {:#?}",
-            request
-        );
-    };
-    content
+fn prompt_from_request(request: &RawCompletionRequest) -> &str {
+    &request.prompt
 }
 
 struct RequestChannels {
-    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
+    predict:
+        mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender<RawCompletionResponse>)>,
     reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
 }
 

crates/edit_prediction/src/zeta2.rs

@@ -1,6 +1,5 @@
 #[cfg(feature = "cli-support")]
 use crate::EvalCacheEntryKind;
-use crate::open_ai_response::text_from_response;
 use crate::prediction::EditPredictionResult;
 use crate::{
     CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
@@ -8,6 +7,7 @@ use crate::{
     EditPredictionStore,
 };
 use anyhow::{Result, anyhow};
+use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 use gpui::{App, Task, prelude::*};
 use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
@@ -75,20 +75,12 @@ pub fn request_prediction_with_zeta2(
                     .ok();
             }
 
-            let request = open_ai::Request {
+            let request = RawCompletionRequest {
                 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
-                messages: vec![open_ai::RequestMessage::User {
-                    content: open_ai::MessageContent::Plain(prompt),
-                }],
-                stream: false,
-                max_completion_tokens: None,
-                stop: Default::default(),
-                temperature: Default::default(),
-                tool_choice: None,
-                parallel_tool_calls: None,
-                tools: vec![],
-                prompt_cache_key: None,
-                reasoning_effort: None,
+                prompt,
+                temperature: None,
+                stop: vec![],
+                max_tokens: None,
             };
 
             log::trace!("Sending edit prediction request");
@@ -108,9 +100,9 @@ pub fn request_prediction_with_zeta2(
 
             log::trace!("Got edit prediction response");
 
-            let (res, usage) = response?;
+            let (mut res, usage) = response?;
             let request_id = EditPredictionId(res.id.clone().into());
-            let Some(mut output_text) = text_from_response(res) else {
+            let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else {
                 return Ok((Some((request_id, None)), usage));
             };
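
The output text is now taken straight from the last completion choice rather than going through the removed text_from_response helper. A tiny sketch of that extraction logic, using a stand-in choice type:

// Stand-in for RawCompletionChoice; only the `text` field matters here.
struct Choice {
    text: String,
}

fn output_text(mut choices: Vec<Choice>) -> Option<String> {
    // Mirrors `res.choices.pop().map(|choice| choice.text)`: take the last
    // (in practice only) choice, or return None when the model produced no
    // choices, which the caller treats as "no prediction".
    choices.pop().map(|choice| choice.text)
}

fn main() {
    assert_eq!(output_text(vec![]), None);
    assert_eq!(
        output_text(vec![Choice { text: "new excerpt".into() }]),
        Some("new excerpt".into())
    );
}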