#[cfg(feature = "eval-support")]
use crate::EvalCacheEntryKind;
use crate::prediction::EditPredictionResult;
use crate::{
    DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
    EditPredictionRequestedDebugEvent, EditPredictionStore,
};
use anyhow::{Result, anyhow, bail};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
use cloud_zeta2_prompt::CURSOR_MARKER;
use edit_prediction_context::{EditPredictionExcerpt, Line};
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
use futures::channel::oneshot;
use gpui::{Entity, Task, prelude::*};
use language::{Anchor, BufferSnapshot};
use language::{Buffer, Point, ToOffset as _, ToPoint};
use project::{Project, ProjectItem as _};
use release_channel::AppVersion;
use std::{
    env,
    path::Path,
    sync::Arc,
    time::{Duration, Instant},
};

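/// Requests an edit prediction from the zeta2 model for `position` in `active_buffer`.
///
/// Builds a prompt from the cursor excerpt, the other `included_files`, and recent edit
/// `events`, sends it to the LLM endpoint on a background task, and parses the response
/// into an [`EditPredictionResult`]. Resolves to `Ok(None)` when no excerpt can be
/// selected around the cursor.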
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    project: &Entity<Project>,
    active_buffer: &Entity<Buffer>,
    active_snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<Event>>,
    mut included_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let options = store.options.clone();
    let buffer_snapshotted_at = Instant::now();

    let Some((excerpt_path, active_project_path)) = active_snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .zip(active_buffer.read(cx).project_path(cx))
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let debug_tx = store.debug_tx.clone();

    let file = active_buffer.read(cx).file();

    let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));

    // TODO data collection
    let can_collect_data = file
        .as_ref()
        .map_or(false, |file| store.can_collect_file(project, file, cx));

    #[cfg(feature = "eval-support")]
    let eval_cache = store.eval_cache.clone();

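    // Build the request payload and call the model on the background executor.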
    let request_task = cx.background_spawn({
        let active_buffer = active_buffer.clone();
        async move {
            let cursor_offset = position.to_offset(&active_snapshot);
            let cursor_point = cursor_offset.to_point(&active_snapshot);

            let before_retrieval = Instant::now();

            let excerpt_options = options.context;

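            // Select an excerpt of the active buffer around the cursor; without one there
            // is nothing to send, so resolve with no prediction.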
            let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                cursor_point,
                &active_snapshot,
                &excerpt_options,
            ) else {
                return Ok((None, None));
            };

            let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                ..active_snapshot.anchor_before(excerpt.range.end);
            let related_excerpt = RelatedExcerpt {
                anchor_range: excerpt_anchor_range.clone(),
                point_range: Point::new(excerpt.line_range.start.0, 0)
                    ..Point::new(excerpt.line_range.end.0, 0),
                text: active_snapshot.as_rope().slice(excerpt.range),
            };

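            // Merge the cursor excerpt into the active buffer's entry (adding one if the
            // buffer isn't among the included files yet) and move that entry to the end.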
            if let Some(buffer_ix) = included_files
                .iter()
                .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
            {
                let file = &mut included_files[buffer_ix];
                file.excerpts.push(related_excerpt);
                file.merge_excerpts();
                let last_ix = included_files.len() - 1;
                included_files.swap(buffer_ix, last_ix);
            } else {
                let active_file = RelatedFile {
                    path: active_project_path,
                    buffer: active_buffer.downgrade(),
                    excerpts: vec![related_excerpt],
                    max_row: active_snapshot.max_point().row,
                };
                included_files.push(active_file);
            }

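            // Convert the collected files into the wire format expected by the
            // predict_edits_v3 endpoint.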
            let included_files = included_files
                .iter()
                .map(|related_file| predict_edits_v3::RelatedFile {
                    path: Arc::from(related_file.path.path.as_std_path()),
                    max_row: Line(related_file.max_row),
                    excerpts: related_file
                        .excerpts
                        .iter()
                        .map(|excerpt| predict_edits_v3::Excerpt {
                            start_line: Line(excerpt.point_range.start.row),
                            text: excerpt.text.to_string().into(),
                        })
                        .collect(),
                })
                .collect::<Vec<_>>();

            let cloud_request = predict_edits_v3::PredictEditsRequest {
                excerpt_path,
                excerpt: String::new(),
                excerpt_line_range: Line(0)..Line(0),
                excerpt_range: 0..0,
                cursor_point: predict_edits_v3::Point {
                    line: predict_edits_v3::Line(cursor_point.row),
                    column: cursor_point.column,
                },
                related_files: included_files,
                events,
                can_collect_data,
                debug_info: debug_tx.is_some(),
                prompt_max_bytes: Some(options.max_prompt_bytes),
                prompt_format: options.prompt_format,
                excerpt_parent: None,
                git_info: None,
                trigger,
            };

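            // Build the prompt locally so it can also be surfaced through the debug channel.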
            let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

            let inputs = EditPredictionInputs {
                included_files: cloud_request.related_files,
                events: cloud_request.events,
                cursor_point: cloud_request.cursor_point,
                cursor_path: cloud_request.excerpt_path,
            };

            let retrieval_time = Instant::now() - before_retrieval;

            let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                let (response_tx, response_rx) = oneshot::channel();

                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionRequested(
                        EditPredictionRequestedDebugEvent {
                            inputs: inputs.clone(),
                            retrieval_time,
                            buffer: active_buffer.downgrade(),
                            local_prompt: match prompt_result.as_ref() {
                                Ok(prompt) => Ok(prompt.clone()),
                                Err(err) => Err(err.to_string()),
                            },
                            position,
                            response_rx,
                        },
                    ))
                    .ok();
                Some(response_tx)
            } else {
                None
            };

            if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((Err("Request skipped".to_string()), Duration::ZERO))
                        .ok();
                }
                anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
            }

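            // Issue a plain, non-streaming completion request using the generation
            // parameters associated with the chosen prompt format.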
            let prompt = prompt_result?;
            let generation_params =
                cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
            let request = open_ai::Request {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: generation_params.stop.unwrap_or_default(),
                temperature: generation_params.temperature.unwrap_or(0.7),
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            log::trace!("Sending edit prediction request");

            let before_request = Instant::now();
            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache,
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();
            let request_time = received_response_at - before_request;

            log::trace!("Got edit prediction response");

            if let Some(debug_response_tx) = debug_response_tx {
                debug_response_tx
                    .send((
                        response
                            .as_ref()
                            .map_err(|err| err.to_string())
                            .map(|response| response.0.clone()),
                        request_time,
                    ))
                    .ok();
            }

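            // A response without any text is still recorded under its request id; it is
            // surfaced later as an `Empty` rejection.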
            let (res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            let Some(mut output_text) = text_from_response(res) else {
                return Ok((Some((request_id, None)), usage));
            };

            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

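            // Map paths mentioned in the model output back to buffers; only the active
            // buffer (and its cursor excerpt) is available for edits here.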
            let get_buffer_from_context = |path: &Path| {
                if Some(path) == active_file_full_path.as_deref() {
                    Some((
                        &active_snapshot,
                        std::slice::from_ref(&excerpt_anchor_range),
                    ))
                } else {
                    None
                }
            };

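            // Parse the model output into edits using the parser that matches the prompt
            // format the request was built with.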
            let (_, edits) = match options.prompt_format {
                PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
                    if output_text.contains("--- a/\n+++ b/\nNo edits") {
                        let edits = vec![];
                        (&active_snapshot, edits)
                    } else {
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                }
                PromptFormat::OldTextNewText => {
                    crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
                }
                _ => {
                    bail!("unsupported prompt format {}", options.prompt_format)
                }
            };

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        inputs,
                        active_buffer,
                        active_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

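    // Back on the foreground: interpret the response and assemble the final prediction result.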
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}

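/// Extracts the assistant's text output from an OpenAI-style completion response.
///
/// Returns `None` when the response has no choices or when the message contains no
/// usable text content.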
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No output from Baseten completion response");
                return None;
            }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}