1use crate::cursor_excerpt::compute_excerpt_ranges;
2use crate::prediction::EditPredictionResult;
3use crate::{
4 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
5 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
6};
7use anyhow::{Context as _, Result};
8use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
10use edit_prediction_types::PredictedCursorPosition;
11use futures::AsyncReadExt as _;
12use gpui::{App, AppContext as _, Task, http_client, prelude::*};
13use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
14use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
15use release_channel::AppVersion;
16use text::{Anchor, Bias};
17
18use std::env;
19use std::ops::Range;
20use std::{path::Path, sync::Arc, time::Instant};
21use zeta_prompt::{
22 CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
23 format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
24 zeta1::{self, EDITABLE_REGION_END_MARKER},
25};
26
/// Requests an edit prediction for `position` in `buffer` via one of the Zeta
/// backends.
///
/// Backend selection, in priority order:
/// 1. a custom server (Ollama or an OpenAI-compatible API) when configured in
///    the language settings,
/// 2. the raw Zeta2 endpoint when the store carries a raw config,
/// 3. otherwise the hosted V3 endpoint (server picks model/version and strips
///    the suffix).
///
/// Resolves to `Ok(None)` when the request is skipped (prompt contained
/// special tokens), and to a result with
/// `EditPredictionRejectReason::Empty` when the model returned no text.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    preferred_model: Option<EditPredictionModelKind>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Only Ollama and OpenAI-compatible providers route through a custom
    // server; every other provider falls through to the raw/V3 endpoints.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Timestamp of when we captured the buffer state, reported to
    // `EditPredictionResult::new` for latency accounting.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();

    // Buffers without a backing file are labeled "untitled" in the prompt.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Everything network/CPU heavy happens off the main thread.
    let request_task = cx.background_spawn({
        async move {
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once in whichever backend branch runs below;
            // the compiler enforces definite initialization.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                zeta_version,
                preferred_model,
                is_open_source,
                can_collect_data,
            );

            // Refuse to send a prompt that already contains the model's own
            // control tokens — the response would be unparseable.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            let is_zeta1 = preferred_model == Some(EditPredictionModelKind::Zeta1);
            let excerpt_ranges = prompt_input
                .excerpt_ranges
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("excerpt_ranges missing from prompt input"))?;

            // Mirror the exact prompt we are about to send to the debug
            // channel, so the debug view matches the request.
            if let Some(debug_tx) = &debug_tx {
                let prompt = if is_zeta1 {
                    zeta1::format_zeta1_from_input(
                        &prompt_input,
                        excerpt_ranges.editable_350.clone(),
                        excerpt_ranges.editable_350_context_150.clone(),
                    )
                } else {
                    format_zeta_prompt(&prompt_input, zeta_version)
                };
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, model_version, usage) = if let Some(custom_settings) =
                &custom_server_settings
            {
                // NOTE(review): `* 4` presumably pads the configured output
                // budget (tokens vs. something coarser?) — confirm the unit.
                let max_tokens = custom_settings.max_output_tokens * 4;

                if is_zeta1 {
                    // Zeta1 uses the fixed 350-line editable window and marks
                    // the end of the region with a sentinel the model should
                    // stop at.
                    let ranges = excerpt_ranges;
                    let prompt = zeta1::format_zeta1_from_input(
                        &prompt_input,
                        ranges.editable_350.clone(),
                        ranges.editable_350_context_150.clone(),
                    );
                    editable_range_in_excerpt = ranges.editable_350.clone();
                    // Some servers only stop on an exact token match, so list
                    // the marker with trailing-newline variants.
                    let stop_tokens = vec![
                        EDITABLE_REGION_END_MARKER.to_string(),
                        format!("{EDITABLE_REGION_END_MARKER}\n"),
                        format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                        format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                    ];

                    let (response_text, request_id) = send_custom_server_request(
                        provider,
                        custom_settings,
                        prompt,
                        max_tokens,
                        stop_tokens,
                        &http_client,
                    )
                    .await?;

                    let request_id = EditPredictionId(request_id.into());
                    let output_text = zeta1::clean_zeta1_model_output(&response_text);

                    (request_id, output_text, None, None)
                } else {
                    // Zeta2 prompts end with a prefill that the model
                    // continues; the prefill is re-prepended to the response
                    // before cleanup below.
                    let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                    let prefill = get_prefill(&prompt_input, zeta_version);
                    let prompt = format!("{prompt}{prefill}");

                    editable_range_in_excerpt = prompt_input
                        .excerpt_ranges
                        .as_ref()
                        .map(|ranges| zeta_prompt::excerpt_range_for_format(zeta_version, ranges).0)
                        .unwrap_or(prompt_input.editable_range_in_excerpt.clone());

                    let (response_text, request_id) = send_custom_server_request(
                        provider,
                        custom_settings,
                        prompt,
                        max_tokens,
                        vec![],
                        &http_client,
                    )
                    .await?;

                    let request_id = EditPredictionId(request_id.into());
                    let output_text = if response_text.is_empty() {
                        None
                    } else {
                        let output = format!("{prefill}{response_text}");
                        Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                    };

                    (request_id, output_text, None, None)
                }
            } else if let Some(config) = &raw_config {
                // Raw Zeta2 endpoint: build the completion request ourselves.
                let prompt = format_zeta_prompt(&prompt_input, config.format);
                let prefill = get_prefill(&prompt_input, config.format);
                let prompt = format!("{prompt}{prefill}");
                let request = RawCompletionRequest {
                    model: config.model_id.clone().unwrap_or_default(),
                    prompt,
                    temperature: None,
                    stop: vec![],
                    max_tokens: Some(2048),
                    environment: Some(config.format.to_string().to_lowercase()),
                };

                // NOTE(review): this branch takes tuple element `.1` of
                // `excerpt_range_for_format` while the custom-server branch
                // takes `.0` — confirm the asymmetry is intentional.
                editable_range_in_excerpt = prompt_input
                    .excerpt_ranges
                    .as_ref()
                    .map(|ranges| zeta_prompt::excerpt_range_for_format(config.format, ranges).1)
                    .unwrap_or(prompt_input.editable_range_in_excerpt.clone());

                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                    request,
                    client,
                    None,
                    llm_token,
                    app_version,
                )
                .await?;

                let request_id = EditPredictionId(response.id.clone().into());
                let output_text = response.choices.pop().map(|choice| {
                    let response = &choice.text;
                    let output = format!("{prefill}{response}");
                    clean_zeta2_model_output(&output, config.format).to_string()
                });

                (request_id, output_text, None, usage)
            } else {
                // Use V3 endpoint - server handles model/version selection and suffix stripping
                let (response, usage) = EditPredictionStore::send_v3_request(
                    prompt_input.clone(),
                    client,
                    llm_token,
                    app_version,
                    trigger,
                )
                .await?;

                let request_id = EditPredictionId(response.request_id.into());
                let output_text = if response.output.is_empty() {
                    None
                } else {
                    Some(response.output)
                };
                editable_range_in_excerpt = response.editable_range;
                let model_version = response.model_version;

                (request_id, output_text, model_version, usage)
            };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No text at all: keep the request id (for telemetry/accept
            // reporting) but record the prediction as absent.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Normalize both sides to end in a newline so the diff doesn't
            // produce a spurious trailing edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    // Back on the foreground: unpack the background result, report usage, and
    // build the final EditPredictionResult.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // An id without a prediction payload means the model returned nothing.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        // Data collection is opt-in; queue the settled prediction for upload
        // only when the user allowed it.
        if can_collect_data {
            this.update(cx, |this, cx| {
                this.enqueue_settled_prediction(
                    id.clone(),
                    &project,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    editable_range_in_buffer,
                    cx,
                );
            })
            .ok();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
369
370pub fn zeta2_prompt_input(
371 snapshot: &language::BufferSnapshot,
372 related_files: Vec<zeta_prompt::RelatedFile>,
373 events: Vec<Arc<zeta_prompt::Event>>,
374 excerpt_path: Arc<Path>,
375 cursor_offset: usize,
376 zeta_format: ZetaFormat,
377 preferred_model: Option<EditPredictionModelKind>,
378 is_open_source: bool,
379 can_collect_data: bool,
380) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
381 let cursor_point = cursor_offset.to_point(snapshot);
382
383 let (full_context, full_context_offset_range, excerpt_ranges) =
384 compute_excerpt_ranges(cursor_point, snapshot);
385
386 let full_context_start_offset = full_context_offset_range.start;
387 let full_context_start_row = full_context.start.row;
388
389 let editable_offset_range = match preferred_model {
390 Some(EditPredictionModelKind::Zeta1) => excerpt_ranges.editable_350.clone(),
391 _ => zeta_prompt::excerpt_range_for_format(zeta_format, &excerpt_ranges).0,
392 };
393
394 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
395
396 let prompt_input = zeta_prompt::ZetaPromptInput {
397 cursor_path: excerpt_path,
398 cursor_excerpt: snapshot
399 .text_for_range(full_context)
400 .collect::<String>()
401 .into(),
402 editable_range_in_excerpt: editable_offset_range,
403 cursor_offset_in_excerpt,
404 excerpt_start_row: Some(full_context_start_row),
405 events,
406 related_files,
407 excerpt_ranges: Some(excerpt_ranges),
408 preferred_model,
409 in_open_source_repo: is_open_source,
410 can_collect_data,
411 };
412 (full_context_offset_range, prompt_input)
413}
414
415pub(crate) async fn send_custom_server_request(
416 provider: settings::EditPredictionProvider,
417 settings: &OpenAiCompatibleEditPredictionSettings,
418 prompt: String,
419 max_tokens: u32,
420 stop_tokens: Vec<String>,
421 http_client: &Arc<dyn http_client::HttpClient>,
422) -> Result<(String, String)> {
423 match provider {
424 settings::EditPredictionProvider::Ollama => {
425 let response =
426 ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
427 .await?;
428 Ok((response.response, response.created_at))
429 }
430 _ => {
431 let request = RawCompletionRequest {
432 model: settings.model.clone(),
433 prompt,
434 max_tokens: Some(max_tokens),
435 temperature: None,
436 stop: stop_tokens
437 .into_iter()
438 .map(std::borrow::Cow::Owned)
439 .collect(),
440 environment: None,
441 };
442
443 let request_body = serde_json::to_string(&request)?;
444 let http_request = http_client::Request::builder()
445 .method(http_client::Method::POST)
446 .uri(settings.api_url.as_ref())
447 .header("Content-Type", "application/json")
448 .body(http_client::AsyncBody::from(request_body))?;
449
450 let mut response = http_client.send(http_request).await?;
451 let status = response.status();
452
453 if !status.is_success() {
454 let mut body = String::new();
455 response.body_mut().read_to_string(&mut body).await?;
456 anyhow::bail!("custom server error: {} - {}", status, body);
457 }
458
459 let mut body = String::new();
460 response.body_mut().read_to_string(&mut body).await?;
461
462 let parsed: RawCompletionResponse =
463 serde_json::from_str(&body).context("Failed to parse completion response")?;
464 let text = parsed
465 .choices
466 .into_iter()
467 .next()
468 .map(|choice| choice.text)
469 .unwrap_or_default();
470 Ok((text, parsed.id))
471 }
472 }
473}
474
/// Reports an accepted edit prediction to the accept endpoint.
///
/// The endpoint can be overridden with the `ZED_ACCEPT_PREDICTION_URL`
/// environment variable (useful for local testing); otherwise the hosted
/// `/predict_edits/accept` URL is used. The report is fire-and-forget: it
/// runs on a detached background task and failures are only logged.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // With a raw zeta2 config there is no hosted endpoint to notify — skip
    // unless a custom accept URL was explicitly provided.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Only the hosted endpoint needs authentication; a custom URL is assumed
    // to be a local/test server.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        // The builder closure may be invoked more than once (e.g. on retry),
        // hence the clones of the captured id and version.
        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
524
525pub fn compute_edits(
526 old_text: String,
527 new_text: &str,
528 offset: usize,
529 snapshot: &BufferSnapshot,
530) -> Vec<(Range<Anchor>, Arc<str>)> {
531 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
532}
533
/// Diffs `old_text` against `new_text` and converts the differences into
/// anchored buffer edits, optionally locating where a cursor placed in
/// `new_text` lands in the buffer.
///
/// `offset` is the byte offset of `old_text` within `snapshot`; every edit
/// range is translated by it. `cursor_offset_in_new_text` is a byte offset
/// into `new_text`; when `Some`, the returned cursor position maps it back to
/// a buffer anchor (plus an offset within an insertion when the cursor falls
/// inside newly inserted text).
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Once the cursor has been placed (or was never requested), delta
            // stops accumulating — it is only ever read for cursor mapping.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor sits in unchanged text before this edit: map it
                    // back to old-text coordinates via delta.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor falls inside this edit's inserted text: anchor at
                    // the edit start and remember the offset into the
                    // insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Trim the common prefix, then the common suffix of what remains,
            // so the edit only covers the bytes that actually change.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Translate into buffer offsets and shrink by the trimmed ends.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp defensively — the model's editable range may extend past
            // the end of the buffer.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: use a single anchor so the edit stays glued
                // to the insertion point.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor lies after the last edit: map through the final delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
617
/// Returns the length in bytes (UTF-8) of the longest common prefix of the
/// two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (x, y) in a.zip(b) {
        if x != y {
            break;
        }
        bytes += x.len_utf8();
    }
    bytes
}