1use crate::cursor_excerpt::compute_excerpt_ranges;
2use crate::prediction::EditPredictionResult;
3use crate::{
4 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
5 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
6};
7use anyhow::Result;
8use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
10use edit_prediction_types::PredictedCursorPosition;
11use gpui::{App, AppContext as _, Task, prelude::*};
12use language::language_settings::all_language_settings;
13use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
14use release_channel::AppVersion;
15use settings::EditPredictionPromptFormat;
16use text::{Anchor, Bias};
17
18use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
19use zeta_prompt::{
20 CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
21 output_with_context_for_format, prompt_input_contains_special_tokens,
22 zeta1::{self, EDITABLE_REGION_END_MARKER},
23};
24
25use crate::open_ai_compatible::{
26 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
27};
28
/// Requests an edit prediction for the cursor position described by the
/// destructured `EditPredictionModelInput`, routing through one of three
/// backends:
///
/// 1. A user-configured custom server (Ollama or an OpenAI-compatible API),
///    using either the Zeta1 or Zeta2 prompt format.
/// 2. A raw completion endpoint, when the store has a `zeta2_raw_config`.
/// 3. The hosted V3 predict-edits endpoint (the default path).
///
/// The returned task resolves to `Ok(None)` when no prediction was produced
/// (prompt contained special tokens, or the response was discarded upstream),
/// or to an `EditPredictionResult` carrying either edits or a reject reason.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Only the Ollama and OpenAI-compatible providers carry custom-server
    // settings; every other provider goes through Zed's hosted endpoints.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Buffers without a backing file (scratch buffers) are reported as
    // "untitled".
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    let request_task = cx.background_spawn({
        async move {
            // The prompt format comes from the raw config when present;
            // otherwise fall back to the default Zeta format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once per backend branch below; offsets are
            // relative to the excerpt, not the whole buffer.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
            );

            // If the surrounding text already contains the model's special
            // tokens, skip prediction entirely rather than confuse the model.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): the configured output-token budget is
                    // scaled 4x here — presumably headroom for markers/prefill;
                    // confirm against custom-server behavior.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                ranges.editable_350.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            editable_range_in_excerpt = ranges.editable_350.clone();
                            // Stop at the editable-region end marker, with up
                            // to two trailing newlines included as variants.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);

                            (request_id, output_text, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // The prefill is appended to the prompt and later
                            // re-prepended to the response so cleaning sees the
                            // full model output.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                                zeta_version,
                                &prompt_input.excerpt_ranges,
                            )
                            .0;

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    // Default the environment tag to the lowercased format name
                    // when the config does not set one explicitly.
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment,
                    };

                    editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                        config.format,
                        &prompt_input.excerpt_ranges,
                    )
                    .1;

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    // Only the last choice is used; the prefill is re-attached
                    // before cleaning, mirroring the custom-server Zeta2 path.
                    let output_text = response.choices.pop().map(|choice| {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        clean_zeta2_model_output(&output, config.format).to_string()
                    });

                    (request_id, output_text, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = if response.output.is_empty() {
                        None
                    } else {
                        Some(response.output)
                    };
                    editable_range_in_excerpt = response.editable_range;
                    let model_version = response.model_version;

                    (request_id, output_text, model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // An empty prediction still gets reported upstream with its
            // request id so it can be recorded/rejected appropriately.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Translate the excerpt-relative editable range into whole-buffer
            // offsets using the excerpt's start offset.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // For the hashline format, the model may return <|set|>/<|insert|>
            // edit commands instead of a full replacement. Apply them against
            // the original editable region to produce the full replacement text.
            // This must happen before cursor marker stripping because the cursor
            // marker is embedded inside edit command content.
            if let Some(rewritten_output) =
                output_with_context_for_format(zeta_version, &old_text, &output_text)?
            {
                output_text = rewritten_output;
            }

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline so the diff does not
            // produce a spurious trailing-newline hunk.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    // Foreground half: unpack the background result and build the final
    // `EditPredictionResult`, optionally queueing data collection.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // `None` here means the model produced no output text; surface it as
        // an explicitly rejected (empty) prediction rather than dropping it.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            this.update(cx, |this, cx| {
                this.enqueue_settled_prediction(
                    id.clone(),
                    &project,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    editable_range_in_buffer,
                    cx,
                );
            })
            .ok();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
384
385pub fn zeta2_prompt_input(
386 snapshot: &language::BufferSnapshot,
387 related_files: Vec<zeta_prompt::RelatedFile>,
388 events: Vec<Arc<zeta_prompt::Event>>,
389 excerpt_path: Arc<Path>,
390 cursor_offset: usize,
391 preferred_experiment: Option<String>,
392 is_open_source: bool,
393 can_collect_data: bool,
394) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
395 let cursor_point = cursor_offset.to_point(snapshot);
396
397 let (full_context, full_context_offset_range, excerpt_ranges) =
398 compute_excerpt_ranges(cursor_point, snapshot);
399
400 let related_files = crate::filter_redundant_excerpts(
401 related_files,
402 excerpt_path.as_ref(),
403 full_context.start.row..full_context.end.row,
404 );
405
406 let full_context_start_offset = full_context_offset_range.start;
407 let full_context_start_row = full_context.start.row;
408
409 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
410
411 let prompt_input = zeta_prompt::ZetaPromptInput {
412 cursor_path: excerpt_path,
413 cursor_excerpt: snapshot
414 .text_for_range(full_context)
415 .collect::<String>()
416 .into(),
417 cursor_offset_in_excerpt,
418 excerpt_start_row: Some(full_context_start_row),
419 events,
420 related_files,
421 excerpt_ranges,
422 experiment: preferred_experiment,
423 in_open_source_repo: is_open_source,
424 can_collect_data,
425 };
426 (full_context_offset_range, prompt_input)
427}
428
/// Notifies the server that the user accepted an edit prediction.
///
/// The accept endpoint defaults to Zed's hosted `/predict_edits/accept` URL,
/// but can be overridden with the `ZED_ACCEPT_PREDICTION_URL` env var (e.g.
/// for local testing). The request is fire-and-forget: it is detached and any
/// error is only logged.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // With a raw zeta2 config there is no hosted endpoint to notify, so skip
    // unless a custom accept URL was explicitly provided.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Only the hosted endpoint requires authentication.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            // Builder closure may be invoked again on retry, hence the clones.
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
484
485pub fn compute_edits(
486 old_text: String,
487 new_text: &str,
488 offset: usize,
489 snapshot: &BufferSnapshot,
490) -> Vec<(Range<Anchor>, Arc<str>)> {
491 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
492}
493
/// Diffs `old_text` against `new_text` and converts each diff hunk into an
/// edit anchored in `snapshot`, additionally mapping
/// `cursor_offset_in_new_text` (a byte offset into `new_text`) back to a
/// position in the buffer.
///
/// `offset` is the byte offset in `snapshot` at which `old_text` starts; all
/// diff ranges are shifted by it before anchoring. Returns the anchored edits
/// and, if a cursor offset was supplied, the predicted cursor position.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // (Once found, the cursor is never recomputed and `delta` stops
            // being updated — it is only needed for cursor mapping.)
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor precedes this edit: map it back into old-text
                    // coordinates via `delta`, then into the buffer.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and remember the offset within the
                    // insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the hunk by its common prefix and suffix so the anchors
            // hug the text that actually changed.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp to the snapshot in case the editable region extends past
            // the end of the buffer.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: both anchors at the same point.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after every edit: map it through the final delta, clipping
    // to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
577
/// Returns the length in bytes (per `char::len_utf8`) of the longest common
/// prefix of the two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (x, y) in a.zip(b) {
        if x != y {
            break;
        }
        bytes += x.len_utf8();
    }
    bytes
}