1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError,
5 cursor_excerpt::{self, compute_cursor_excerpt, compute_syntax_ranges},
6 prediction::EditPredictionResult,
7};
8use anyhow::Result;
9use cloud_llm_client::{
10 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
11};
12use edit_prediction_types::PredictedCursorPosition;
13use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
14use language::{
15 Buffer, BufferSnapshot, DiagnosticSeverity, OffsetRangeExt as _, ToOffset as _,
16 language_settings::all_language_settings, text_diff,
17};
18use release_channel::AppVersion;
19use settings::EditPredictionPromptFormat;
20use text::{Anchor, Bias, Point};
21use ui::SharedString;
22use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
23use zeta_prompt::{ParsedOutput, ZetaPromptInput};
24
25use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
26use zeta_prompt::{
27 CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
28 prompt_input_contains_special_tokens, stop_tokens_for_format,
29 zeta1::{self, EDITABLE_REGION_END_MARKER},
30};
31
32use crate::open_ai_compatible::{
33 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
34};
35
36pub fn request_prediction_with_zeta(
37 store: &mut EditPredictionStore,
38 EditPredictionModelInput {
39 buffer,
40 snapshot,
41 position,
42 related_files,
43 events,
44 debug_tx,
45 trigger,
46 project,
47 diagnostic_search_range,
48 can_collect_data,
49 is_open_source,
50 ..
51 }: EditPredictionModelInput,
52 capture_data: Option<Vec<StoredEvent>>,
53 cx: &mut Context<EditPredictionStore>,
54) -> Task<Result<Option<EditPredictionResult>>> {
55 let settings = &all_language_settings(None, cx).edit_predictions;
56 let provider = settings.provider;
57 let custom_server_settings = match provider {
58 settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
59 settings::EditPredictionProvider::OpenAiCompatibleApi => {
60 settings.open_ai_compatible_api.clone()
61 }
62 _ => None,
63 };
64
65 let http_client = cx.http_client();
66 let buffer_snapshotted_at = Instant::now();
67 let raw_config = store.zeta2_raw_config().cloned();
68 let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
69 let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
70
71 let excerpt_path: Arc<Path> = snapshot
72 .file()
73 .map(|file| -> Arc<Path> { file.full_path(cx).into() })
74 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
75
76 let repo_url = if can_collect_data {
77 let buffer_id = buffer.read(cx).remote_id();
78 project
79 .read(cx)
80 .git_store()
81 .read(cx)
82 .repository_and_path_for_buffer_id(buffer_id, cx)
83 .and_then(|(repo, _)| repo.read(cx).default_remote_url())
84 } else {
85 None
86 };
87
88 let client = store.client.clone();
89 let llm_token = store.llm_token.clone();
90 let organization_id = store
91 .user_store
92 .read(cx)
93 .current_organization()
94 .map(|organization| organization.id.clone());
95 let app_version = AppVersion::global(cx);
96
97 struct Prediction {
98 prompt_input: ZetaPromptInput,
99 buffer: Entity<Buffer>,
100 snapshot: BufferSnapshot,
101 edits: Vec<(Range<Anchor>, Arc<str>)>,
102 cursor_position: Option<PredictedCursorPosition>,
103 received_response_at: Instant,
104 editable_range_in_buffer: Range<usize>,
105 model_version: Option<String>,
106 }
107
108 let request_task = cx.background_spawn({
109 async move {
110 let zeta_version = raw_config
111 .as_ref()
112 .map(|config| config.format)
113 .unwrap_or(ZetaFormat::default());
114
115 let cursor_offset = position.to_offset(&snapshot);
116 let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
117 &snapshot,
118 related_files,
119 events,
120 diagnostic_search_range,
121 excerpt_path,
122 cursor_offset,
123 preferred_experiment,
124 is_open_source,
125 can_collect_data,
126 repo_url,
127 );
128
129 if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
130 return Err(anyhow::anyhow!("prompt contains special tokens"));
131 }
132
133 if let Some(debug_tx) = &debug_tx {
134 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
135 debug_tx
136 .unbounded_send(DebugEvent::EditPredictionStarted(
137 EditPredictionStartedDebugEvent {
138 buffer: buffer.downgrade(),
139 prompt: Some(prompt),
140 position,
141 },
142 ))
143 .ok();
144 }
145
146 log::trace!("Sending edit prediction request");
147
148 let (request_id, output, model_version, usage) =
149 if let Some(custom_settings) = &custom_server_settings {
150 let max_tokens = custom_settings.max_output_tokens * 4;
151
152 match custom_settings.prompt_format {
153 EditPredictionPromptFormat::Zeta => {
154 let ranges = &prompt_input.excerpt_ranges;
155 let editable_range_in_excerpt = ranges.editable_350.clone();
156 let prompt = zeta1::format_zeta1_from_input(
157 &prompt_input,
158 editable_range_in_excerpt.clone(),
159 ranges.editable_350_context_150.clone(),
160 );
161 let stop_tokens = vec![
162 EDITABLE_REGION_END_MARKER.to_string(),
163 format!("{EDITABLE_REGION_END_MARKER}\n"),
164 format!("{EDITABLE_REGION_END_MARKER}\n\n"),
165 format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
166 ];
167
168 let (response_text, request_id) = send_custom_server_request(
169 provider,
170 custom_settings,
171 prompt,
172 max_tokens,
173 stop_tokens,
174 open_ai_compatible_api_key.clone(),
175 &http_client,
176 )
177 .await?;
178
179 let request_id = EditPredictionId(request_id.into());
180 let output_text = zeta1::clean_zeta1_model_output(&response_text);
181 let parsed_output = output_text.map(|text| ParsedOutput {
182 new_editable_region: text,
183 range_in_excerpt: editable_range_in_excerpt,
184 });
185
186 (request_id, parsed_output, None, None)
187 }
188 EditPredictionPromptFormat::Zeta2 => {
189 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
190 let prefill = get_prefill(&prompt_input, zeta_version);
191 let prompt = format!("{prompt}{prefill}");
192
193 let (response_text, request_id) = send_custom_server_request(
194 provider,
195 custom_settings,
196 prompt,
197 max_tokens,
198 stop_tokens_for_format(zeta_version)
199 .iter()
200 .map(|token| token.to_string())
201 .collect(),
202 open_ai_compatible_api_key.clone(),
203 &http_client,
204 )
205 .await?;
206
207 let request_id = EditPredictionId(request_id.into());
208 let output_text = if response_text.is_empty() {
209 None
210 } else {
211 let output = format!("{prefill}{response_text}");
212 Some(parse_zeta2_model_output(
213 &output,
214 zeta_version,
215 &prompt_input,
216 )?)
217 };
218
219 (request_id, output_text, None, None)
220 }
221 _ => anyhow::bail!("unsupported prompt format"),
222 }
223 } else if let Some(config) = &raw_config {
224 let prompt = format_zeta_prompt(&prompt_input, config.format);
225 let prefill = get_prefill(&prompt_input, config.format);
226 let prompt = format!("{prompt}{prefill}");
227 let environment = config
228 .environment
229 .clone()
230 .or_else(|| Some(config.format.to_string().to_lowercase()));
231 let request = RawCompletionRequest {
232 model: config.model_id.clone().unwrap_or_default(),
233 prompt,
234 temperature: None,
235 stop: stop_tokens_for_format(config.format)
236 .iter()
237 .map(|token| std::borrow::Cow::Borrowed(*token))
238 .collect(),
239 max_tokens: Some(2048),
240 environment,
241 };
242
243 let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
244 request,
245 client,
246 None,
247 llm_token,
248 organization_id,
249 app_version,
250 )
251 .await?;
252
253 let request_id = EditPredictionId(response.id.clone().into());
254 let output = if let Some(choice) = response.choices.pop() {
255 let response = &choice.text;
256 let output = format!("{prefill}{response}");
257 Some(parse_zeta2_model_output(
258 &output,
259 config.format,
260 &prompt_input,
261 )?)
262 } else {
263 None
264 };
265
266 (request_id, output, None, usage)
267 } else {
268 // Use V3 endpoint - server handles model/version selection and suffix stripping
269 let (response, usage) = EditPredictionStore::send_v3_request(
270 prompt_input.clone(),
271 client,
272 llm_token,
273 organization_id,
274 app_version,
275 trigger,
276 )
277 .await?;
278
279 let request_id = EditPredictionId(response.request_id.into());
280 let output_text = Some(response.output).filter(|s| !s.is_empty());
281 let model_version = response.model_version;
282 let parsed_output = ParsedOutput {
283 new_editable_region: output_text.unwrap_or_default(),
284 range_in_excerpt: response.editable_range,
285 };
286
287 (request_id, Some(parsed_output), model_version, usage)
288 };
289
290 let received_response_at = Instant::now();
291
292 log::trace!("Got edit prediction response");
293
294 let Some(ParsedOutput {
295 new_editable_region: mut output_text,
296 range_in_excerpt: editable_range_in_excerpt,
297 }) = output
298 else {
299 return Ok(((request_id, None), None));
300 };
301
302 let editable_range_in_buffer = editable_range_in_excerpt.start
303 + full_context_offset_range.start
304 ..editable_range_in_excerpt.end + full_context_offset_range.start;
305
306 let mut old_text = snapshot
307 .text_for_range(editable_range_in_buffer.clone())
308 .collect::<String>();
309
310 // Client-side cursor marker processing (applies to both raw and v3 responses)
311 let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
312 if let Some(offset) = cursor_offset_in_output {
313 log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
314 output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
315 }
316
317 if let Some(debug_tx) = &debug_tx {
318 debug_tx
319 .unbounded_send(DebugEvent::EditPredictionFinished(
320 EditPredictionFinishedDebugEvent {
321 buffer: buffer.downgrade(),
322 position,
323 model_output: Some(output_text.clone()),
324 },
325 ))
326 .ok();
327 }
328
329 if !output_text.is_empty() && !output_text.ends_with('\n') {
330 output_text.push('\n');
331 }
332 if !old_text.is_empty() && !old_text.ends_with('\n') {
333 old_text.push('\n');
334 }
335
336 let (edits, cursor_position) = compute_edits_and_cursor_position(
337 old_text,
338 &output_text,
339 editable_range_in_buffer.start,
340 cursor_offset_in_output,
341 &snapshot,
342 );
343
344 anyhow::Ok((
345 (
346 request_id,
347 Some(Prediction {
348 prompt_input,
349 buffer,
350 snapshot: snapshot.clone(),
351 edits,
352 cursor_position,
353 received_response_at,
354 editable_range_in_buffer,
355 model_version,
356 }),
357 ),
358 usage,
359 ))
360 }
361 });
362
363 cx.spawn(async move |this, cx| {
364 let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;
365
366 let Some(Prediction {
367 prompt_input: inputs,
368 buffer: edited_buffer,
369 snapshot: edited_buffer_snapshot,
370 edits,
371 cursor_position,
372 received_response_at,
373 editable_range_in_buffer,
374 model_version,
375 }) = prediction
376 else {
377 return Ok(Some(EditPredictionResult {
378 id,
379 prediction: Err(EditPredictionRejectReason::Empty),
380 }));
381 };
382
383 if can_collect_data {
384 let weak_this = this.clone();
385 let id = id.clone();
386 let edited_buffer = edited_buffer.clone();
387 let edited_buffer_snapshot = edited_buffer_snapshot.clone();
388 let example_task = capture_data.and_then(|stored_events| {
389 cx.update(|cx| {
390 crate::capture_example(
391 project.clone(),
392 edited_buffer.clone(),
393 position,
394 stored_events,
395 false,
396 cx,
397 )
398 })
399 });
400 cx.spawn(async move |cx| {
401 let example_spec = if let Some(task) = example_task {
402 task.await.ok()
403 } else {
404 None
405 };
406
407 weak_this
408 .update(cx, |this, cx| {
409 this.enqueue_settled_prediction(
410 id.clone(),
411 &project,
412 &edited_buffer,
413 &edited_buffer_snapshot,
414 editable_range_in_buffer,
415 example_spec,
416 cx,
417 );
418 })
419 .ok();
420 })
421 .detach();
422 }
423
424 Ok(Some(
425 EditPredictionResult::new(
426 id,
427 &edited_buffer,
428 &edited_buffer_snapshot,
429 edits.into(),
430 cursor_position,
431 buffer_snapshotted_at,
432 received_response_at,
433 inputs,
434 model_version,
435 cx,
436 )
437 .await,
438 ))
439 })
440}
441
442fn handle_api_response<T>(
443 this: &WeakEntity<EditPredictionStore>,
444 response: Result<(T, Option<client::EditPredictionUsage>)>,
445 cx: &mut gpui::AsyncApp,
446) -> Result<T> {
447 match response {
448 Ok((data, usage)) => {
449 if let Some(usage) = usage {
450 this.update(cx, |this, cx| {
451 this.user_store.update(cx, |user_store, cx| {
452 user_store.update_edit_prediction_usage(usage, cx);
453 });
454 })
455 .ok();
456 }
457 Ok(data)
458 }
459 Err(err) => {
460 if err.is::<ZedUpdateRequiredError>() {
461 cx.update(|cx| {
462 this.update(cx, |this, _cx| {
463 this.update_required = true;
464 })
465 .ok();
466
467 let error_message: SharedString = err.to_string().into();
468 show_app_notification(
469 NotificationId::unique::<ZedUpdateRequiredError>(),
470 cx,
471 move |cx| {
472 cx.new(|cx| {
473 ErrorMessagePrompt::new(error_message.clone(), cx)
474 .with_link_button("Update Zed", "https://zed.dev/releases")
475 })
476 },
477 );
478 });
479 }
480 Err(err)
481 }
482 }
483}
484
485pub(crate) fn active_buffer_diagnostics(
486 snapshot: &language::BufferSnapshot,
487 diagnostic_search_range: Range<Point>,
488 additional_context_token_count: usize,
489) -> Vec<zeta_prompt::ActiveBufferDiagnostic> {
490 snapshot
491 .diagnostics_in_range::<Point, Point>(diagnostic_search_range, false)
492 .map(|entry| {
493 let severity = match entry.diagnostic.severity {
494 DiagnosticSeverity::ERROR => Some(1),
495 DiagnosticSeverity::WARNING => Some(2),
496 DiagnosticSeverity::INFORMATION => Some(3),
497 DiagnosticSeverity::HINT => Some(4),
498 _ => None,
499 };
500 let diagnostic_point_range = entry.range.clone();
501 let snippet_point_range = cursor_excerpt::expand_context_syntactically_then_linewise(
502 snapshot,
503 diagnostic_point_range.clone(),
504 additional_context_token_count,
505 );
506 let snippet = snapshot
507 .text_for_range(snippet_point_range.clone())
508 .collect::<String>();
509 let snippet_start_offset = snippet_point_range.start.to_offset(snapshot);
510 let diagnostic_offset_range = diagnostic_point_range.to_offset(snapshot);
511 zeta_prompt::ActiveBufferDiagnostic {
512 severity,
513 message: entry.diagnostic.message.clone(),
514 snippet,
515 snippet_buffer_row_range: diagnostic_point_range.start.row
516 ..diagnostic_point_range.end.row,
517 diagnostic_range_in_snippet: diagnostic_offset_range.start - snippet_start_offset
518 ..diagnostic_offset_range.end - snippet_start_offset,
519 }
520 })
521 .collect()
522}
523
524pub fn zeta2_prompt_input(
525 snapshot: &language::BufferSnapshot,
526 related_files: Vec<zeta_prompt::RelatedFile>,
527 events: Vec<Arc<zeta_prompt::Event>>,
528 diagnostic_search_range: Range<Point>,
529 excerpt_path: Arc<Path>,
530 cursor_offset: usize,
531 preferred_experiment: Option<String>,
532 is_open_source: bool,
533 can_collect_data: bool,
534 repo_url: Option<String>,
535) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
536 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
537 compute_cursor_excerpt(snapshot, cursor_offset);
538
539 let cursor_excerpt: Arc<str> = snapshot
540 .text_for_range(excerpt_point_range.clone())
541 .collect::<String>()
542 .into();
543 let syntax_ranges = compute_syntax_ranges(snapshot, cursor_offset, &excerpt_offset_range);
544 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
545 &cursor_excerpt,
546 cursor_offset_in_excerpt,
547 &syntax_ranges,
548 );
549
550 let active_buffer_diagnostics =
551 active_buffer_diagnostics(snapshot, diagnostic_search_range, 100);
552
553 let prompt_input = zeta_prompt::ZetaPromptInput {
554 cursor_path: excerpt_path,
555 cursor_excerpt,
556 cursor_offset_in_excerpt,
557 excerpt_start_row: Some(excerpt_point_range.start.row),
558 events,
559 related_files: Some(related_files),
560 active_buffer_diagnostics,
561 excerpt_ranges,
562 syntax_ranges: Some(syntax_ranges),
563 experiment: preferred_experiment,
564 in_open_source_repo: is_open_source,
565 can_collect_data,
566 repo_url,
567 };
568 (excerpt_offset_range, prompt_input)
569}
570
571pub(crate) fn edit_prediction_accepted(
572 store: &EditPredictionStore,
573 current_prediction: CurrentEditPrediction,
574 cx: &App,
575) {
576 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
577 if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
578 return;
579 }
580
581 let request_id = current_prediction.prediction.id.to_string();
582 let model_version = current_prediction.prediction.model_version;
583 let require_auth = custom_accept_url.is_none();
584 let client = store.client.clone();
585 let llm_token = store.llm_token.clone();
586 let organization_id = store
587 .user_store
588 .read(cx)
589 .current_organization()
590 .map(|organization| organization.id.clone());
591 let app_version = AppVersion::global(cx);
592
593 cx.background_spawn(async move {
594 let url = if let Some(accept_edits_url) = custom_accept_url {
595 gpui::http_client::Url::parse(&accept_edits_url)?
596 } else {
597 client
598 .http_client()
599 .build_zed_llm_url("/predict_edits/accept", &[])?
600 };
601
602 let response = EditPredictionStore::send_api_request::<()>(
603 move |builder| {
604 let req = builder.uri(url.as_ref()).body(
605 serde_json::to_string(&AcceptEditPredictionBody {
606 request_id: request_id.clone(),
607 model_version: model_version.clone(),
608 })?
609 .into(),
610 );
611 Ok(req?)
612 },
613 client,
614 llm_token,
615 organization_id,
616 app_version,
617 require_auth,
618 )
619 .await;
620
621 response?;
622 anyhow::Ok(())
623 })
624 .detach_and_log_err(cx);
625}
626
627pub fn compute_edits(
628 old_text: String,
629 new_text: &str,
630 offset: usize,
631 snapshot: &BufferSnapshot,
632) -> Vec<(Range<Anchor>, Arc<str>)> {
633 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
634}
635
636pub fn compute_edits_and_cursor_position(
637 old_text: String,
638 new_text: &str,
639 offset: usize,
640 cursor_offset_in_new_text: Option<usize>,
641 snapshot: &BufferSnapshot,
642) -> (
643 Vec<(Range<Anchor>, Arc<str>)>,
644 Option<PredictedCursorPosition>,
645) {
646 let diffs = text_diff(&old_text, new_text);
647
648 // Delta represents the cumulative change in byte count from all preceding edits.
649 // new_offset = old_offset + delta, so old_offset = new_offset - delta
650 let mut delta: isize = 0;
651 let mut cursor_position: Option<PredictedCursorPosition> = None;
652 let buffer_len = snapshot.len();
653
654 let edits = diffs
655 .iter()
656 .map(|(raw_old_range, new_text)| {
657 // Compute cursor position if it falls within or before this edit.
658 if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
659 let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
660 let edit_end_in_new = edit_start_in_new + new_text.len();
661
662 if cursor_offset < edit_start_in_new {
663 let cursor_in_old = (cursor_offset as isize - delta) as usize;
664 let buffer_offset = (offset + cursor_in_old).min(buffer_len);
665 cursor_position = Some(PredictedCursorPosition::at_anchor(
666 snapshot.anchor_after(buffer_offset),
667 ));
668 } else if cursor_offset < edit_end_in_new {
669 let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
670 let offset_within_insertion = cursor_offset - edit_start_in_new;
671 cursor_position = Some(PredictedCursorPosition::new(
672 snapshot.anchor_before(buffer_offset),
673 offset_within_insertion,
674 ));
675 }
676
677 delta += new_text.len() as isize - raw_old_range.len() as isize;
678 }
679
680 // Compute the edit with prefix/suffix trimming.
681 let mut old_range = raw_old_range.clone();
682 let old_slice = &old_text[old_range.clone()];
683
684 let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
685 let suffix_len = common_prefix(
686 old_slice[prefix_len..].chars().rev(),
687 new_text[prefix_len..].chars().rev(),
688 );
689
690 old_range.start += offset;
691 old_range.end += offset;
692 old_range.start += prefix_len;
693 old_range.end -= suffix_len;
694
695 old_range.start = old_range.start.min(buffer_len);
696 old_range.end = old_range.end.min(buffer_len);
697
698 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
699 let range = if old_range.is_empty() {
700 let anchor = snapshot.anchor_after(old_range.start);
701 anchor..anchor
702 } else {
703 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
704 };
705 (range, new_text)
706 })
707 .collect();
708
709 if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
710 let cursor_in_old = (cursor_offset as isize - delta) as usize;
711 let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
712 cursor_position = Some(PredictedCursorPosition::at_anchor(
713 snapshot.anchor_after(buffer_offset),
714 ));
715 }
716
717 (edits, cursor_position)
718}
719
/// Returns the UTF-8 byte length of the longest run of leading characters that `a` and `b`
/// have in common.
///
/// Passing reversed character iterators yields the length of the common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}