1use crate::cursor_excerpt::compute_excerpt_ranges;
2use crate::prediction::EditPredictionResult;
3use crate::{
4 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
5 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
6};
7use anyhow::{Context as _, Result};
8use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
10use edit_prediction_types::PredictedCursorPosition;
11use futures::AsyncReadExt as _;
12use gpui::{App, AppContext as _, Task, http_client, prelude::*};
13use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
14use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
15use release_channel::AppVersion;
16use settings::EditPredictionPromptFormat;
17use text::{Anchor, Bias};
18
19use std::env;
20use std::ops::Range;
21use std::{path::Path, sync::Arc, time::Instant};
22use zeta_prompt::{
23 CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
24 prompt_input_contains_special_tokens,
25 zeta1::{self, EDITABLE_REGION_END_MARKER},
26};
27
/// Requests an edit prediction for `position` in `buffer`.
///
/// The request is routed to one of three backends, in priority order:
/// 1. A user-configured custom server (Ollama or an OpenAI-compatible API),
///    when the edit-prediction provider settings select one.
/// 2. The raw zeta2 completion endpoint, when the store carries a raw config.
/// 3. The cloud V3 endpoint otherwise (the server picks model/version and
///    strips suffixes).
///
/// Returns `Ok(None)` when no prediction could be attempted, otherwise an
/// `EditPredictionResult` that carries either a rejection reason or the
/// computed edits and predicted cursor position.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Custom-server settings are only consulted for the two self-hosted
    // providers; any other provider falls through to the zed endpoints below.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Captured now so latency can be measured from the moment the buffer was
    // snapshotted, not from when the background task starts running.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());

    // Path shown to the model for the cursor excerpt; unsaved buffers get a
    // placeholder name.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    let request_task = cx.background_spawn({
        async move {
            // Prompt format version comes from the raw config when present.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once in every branch below. Expressed in
            // excerpt-local byte offsets until translated to buffer offsets
            // after the response arrives.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
            );

            // Bail out rather than send a prompt whose content already
            // contains the format's special tokens (they would corrupt
            // prompt/response parsing).
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Every branch yields the same 4-tuple:
            // (request id, optional raw model output, optional model version, optional usage).
            let (request_id, output_text, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): output-token budget is scaled by 4 —
                    // presumably a chars-per-token allowance; confirm intent.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                ranges.editable_350.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            editable_range_in_excerpt = ranges.editable_350.clone();
                            // Stop generation at the end-of-editable-region
                            // marker, tolerating up to three trailing newlines.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);

                            (request_id, output_text, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // The prefill is sent as part of the prompt and
                            // re-prepended to the response before cleanup.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                                zeta_version,
                                &prompt_input.excerpt_ranges,
                            )
                            .0;

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment: Some(config.format.to_string().to_lowercase()),
                    };

                    // NOTE(review): this takes tuple component `.1`, while the
                    // custom-server Zeta2 branch above takes `.0` — confirm the
                    // component ordering difference is intentional.
                    editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                        config.format,
                        &prompt_input.excerpt_ranges,
                    )
                    .1;

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    let output_text = response.choices.pop().map(|choice| {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        clean_zeta2_model_output(&output, config.format).to_string()
                    });

                    (request_id, output_text, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = if response.output.is_empty() {
                        None
                    } else {
                        Some(response.output)
                    };
                    editable_range_in_excerpt = response.editable_range;
                    let model_version = response.model_version;

                    (request_id, output_text, model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No output means an empty prediction: still report the request id
            // (and usage) so the caller can surface it as a rejection.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Translate the excerpt-local editable range into absolute buffer
            // offsets by shifting by the excerpt's start offset.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Normalize both sides to end with a newline so the diff doesn't
            // produce a spurious trailing-newline edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    // Foreground half: unwrap the API response, record data-collection info if
    // permitted, and build the final result.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // A present id with no prediction payload means the model produced an
        // empty prediction.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            this.update(cx, |this, cx| {
                this.enqueue_settled_prediction(
                    id.clone(),
                    &project,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    editable_range_in_buffer,
                    cx,
                );
            })
            .ok();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
358
359pub fn zeta2_prompt_input(
360 snapshot: &language::BufferSnapshot,
361 related_files: Vec<zeta_prompt::RelatedFile>,
362 events: Vec<Arc<zeta_prompt::Event>>,
363 excerpt_path: Arc<Path>,
364 cursor_offset: usize,
365 preferred_experiment: Option<String>,
366 is_open_source: bool,
367 can_collect_data: bool,
368) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
369 let cursor_point = cursor_offset.to_point(snapshot);
370
371 let (full_context, full_context_offset_range, excerpt_ranges) =
372 compute_excerpt_ranges(cursor_point, snapshot);
373
374 let related_files = crate::filter_redundant_excerpts(
375 related_files,
376 excerpt_path.as_ref(),
377 full_context.start.row..full_context.end.row,
378 );
379
380 let full_context_start_offset = full_context_offset_range.start;
381 let full_context_start_row = full_context.start.row;
382
383 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
384
385 let prompt_input = zeta_prompt::ZetaPromptInput {
386 cursor_path: excerpt_path,
387 cursor_excerpt: snapshot
388 .text_for_range(full_context)
389 .collect::<String>()
390 .into(),
391 cursor_offset_in_excerpt,
392 excerpt_start_row: Some(full_context_start_row),
393 events,
394 related_files,
395 excerpt_ranges,
396 experiment: preferred_experiment,
397 in_open_source_repo: is_open_source,
398 can_collect_data,
399 };
400 (full_context_offset_range, prompt_input)
401}
402
403pub(crate) async fn send_custom_server_request(
404 provider: settings::EditPredictionProvider,
405 settings: &OpenAiCompatibleEditPredictionSettings,
406 prompt: String,
407 max_tokens: u32,
408 stop_tokens: Vec<String>,
409 http_client: &Arc<dyn http_client::HttpClient>,
410) -> Result<(String, String)> {
411 match provider {
412 settings::EditPredictionProvider::Ollama => {
413 let response =
414 ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
415 .await?;
416 Ok((response.response, response.created_at))
417 }
418 _ => {
419 let request = RawCompletionRequest {
420 model: settings.model.clone(),
421 prompt,
422 max_tokens: Some(max_tokens),
423 temperature: None,
424 stop: stop_tokens
425 .into_iter()
426 .map(std::borrow::Cow::Owned)
427 .collect(),
428 environment: None,
429 };
430
431 let request_body = serde_json::to_string(&request)?;
432 let http_request = http_client::Request::builder()
433 .method(http_client::Method::POST)
434 .uri(settings.api_url.as_ref())
435 .header("Content-Type", "application/json")
436 .body(http_client::AsyncBody::from(request_body))?;
437
438 let mut response = http_client.send(http_request).await?;
439 let status = response.status();
440
441 if !status.is_success() {
442 let mut body = String::new();
443 response.body_mut().read_to_string(&mut body).await?;
444 anyhow::bail!("custom server error: {} - {}", status, body);
445 }
446
447 let mut body = String::new();
448 response.body_mut().read_to_string(&mut body).await?;
449
450 let parsed: RawCompletionResponse =
451 serde_json::from_str(&body).context("Failed to parse completion response")?;
452 let text = parsed
453 .choices
454 .into_iter()
455 .next()
456 .map(|choice| choice.text)
457 .unwrap_or_default();
458 Ok((text, parsed.id))
459 }
460 }
461}
462
/// Reports an accepted edit prediction to the prediction service, or to a
/// custom URL supplied via the `ZED_ACCEPT_PREDICTION_URL` environment
/// variable. Fire-and-forget: failures are logged, not surfaced.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // With a raw zeta2 config there is nothing to notify — unless the user
    // explicitly pointed us at a custom accept URL.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Authentication is only required when talking to the zed endpoint.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        // The builder closure may be retried by `send_api_request`, hence the
        // clones of the captured request fields.
        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
512
513pub fn compute_edits(
514 old_text: String,
515 new_text: &str,
516 offset: usize,
517 snapshot: &BufferSnapshot,
518) -> Vec<(Range<Anchor>, Arc<str>)> {
519 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
520}
521
/// Diffs `old_text` against `new_text` and converts the differences into
/// anchored buffer edits, optionally mapping a cursor offset expressed in
/// `new_text` coordinates back into a buffer position.
///
/// * `offset` — buffer offset at which `old_text` begins; edit ranges are
///   shifted by it.
/// * `cursor_offset_in_new_text` — byte offset of the predicted cursor within
///   `new_text`, if the model emitted a cursor marker.
///
/// Returns the prefix/suffix-trimmed edit list together with the predicted
/// cursor position (if one was requested and could be mapped).
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Once the cursor is resolved this guard stops matching, so `delta`
            // is no longer advanced — it is only needed while the cursor is
            // still being located.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: translate
                    // back to old-text coordinates and anchor there directly.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and record the offset within the
                    // insertion (the insertion hasn't been applied yet).
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming: shrink the edit to
            // the bytes that actually differ so anchors hug the real change.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            // Suffix is measured on the post-prefix remainders so the two
            // trims can never overlap.
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Shift into buffer coordinates, then apply the trims.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp to the buffer in case `old_text` extends past its end
            // (e.g. due to the trailing-newline normalization upstream).
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: a single anchor that stays put as text is
                // inserted at it.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor falls after the last edit: translate through the final delta and
    // clip to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
605
/// Byte length of the longest common prefix of two character streams.
///
/// Each matching `char` contributes its UTF-8 width, so the result is a byte
/// count that can be used directly to slice a `&str`.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (x, y) in a.zip(b) {
        if x != y {
            break;
        }
        bytes += x.len_utf8();
    }
    bytes
}