Detailed changes

The diff below moves zeta2's edit-prediction plumbing off `open_ai::Request`/`open_ai::Response` and onto new `RawCompletionRequest`/`RawCompletionResponse` types defined in `cloud_llm_client::predict_edits_v3`, updating the request path, the prompt handling, and the tests accordingly.
@@ -1,6 +1,7 @@
use chrono::Duration;
use serde::{Deserialize, Serialize};
use std::{
+ borrow::Cow,
fmt::{Display, Write as _},
ops::{Add, Range, Sub},
path::Path,
@@ -214,6 +215,38 @@ impl Sub for Line {
}
}
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionRequest {
+ pub model: String,
+ pub prompt: String,
+ pub max_tokens: Option<u32>,
+ pub temperature: Option<f32>,
+ pub stop: Vec<Cow<'static, str>>,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionResponse {
+ pub id: String,
+ pub object: String,
+ pub created: u64,
+ pub model: String,
+ pub choices: Vec<RawCompletionChoice>,
+ pub usage: RawCompletionUsage,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct RawCompletionChoice {
+ pub text: String,
+ pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct RawCompletionUsage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
#[cfg(test)]
mod tests {
use super::*;
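
The new types follow the OpenAI legacy text-completions wire shape: a flat `prompt` instead of chat `messages`, and plain-text `choices`. As a rough illustration of the resulting JSON body — a minimal sketch using a hypothetical local copy of the request type and made-up field values, and assuming `serde_json` is available — a request serializes like this:

```rust
use std::borrow::Cow;

use serde::Serialize;

// Hypothetical local copy of the new request type, for illustration only;
// the real definition lives in cloud_llm_client::predict_edits_v3.
#[derive(Debug, Serialize)]
struct RawCompletionRequest {
    model: String,
    prompt: String,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    stop: Vec<Cow<'static, str>>,
}

fn main() -> serde_json::Result<()> {
    let request = RawCompletionRequest {
        model: "example-model".into(),
        prompt: "<editable_region>\nfn main() {}\n</editable_region>".into(),
        max_tokens: None,
        temperature: None,
        stop: vec![Cow::Borrowed("</editable_region>")],
    };
    // Prints a flat, legacy-completions style JSON body.
    println!("{}", serde_json::to_string_pretty(&request)?);
    Ok(())
}
```

The `Cow<'static, str>` stop sequences serialize as an ordinary JSON array of strings, so callers can mix borrowed literals and owned strings without extra allocation.
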
@@ -1,7 +1,9 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
+use cloud_llm_client::predict_edits_v3::{
+ self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
+};
use cloud_llm_client::{
EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
@@ -1884,13 +1886,13 @@ impl EditPredictionStore {
}
async fn send_raw_llm_request(
- request: open_ai::Request,
+ request: RawCompletionRequest,
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: Version,
#[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
#[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
- ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
+ ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
let url = client
.http_client()
.build_zed_llm_url("/predict_edits/raw", &[])?;
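
With the signature change, the body returned by `/predict_edits/raw` is decoded as a `RawCompletionResponse` rather than an `open_ai::Response`. A minimal sketch of that decoding, using hypothetical local copies of the types and `serde_json` (both assumptions, for illustration only):

```rust
use serde::Deserialize;

// Hypothetical local copies of the response types, for illustration only;
// the real definitions live in cloud_llm_client::predict_edits_v3.
#[derive(Debug, Deserialize)]
struct RawCompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<RawCompletionChoice>,
    usage: RawCompletionUsage,
}

#[derive(Debug, Deserialize)]
struct RawCompletionChoice {
    text: String,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct RawCompletionUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

fn main() -> serde_json::Result<()> {
    // An OpenAI-style "text_completion" body, roughly the shape the endpoint returns.
    let body = r#"{
        "id": "cmpl-123",
        "object": "text_completion",
        "created": 0,
        "model": "example-model",
        "choices": [{ "text": "fn main() {}", "finish_reason": "stop" }],
        "usage": { "prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14 }
    }"#;
    let response: RawCompletionResponse = serde_json::from_str(body)?;
    println!("{}", response.choices[0].text);
    Ok(())
}
```
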
@@ -6,6 +6,9 @@ use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
RejectEditPredictionsBody,
+ predict_edits_v3::{
+ RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage,
+ },
};
use futures::{
AsyncReadExt, StreamExt,
@@ -18,7 +21,6 @@ use gpui::{
use indoc::indoc;
use language::Point;
use lsp::LanguageServerId;
-use open_ai::Usage;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
use project::{FakeFs, Project};
@@ -1325,13 +1327,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// }
// Generate a model response that would apply the given diff to the active file.
-fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
- let prompt = match &request.messages[0] {
- open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(content),
- } => content,
- _ => panic!("unexpected request {request:?}"),
- };
+fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse {
+ let prompt = &request.prompt;
let open = "<editable_region>\n";
let close = "</editable_region>";
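
The test helper no longer pattern-matches a chat message; it reads `request.prompt` directly and slices out the editable region between the markers. A minimal sketch of that slicing, with a made-up `<|user_cursor|>` token standing in for the real cursor marker:

```rust
// A minimal sketch of the slicing the test helper performs on the raw prompt.
// The region markers match the ones above; `<|user_cursor|>` is a made-up
// stand-in for the real cursor marker.
fn editable_region(prompt: &str) -> Option<String> {
    let open = "<editable_region>\n";
    let close = "</editable_region>";
    let cursor = "<|user_cursor|>";

    let start_ix = prompt.find(open)? + open.len();
    let end_ix = start_ix + prompt[start_ix..].find(close)?;
    Some(prompt[start_ix..end_ix].replace(cursor, ""))
}

fn main() {
    let prompt = "ctx\n<editable_region>\nlet x<|user_cursor|> = 1;\n</editable_region>\n";
    assert_eq!(editable_region(prompt).as_deref(), Some("let x = 1;\n"));
    println!("ok");
}
```
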
@@ -1342,20 +1339,16 @@ fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Re
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
- open_ai::Response {
+ RawCompletionResponse {
id: Uuid::new_v4().to_string(),
- object: "response".into(),
+ object: "text_completion".into(),
created: 0,
model: "model".into(),
- choices: vec![open_ai::Choice {
- index: 0,
- message: open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(new_excerpt)),
- tool_calls: vec![],
- },
+ choices: vec![RawCompletionChoice {
+ text: new_excerpt,
finish_reason: None,
}],
- usage: Usage {
+ usage: RawCompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
@@ -1363,23 +1356,13 @@ fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Re
}
}
-fn prompt_from_request(request: &open_ai::Request) -> &str {
- assert_eq!(request.messages.len(), 1);
- let open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(content),
- ..
- } = &request.messages[0]
- else {
- panic!(
- "Request does not have single user message of type Plain. {:#?}",
- request
- );
- };
- content
+fn prompt_from_request(request: &RawCompletionRequest) -> &str {
+ &request.prompt
}
struct RequestChannels {
- predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
+ predict:
+ mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender<RawCompletionResponse>)>,
reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
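
The fake-server plumbing keeps the same shape: each received `RawCompletionRequest` is paired with a `oneshot::Sender` that the test answers with a `RawCompletionResponse`. A minimal sketch of that request/response loop — not the actual test harness — using `String` stand-ins for the real payload types so it stays self-contained:

```rust
use futures::StreamExt as _;
use futures::channel::{mpsc, oneshot};

// String stand-ins keep this sketch self-contained; in the real tests the
// payloads are RawCompletionRequest and RawCompletionResponse.
type FakeRequest = String;
type FakeResponse = String;

// The fake endpoint: answer each request through its paired oneshot sender,
// mirroring the (request, responder) tuples carried by RequestChannels.
async fn serve(
    mut predict: mpsc::UnboundedReceiver<(FakeRequest, oneshot::Sender<FakeResponse>)>,
) {
    while let Some((request, respond)) = predict.next().await {
        // This is where a test would build something like model_response(request, diff).
        respond.send(format!("completion for: {request}")).ok();
    }
}

fn main() {
    futures::executor::block_on(async {
        let (tx, rx) = mpsc::unbounded();
        let client = async move {
            let (respond_tx, respond_rx) = oneshot::channel();
            tx.unbounded_send(("prompt".to_string(), respond_tx)).unwrap();
            drop(tx); // closing the channel lets the serve loop finish
            println!("{}", respond_rx.await.unwrap());
        };
        futures::future::join(serve(rx), client).await;
    });
}
```
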
@@ -1,6 +1,5 @@
#[cfg(feature = "cli-support")]
use crate::EvalCacheEntryKind;
-use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
@@ -8,6 +7,7 @@ use crate::{
EditPredictionStore,
};
use anyhow::{Result, anyhow};
+use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
use gpui::{App, Task, prelude::*};
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
@@ -75,20 +75,12 @@ pub fn request_prediction_with_zeta2(
.ok();
}
- let request = open_ai::Request {
+ let request = RawCompletionRequest {
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
- messages: vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt),
- }],
- stream: false,
- max_completion_tokens: None,
- stop: Default::default(),
- temperature: Default::default(),
- tool_choice: None,
- parallel_tool_calls: None,
- tools: vec![],
- prompt_cache_key: None,
- reasoning_effort: None,
+ prompt,
+ temperature: None,
+ stop: vec![],
+ max_tokens: None,
};
log::trace!("Sending edit prediction request");
@@ -108,9 +100,9 @@ pub fn request_prediction_with_zeta2(
log::trace!("Got edit prediction response");
- let (res, usage) = response?;
+ let (mut res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
- let Some(mut output_text) = text_from_response(res) else {
+ let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else {
return Ok((Some((request_id, None)), usage));
};
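
With `text_from_response` gone, the output text comes straight from the last (in practice, only) choice. A minimal sketch of that pattern, using stand-in copies of the types trimmed to the relevant fields:

```rust
// Minimal stand-ins for the response types, for illustration only.
struct RawCompletionChoice { text: String }
struct RawCompletionResponse { choices: Vec<RawCompletionChoice> }

// Mirrors `res.choices.pop().map(|choice| choice.text)` above: take ownership
// of the last choice's text, or None if there are no choices.
fn output_text(mut res: RawCompletionResponse) -> Option<String> {
    res.choices.pop().map(|choice| choice.text)
}

fn main() {
    let res = RawCompletionResponse {
        choices: vec![RawCompletionChoice { text: "updated excerpt".into() }],
    };
    assert_eq!(output_text(res).as_deref(), Some("updated excerpt"));
    println!("ok");
}
```
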