1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError,
5 cursor_excerpt::{self, compute_cursor_excerpt, compute_syntax_ranges},
6 prediction::EditPredictionResult,
7};
8use anyhow::Result;
9use cloud_llm_client::{
10 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
11};
12use edit_prediction_types::PredictedCursorPosition;
13use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
14use language::{
15 Buffer, BufferSnapshot, DiagnosticSeverity, OffsetRangeExt as _, ToOffset as _,
16 language_settings::all_language_settings, text_diff,
17};
18use release_channel::AppVersion;
19use settings::EditPredictionPromptFormat;
20use text::{Anchor, Bias, Point};
21use ui::SharedString;
22use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
23use zeta_prompt::{ParsedOutput, ZetaPromptInput};
24
25use std::{env, ops::Range, path::Path, sync::Arc};
26use zeta_prompt::{
27 CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
28 prompt_input_contains_special_tokens, stop_tokens_for_format,
29 zeta1::{self, EDITABLE_REGION_END_MARKER},
30};
31
32use crate::open_ai_compatible::{
33 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
34};
35
/// Requests an edit prediction for `buffer` using the Zeta family of models.
///
/// The request is routed to one of three backends, chosen from settings and
/// store state:
/// 1. a custom server (Ollama or an OpenAI-compatible API), using either the
///    zeta1 or zeta2 prompt format;
/// 2. a raw completion endpoint, when a raw zeta2 config is present; or
/// 3. Zed's hosted v3 endpoint, which handles model/version selection
///    server-side.
///
/// The model output is parsed into anchored edits against `snapshot` plus an
/// optional predicted cursor position, and wrapped in an
/// [`EditPredictionResult`]. Resolves to `Ok(None)` when no prediction could
/// be produced (e.g. the prompt could not be formatted for the active format).
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        diagnostic_search_range,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Only the Ollama and OpenAI-compatible providers route through a custom
    // server; any other provider uses Zed's own endpoints.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    // Capture everything the background task needs up front, while we still
    // have synchronous access to `cx` and the store.
    let http_client = cx.http_client();
    let request_start = cx.background_executor().now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Untitled buffers get a placeholder path so the prompt always has one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // The repository remote URL is only attached when the user has opted into
    // data collection.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Intermediate result computed on the background executor and consumed by
    // the foreground task below.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    let request_task = cx.background_spawn({
        async move {
            // The raw config's format (if present) decides which prompt
            // dialect is used; otherwise fall back to the default format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                diagnostic_search_range,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send buffers that already contain the model's special
            // tokens — they would corrupt prompt/output parsing.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version);

            // Mirror the request into the debug channel, if one is attached.
            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: formatted_prompt.clone(),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Each backend branch yields (request id, parsed output, model
            // version, usage). `None` from the whole expression means "no
            // prediction available" and short-circuits to Ok((None, None)).
            let Some((request_id, output, model_version, usage)) =
                (if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): the 4x multiplier presumably turns an
                    // output-token budget into sampling headroom — confirm.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    Some(match custom_settings.prompt_format {
                        // Legacy zeta1 format: editable-region markers instead
                        // of structured output.
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            // Stop on the end marker, with trailing-newline
                            // variants, so generation halts at region end.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                            });

                            (request_id, parsed_output, None, None)
                        }
                        // zeta2 format: prompt + prefill, parsed by the shared
                        // zeta2 output parser.
                        EditPredictionPromptFormat::Zeta2 => {
                            let Some(prompt) = formatted_prompt.clone() else {
                                return Ok((None, None));
                            };
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens_for_format(zeta_version)
                                    .iter()
                                    .map(|token| token.to_string())
                                    .collect(),
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                // The prefill is part of the model's logical
                                // output; prepend it before parsing.
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    })
                } else if let Some(config) = &raw_config {
                    // Raw completion endpoint using the configured format.
                    let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else {
                        return Ok((None, None));
                    };
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: stop_tokens_for_format(config.format)
                            .iter()
                            .map(|token| std::borrow::Cow::Borrowed(*token))
                            .collect(),
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    // Only the first (popped) choice is used, if any.
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    Some((request_id, output, None, usage))
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    // Empty output is normalized to None (no edit proposed).
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    let parsed_output = ParsedOutput {
                        new_editable_region: output_text.unwrap_or_default(),
                        range_in_excerpt: response.editable_range,
                    };

                    Some((request_id, Some(parsed_output), model_version, usage))
                })
            else {
                return Ok((None, None));
            };

            log::trace!("Got edit prediction response");

            // A request id without parsed output means the model declined to
            // predict; propagate the id so the empty result can be reported.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
            }) = output
            else {
                return Ok((Some((request_id, None)), None));
            };

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline so the diff does not
            // produce a spurious trailing edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        editable_range_in_buffer,
                        model_version,
                    }),
                )),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        // Record usage / surface update-required errors, then unwrap.
        let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else {
            return Ok(None);
        };
        let request_duration = cx.background_executor().now() - request_start;

        // An id with no prediction body is reported as an empty rejection so
        // latency still gets tracked.
        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                e2e_latency: request_duration,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            // Optionally capture an example and enqueue the settled
            // prediction in a detached task; failures are best-effort.
            let weak_this = this.clone();
            let id = id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            example_spec,
                            request_duration,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                inputs,
                model_version,
                request_duration,
                cx,
            )
            .await,
        ))
    })
}
448
449fn handle_api_response<T>(
450 this: &WeakEntity<EditPredictionStore>,
451 response: Result<(T, Option<client::EditPredictionUsage>)>,
452 cx: &mut gpui::AsyncApp,
453) -> Result<T> {
454 match response {
455 Ok((data, usage)) => {
456 if let Some(usage) = usage {
457 this.update(cx, |this, cx| {
458 this.user_store.update(cx, |user_store, cx| {
459 user_store.update_edit_prediction_usage(usage, cx);
460 });
461 })
462 .ok();
463 }
464 Ok(data)
465 }
466 Err(err) => {
467 if err.is::<ZedUpdateRequiredError>() {
468 cx.update(|cx| {
469 this.update(cx, |this, _cx| {
470 this.update_required = true;
471 })
472 .ok();
473
474 let error_message: SharedString = err.to_string().into();
475 show_app_notification(
476 NotificationId::unique::<ZedUpdateRequiredError>(),
477 cx,
478 move |cx| {
479 cx.new(|cx| {
480 ErrorMessagePrompt::new(error_message.clone(), cx)
481 .with_link_button("Update Zed", "https://zed.dev/releases")
482 })
483 },
484 );
485 });
486 }
487 Err(err)
488 }
489 }
490}
491
492pub(crate) fn active_buffer_diagnostics(
493 snapshot: &language::BufferSnapshot,
494 diagnostic_search_range: Range<Point>,
495 additional_context_token_count: usize,
496) -> Vec<zeta_prompt::ActiveBufferDiagnostic> {
497 snapshot
498 .diagnostics_in_range::<Point, Point>(diagnostic_search_range, false)
499 .map(|entry| {
500 let severity = match entry.diagnostic.severity {
501 DiagnosticSeverity::ERROR => Some(1),
502 DiagnosticSeverity::WARNING => Some(2),
503 DiagnosticSeverity::INFORMATION => Some(3),
504 DiagnosticSeverity::HINT => Some(4),
505 _ => None,
506 };
507 let diagnostic_point_range = entry.range.clone();
508 let snippet_point_range = cursor_excerpt::expand_context_syntactically_then_linewise(
509 snapshot,
510 diagnostic_point_range.clone(),
511 additional_context_token_count,
512 );
513 let snippet = snapshot
514 .text_for_range(snippet_point_range.clone())
515 .collect::<String>();
516 let snippet_start_offset = snippet_point_range.start.to_offset(snapshot);
517 let diagnostic_offset_range = diagnostic_point_range.to_offset(snapshot);
518 zeta_prompt::ActiveBufferDiagnostic {
519 severity,
520 message: entry.diagnostic.message.clone(),
521 snippet,
522 snippet_buffer_row_range: diagnostic_point_range.start.row
523 ..diagnostic_point_range.end.row,
524 diagnostic_range_in_snippet: diagnostic_offset_range.start - snippet_start_offset
525 ..diagnostic_offset_range.end - snippet_start_offset,
526 }
527 })
528 .collect()
529}
530
531pub fn zeta2_prompt_input(
532 snapshot: &language::BufferSnapshot,
533 related_files: Vec<zeta_prompt::RelatedFile>,
534 events: Vec<Arc<zeta_prompt::Event>>,
535 diagnostic_search_range: Range<Point>,
536 excerpt_path: Arc<Path>,
537 cursor_offset: usize,
538 preferred_experiment: Option<String>,
539 is_open_source: bool,
540 can_collect_data: bool,
541 repo_url: Option<String>,
542) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
543 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
544 compute_cursor_excerpt(snapshot, cursor_offset);
545
546 let cursor_excerpt: Arc<str> = snapshot
547 .text_for_range(excerpt_point_range.clone())
548 .collect::<String>()
549 .into();
550 let syntax_ranges = compute_syntax_ranges(snapshot, cursor_offset, &excerpt_offset_range);
551 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
552 &cursor_excerpt,
553 cursor_offset_in_excerpt,
554 &syntax_ranges,
555 );
556
557 let active_buffer_diagnostics =
558 active_buffer_diagnostics(snapshot, diagnostic_search_range, 100);
559
560 let prompt_input = zeta_prompt::ZetaPromptInput {
561 cursor_path: excerpt_path,
562 cursor_excerpt,
563 cursor_offset_in_excerpt,
564 excerpt_start_row: Some(excerpt_point_range.start.row),
565 events,
566 related_files: Some(related_files),
567 active_buffer_diagnostics,
568 excerpt_ranges,
569 syntax_ranges: Some(syntax_ranges),
570 experiment: preferred_experiment,
571 in_open_source_repo: is_open_source,
572 can_collect_data,
573 repo_url,
574 };
575 (excerpt_offset_range, prompt_input)
576}
577
/// Reports acceptance of `current_prediction` to the predict-edits accept
/// endpoint in a detached background task.
///
/// Skipped entirely when a raw zeta2 config is active, unless the
/// `ZED_ACCEPT_PREDICTION_URL` env var supplies an explicit destination
/// (useful for local testing). Authentication is only required when posting
/// to Zed's own endpoint, i.e. when no custom URL is set.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // Raw-config predictions aren't tracked by the server, so there is
    // nothing to accept — unless a custom URL overrides the destination.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    // Capture everything the async block needs before moving into it.
    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    let e2e_latency = current_prediction.e2e_latency;
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        // Custom URL (if any) takes precedence over the hosted endpoint.
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                        e2e_latency_ms: Some(e2e_latency.as_millis()),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        // Only the success/failure of the request matters; the body is `()`.
        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
635
636pub fn compute_edits(
637 old_text: String,
638 new_text: &str,
639 offset: usize,
640 snapshot: &BufferSnapshot,
641) -> Vec<(Range<Anchor>, Arc<str>)> {
642 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
643}
644
645pub fn compute_edits_and_cursor_position(
646 old_text: String,
647 new_text: &str,
648 offset: usize,
649 cursor_offset_in_new_text: Option<usize>,
650 snapshot: &BufferSnapshot,
651) -> (
652 Vec<(Range<Anchor>, Arc<str>)>,
653 Option<PredictedCursorPosition>,
654) {
655 let diffs = text_diff(&old_text, new_text);
656
657 // Delta represents the cumulative change in byte count from all preceding edits.
658 // new_offset = old_offset + delta, so old_offset = new_offset - delta
659 let mut delta: isize = 0;
660 let mut cursor_position: Option<PredictedCursorPosition> = None;
661 let buffer_len = snapshot.len();
662
663 let edits = diffs
664 .iter()
665 .map(|(raw_old_range, new_text)| {
666 // Compute cursor position if it falls within or before this edit.
667 if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
668 let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
669 let edit_end_in_new = edit_start_in_new + new_text.len();
670
671 if cursor_offset < edit_start_in_new {
672 let cursor_in_old = (cursor_offset as isize - delta) as usize;
673 let buffer_offset = (offset + cursor_in_old).min(buffer_len);
674 cursor_position = Some(PredictedCursorPosition::at_anchor(
675 snapshot.anchor_after(buffer_offset),
676 ));
677 } else if cursor_offset < edit_end_in_new {
678 let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
679 let offset_within_insertion = cursor_offset - edit_start_in_new;
680 cursor_position = Some(PredictedCursorPosition::new(
681 snapshot.anchor_before(buffer_offset),
682 offset_within_insertion,
683 ));
684 }
685
686 delta += new_text.len() as isize - raw_old_range.len() as isize;
687 }
688
689 // Compute the edit with prefix/suffix trimming.
690 let mut old_range = raw_old_range.clone();
691 let old_slice = &old_text[old_range.clone()];
692
693 let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
694 let suffix_len = common_prefix(
695 old_slice[prefix_len..].chars().rev(),
696 new_text[prefix_len..].chars().rev(),
697 );
698
699 old_range.start += offset;
700 old_range.end += offset;
701 old_range.start += prefix_len;
702 old_range.end -= suffix_len;
703
704 old_range.start = old_range.start.min(buffer_len);
705 old_range.end = old_range.end.min(buffer_len);
706
707 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
708 let range = if old_range.is_empty() {
709 let anchor = snapshot.anchor_after(old_range.start);
710 anchor..anchor
711 } else {
712 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
713 };
714 (range, new_text)
715 })
716 .collect();
717
718 if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
719 let cursor_in_old = (cursor_offset as isize - delta) as usize;
720 let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
721 cursor_position = Some(PredictedCursorPosition::at_anchor(
722 snapshot.anchor_after(buffer_offset),
723 ));
724 }
725
726 (edits, cursor_position)
727}
728
/// Returns the UTF-8 byte length of the longest common prefix shared by two
/// character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        bytes += left.len_utf8();
    }
    bytes
}