1use crate::cursor_excerpt::compute_excerpt_ranges;
2use crate::prediction::EditPredictionResult;
3use crate::{
4 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
5 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
6};
7use anyhow::Result;
8use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
10use edit_prediction_types::PredictedCursorPosition;
11use gpui::{App, AppContext as _, Task, prelude::*};
12use language::language_settings::all_language_settings;
13use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
14use release_channel::AppVersion;
15use settings::EditPredictionPromptFormat;
16use text::{Anchor, Bias};
17
18use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
19use zeta_prompt::{
20 CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
21 output_with_context_for_format, prompt_input_contains_special_tokens,
22 zeta1::{self, EDITABLE_REGION_END_MARKER},
23};
24
25use crate::open_ai_compatible::{
26 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
27};
28
/// Requests an edit prediction for `position` in `buffer` using one of three
/// backends, chosen from settings and debug configuration:
///
/// 1. A user-configured custom server (Ollama / OpenAI-compatible API), using
///    either the Zeta1 or Zeta2 prompt format.
/// 2. A raw completion endpoint, when a raw Zeta2 debug config is present.
/// 3. The default V3 prediction endpoint, where the server handles model and
///    version selection.
///
/// The returned task resolves to `Ok(None)` when no prediction was produced,
/// or to an [`EditPredictionResult`] (possibly rejected as `Empty`) otherwise.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Resolve the configured provider; only the Ollama and OpenAI-compatible
    // providers carry custom-server settings.
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    // Capture everything the background task needs while we still have
    // foreground context access.
    let http_client = cx.http_client();
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Untitled buffers get a placeholder path so the prompt always has one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    let request_task = cx.background_spawn({
        async move {
            // A raw debug config may pin a specific prompt format; otherwise
            // fall back to the default Zeta format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once in each backend branch below; expressed
            // relative to the start of the excerpt sent in the prompt.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
            );

            // Bail out rather than send content that could be confused with
            // the model's own special tokens.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): presumably a safety margin so the rewrite
                    // is not truncated mid-region — confirm the intended units
                    // of this `* 4` factor.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        // Zeta1 format: editable-region markers delimit the
                        // rewritten region; generation stops at the end marker.
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                ranges.editable_350.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            editable_range_in_excerpt = ranges.editable_350.clone();
                            // Trailing-newline variants included because stop
                            // sequences are matched verbatim by the server.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);

                            (request_id, output_text, None, None)
                        }
                        // Zeta2 format: the prefill is appended to the prompt
                        // and must be re-attached to the response before
                        // cleanup, since the model continues from it.
                        EditPredictionPromptFormat::Zeta2 => {
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                                zeta_version,
                                &prompt_input.excerpt_ranges,
                            )
                            .0;

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    // Raw completion endpoint (debug path): send the formatted
                    // prompt directly to the configured model.
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment,
                    };

                    // NOTE(review): this branch takes `.1` of the range pair
                    // while the Zeta2 custom-server branch above takes `.0` —
                    // confirm which element is the editable range per format.
                    editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                        config.format,
                        &prompt_input.excerpt_ranges,
                    )
                    .1;

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    let output_text = response.choices.pop().map(|choice| {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        clean_zeta2_model_output(&output, config.format).to_string()
                    });

                    (request_id, output_text, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = if response.output.is_empty() {
                        None
                    } else {
                        Some(response.output)
                    };
                    editable_range_in_excerpt = response.editable_range;
                    let model_version = response.model_version;

                    (request_id, output_text, model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No output means "no prediction": still propagate the request id
            // and usage so the caller can account for them.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // For the hashline format, the model may return <|set|>/<|insert|>
            // edit commands instead of a full replacement. Apply them against
            // the original editable region to produce the full replacement text.
            // This must happen before cursor marker stripping because the cursor
            // marker is embedded inside edit command content.
            if let Some(rewritten_output) =
                output_with_context_for_format(zeta_version, &old_text, &output_text)?
            {
                output_text = rewritten_output;
            }

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end in a newline so the diff below
            // compares the trailing line consistently.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            // Diff old vs. new editable text into anchored edits, mapping the
            // stripped cursor marker (if any) back to a buffer position.
            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // A request id without a prediction payload means the model produced
        // no output; surface it as an explicitly empty prediction.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        // Queue the settled prediction for data collection when permitted.
        if can_collect_data {
            this.update(cx, |this, cx| {
                this.enqueue_settled_prediction(
                    id.clone(),
                    &project,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    editable_range_in_buffer,
                    cx,
                );
            })
            .ok();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
384
385pub fn zeta2_prompt_input(
386 snapshot: &language::BufferSnapshot,
387 related_files: Vec<zeta_prompt::RelatedFile>,
388 events: Vec<Arc<zeta_prompt::Event>>,
389 excerpt_path: Arc<Path>,
390 cursor_offset: usize,
391 preferred_experiment: Option<String>,
392 is_open_source: bool,
393 can_collect_data: bool,
394) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
395 let cursor_point = cursor_offset.to_point(snapshot);
396
397 let (full_context, full_context_offset_range, excerpt_ranges) =
398 compute_excerpt_ranges(cursor_point, snapshot);
399
400 let full_context_start_offset = full_context_offset_range.start;
401 let full_context_start_row = full_context.start.row;
402
403 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
404
405 let prompt_input = zeta_prompt::ZetaPromptInput {
406 cursor_path: excerpt_path,
407 cursor_excerpt: snapshot
408 .text_for_range(full_context)
409 .collect::<String>()
410 .into(),
411 cursor_offset_in_excerpt,
412 excerpt_start_row: Some(full_context_start_row),
413 events,
414 related_files,
415 excerpt_ranges,
416 experiment: preferred_experiment,
417 in_open_source_repo: is_open_source,
418 can_collect_data,
419 };
420 (full_context_offset_range, prompt_input)
421}
422
/// Reports acceptance of `current_prediction` to the prediction service.
///
/// The target URL comes from the `ZED_ACCEPT_PREDICTION_URL` environment
/// variable when set (authentication is skipped in that case), otherwise the
/// Zed LLM `/predict_edits/accept` endpoint. Predictions produced via a raw
/// Zeta2 debug config are not reported unless a custom URL is set. The report
/// is fire-and-forget: the spawned task is detached and errors are only logged.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // Debug-config predictions opt out of acceptance reporting unless a
    // custom accept URL explicitly opts back in.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Only require authentication when talking to the real endpoint.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        // NOTE(review): the clones inside the builder closure suggest it may
        // be invoked more than once (e.g. on retry) — confirm against
        // `send_api_request`'s contract.
        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
478
479pub fn compute_edits(
480 old_text: String,
481 new_text: &str,
482 offset: usize,
483 snapshot: &BufferSnapshot,
484) -> Vec<(Range<Anchor>, Arc<str>)> {
485 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
486}
487
/// Diffs `old_text` against `new_text` and converts the differences into
/// anchored buffer edits, where `offset` is the absolute buffer offset at
/// which `old_text` begins.
///
/// When `cursor_offset_in_new_text` is provided (a byte offset into
/// `new_text`, as produced by stripping the cursor marker from model output),
/// it is also mapped back to a [`PredictedCursorPosition`] in the buffer.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // `delta` only accumulates until the cursor is placed; that is
            // sufficient because it is never read once `cursor_position` is
            // `Some`.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // straight back to old-text coordinates.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and record the offset into the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the edit by the bytes shared at its start and end so the
            // replacement is minimal. The suffix is computed on the post-prefix
            // slices, so prefix and suffix can never overlap.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Translate into absolute buffer coordinates, clamped to the
            // buffer's current length.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: a single `anchor_after` so the insertion
                // point tracks subsequent buffer changes.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor falls after the final edit: map it back through the total delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
571
/// Returns the length in **bytes** of the longest common prefix of the two
/// character streams (pass reversed iterators to measure a common suffix).
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (x, y) in a.zip(b) {
        if x != y {
            break;
        }
        // Accumulate UTF-8 width so the result is a valid byte index.
        byte_len += x.len_utf8();
    }
    byte_len
}