1use crate::cursor_excerpt::compute_excerpt_ranges;
2use crate::prediction::EditPredictionResult;
3use crate::{
4 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
5 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
6};
7use anyhow::{Context as _, Result};
8use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
10use edit_prediction_types::PredictedCursorPosition;
11use futures::AsyncReadExt as _;
12use gpui::{App, AppContext as _, Task, http_client, prelude::*};
13use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
14use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
15use release_channel::AppVersion;
16use text::{Anchor, Bias};
17
18use std::env;
19use std::ops::Range;
20use std::{path::Path, sync::Arc, time::Instant};
21use zeta_prompt::{
22 CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
23 format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
24 zeta1::{self, EDITABLE_REGION_END_MARKER},
25};
26
/// Requests an edit prediction for `position` in `buffer` and returns a task
/// resolving to the prediction result (or `None` when no request was made).
///
/// The request is routed to one of three backends:
/// - a user-configured custom server (Ollama or an OpenAI-compatible API),
/// - a raw zeta2 endpoint (when a raw config is present on the store),
/// - Zed's hosted V3 predict-edits endpoint (the default).
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        ..
    }: EditPredictionModelInput,
    preferred_model: Option<EditPredictionModelKind>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Settings for a user-configured prediction server; `None` means Zed's
    // hosted endpoints are used instead.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Captured now so latency is measured from the snapshot the prediction is
    // based on, not from when the response arrives.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();

    // Untitled buffers still need a path for the prompt.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Everything feeding the prompt (the file, recorded events, and related
    // files) must be open source for the request to count as open source.
    let is_open_source = snapshot
        .file()
        .map_or(false, |file| store.is_file_open_source(&project, file, cx))
        && events.iter().all(|event| event.in_open_source_repo())
        && related_files.iter().all(|file| file.in_open_source_repo);

    // Data collection requires both open-source content and user opt-in.
    let can_collect_data = is_open_source && store.is_data_collection_enabled(cx);

    let request_task = cx.background_spawn({
        async move {
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once in each backend branch below.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                zeta_version,
                preferred_model,
                is_open_source,
                can_collect_data,
            );

            // Bail out quietly when user text contains the model's own special
            // tokens; sending it would corrupt the prompt.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            let is_zeta1 = preferred_model == Some(EditPredictionModelKind::Zeta1);
            let excerpt_ranges = prompt_input
                .excerpt_ranges
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("excerpt_ranges missing from prompt input"))?;

            // Report the formatted prompt to any attached debug listener.
            if let Some(debug_tx) = &debug_tx {
                let prompt = if is_zeta1 {
                    zeta1::format_zeta1_from_input(
                        &prompt_input,
                        excerpt_ranges.editable_350.clone(),
                        excerpt_ranges.editable_350_context_150.clone(),
                    )
                } else {
                    format_zeta_prompt(&prompt_input, zeta_version)
                };
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, model_version, usage) = if let Some(custom_settings) =
                &custom_server_settings
            {
                // NOTE(review): 4x headroom over the configured output-token
                // limit — rationale not visible here, confirm.
                let max_tokens = custom_settings.max_output_tokens * 4;

                if is_zeta1 {
                    // Zeta1 uses a fixed editable-region prompt format and
                    // stops on the end-of-region marker (with trailing-newline
                    // variants, since servers may tokenize them together).
                    let ranges = excerpt_ranges;
                    let prompt = zeta1::format_zeta1_from_input(
                        &prompt_input,
                        ranges.editable_350.clone(),
                        ranges.editable_350_context_150.clone(),
                    );
                    editable_range_in_excerpt = ranges.editable_350.clone();
                    let stop_tokens = vec![
                        EDITABLE_REGION_END_MARKER.to_string(),
                        format!("{EDITABLE_REGION_END_MARKER}\n"),
                        format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                        format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                    ];

                    let (response_text, request_id) = send_custom_server_request(
                        provider,
                        custom_settings,
                        prompt,
                        max_tokens,
                        stop_tokens,
                        &http_client,
                    )
                    .await?;

                    let request_id = EditPredictionId(request_id.into());
                    let output_text = zeta1::clean_zeta1_model_output(&response_text);

                    (request_id, output_text, None, None)
                } else {
                    // Zeta2-style prompt with a prefill the model continues from.
                    let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                    let prefill = get_prefill(&prompt_input, zeta_version);
                    let prompt = format!("{prompt}{prefill}");

                    editable_range_in_excerpt = prompt_input
                        .excerpt_ranges
                        .as_ref()
                        .map(|ranges| zeta_prompt::excerpt_range_for_format(zeta_version, ranges).0)
                        .unwrap_or(prompt_input.editable_range_in_excerpt.clone());

                    let (response_text, request_id) = send_custom_server_request(
                        provider,
                        custom_settings,
                        prompt,
                        max_tokens,
                        vec![],
                        &http_client,
                    )
                    .await?;

                    let request_id = EditPredictionId(request_id.into());
                    let output_text = if response_text.is_empty() {
                        None
                    } else {
                        // Re-attach the prefill so cleanup sees the full output.
                        let output = format!("{prefill}{response_text}");
                        Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                    };

                    (request_id, output_text, None, None)
                }
            } else if let Some(config) = &raw_config {
                // Raw zeta2 endpoint: build the completion request directly.
                let prompt = format_zeta_prompt(&prompt_input, config.format);
                let prefill = get_prefill(&prompt_input, config.format);
                let prompt = format!("{prompt}{prefill}");
                let request = RawCompletionRequest {
                    model: config.model_id.clone().unwrap_or_default(),
                    prompt,
                    temperature: None,
                    stop: vec![],
                    max_tokens: Some(2048),
                    environment: Some(config.format.to_string().to_lowercase()),
                };

                // NOTE(review): this branch takes `.1` of the range pair while
                // the custom-server branch above takes `.0` — confirm the
                // intended element for each backend.
                editable_range_in_excerpt = prompt_input
                    .excerpt_ranges
                    .as_ref()
                    .map(|ranges| zeta_prompt::excerpt_range_for_format(config.format, ranges).1)
                    .unwrap_or(prompt_input.editable_range_in_excerpt.clone());

                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                    request,
                    client,
                    None,
                    llm_token,
                    app_version,
                )
                .await?;

                let request_id = EditPredictionId(response.id.clone().into());
                let output_text = response.choices.pop().map(|choice| {
                    let response = &choice.text;
                    let output = format!("{prefill}{response}");
                    clean_zeta2_model_output(&output, config.format).to_string()
                });

                (request_id, output_text, None, usage)
            } else {
                // Use V3 endpoint - server handles model/version selection and suffix stripping
                let (response, usage) = EditPredictionStore::send_v3_request(
                    prompt_input.clone(),
                    client,
                    llm_token,
                    app_version,
                    trigger,
                )
                .await?;

                let request_id = EditPredictionId(response.request_id.into());
                let output_text = if response.output.is_empty() {
                    None
                } else {
                    Some(response.output)
                };
                editable_range_in_excerpt = response.editable_range;
                let model_version = response.model_version;

                (request_id, output_text, model_version, usage)
            };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // An empty prediction still yields a request id so it can be
            // reported as empty upstream.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Normalize trailing newlines so the diff doesn't produce a
            // spurious edit at the end of the region.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        full_context_offset_range,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // No prediction payload means the model produced nothing usable.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            full_context_offset_range,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        // When data collection is allowed, re-snapshot the edited region after
        // a delay and report how it evolved.
        if can_collect_data {
            cx.spawn({
                let weak_buffer = edited_buffer.downgrade();
                // Anchors keep tracking the regions across subsequent edits.
                let context_anchor_range =
                    edited_buffer_snapshot.anchor_range_around(full_context_offset_range);
                let editable_anchor_range =
                    edited_buffer_snapshot.anchor_range_around(editable_range_in_buffer);
                let request_id = id.0.clone();
                async move |cx| {
                    cx.background_executor()
                        .timer(std::time::Duration::from_secs(30))
                        .await;

                    // The buffer may have been dropped in the meantime.
                    let Some(buffer) = weak_buffer.upgrade() else {
                        return;
                    };
                    let (new_cursor_region, editable_range_in_excerpt) =
                        buffer.read_with(cx, |buffer, _| {
                            let context_start =
                                buffer.offset_for_anchor(&context_anchor_range.start);
                            // Editable range re-expressed relative to the
                            // context excerpt's current start.
                            let editable_range_in_excerpt = (buffer
                                .offset_for_anchor(&editable_anchor_range.start)
                                - context_start)
                                ..(buffer.offset_for_anchor(&editable_anchor_range.end)
                                    - context_start);
                            let text = buffer
                                .text_for_range(context_anchor_range)
                                .collect::<String>();
                            (text, editable_range_in_excerpt)
                        });
                    telemetry::event!(
                        "Edit Prediction Snapshot",
                        request_id,
                        new_cursor_region,
                        editable_range_in_excerpt,
                    );
                }
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
404
405pub fn zeta2_prompt_input(
406 snapshot: &language::BufferSnapshot,
407 related_files: Vec<zeta_prompt::RelatedFile>,
408 events: Vec<Arc<zeta_prompt::Event>>,
409 excerpt_path: Arc<Path>,
410 cursor_offset: usize,
411 zeta_format: ZetaFormat,
412 preferred_model: Option<EditPredictionModelKind>,
413 is_open_source: bool,
414 can_collect_data: bool,
415) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
416 let cursor_point = cursor_offset.to_point(snapshot);
417
418 let (full_context, full_context_offset_range, excerpt_ranges) =
419 compute_excerpt_ranges(cursor_point, snapshot);
420
421 let related_files = crate::filter_redundant_excerpts(
422 related_files,
423 excerpt_path.as_ref(),
424 full_context.start.row..full_context.end.row,
425 );
426
427 let full_context_start_offset = full_context_offset_range.start;
428 let full_context_start_row = full_context.start.row;
429
430 let editable_offset_range = match preferred_model {
431 Some(EditPredictionModelKind::Zeta1) => excerpt_ranges.editable_350.clone(),
432 _ => zeta_prompt::excerpt_range_for_format(zeta_format, &excerpt_ranges).0,
433 };
434
435 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
436
437 let prompt_input = zeta_prompt::ZetaPromptInput {
438 cursor_path: excerpt_path,
439 cursor_excerpt: snapshot
440 .text_for_range(full_context)
441 .collect::<String>()
442 .into(),
443 editable_range_in_excerpt: editable_offset_range,
444 cursor_offset_in_excerpt,
445 excerpt_start_row: Some(full_context_start_row),
446 events,
447 related_files,
448 excerpt_ranges: Some(excerpt_ranges),
449 preferred_model,
450 in_open_source_repo: is_open_source,
451 can_collect_data,
452 };
453 (full_context_offset_range, prompt_input)
454}
455
456pub(crate) async fn send_custom_server_request(
457 provider: settings::EditPredictionProvider,
458 settings: &OpenAiCompatibleEditPredictionSettings,
459 prompt: String,
460 max_tokens: u32,
461 stop_tokens: Vec<String>,
462 http_client: &Arc<dyn http_client::HttpClient>,
463) -> Result<(String, String)> {
464 match provider {
465 settings::EditPredictionProvider::Ollama => {
466 let response =
467 ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
468 .await?;
469 Ok((response.response, response.created_at))
470 }
471 _ => {
472 let request = RawCompletionRequest {
473 model: settings.model.clone(),
474 prompt,
475 max_tokens: Some(max_tokens),
476 temperature: None,
477 stop: stop_tokens
478 .into_iter()
479 .map(std::borrow::Cow::Owned)
480 .collect(),
481 environment: None,
482 };
483
484 let request_body = serde_json::to_string(&request)?;
485 let http_request = http_client::Request::builder()
486 .method(http_client::Method::POST)
487 .uri(settings.api_url.as_ref())
488 .header("Content-Type", "application/json")
489 .body(http_client::AsyncBody::from(request_body))?;
490
491 let mut response = http_client.send(http_request).await?;
492 let status = response.status();
493
494 if !status.is_success() {
495 let mut body = String::new();
496 response.body_mut().read_to_string(&mut body).await?;
497 anyhow::bail!("custom server error: {} - {}", status, body);
498 }
499
500 let mut body = String::new();
501 response.body_mut().read_to_string(&mut body).await?;
502
503 let parsed: RawCompletionResponse =
504 serde_json::from_str(&body).context("Failed to parse completion response")?;
505 let text = parsed
506 .choices
507 .into_iter()
508 .next()
509 .map(|choice| choice.text)
510 .unwrap_or_default();
511 Ok((text, parsed.id))
512 }
513 }
514}
515
516pub(crate) fn edit_prediction_accepted(
517 store: &EditPredictionStore,
518 current_prediction: CurrentEditPrediction,
519 cx: &App,
520) {
521 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
522 if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
523 return;
524 }
525
526 let request_id = current_prediction.prediction.id.to_string();
527 let model_version = current_prediction.prediction.model_version;
528 let require_auth = custom_accept_url.is_none();
529 let client = store.client.clone();
530 let llm_token = store.llm_token.clone();
531 let app_version = AppVersion::global(cx);
532
533 cx.background_spawn(async move {
534 let url = if let Some(accept_edits_url) = custom_accept_url {
535 gpui::http_client::Url::parse(&accept_edits_url)?
536 } else {
537 client
538 .http_client()
539 .build_zed_llm_url("/predict_edits/accept", &[])?
540 };
541
542 let response = EditPredictionStore::send_api_request::<()>(
543 move |builder| {
544 let req = builder.uri(url.as_ref()).body(
545 serde_json::to_string(&AcceptEditPredictionBody {
546 request_id: request_id.clone(),
547 model_version: model_version.clone(),
548 })?
549 .into(),
550 );
551 Ok(req?)
552 },
553 client,
554 llm_token,
555 app_version,
556 require_auth,
557 )
558 .await;
559
560 response?;
561 anyhow::Ok(())
562 })
563 .detach_and_log_err(cx);
564}
565
566pub fn compute_edits(
567 old_text: String,
568 new_text: &str,
569 offset: usize,
570 snapshot: &BufferSnapshot,
571) -> Vec<(Range<Anchor>, Arc<str>)> {
572 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
573}
574
/// Diffs `old_text` against `new_text` and converts each difference into an
/// anchored edit on `snapshot`, optionally also resolving where the predicted
/// cursor (a byte offset into `new_text`) lands in the buffer.
///
/// `offset` is the byte offset in `snapshot` at which `old_text` starts, used
/// to translate diff-local offsets into absolute buffer offsets.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    // Clamp target for buffer offsets: the caller may have appended a trailing
    // newline to `old_text`, pushing offsets past the end of the snapshot.
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // `delta` only needs to stay accurate while the cursor is still
            // unresolved, so its update lives inside this guard; it is unused
            // once `cursor_position` is set (or when no cursor was requested).
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // back into old-text (and hence buffer) coordinates.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and record the offset into the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the edit to the minimal differing span by trimming the
            // prefix and suffix shared between old and new text.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp to the buffer in case the normalized old text ran past it.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: a collapsed anchor pair at the insertion point.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor lay past every edit (or there were no edits): map it through the
    // final delta and clip to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
658
/// Returns the length in bytes (UTF-8) of the longest common prefix of the
/// two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, mut b: T2) -> usize {
    let mut bytes = 0;
    for ch in a {
        // Stop at the first mismatch or when `b` is exhausted.
        match b.next() {
            Some(other) if other == ch => bytes += ch.len_utf8(),
            _ => break,
        }
    }
    bytes
}