#[cfg(feature = "eval-support")]
use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
    DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
    EditPredictionRequestedDebugEvent, EditPredictionStore,
};
use anyhow::{Result, anyhow, bail};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
use cloud_zeta2_prompt::CURSOR_MARKER;
use edit_prediction_context::{EditPredictionExcerpt, Line};
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
use futures::channel::oneshot;
use gpui::{Entity, Task, prelude::*};
use language::{Anchor, BufferSnapshot};
use language::{Buffer, Point, ToOffset as _, ToPoint};
use project::{Project, ProjectItem as _};
use release_channel::AppVersion;
use std::{
    env,
    path::Path,
    sync::Arc,
    time::{Duration, Instant},
};
27
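/// Requests an edit prediction from the Zeta2 backend for `position` in the
/// active buffer: selects an excerpt around the cursor, folds it into
/// `included_files`, builds the prompt, and issues the model request on a
/// background task.
///
/// Resolves to `Ok(None)` when no excerpt can be selected, and to a result
/// with `EditPredictionRejectReason::Empty` when the model returns no text.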
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    project: &Entity<Project>,
    active_buffer: &Entity<Buffer>,
    active_snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<Event>>,
    mut included_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let options = store.options.clone();
    let buffer_snapshotted_at = Instant::now();

    let Some((excerpt_path, active_project_path)) = active_snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .zip(active_buffer.read(cx).project_path(cx))
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

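    // Clone the handles the request needs up front; the work below runs on a
    // background task that can't borrow from `store`.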
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let debug_tx = store.debug_tx.clone();

    let file = active_buffer.read(cx).file();

    let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));

    // TODO data collection
    let can_collect_data = file
        .as_ref()
        .map_or(false, |file| store.can_collect_file(project, file, cx));

    #[cfg(feature = "eval-support")]
    let eval_cache = store.eval_cache.clone();

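    // Everything from excerpt selection through the model round-trip happens
    // off the main thread.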
    let request_task = cx.background_spawn({
        let active_buffer = active_buffer.clone();
        async move {
            let cursor_offset = position.to_offset(&active_snapshot);
            let cursor_point = cursor_offset.to_point(&active_snapshot);

            let before_retrieval = Instant::now();

            let excerpt_options = options.context;

            let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                cursor_point,
                &active_snapshot,
                &excerpt_options,
            ) else {
                return Ok((None, None));
            };

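            // Anchor the excerpt range so it can still be resolved against the
            // snapshot when the model's edits are parsed below.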
            let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                ..active_snapshot.anchor_before(excerpt.range.end);
            let related_excerpt = RelatedExcerpt {
                anchor_range: excerpt_anchor_range.clone(),
                point_range: Point::new(excerpt.line_range.start.0, 0)
                    ..Point::new(excerpt.line_range.end.0, 0),
                text: active_snapshot.as_rope().slice(excerpt.range),
            };

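            // If the active buffer already has a RelatedFile entry, fold the
            // cursor excerpt into it and move that entry to the end of the
            // list; otherwise append a fresh entry for the active file.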
            if let Some(buffer_ix) = included_files
                .iter()
                .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
            {
                let file = &mut included_files[buffer_ix];
                file.excerpts.push(related_excerpt);
                file.merge_excerpts();
                let last_ix = included_files.len() - 1;
                included_files.swap(buffer_ix, last_ix);
            } else {
                let active_file = RelatedFile {
                    path: active_project_path,
                    buffer: active_buffer.downgrade(),
                    excerpts: vec![related_excerpt],
                    max_row: active_snapshot.max_point().row,
                };
                included_files.push(active_file);
            }

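            // Convert the collected files into the wire representation used by
            // the predict_edits_v3 protocol.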
            let included_files = included_files
                .iter()
                .map(|related_file| predict_edits_v3::RelatedFile {
                    path: Arc::from(related_file.path.path.as_std_path()),
                    max_row: Line(related_file.max_row),
                    excerpts: related_file
                        .excerpts
                        .iter()
                        .map(|excerpt| predict_edits_v3::Excerpt {
                            start_line: Line(excerpt.point_range.start.row),
                            text: excerpt.text.to_string().into(),
                        })
                        .collect(),
                })
                .collect::<Vec<_>>();

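            // The standalone excerpt fields are sent empty: the cursor excerpt
            // travels inside `related_files` instead.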
            let cloud_request = predict_edits_v3::PredictEditsRequest {
                excerpt_path,
                excerpt: String::new(),
                excerpt_line_range: Line(0)..Line(0),
                excerpt_range: 0..0,
                cursor_point: predict_edits_v3::Point {
                    line: predict_edits_v3::Line(cursor_point.row),
                    column: cursor_point.column,
                },
                related_files: included_files,
                events,
                can_collect_data,
                debug_info: debug_tx.is_some(),
                prompt_max_bytes: Some(options.max_prompt_bytes),
                prompt_format: options.prompt_format,
                excerpt_parent: None,
                git_info: None,
                trigger,
            };

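            // The prompt is rendered client-side; the same string is sent to
            // the model and surfaced through the debug channel.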
            let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

            let inputs = EditPredictionInputs {
                included_files: cloud_request.related_files,
                events: cloud_request.events,
                cursor_point: cloud_request.cursor_point,
                cursor_path: cloud_request.excerpt_path,
            };

            let retrieval_time = Instant::now() - before_retrieval;

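            // If a debugger is listening, announce the request and hand over a
            // oneshot through which the raw response will be reported.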
            let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                let (response_tx, response_rx) = oneshot::channel();

                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionRequested(
                        EditPredictionRequestedDebugEvent {
                            inputs: inputs.clone(),
                            retrieval_time,
                            buffer: active_buffer.downgrade(),
                            local_prompt: match prompt_result.as_ref() {
                                Ok(prompt) => Ok(prompt.clone()),
                                Err(err) => Err(err.to_string()),
                            },
                            position,
                            response_rx,
                        },
                    ))
                    .ok();
                Some(response_tx)
            } else {
                None
            };

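            // Debug escape hatch: exercise all of the request construction
            // above without actually hitting the model.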
            if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((Err("Request skipped".to_string()), Duration::ZERO))
                        .ok();
                }
                bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
            }

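            // Non-streaming, OpenAI-style completion request; stop sequences
            // and temperature come from the prompt format's generation params.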
            let prompt = prompt_result?;
            let generation_params =
                cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
            let request = open_ai::Request {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: generation_params.stop.unwrap_or_default(),
                temperature: generation_params.temperature.or(Some(0.7)),
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            log::trace!("Sending edit prediction request");

            let before_request = Instant::now();
            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache,
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();
            let request_time = received_response_at - before_request;

            log::trace!("Got edit prediction response");

            if let Some(debug_response_tx) = debug_response_tx {
                debug_response_tx
                    .send((
                        response
                            .as_ref()
                            .map_err(|err| err.to_string())
                            .map(|response| response.0.clone()),
                        request_time,
                    ))
                    .ok();
            }

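            // A response with no text still yields a request id, so the caller
            // can record the prediction as empty.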
            let (res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            let Some(mut output_text) = text_from_response(res) else {
                return Ok((Some((request_id, None)), usage));
            };

            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

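            // The edit parsers resolve paths mentioned in the model output;
            // only the active buffer is available to resolve against here.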
            let get_buffer_from_context = |path: &Path| {
                if Some(path) == active_file_full_path.as_deref() {
                    Some((
                        &active_snapshot,
                        std::slice::from_ref(&excerpt_anchor_range),
                    ))
                } else {
                    None
                }
            };

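            // Parse the model output according to the prompt format it was
            // generated with: unified diffs for the minimal formats, XML
            // old_text/new_text pairs otherwise.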
            let (_, edits) = match options.prompt_format {
                PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
                    if output_text.contains("--- a/\n+++ b/\nNo edits") {
                        let edits = vec![];
                        (&active_snapshot, edits)
                    } else {
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                }
                PromptFormat::OldTextNewText => {
                    crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
                }
                _ => {
                    bail!("unsupported prompt format {}", options.prompt_format)
                }
            };

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        inputs,
                        active_buffer,
                        active_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

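    // Back on the foreground: translate the raw response into an
    // EditPredictionResult (or an empty rejection) for the store.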
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}