1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError, cursor_excerpt::compute_excerpt_ranges,
5 prediction::EditPredictionResult,
6};
7use anyhow::Result;
8use cloud_llm_client::{
9 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
10};
11use edit_prediction_types::PredictedCursorPosition;
12use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
13use language::{
14 Buffer, BufferSnapshot, ToOffset as _, ToPoint, language_settings::all_language_settings,
15 text_diff,
16};
17use release_channel::AppVersion;
18use settings::EditPredictionPromptFormat;
19use text::{Anchor, Bias};
20use ui::SharedString;
21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
22use zeta_prompt::{ParsedOutput, ZetaPromptInput};
23
24use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
25use zeta_prompt::{
26 CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
27 prompt_input_contains_special_tokens, stop_tokens_for_format,
28 zeta1::{self, EDITABLE_REGION_END_MARKER},
29};
30
31use crate::open_ai_compatible::{
32 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
33};
34
/// Requests an edit prediction at `position` in `buffer` and resolves it into
/// concrete buffer edits.
///
/// The prompt is assembled locally from the cursor excerpt, related files, and
/// recent events, then dispatched to one of three backends, in priority order:
/// 1. a user-configured custom server (Ollama or an OpenAI-compatible API),
/// 2. a raw zeta2 config from the store (raw-completion endpoint), or
/// 3. the hosted V3 endpoint, where the server picks the model and strips the
///    suffix itself.
///
/// When `can_collect_data` is set, the settled prediction (plus an optional
/// example built from `capture_data`) is enqueued for collection.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Custom-server settings are only honored for the matching provider.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Untitled buffers get a placeholder path so the prompt always has one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // The repository URL is only attached when data collection is permitted.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Intermediate parse of a model response, handed from the background
    // request task to the foreground task that builds the final result.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        received_response_at: Instant,
        // Byte range of the editable region within `snapshot`.
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    let request_task = cx.background_spawn({
        async move {
            // A raw config pins the prompt format; otherwise use the default.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send prompts whose content could collide with the
            // model's own control tokens.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): the x4 headroom over the configured output
                    // tokens looks like a safety margin — confirm the intent.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        // Zeta (v1): editable-region-marker based prompt.
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            // Stop as soon as the end-of-editable-region marker
                            // appears, including trailing-newline variants.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                            });

                            (request_id, parsed_output, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // The prefill is appended to the prompt and
                            // re-attached to the response before parsing, so
                            // the parser sees the full editable region.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens_for_format(zeta_version)
                                    .iter()
                                    .map(|token| token.to_string())
                                    .collect(),
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    // Default the environment to the lowercased format name.
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: stop_tokens_for_format(config.format)
                            .iter()
                            .map(|token| std::borrow::Cow::Borrowed(*token))
                            .collect(),
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    // As in the Zeta2 branch, re-attach the prefill before
                    // parsing the model output.
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    (request_id, output, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    let parsed_output = ParsedOutput {
                        new_editable_region: output_text.unwrap_or_default(),
                        range_in_excerpt: response.editable_range,
                    };

                    (request_id, Some(parsed_output), model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No parsed output means an empty prediction.
            // NOTE(review): `usage` returned by the request is dropped on this
            // early-return path — confirm that's intentional.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
            }) = output
            else {
                return Ok(((request_id, None), None));
            };

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline before diffing.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                (
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                        model_version,
                    }),
                ),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        // Records usage and surfaces "update required" errors before
        // unwrapping the response.
        let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;

        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            let weak_this = this.clone();
            let id = id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            // Optionally build a shareable example from the stored events.
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            example_spec,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
438
439fn handle_api_response<T>(
440 this: &WeakEntity<EditPredictionStore>,
441 response: Result<(T, Option<client::EditPredictionUsage>)>,
442 cx: &mut gpui::AsyncApp,
443) -> Result<T> {
444 match response {
445 Ok((data, usage)) => {
446 if let Some(usage) = usage {
447 this.update(cx, |this, cx| {
448 this.user_store.update(cx, |user_store, cx| {
449 user_store.update_edit_prediction_usage(usage, cx);
450 });
451 })
452 .ok();
453 }
454 Ok(data)
455 }
456 Err(err) => {
457 if err.is::<ZedUpdateRequiredError>() {
458 cx.update(|cx| {
459 this.update(cx, |this, _cx| {
460 this.update_required = true;
461 })
462 .ok();
463
464 let error_message: SharedString = err.to_string().into();
465 show_app_notification(
466 NotificationId::unique::<ZedUpdateRequiredError>(),
467 cx,
468 move |cx| {
469 cx.new(|cx| {
470 ErrorMessagePrompt::new(error_message.clone(), cx)
471 .with_link_button("Update Zed", "https://zed.dev/releases")
472 })
473 },
474 );
475 });
476 }
477 Err(err)
478 }
479 }
480}
481
482pub fn zeta2_prompt_input(
483 snapshot: &language::BufferSnapshot,
484 related_files: Vec<zeta_prompt::RelatedFile>,
485 events: Vec<Arc<zeta_prompt::Event>>,
486 excerpt_path: Arc<Path>,
487 cursor_offset: usize,
488 preferred_experiment: Option<String>,
489 is_open_source: bool,
490 can_collect_data: bool,
491 repo_url: Option<String>,
492) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
493 let cursor_point = cursor_offset.to_point(snapshot);
494
495 let (full_context, full_context_offset_range, excerpt_ranges) =
496 compute_excerpt_ranges(cursor_point, snapshot);
497
498 let full_context_start_offset = full_context_offset_range.start;
499 let full_context_start_row = full_context.start.row;
500
501 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
502
503 let prompt_input = zeta_prompt::ZetaPromptInput {
504 cursor_path: excerpt_path,
505 cursor_excerpt: snapshot
506 .text_for_range(full_context)
507 .collect::<String>()
508 .into(),
509 cursor_offset_in_excerpt,
510 excerpt_start_row: Some(full_context_start_row),
511 events,
512 related_files: Some(related_files),
513 excerpt_ranges,
514 experiment: preferred_experiment,
515 in_open_source_repo: is_open_source,
516 can_collect_data,
517 repo_url,
518 };
519 (full_context_offset_range, prompt_input)
520}
521
522pub(crate) fn edit_prediction_accepted(
523 store: &EditPredictionStore,
524 current_prediction: CurrentEditPrediction,
525 cx: &App,
526) {
527 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
528 if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
529 return;
530 }
531
532 let request_id = current_prediction.prediction.id.to_string();
533 let model_version = current_prediction.prediction.model_version;
534 let require_auth = custom_accept_url.is_none();
535 let client = store.client.clone();
536 let llm_token = store.llm_token.clone();
537 let organization_id = store
538 .user_store
539 .read(cx)
540 .current_organization()
541 .map(|organization| organization.id.clone());
542 let app_version = AppVersion::global(cx);
543
544 cx.background_spawn(async move {
545 let url = if let Some(accept_edits_url) = custom_accept_url {
546 gpui::http_client::Url::parse(&accept_edits_url)?
547 } else {
548 client
549 .http_client()
550 .build_zed_llm_url("/predict_edits/accept", &[])?
551 };
552
553 let response = EditPredictionStore::send_api_request::<()>(
554 move |builder| {
555 let req = builder.uri(url.as_ref()).body(
556 serde_json::to_string(&AcceptEditPredictionBody {
557 request_id: request_id.clone(),
558 model_version: model_version.clone(),
559 })?
560 .into(),
561 );
562 Ok(req?)
563 },
564 client,
565 llm_token,
566 organization_id,
567 app_version,
568 require_auth,
569 )
570 .await;
571
572 response?;
573 anyhow::Ok(())
574 })
575 .detach_and_log_err(cx);
576}
577
578pub fn compute_edits(
579 old_text: String,
580 new_text: &str,
581 offset: usize,
582 snapshot: &BufferSnapshot,
583) -> Vec<(Range<Anchor>, Arc<str>)> {
584 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
585}
586
/// Diffs `old_text` against `new_text` and converts each hunk into an
/// anchor-ranged buffer edit, optionally mapping a byte offset in `new_text`
/// back to a position in the buffer.
///
/// `offset` is the byte offset in `snapshot` where `old_text` begins; hunk
/// ranges are translated by it and clamped to the buffer length. When
/// `cursor_offset_in_new_text` is provided, the result includes a
/// `PredictedCursorPosition` — anchored in unchanged text when the cursor
/// falls outside every hunk, or expressed as an offset within an insertion
/// when it lands inside one.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Delta only needs to advance while the cursor is still
            // unresolved, so the accumulation lives inside this guard.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor sits in unchanged text before this hunk: map it
                    // back to old-text coordinates and anchor it there.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor sits inside this hunk's inserted text: record the
                    // hunk's start anchor plus the offset into the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            // Shrinks each hunk to its minimal differing core so anchors hug
            // the actual change. Suffix matching starts past the prefix to
            // avoid overlapping the two trims.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp to the buffer in case the editable range ran past its end.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: use a single collapsed anchor.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor lies after the last hunk: map it through the final delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
670
/// Byte length of the longest common prefix of two char streams.
///
/// Measured in UTF-8 bytes so the result can be used directly as a byte
/// offset into the underlying strings.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (ca, cb) in a.zip(b) {
        if ca != cb {
            break;
        }
        bytes += ca.len_utf8();
    }
    bytes
}