1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError,
5 cursor_excerpt::{self, compute_cursor_excerpt, compute_syntax_ranges},
6 prediction::EditPredictionResult,
7};
8use anyhow::Result;
9use cloud_llm_client::{
10 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
11};
12use edit_prediction_types::PredictedCursorPosition;
13use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
14use language::{
15 Buffer, BufferSnapshot, DiagnosticSeverity, OffsetRangeExt as _, ToOffset as _,
16 language_settings::all_language_settings, text_diff,
17};
18use release_channel::AppVersion;
19use settings::EditPredictionPromptFormat;
20use text::{Anchor, Bias, Point};
21use ui::SharedString;
22use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
23use zeta_prompt::{ParsedOutput, ZetaPromptInput};
24
25use std::{env, ops::Range, path::Path, sync::Arc};
26use zeta_prompt::{
27 ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
28 parsed_output_from_editable_region, prompt_input_contains_special_tokens,
29 stop_tokens_for_format,
30 zeta1::{self, EDITABLE_REGION_END_MARKER},
31};
32
33use crate::open_ai_compatible::{
34 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
35};
36
/// Requests an edit prediction for the current cursor position, dispatching to
/// one of three backends in priority order:
/// 1. a user-configured custom server (Ollama / OpenAI-compatible API),
/// 2. a raw zeta2 model configuration (`zeta2_raw_config`), or
/// 3. the default cloud V3 endpoint.
///
/// Returns a task resolving to `Ok(None)` when no prediction was attempted
/// (e.g. no prompt could be formatted), or an `EditPredictionResult` carrying
/// either the parsed edits or a reject reason.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        diagnostic_search_range,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Custom-server settings only exist for the two self-hosted providers;
    // everything else goes through Zed's own endpoints.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    let request_start = cx.background_executor().now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // Only look up the git remote URL when the user has opted into data
    // collection; otherwise it is never sent.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Intermediate result carried from the background request task back to the
    // foreground task below.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    let request_task = cx.background_spawn({
        async move {
            // Prompt format comes from the raw config when present; otherwise
            // the default zeta format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                diagnostic_search_range,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send buffer text that already contains the model's
            // special/control tokens — interpolating it would corrupt the prompt.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version);

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: formatted_prompt.clone(),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Dispatch to one of the three backends. Each branch produces
            // (request_id, parsed_output, model_version, usage); `None` from the
            // whole expression means "no prediction was attempted".
            let Some((request_id, output, model_version, usage)) =
                (if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): looks like a rough chars-per-token budget
                    // heuristic (x4) — confirm against the server's accounting.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    Some(match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            // Legacy zeta1 format: the editable region is marked
                            // with explicit delimiters, so stop generation on the
                            // end marker (plus trailing-newline variants).
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                                cursor_offset_in_new_editable_region: None,
                            });

                            (request_id, parsed_output, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            let Some(prompt) = formatted_prompt.clone() else {
                                return Ok((None, None));
                            };
                            // The prefill is appended to the prompt here and
                            // re-prepended to the raw response before parsing.
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens_for_format(zeta_version)
                                    .iter()
                                    .map(|token| token.to_string())
                                    .collect(),
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    })
                } else if let Some(config) = &raw_config {
                    // Raw zeta2 config: call the LLM completion endpoint
                    // directly with a fully formatted prompt.
                    let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else {
                        return Ok((None, None));
                    };
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: stop_tokens_for_format(config.format)
                            .iter()
                            .map(|token| std::borrow::Cow::Borrowed(*token))
                            .collect(),
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    // Only the last choice is consumed; no choices means an
                    // empty prediction.
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    Some((request_id, output, None, usage))
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    let parsed_output = parsed_output_from_editable_region(
                        response.editable_range,
                        output_text.unwrap_or_default(),
                    );

                    Some((request_id, Some(parsed_output), model_version, usage))
                })
            else {
                return Ok((None, None));
            };

            log::trace!("Got edit prediction response");

            // No parsed output: report the request id with an empty prediction
            // so the caller can still account for the round trip.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
                cursor_offset_in_new_editable_region: cursor_offset_in_output,
            }) = output
            else {
                return Ok((Some((request_id, None)), None));
            };

            // Translate the editable range from excerpt-relative offsets to
            // buffer-relative offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline so the diff below
            // doesn't produce a spurious trailing-newline edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        editable_range_in_buffer,
                        model_version,
                    }),
                )),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        // Applies usage accounting and "update required" handling to the raw
        // response before unpacking it.
        let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else {
            return Ok(None);
        };
        let request_duration = cx.background_executor().now() - request_start;

        // An id without a prediction means the model replied with no edits.
        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                e2e_latency: request_duration,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        let result = EditPredictionResult::new(
            id,
            &edited_buffer,
            &edited_buffer_snapshot,
            edits.into(),
            cursor_position,
            inputs,
            model_version,
            request_duration,
            cx,
        )
        .await;

        // With data collection enabled, optionally capture an example and
        // enqueue the settled prediction for later reporting. This is
        // best-effort: failures are swallowed via `.ok()`.
        if can_collect_data && let Ok(prediction) = &result.prediction {
            let weak_this = this.clone();
            let request_id = prediction.id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            let editable_range_in_buffer = editable_range_in_buffer.clone();
            let edit_preview = prediction.edit_preview.clone();
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            request_id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            &edit_preview,
                            example_spec,
                            request_duration,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(result))
    })
}
446
447fn handle_api_response<T>(
448 this: &WeakEntity<EditPredictionStore>,
449 response: Result<(T, Option<client::EditPredictionUsage>)>,
450 cx: &mut gpui::AsyncApp,
451) -> Result<T> {
452 match response {
453 Ok((data, usage)) => {
454 if let Some(usage) = usage {
455 this.update(cx, |this, cx| {
456 this.user_store.update(cx, |user_store, cx| {
457 user_store.update_edit_prediction_usage(usage, cx);
458 });
459 })
460 .ok();
461 }
462 Ok(data)
463 }
464 Err(err) => {
465 if err.is::<ZedUpdateRequiredError>() {
466 cx.update(|cx| {
467 this.update(cx, |this, _cx| {
468 this.update_required = true;
469 })
470 .ok();
471
472 let error_message: SharedString = err.to_string().into();
473 show_app_notification(
474 NotificationId::unique::<ZedUpdateRequiredError>(),
475 cx,
476 move |cx| {
477 cx.new(|cx| {
478 ErrorMessagePrompt::new(error_message.clone(), cx)
479 .with_link_button("Update Zed", "https://zed.dev/releases")
480 })
481 },
482 );
483 });
484 }
485 Err(err)
486 }
487 }
488}
489
490pub(crate) fn active_buffer_diagnostics(
491 snapshot: &language::BufferSnapshot,
492 diagnostic_search_range: Range<Point>,
493 additional_context_token_count: usize,
494) -> Vec<zeta_prompt::ActiveBufferDiagnostic> {
495 snapshot
496 .diagnostics_in_range::<Point, Point>(diagnostic_search_range, false)
497 .map(|entry| {
498 let severity = match entry.diagnostic.severity {
499 DiagnosticSeverity::ERROR => Some(1),
500 DiagnosticSeverity::WARNING => Some(2),
501 DiagnosticSeverity::INFORMATION => Some(3),
502 DiagnosticSeverity::HINT => Some(4),
503 _ => None,
504 };
505 let diagnostic_point_range = entry.range.clone();
506 let snippet_point_range = cursor_excerpt::expand_context_syntactically_then_linewise(
507 snapshot,
508 diagnostic_point_range.clone(),
509 additional_context_token_count,
510 );
511 let snippet = snapshot
512 .text_for_range(snippet_point_range.clone())
513 .collect::<String>();
514 let snippet_start_offset = snippet_point_range.start.to_offset(snapshot);
515 let diagnostic_offset_range = diagnostic_point_range.to_offset(snapshot);
516 zeta_prompt::ActiveBufferDiagnostic {
517 severity,
518 message: entry.diagnostic.message.clone(),
519 snippet,
520 snippet_buffer_row_range: diagnostic_point_range.start.row
521 ..diagnostic_point_range.end.row,
522 diagnostic_range_in_snippet: diagnostic_offset_range.start - snippet_start_offset
523 ..diagnostic_offset_range.end - snippet_start_offset,
524 }
525 })
526 .collect()
527}
528
529pub fn zeta2_prompt_input(
530 snapshot: &language::BufferSnapshot,
531 related_files: Vec<zeta_prompt::RelatedFile>,
532 events: Vec<Arc<zeta_prompt::Event>>,
533 diagnostic_search_range: Range<Point>,
534 excerpt_path: Arc<Path>,
535 cursor_offset: usize,
536 preferred_experiment: Option<String>,
537 is_open_source: bool,
538 can_collect_data: bool,
539 repo_url: Option<String>,
540) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
541 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
542 compute_cursor_excerpt(snapshot, cursor_offset);
543
544 let cursor_excerpt: Arc<str> = snapshot
545 .text_for_range(excerpt_point_range.clone())
546 .collect::<String>()
547 .into();
548 let syntax_ranges = compute_syntax_ranges(snapshot, cursor_offset, &excerpt_offset_range);
549 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
550 &cursor_excerpt,
551 cursor_offset_in_excerpt,
552 &syntax_ranges,
553 );
554
555 let active_buffer_diagnostics =
556 active_buffer_diagnostics(snapshot, diagnostic_search_range, 100);
557
558 let prompt_input = zeta_prompt::ZetaPromptInput {
559 cursor_path: excerpt_path,
560 cursor_excerpt,
561 cursor_offset_in_excerpt,
562 excerpt_start_row: Some(excerpt_point_range.start.row),
563 events,
564 related_files: Some(related_files),
565 active_buffer_diagnostics,
566 excerpt_ranges,
567 syntax_ranges: Some(syntax_ranges),
568 experiment: preferred_experiment,
569 in_open_source_repo: is_open_source,
570 can_collect_data,
571 repo_url,
572 };
573 (excerpt_offset_range, prompt_input)
574}
575
/// Notifies the server that the user accepted an edit prediction, so it can be
/// recorded for usage/quality accounting.
///
/// Skipped entirely when a raw zeta2 model config is active, unless the
/// `ZED_ACCEPT_PREDICTION_URL` env var overrides the destination (used for
/// local testing).
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    let e2e_latency = current_prediction.e2e_latency;
    // Auth is only required when talking to the real zed.dev endpoint, not a
    // local override.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        // Resolve the accept endpoint: the env override wins, otherwise the
        // zed.dev LLM service URL.
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                        e2e_latency_ms: Some(e2e_latency.as_millis()),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    // Errors are logged rather than surfaced; accepting must never block the UI.
    .detach_and_log_err(cx);
}
633
634pub fn compute_edits(
635 old_text: String,
636 new_text: &str,
637 offset: usize,
638 snapshot: &BufferSnapshot,
639) -> Vec<(Range<Anchor>, Arc<str>)> {
640 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
641}
642
/// Diffs `old_text` against `new_text` and converts the differences into
/// anchored buffer edits, optionally resolving where the cursor should land
/// after the edits apply.
///
/// `offset` is the byte offset of `old_text` within `snapshot`;
/// `cursor_offset_in_new_text`, when provided, is a byte offset into
/// `new_text`. Returns the edit list plus the predicted cursor position, if
/// one was requested.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Once resolved, the whole guard (including the delta update) is
            // skipped — delta is only consumed by cursor resolution.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor precedes this edit: it sits in unchanged text, so
                    // translate it back into old-text coordinates directly.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lands inside this edit's inserted text: anchor at
                    // the edit start and record the offset within the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the replacement to the truly-changed middle: drop bytes
            // shared at the start (prefix) and at the end (suffix).
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Rebase into buffer coordinates, then apply the trims and clamp to
            // the buffer length for safety.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            // Pure insertions use a single anchor on both ends so the edit
            // stays attached to the insertion point.
            let range = if old_range.is_empty() {
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after every edit (or there were no edits): translate it via
    // the final cumulative delta and clip to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
726
/// Returns the length in BYTES (not chars) of the longest common prefix of the
/// two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (ca, cb) in a.zip(b) {
        if ca != cb {
            break;
        }
        byte_len += ca.len_utf8();
    }
    byte_len
}