1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError,
5 cursor_excerpt::{self, compute_cursor_excerpt, compute_syntax_ranges},
6 prediction::EditPredictionResult,
7};
8use anyhow::Result;
9use cloud_llm_client::{
10 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
11};
12use edit_prediction_types::PredictedCursorPosition;
13use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
14use language::{
15 Buffer, BufferSnapshot, DiagnosticSeverity, OffsetRangeExt as _, ToOffset as _,
16 language_settings::all_language_settings, text_diff,
17};
18use release_channel::AppVersion;
19use settings::EditPredictionPromptFormat;
20use text::{Anchor, Bias, Point};
21use ui::SharedString;
22use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
23use zeta_prompt::{ParsedOutput, ZetaPromptInput};
24
25use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
26use zeta_prompt::{
27 CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
28 prompt_input_contains_special_tokens, stop_tokens_for_format,
29 zeta1::{self, EDITABLE_REGION_END_MARKER},
30};
31
32use crate::open_ai_compatible::{
33 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
34};
35
/// Requests an edit prediction for `position` in `buffer` via one of the Zeta
/// backends and returns the parsed result (or `None` when no usable request
/// could be made or the backend produced no output).
///
/// Backend selection, in priority order:
/// 1. A custom server (Ollama / OpenAI-compatible API) from user settings,
///    speaking either the Zeta1 or Zeta2 prompt format.
/// 2. A raw zeta2 config from the store, hitting the raw completion endpoint.
/// 3. The hosted V3 endpoint, which selects model/version server-side.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        diagnostic_search_range,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Custom-server settings exist only for the self-hosted providers; `None`
    // routes the request to Zed's hosted endpoints instead.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Captured before any async work so latency can be measured from the
    // moment the buffer state was frozen.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Path shown to the model for the active buffer; unsaved buffers get a
    // placeholder.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // Repository URL is only attached when the user consented to data
    // collection.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Intermediate payload handed from the background request task to the
    // foreground task below.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        received_response_at: Instant,
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    let request_task = cx.background_spawn({
        async move {
            // The raw config (if any) dictates the prompt format; otherwise
            // use the default Zeta format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Builds the full prompt input (excerpt, diagnostics, events,
            // related files) and the buffer offset range the excerpt covers.
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                diagnostic_search_range,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send content that would collide with the model's own
            // control tokens.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version);

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: formatted_prompt.clone(),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Each backend yields (request id, parsed output, model version,
            // usage); `None` from the whole expression means "nothing to do".
            let Some((request_id, output, model_version, usage)) =
                (if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): presumably headroom for markers/prefill
                    // beyond the configured output budget — TODO confirm.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    Some(match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            // Zeta1: editable-region prompt with explicit
                            // end-marker stop sequences.
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            // Several trailing-newline variants of the end
                            // marker, since servers may stop late.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            // `None` here means the model produced nothing
                            // usable; surfaced as an empty prediction later.
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                            });

                            (request_id, parsed_output, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            let Some(prompt) = formatted_prompt.clone() else {
                                return Ok((None, None));
                            };
                            // Prefill is appended to the prompt and later
                            // re-prepended to the response before parsing.
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens_for_format(zeta_version)
                                    .iter()
                                    .map(|token| token.to_string())
                                    .collect(),
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    })
                } else if let Some(config) = &raw_config {
                    // Raw zeta2 config: hit the raw completion endpoint with
                    // an explicitly formatted prompt.
                    let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else {
                        return Ok((None, None));
                    };
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: stop_tokens_for_format(config.format)
                            .iter()
                            .map(|token| std::borrow::Cow::Borrowed(*token))
                            .collect(),
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    Some((request_id, output, None, usage))
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    // NOTE(review): unlike the other backends, an empty v3
                    // output still becomes a `ParsedOutput` (with an empty
                    // editable region) rather than `None`; if the server's
                    // `editable_range` is non-empty this would diff as a
                    // deletion — confirm the server guarantees consistency.
                    let parsed_output = ParsedOutput {
                        new_editable_region: output_text.unwrap_or_default(),
                        range_in_excerpt: response.editable_range,
                    };

                    Some((request_id, Some(parsed_output), model_version, usage))
                })
            else {
                return Ok((None, None));
            };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No parsed output: record the request id but report an empty
            // prediction.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
            }) = output
            else {
                return Ok((Some((request_id, None)), None));
            };

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize trailing newlines on both sides so the diff doesn't
            // produce a spurious edit at the region boundary.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            // Diff old vs. new editable region into anchored edits plus an
            // optional predicted cursor position.
            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                        model_version,
                    }),
                )),
                usage,
            ))
        }
    });

    // Foreground task: report usage / update-required errors, optionally
    // capture a training example, and assemble the final result.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else {
            return Ok(None);
        };

        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            let weak_this = this.clone();
            let id = id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            // Detached: example capture and prediction settling must not
            // block returning the prediction to the caller.
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            example_spec,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
451
452fn handle_api_response<T>(
453 this: &WeakEntity<EditPredictionStore>,
454 response: Result<(T, Option<client::EditPredictionUsage>)>,
455 cx: &mut gpui::AsyncApp,
456) -> Result<T> {
457 match response {
458 Ok((data, usage)) => {
459 if let Some(usage) = usage {
460 this.update(cx, |this, cx| {
461 this.user_store.update(cx, |user_store, cx| {
462 user_store.update_edit_prediction_usage(usage, cx);
463 });
464 })
465 .ok();
466 }
467 Ok(data)
468 }
469 Err(err) => {
470 if err.is::<ZedUpdateRequiredError>() {
471 cx.update(|cx| {
472 this.update(cx, |this, _cx| {
473 this.update_required = true;
474 })
475 .ok();
476
477 let error_message: SharedString = err.to_string().into();
478 show_app_notification(
479 NotificationId::unique::<ZedUpdateRequiredError>(),
480 cx,
481 move |cx| {
482 cx.new(|cx| {
483 ErrorMessagePrompt::new(error_message.clone(), cx)
484 .with_link_button("Update Zed", "https://zed.dev/releases")
485 })
486 },
487 );
488 });
489 }
490 Err(err)
491 }
492 }
493}
494
495pub(crate) fn active_buffer_diagnostics(
496 snapshot: &language::BufferSnapshot,
497 diagnostic_search_range: Range<Point>,
498 additional_context_token_count: usize,
499) -> Vec<zeta_prompt::ActiveBufferDiagnostic> {
500 snapshot
501 .diagnostics_in_range::<Point, Point>(diagnostic_search_range, false)
502 .map(|entry| {
503 let severity = match entry.diagnostic.severity {
504 DiagnosticSeverity::ERROR => Some(1),
505 DiagnosticSeverity::WARNING => Some(2),
506 DiagnosticSeverity::INFORMATION => Some(3),
507 DiagnosticSeverity::HINT => Some(4),
508 _ => None,
509 };
510 let diagnostic_point_range = entry.range.clone();
511 let snippet_point_range = cursor_excerpt::expand_context_syntactically_then_linewise(
512 snapshot,
513 diagnostic_point_range.clone(),
514 additional_context_token_count,
515 );
516 let snippet = snapshot
517 .text_for_range(snippet_point_range.clone())
518 .collect::<String>();
519 let snippet_start_offset = snippet_point_range.start.to_offset(snapshot);
520 let diagnostic_offset_range = diagnostic_point_range.to_offset(snapshot);
521 zeta_prompt::ActiveBufferDiagnostic {
522 severity,
523 message: entry.diagnostic.message.clone(),
524 snippet,
525 snippet_buffer_row_range: diagnostic_point_range.start.row
526 ..diagnostic_point_range.end.row,
527 diagnostic_range_in_snippet: diagnostic_offset_range.start - snippet_start_offset
528 ..diagnostic_offset_range.end - snippet_start_offset,
529 }
530 })
531 .collect()
532}
533
534pub fn zeta2_prompt_input(
535 snapshot: &language::BufferSnapshot,
536 related_files: Vec<zeta_prompt::RelatedFile>,
537 events: Vec<Arc<zeta_prompt::Event>>,
538 diagnostic_search_range: Range<Point>,
539 excerpt_path: Arc<Path>,
540 cursor_offset: usize,
541 preferred_experiment: Option<String>,
542 is_open_source: bool,
543 can_collect_data: bool,
544 repo_url: Option<String>,
545) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
546 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
547 compute_cursor_excerpt(snapshot, cursor_offset);
548
549 let cursor_excerpt: Arc<str> = snapshot
550 .text_for_range(excerpt_point_range.clone())
551 .collect::<String>()
552 .into();
553 let syntax_ranges = compute_syntax_ranges(snapshot, cursor_offset, &excerpt_offset_range);
554 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
555 &cursor_excerpt,
556 cursor_offset_in_excerpt,
557 &syntax_ranges,
558 );
559
560 let active_buffer_diagnostics =
561 active_buffer_diagnostics(snapshot, diagnostic_search_range, 100);
562
563 let prompt_input = zeta_prompt::ZetaPromptInput {
564 cursor_path: excerpt_path,
565 cursor_excerpt,
566 cursor_offset_in_excerpt,
567 excerpt_start_row: Some(excerpt_point_range.start.row),
568 events,
569 related_files: Some(related_files),
570 active_buffer_diagnostics,
571 excerpt_ranges,
572 syntax_ranges: Some(syntax_ranges),
573 experiment: preferred_experiment,
574 in_open_source_repo: is_open_source,
575 can_collect_data,
576 repo_url,
577 };
578 (excerpt_offset_range, prompt_input)
579}
580
581pub(crate) fn edit_prediction_accepted(
582 store: &EditPredictionStore,
583 current_prediction: CurrentEditPrediction,
584 cx: &App,
585) {
586 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
587 if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
588 return;
589 }
590
591 let request_id = current_prediction.prediction.id.to_string();
592 let model_version = current_prediction.prediction.model_version;
593 let require_auth = custom_accept_url.is_none();
594 let client = store.client.clone();
595 let llm_token = store.llm_token.clone();
596 let organization_id = store
597 .user_store
598 .read(cx)
599 .current_organization()
600 .map(|organization| organization.id.clone());
601 let app_version = AppVersion::global(cx);
602
603 cx.background_spawn(async move {
604 let url = if let Some(accept_edits_url) = custom_accept_url {
605 gpui::http_client::Url::parse(&accept_edits_url)?
606 } else {
607 client
608 .http_client()
609 .build_zed_llm_url("/predict_edits/accept", &[])?
610 };
611
612 let response = EditPredictionStore::send_api_request::<()>(
613 move |builder| {
614 let req = builder.uri(url.as_ref()).body(
615 serde_json::to_string(&AcceptEditPredictionBody {
616 request_id: request_id.clone(),
617 model_version: model_version.clone(),
618 })?
619 .into(),
620 );
621 Ok(req?)
622 },
623 client,
624 llm_token,
625 organization_id,
626 app_version,
627 require_auth,
628 )
629 .await;
630
631 response?;
632 anyhow::Ok(())
633 })
634 .detach_and_log_err(cx);
635}
636
637pub fn compute_edits(
638 old_text: String,
639 new_text: &str,
640 offset: usize,
641 snapshot: &BufferSnapshot,
642) -> Vec<(Range<Anchor>, Arc<str>)> {
643 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
644}
645
/// Diffs `old_text` (located at byte `offset` in `snapshot`) against
/// `new_text`, producing anchored edits with common prefixes/suffixes trimmed
/// per hunk, plus an optional predicted cursor position mapped from
/// `cursor_offset_in_new_text` (a byte offset into `new_text`) back into
/// buffer coordinates.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    // Note: delta is only advanced while the cursor is still unplaced (and
    // only when a cursor offset was requested at all), which is sufficient
    // because delta feeds cursor mapping exclusively.
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // straight back to old coordinates via delta.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and remember the offset into the
                    // insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the hunk to the minimal differing span: drop the shared
            // leading prefix, then the shared trailing suffix of what's left.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Translate into absolute buffer offsets, clamped to buffer
            // length (old_text may extend past a shrinking snapshot tail).
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            // Pure insertions use a single after-biased anchor so the edit
            // point tracks concurrent edits; replacements bracket the range.
            let range = if old_range.is_empty() {
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after the last edit (or there were no edits): map it back
    // through the final delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
729
/// Returns the length in BYTES of the longest common prefix of the two
/// character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (x, y) in a.zip(b) {
        if x != y {
            break;
        }
        byte_len += x.len_utf8();
    }
    byte_len
}