1use crate::cursor_excerpt::compute_excerpt_ranges;
2use crate::prediction::EditPredictionResult;
3use crate::{
4 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
5 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
6};
7use anyhow::Result;
8use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
10use edit_prediction_types::PredictedCursorPosition;
11use gpui::{App, AppContext as _, Task, prelude::*};
12use language::language_settings::all_language_settings;
13use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
14use release_channel::AppVersion;
15use settings::EditPredictionPromptFormat;
16use text::{Anchor, Bias};
17
18use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
19use zeta_prompt::{
20 CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
21 output_with_context_for_format, prompt_input_contains_special_tokens,
22 zeta1::{self, EDITABLE_REGION_END_MARKER},
23};
24
25use crate::open_ai_compatible::{
26 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
27};
28
/// Requests an edit prediction for `position` in `buffer` using the Zeta
/// family of prompt formats.
///
/// The request is routed to one of three backends:
/// - a user-configured custom server (Ollama or an OpenAI-compatible API),
///   formatted as either a Zeta1 or Zeta2 prompt,
/// - Zed's raw completion endpoint, when a `zeta2_raw_config` is set, or
/// - Zed's v3 predict-edits endpoint otherwise (the server selects the model
///   and performs suffix stripping).
///
/// The returned task resolves to `Ok(None)` when no prediction was attempted
/// or tracked, or to an [`EditPredictionResult`] — which carries
/// `Err(EditPredictionRejectReason::Empty)` when the model produced no edits.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Custom-server settings only exist for the Ollama and OpenAI-compatible
    // providers; every other provider goes through Zed's own endpoints.
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    // Capture everything the background task needs while still on the main
    // thread. `buffer_snapshotted_at` is passed through to the final
    // `EditPredictionResult::new` (request timing).
    let http_client = cx.http_client();
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Untitled buffers get a placeholder path so the prompt always has one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // Only include the repository remote URL when the user has opted in to
    // data collection.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    let request_task = cx.background_spawn({
        async move {
            // The raw config's format (when present) determines how the prompt
            // is built and how the model output is cleaned up.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned by every backend branch below; expressed relative to
            // the start of the excerpt embedded in the prompt.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to predict when the buffer content itself contains the
            // model's special tokens — they would corrupt the prompt.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Each branch yields (request id, optional raw model output,
            // optional model version, optional usage info).
            let (request_id, output_text, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): the output budget is 4x the configured
                    // max-output-tokens — confirm this scaling is intentional.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            // Zeta1: fixed editable region plus context;
                            // generation is stopped at the editable-region end
                            // marker (including trailing-newline variants).
                            let ranges = &prompt_input.excerpt_ranges;
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                ranges.editable_350.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            editable_range_in_excerpt = ranges.editable_350.clone();
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);

                            (request_id, output_text, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // Zeta2: the prompt ends with a prefill the model
                            // continues, so the prefill must be prepended to
                            // the response before cleanup.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                                zeta_version,
                                &prompt_input.excerpt_ranges,
                            )
                            .0;

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    // Raw endpoint: send a fully formatted prompt (with
                    // prefill) and let the configured model complete it.
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    // Default the environment to the lowercased format name
                    // when none is configured explicitly.
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment,
                    };

                    editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                        config.format,
                        &prompt_input.excerpt_ranges,
                    )
                    .1;

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    // Take the last choice; prepend the prefill before cleanup
                    // as with the custom Zeta2 path.
                    let output_text = response.choices.pop().map(|choice| {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        clean_zeta2_model_output(&output, config.format).to_string()
                    });

                    (request_id, output_text, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = if response.output.is_empty() {
                        None
                    } else {
                        Some(response.output)
                    };
                    editable_range_in_excerpt = response.editable_range;
                    let model_version = response.model_version;

                    (request_id, output_text, model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // Absent output means the model predicted no change; still return
            // the request id so the prediction can be tracked.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Translate the editable range from excerpt-relative offsets to
            // absolute buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // For the hashline format, the model may return <|set|>/<|insert|>
            // edit commands instead of a full replacement. Apply them against
            // the original editable region to produce the full replacement text.
            // This must happen before cursor marker stripping because the cursor
            // marker is embedded inside edit command content.
            if let Some(rewritten_output) =
                output_with_context_for_format(zeta_version, &old_text, &output_text)?
            {
                output_text = rewritten_output;
            }

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline so the diff below
            // aligns on whole lines.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        // Hand the raw outcome to the store's shared response handler
        // (presumably error/usage bookkeeping — see `handle_api_response`).
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // A tracked request with no prediction payload is reported as an
        // empty (rejected) prediction.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        // Queue the settled prediction for (opted-in) data collection;
        // failures here are ignored.
        if can_collect_data {
            this.update(cx, |this, cx| {
                this.enqueue_settled_prediction(
                    id.clone(),
                    &project,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    editable_range_in_buffer,
                    cx,
                );
            })
            .ok();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
397
398pub fn zeta2_prompt_input(
399 snapshot: &language::BufferSnapshot,
400 related_files: Vec<zeta_prompt::RelatedFile>,
401 events: Vec<Arc<zeta_prompt::Event>>,
402 excerpt_path: Arc<Path>,
403 cursor_offset: usize,
404 preferred_experiment: Option<String>,
405 is_open_source: bool,
406 can_collect_data: bool,
407 repo_url: Option<String>,
408) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
409 let cursor_point = cursor_offset.to_point(snapshot);
410
411 let (full_context, full_context_offset_range, excerpt_ranges) =
412 compute_excerpt_ranges(cursor_point, snapshot);
413
414 let related_files = crate::filter_redundant_excerpts(
415 related_files,
416 excerpt_path.as_ref(),
417 full_context.start.row..full_context.end.row,
418 );
419
420 let full_context_start_offset = full_context_offset_range.start;
421 let full_context_start_row = full_context.start.row;
422
423 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
424
425 let prompt_input = zeta_prompt::ZetaPromptInput {
426 cursor_path: excerpt_path,
427 cursor_excerpt: snapshot
428 .text_for_range(full_context)
429 .collect::<String>()
430 .into(),
431 cursor_offset_in_excerpt,
432 excerpt_start_row: Some(full_context_start_row),
433 events,
434 related_files,
435 excerpt_ranges,
436 experiment: preferred_experiment,
437 in_open_source_repo: is_open_source,
438 can_collect_data,
439 repo_url,
440 };
441 (full_context_offset_range, prompt_input)
442}
443
/// Reports that the user accepted `current_prediction`.
///
/// The target URL is `ZED_ACCEPT_PREDICTION_URL` when that environment
/// variable is set (and then no authentication is required), otherwise Zed's
/// `/predict_edits/accept` LLM endpoint. When a raw zeta2 config is active
/// and no override URL is set, acceptance is not reported at all.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Authentication is only required when talking to Zed's own endpoint.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Fire-and-forget: the acceptance POST runs in the background and any
    // error is only logged (see `detach_and_log_err` below).
    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                // Body carries the request id and model version so the server
                // can attribute the acceptance.
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
499
500pub fn compute_edits(
501 old_text: String,
502 new_text: &str,
503 offset: usize,
504 snapshot: &BufferSnapshot,
505) -> Vec<(Range<Anchor>, Arc<str>)> {
506 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
507}
508
/// Diffs `old_text` against `new_text` and converts each hunk into an
/// anchored edit on `snapshot`, where `offset` is the buffer offset at which
/// `old_text` begins.
///
/// If `cursor_offset_in_new_text` is given (a byte offset into `new_text`),
/// also computes where that cursor lands in the buffer:
/// - before a hunk: mapped back to old coordinates and anchored directly;
/// - inside a hunk's replacement text: anchored at the hunk start, plus the
///   offset within the inserted text;
/// - past every hunk: handled by the fallback after the loop.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // `delta` is only consumed by this cursor mapping, so it stops
            // accumulating once the cursor has been placed (or was never
            // requested).
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor precedes this hunk: the text there is unchanged,
                    // so map it straight back to old coordinates.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor sits inside this hunk's replacement text: anchor
                    // at the hunk start and record the offset within the
                    // inserted text.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the hunk to the minimal differing span: trim the common
            // prefix, then the common suffix of what remains.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp to the buffer: `old_text` may extend one newline past the
            // snapshot's end (callers normalize trailing newlines before
            // diffing).
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: represent it with a single anchor.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor falls after the last hunk: map it through the total delta and
    // clip to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
592
/// Returns the length in bytes (UTF-8 encoded) of the longest common prefix
/// of the two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (ca, cb) in a.zip(b) {
        if ca != cb {
            break;
        }
        bytes += ca.len_utf8();
    }
    bytes
}