1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError,
5 cursor_excerpt::{self, compute_cursor_excerpt, compute_syntax_ranges},
6 prediction::EditPredictionResult,
7};
8use anyhow::Result;
9use cloud_llm_client::{
10 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
11};
12use edit_prediction_types::PredictedCursorPosition;
13use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
14use language::{
15 Buffer, BufferSnapshot, DiagnosticSeverity, OffsetRangeExt as _, ToOffset as _,
16 language_settings::all_language_settings, text_diff,
17};
18use release_channel::AppVersion;
19use settings::EditPredictionPromptFormat;
20use text::{Anchor, Bias, Point};
21use ui::SharedString;
22use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
23use zeta_prompt::{ParsedOutput, ZetaPromptInput};
24
25use std::{env, ops::Range, path::Path, sync::Arc};
26use zeta_prompt::{
27 ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
28 parsed_output_from_editable_region, prompt_input_contains_special_tokens,
29 stop_tokens_for_format,
30 zeta1::{self, EDITABLE_REGION_END_MARKER},
31};
32
33use crate::open_ai_compatible::{
34 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
35};
36
/// Requests an edit prediction for `buffer` at `position`, dispatching to one of
/// three backends depending on configuration:
/// 1. a user-configured custom server (Ollama or an OpenAI-compatible API),
/// 2. a raw zeta2 completion endpoint when a raw config is present, or
/// 3. the default v3 endpoint, where the server selects model/version.
///
/// Returns `Ok(None)` when no prediction was produced. When `can_collect_data`
/// is set, also captures an example (from `capture_data`) and enqueues the
/// settled prediction for data collection.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        diagnostic_search_range,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Resolve which provider to use; only Ollama and OpenAI-compatible
    // providers carry custom-server settings.
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    // Everything the background task needs must be captured on the foreground
    // thread up front, since `cx` is not available inside `background_spawn`.
    let http_client = cx.http_client();
    let request_start = cx.background_executor().now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Untitled buffers get a placeholder path for the prompt.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // The repo URL is only attached when the user has opted into data collection.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Intermediate result carried from the background request task to the
    // foreground continuation below.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    let request_task = cx.background_spawn({
        async move {
            // The raw config's format wins; otherwise fall back to the default
            // zeta format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                diagnostic_search_range,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send buffers that contain the model's own special
            // tokens — they would corrupt the prompt.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version);

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: formatted_prompt.clone(),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Each backend yields (request_id, parsed output, model version,
            // usage); `None` overall means "no prediction, not an error".
            let Some((request_id, output, model_version, usage)) =
                (if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): output-token budget scaled ×4 — presumably
                    // converting a token count to a larger completion budget;
                    // confirm the intent of this factor.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    Some(match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            // Legacy zeta1 format: editable-region markers with
                            // fixed-size excerpt ranges.
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            // Stop on the end marker with up to three trailing
                            // newlines, since servers may emit them attached.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                                cursor_offset_in_new_editable_region: None,
                            });

                            (request_id, parsed_output, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            let Some(prompt) = formatted_prompt.clone() else {
                                return Ok((None, None));
                            };
                            // The prefill is appended to the prompt and later
                            // re-prepended to the response before parsing.
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens_for_format(zeta_version)
                                    .iter()
                                    .map(|token| token.to_string())
                                    .collect(),
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    })
                } else if let Some(config) = &raw_config {
                    // Raw zeta2 endpoint: we build the completion request
                    // ourselves from the configured format/model.
                    let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else {
                        return Ok((None, None));
                    };
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: stop_tokens_for_format(config.format)
                            .iter()
                            .map(|token| std::borrow::Cow::Borrowed(*token))
                            .collect(),
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    Some((request_id, output, None, usage))
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    let parsed_output = parsed_output_from_editable_region(
                        response.editable_range,
                        output_text.unwrap_or_default(),
                    );

                    Some((request_id, Some(parsed_output), model_version, usage))
                })
            else {
                return Ok((None, None));
            };

            log::trace!("Got edit prediction response");

            // No parsed output: report the request id so usage/telemetry still
            // flow, but no prediction.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
                cursor_offset_in_new_editable_region: cursor_offset_in_output,
            }) = output
            else {
                return Ok((Some((request_id, None)), None));
            };

            // Translate the excerpt-relative editable range into buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline so the diff doesn't
            // produce a spurious trailing edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        editable_range_in_buffer,
                        model_version,
                    }),
                )),
                usage,
            ))
        }
    });

    // Foreground continuation: record usage / surface errors, then build the
    // final EditPredictionResult (or an Empty rejection).
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else {
            return Ok(None);
        };
        let request_duration = cx.background_executor().now() - request_start;

        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            // A request that completed but produced no prediction is reported
            // as an explicit Empty rejection rather than `None`.
            return Ok(Some(EditPredictionResult {
                id,
                e2e_latency: request_duration,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            let weak_this = this.clone();
            let id = id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            // Detached: data collection must not block returning the prediction.
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            example_spec,
                            request_duration,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                inputs,
                model_version,
                request_duration,
                cx,
            )
            .await,
        ))
    })
}
444
/// Unpacks an edit-prediction API response: on success, records any usage
/// counters in the user store and returns the payload; on failure, checks for
/// `ZedUpdateRequiredError` and surfaces an update notification before
/// propagating the error unchanged.
fn handle_api_response<T>(
    this: &WeakEntity<EditPredictionStore>,
    response: Result<(T, Option<client::EditPredictionUsage>)>,
    cx: &mut gpui::AsyncApp,
) -> Result<T> {
    match response {
        Ok((data, usage)) => {
            // Record updated usage limits when the server reported them.
            // `.ok()` because the store may have been dropped; that's fine.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }
            Ok(data)
        }
        Err(err) => {
            // The server no longer supports this client version: flag the
            // store first, then show a notification linking to the releases
            // page. The error is still returned to the caller.
            if err.is::<ZedUpdateRequiredError>() {
                cx.update(|cx| {
                    this.update(cx, |this, _cx| {
                        this.update_required = true;
                    })
                    .ok();

                    let error_message: SharedString = err.to_string().into();
                    show_app_notification(
                        NotificationId::unique::<ZedUpdateRequiredError>(),
                        cx,
                        move |cx| {
                            cx.new(|cx| {
                                ErrorMessagePrompt::new(error_message.clone(), cx)
                                    .with_link_button("Update Zed", "https://zed.dev/releases")
                            })
                        },
                    );
                });
            }
            Err(err)
        }
    }
}
487
/// Collects the diagnostics that intersect `diagnostic_search_range` in the
/// active buffer and packages each with a surrounding snippet (expanded
/// syntactically, then line-wise, up to `additional_context_token_count`)
/// for inclusion in the prompt.
pub(crate) fn active_buffer_diagnostics(
    snapshot: &language::BufferSnapshot,
    diagnostic_search_range: Range<Point>,
    additional_context_token_count: usize,
) -> Vec<zeta_prompt::ActiveBufferDiagnostic> {
    snapshot
        .diagnostics_in_range::<Point, Point>(diagnostic_search_range, false)
        .map(|entry| {
            // Map severities onto the prompt's numeric 1..=4 scale
            // (1 = error, 4 = hint); unknown severities are dropped to None.
            let severity = match entry.diagnostic.severity {
                DiagnosticSeverity::ERROR => Some(1),
                DiagnosticSeverity::WARNING => Some(2),
                DiagnosticSeverity::INFORMATION => Some(3),
                DiagnosticSeverity::HINT => Some(4),
                _ => None,
            };
            let diagnostic_point_range = entry.range.clone();
            // Grow the diagnostic's range to include useful surrounding context.
            let snippet_point_range = cursor_excerpt::expand_context_syntactically_then_linewise(
                snapshot,
                diagnostic_point_range.clone(),
                additional_context_token_count,
            );
            let snippet = snapshot
                .text_for_range(snippet_point_range.clone())
                .collect::<String>();
            let snippet_start_offset = snippet_point_range.start.to_offset(snapshot);
            let diagnostic_offset_range = diagnostic_point_range.to_offset(snapshot);
            zeta_prompt::ActiveBufferDiagnostic {
                severity,
                message: entry.diagnostic.message.clone(),
                snippet,
                // NOTE(review): despite the field name, this is the row range
                // of the *diagnostic*, not of the expanded snippet
                // (`snippet_point_range`) — confirm which range the consumer
                // expects here.
                snippet_buffer_row_range: diagnostic_point_range.start.row
                    ..diagnostic_point_range.end.row,
                // Byte range of the diagnostic relative to the snippet's start.
                diagnostic_range_in_snippet: diagnostic_offset_range.start - snippet_start_offset
                    ..diagnostic_offset_range.end - snippet_start_offset,
            }
        })
        .collect()
}
526
527pub fn zeta2_prompt_input(
528 snapshot: &language::BufferSnapshot,
529 related_files: Vec<zeta_prompt::RelatedFile>,
530 events: Vec<Arc<zeta_prompt::Event>>,
531 diagnostic_search_range: Range<Point>,
532 excerpt_path: Arc<Path>,
533 cursor_offset: usize,
534 preferred_experiment: Option<String>,
535 is_open_source: bool,
536 can_collect_data: bool,
537 repo_url: Option<String>,
538) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
539 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
540 compute_cursor_excerpt(snapshot, cursor_offset);
541
542 let cursor_excerpt: Arc<str> = snapshot
543 .text_for_range(excerpt_point_range.clone())
544 .collect::<String>()
545 .into();
546 let syntax_ranges = compute_syntax_ranges(snapshot, cursor_offset, &excerpt_offset_range);
547 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
548 &cursor_excerpt,
549 cursor_offset_in_excerpt,
550 &syntax_ranges,
551 );
552
553 let active_buffer_diagnostics =
554 active_buffer_diagnostics(snapshot, diagnostic_search_range, 100);
555
556 let prompt_input = zeta_prompt::ZetaPromptInput {
557 cursor_path: excerpt_path,
558 cursor_excerpt,
559 cursor_offset_in_excerpt,
560 excerpt_start_row: Some(excerpt_point_range.start.row),
561 events,
562 related_files: Some(related_files),
563 active_buffer_diagnostics,
564 excerpt_ranges,
565 syntax_ranges: Some(syntax_ranges),
566 experiment: preferred_experiment,
567 in_open_source_repo: is_open_source,
568 can_collect_data,
569 repo_url,
570 };
571 (excerpt_offset_range, prompt_input)
572}
573
/// Reports an accepted edit prediction to the server's accept endpoint (or to
/// the URL in `ZED_ACCEPT_PREDICTION_URL`, when set), including the model
/// version and end-to-end latency. Fire-and-forget: errors are only logged.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    // Env var override lets tests/local setups redirect accept reports.
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // With a raw zeta2 config there is no server-side record to accept, so skip
    // reporting unless a custom URL explicitly asks for it.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    let e2e_latency = current_prediction.e2e_latency;
    // Only the official endpoint requires the LLM token.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                        e2e_latency_ms: Some(e2e_latency.as_millis()),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
631
632pub fn compute_edits(
633 old_text: String,
634 new_text: &str,
635 offset: usize,
636 snapshot: &BufferSnapshot,
637) -> Vec<(Range<Anchor>, Arc<str>)> {
638 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
639}
640
/// Diffs `old_text` against `new_text` (both relative to buffer `offset`) and
/// converts each hunk into an anchor edit, trimming any common prefix/suffix
/// within each hunk. If `cursor_offset_in_new_text` is given, also maps that
/// offset in the new text back to a position in the current buffer.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Note `delta` is only maintained while the cursor is still
            // unplaced — it is not read anywhere after placement.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // back through `delta` into the old text / buffer.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and remember the offset into the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the hunk by the bytes it shares with the replacement, at
            // both ends (suffix is measured after removing the prefix so the
            // two trims can't overlap).
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Translate into buffer offsets, apply the trims, and clamp to the
            // buffer's length.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            // Pure insertions get a collapsed anchor..anchor range.
            let range = if old_range.is_empty() {
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor falls after the last edit: map it through the final delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
724
/// Returns the length in bytes (UTF-8) of the longest common prefix of the two
/// character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, mut b: T2) -> usize {
    let mut bytes = 0;
    for ch in a {
        // Stop as soon as the streams diverge or `b` runs out.
        match b.next() {
            Some(other) if other == ch => bytes += ch.len_utf8(),
            _ => break,
        }
    }
    bytes
}