1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError,
5 cursor_excerpt::{self, compute_cursor_excerpt, compute_syntax_ranges},
6 prediction::EditPredictionResult,
7};
8use anyhow::Result;
9use cloud_llm_client::{
10 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
11};
12use edit_prediction_types::PredictedCursorPosition;
13use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
14use language::{
15 Buffer, BufferSnapshot, DiagnosticSeverity, OffsetRangeExt as _, ToOffset as _,
16 language_settings::all_language_settings, text_diff,
17};
18use release_channel::AppVersion;
19use settings::EditPredictionPromptFormat;
20use text::{Anchor, Bias, Point};
21use ui::SharedString;
22use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
23use zeta_prompt::{ParsedOutput, ZetaPromptInput};
24
25use std::{env, ops::Range, path::Path, sync::Arc};
26use zeta_prompt::{
27 ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
28 parsed_output_from_editable_region, prompt_input_contains_special_tokens,
29 stop_tokens_for_format,
30 zeta1::{self, EDITABLE_REGION_END_MARKER},
31};
32
33use crate::open_ai_compatible::{
34 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
35};
36
/// Requests an edit prediction from the Zeta model family.
///
/// Builds a prompt from the cursor excerpt, related files, recent events, and
/// nearby diagnostics, then dispatches it over one of three paths:
/// a user-configured custom server (Ollama / OpenAI-compatible API), a raw
/// zeta2 config endpoint, or the default V3 endpoint. The model output is
/// parsed back into anchored buffer edits plus an optional predicted cursor
/// position.
///
/// Returns `Ok(None)` when no request could be made (e.g. no prompt was
/// formatted); otherwise an `EditPredictionResult` whose `prediction` is
/// rejected as `Empty` when the model produced no usable output.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        mode,
        trigger,
        project,
        diagnostic_search_range,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Custom-server settings only exist for the Ollama / OpenAI-compatible providers.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    let request_start = cx.background_executor().now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Path shown in the prompt for the active buffer; unsaved buffers become "untitled".
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // Only attach the repository URL when data collection is permitted.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Everything the foreground task needs to turn a parsed model response
    // into an `EditPredictionResult`.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        editable_range_in_buffer: Range<usize>,
    }

    let request_task = cx.background_spawn({
        async move {
            // A raw config (if present) pins the prompt format; otherwise use the default.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                diagnostic_search_range,
                excerpt_path,
                cursor_offset,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send buffer content that already contains the model's
            // special tokens — it would corrupt prompt/response parsing.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version);

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: formatted_prompt.clone(),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Dispatch over one of the three request paths. Each arm yields
            // `(request_id, parsed_output, model_version, usage)`.
            let Some((request_id, output, model_version, usage)) =
                (if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): output budget looks scaled by ~4 tokens per
                    // configured unit — confirm intent of the `* 4`.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    Some(match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            // Stop on the end-of-editable-region marker, plus a few
                            // trailing-newline variants.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                                cursor_offset_in_new_editable_region: None,
                            });

                            (request_id, parsed_output, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            let Some(prompt) = formatted_prompt.clone() else {
                                return Ok((None, None));
                            };
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens_for_format(zeta_version)
                                    .iter()
                                    .map(|token| token.to_string())
                                    .collect(),
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                // Re-attach the prefill so the parser sees the
                                // complete model output.
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    })
                } else if let Some(config) = &raw_config {
                    // Raw zeta2 config: send a completion request with an
                    // explicitly configured model, format, and stop tokens.
                    let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else {
                        return Ok((None, None));
                    };
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: stop_tokens_for_format(config.format)
                            .iter()
                            .map(|token| std::borrow::Cow::Borrowed(*token))
                            .collect(),
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    Some((request_id, output, None, usage))
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        preferred_experiment.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                        mode,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    let parsed_output = parsed_output_from_editable_region(
                        response.editable_range,
                        output_text.unwrap_or_default(),
                    );

                    Some((request_id, Some(parsed_output), model_version, usage))
                })
            else {
                return Ok((None, None));
            };

            log::trace!("Got edit prediction response");

            // No parsed output: still report the request id so the empty
            // prediction can be recorded downstream.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
                cursor_offset_in_new_editable_region: cursor_offset_in_output,
            }) = output
            else {
                return Ok((Some((request_id, None, model_version)), None));
            };

            // Translate the excerpt-relative editable range into whole-buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize trailing newlines on both sides so the diff compares
            // whole lines consistently.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            let prediction = Some(Prediction {
                prompt_input,
                buffer,
                snapshot: snapshot.clone(),
                edits,
                cursor_position,
                editable_range_in_buffer,
            });

            anyhow::Ok((Some((request_id, prediction, model_version)), usage))
        }
    });

    cx.spawn(async move |this, cx| {
        // Records usage / surfaces update-required errors, then unpacks the response.
        let Some((id, prediction, model_version)) =
            handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };
        let request_duration = cx.background_executor().now() - request_start;

        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            editable_range_in_buffer,
            ..
        }) = prediction
        else {
            // The request succeeded but the model produced nothing usable.
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
                model_version,
                e2e_latency: request_duration,
            }));
        };

        let result = EditPredictionResult::new(
            id,
            &edited_buffer,
            &edited_buffer_snapshot,
            edits.into(),
            cursor_position,
            inputs,
            model_version,
            request_duration,
            cx,
        )
        .await;

        // When data collection is enabled, enqueue the settled prediction
        // (optionally with a captured example) on a detached task.
        if can_collect_data && let Ok(prediction) = &result.prediction {
            let weak_this = this.clone();
            let request_id = prediction.id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            let editable_range_in_buffer = editable_range_in_buffer.clone();
            let edit_preview = prediction.edit_preview.clone();
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            request_id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            &edit_preview,
                            example_spec,
                            request_duration,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(result))
    })
}
445
446fn handle_api_response<T>(
447 this: &WeakEntity<EditPredictionStore>,
448 response: Result<(T, Option<client::EditPredictionUsage>)>,
449 cx: &mut gpui::AsyncApp,
450) -> Result<T> {
451 match response {
452 Ok((data, usage)) => {
453 if let Some(usage) = usage {
454 this.update(cx, |this, cx| {
455 this.user_store.update(cx, |user_store, cx| {
456 user_store.update_edit_prediction_usage(usage, cx);
457 });
458 })
459 .ok();
460 }
461 Ok(data)
462 }
463 Err(err) => {
464 if err.is::<ZedUpdateRequiredError>() {
465 cx.update(|cx| {
466 this.update(cx, |this, _cx| {
467 this.update_required = true;
468 })
469 .ok();
470
471 let error_message: SharedString = err.to_string().into();
472 show_app_notification(
473 NotificationId::unique::<ZedUpdateRequiredError>(),
474 cx,
475 move |cx| {
476 cx.new(|cx| {
477 ErrorMessagePrompt::new(error_message.clone(), cx)
478 .with_link_button("Update Zed", "https://zed.dev/releases")
479 })
480 },
481 );
482 });
483 }
484 Err(err)
485 }
486 }
487}
488
489pub(crate) fn active_buffer_diagnostics(
490 snapshot: &language::BufferSnapshot,
491 diagnostic_search_range: Range<Point>,
492 additional_context_token_count: usize,
493) -> Vec<zeta_prompt::ActiveBufferDiagnostic> {
494 snapshot
495 .diagnostics_in_range::<Point, Point>(diagnostic_search_range, false)
496 .map(|entry| {
497 let severity = match entry.diagnostic.severity {
498 DiagnosticSeverity::ERROR => Some(1),
499 DiagnosticSeverity::WARNING => Some(2),
500 DiagnosticSeverity::INFORMATION => Some(3),
501 DiagnosticSeverity::HINT => Some(4),
502 _ => None,
503 };
504 let diagnostic_point_range = entry.range.clone();
505 let snippet_point_range = cursor_excerpt::expand_context_syntactically_then_linewise(
506 snapshot,
507 diagnostic_point_range.clone(),
508 additional_context_token_count,
509 );
510 let snippet = snapshot
511 .text_for_range(snippet_point_range.clone())
512 .collect::<String>();
513 let snippet_start_offset = snippet_point_range.start.to_offset(snapshot);
514 let diagnostic_offset_range = diagnostic_point_range.to_offset(snapshot);
515 zeta_prompt::ActiveBufferDiagnostic {
516 severity,
517 message: entry.diagnostic.message.clone(),
518 snippet,
519 snippet_buffer_row_range: diagnostic_point_range.start.row
520 ..diagnostic_point_range.end.row,
521 diagnostic_range_in_snippet: diagnostic_offset_range.start - snippet_start_offset
522 ..diagnostic_offset_range.end - snippet_start_offset,
523 }
524 })
525 .collect()
526}
527
528pub fn zeta2_prompt_input(
529 snapshot: &language::BufferSnapshot,
530 related_files: Vec<zeta_prompt::RelatedFile>,
531 events: Vec<Arc<zeta_prompt::Event>>,
532 diagnostic_search_range: Range<Point>,
533 excerpt_path: Arc<Path>,
534 cursor_offset: usize,
535 is_open_source: bool,
536 can_collect_data: bool,
537 repo_url: Option<String>,
538) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
539 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
540 compute_cursor_excerpt(snapshot, cursor_offset);
541
542 let cursor_excerpt: Arc<str> = snapshot
543 .text_for_range(excerpt_point_range.clone())
544 .collect::<String>()
545 .into();
546 let syntax_ranges = compute_syntax_ranges(snapshot, cursor_offset, &excerpt_offset_range);
547 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
548 &cursor_excerpt,
549 cursor_offset_in_excerpt,
550 &syntax_ranges,
551 );
552
553 let active_buffer_diagnostics =
554 active_buffer_diagnostics(snapshot, diagnostic_search_range, 100);
555
556 let prompt_input = zeta_prompt::ZetaPromptInput {
557 cursor_path: excerpt_path,
558 cursor_excerpt,
559 cursor_offset_in_excerpt,
560 excerpt_start_row: Some(excerpt_point_range.start.row),
561 events,
562 related_files: Some(related_files),
563 active_buffer_diagnostics,
564 excerpt_ranges,
565 syntax_ranges: Some(syntax_ranges),
566 in_open_source_repo: is_open_source,
567 can_collect_data,
568 repo_url,
569 };
570 (excerpt_offset_range, prompt_input)
571}
572
573pub(crate) fn edit_prediction_accepted(
574 store: &EditPredictionStore,
575 current_prediction: CurrentEditPrediction,
576 cx: &App,
577) {
578 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
579 if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
580 return;
581 }
582
583 let request_id = current_prediction.prediction.id.to_string();
584 let model_version = current_prediction.prediction.model_version;
585 let e2e_latency = current_prediction.e2e_latency;
586 let require_auth = custom_accept_url.is_none();
587 let client = store.client.clone();
588 let llm_token = store.llm_token.clone();
589 let organization_id = store
590 .user_store
591 .read(cx)
592 .current_organization()
593 .map(|organization| organization.id.clone());
594 let app_version = AppVersion::global(cx);
595
596 cx.background_spawn(async move {
597 let url = if let Some(accept_edits_url) = custom_accept_url {
598 gpui::http_client::Url::parse(&accept_edits_url)?
599 } else {
600 client
601 .http_client()
602 .build_zed_llm_url("/predict_edits/accept", &[])?
603 };
604
605 let response = EditPredictionStore::send_api_request::<()>(
606 move |builder| {
607 let req = builder.uri(url.as_ref()).body(
608 serde_json::to_string(&AcceptEditPredictionBody {
609 request_id: request_id.clone(),
610 model_version: model_version.clone(),
611 e2e_latency_ms: Some(e2e_latency.as_millis()),
612 })?
613 .into(),
614 );
615 Ok(req?)
616 },
617 client,
618 llm_token,
619 organization_id,
620 app_version,
621 require_auth,
622 )
623 .await;
624
625 response?;
626 anyhow::Ok(())
627 })
628 .detach_and_log_err(cx);
629}
630
/// Convenience wrapper around [`compute_edits_and_cursor_position`] for
/// callers that don't need a predicted cursor position.
///
/// `offset` is the buffer offset at which `old_text` begins.
pub fn compute_edits(
    old_text: String,
    new_text: &str,
    offset: usize,
    snapshot: &BufferSnapshot,
) -> Vec<(Range<Anchor>, Arc<str>)> {
    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
}
639
/// Diffs `old_text` against `new_text` and converts the hunks into anchored
/// buffer edits, optionally mapping a cursor offset expressed in `new_text`
/// coordinates back to a buffer position.
///
/// `old_text` begins at buffer offset `offset`. Each edit is trimmed of its
/// common prefix/suffix so the anchored range covers only text that actually
/// changed, and all offsets are clamped to the snapshot's length.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Note: once the cursor has been placed, this whole branch —
            // including the delta update below — is skipped. That's sound
            // because delta is only consumed by the cursor mapping, so it no
            // longer needs maintaining once the cursor is resolved.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // straight back to old-text coordinates via delta.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and record the offset within the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Shift into buffer coordinates, then shrink by the untouched
            // prefix/suffix and clamp into bounds.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            // Pure insertions use a single `anchor_after` on both ends so the
            // edit stays attached to the preceding text.
            let range = if old_range.is_empty() {
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after the final edit (or there were no edits): map it back
    // through the accumulated delta and clip into the buffer.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
723
/// Returns the length **in bytes** (UTF-8) of the longest common prefix of
/// the two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (lhs, rhs) in a.zip(b) {
        if lhs != rhs {
            break;
        }
        byte_len += lhs.len_utf8();
    }
    byte_len
}