1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError, cursor_excerpt::compute_excerpt_ranges,
5 prediction::EditPredictionResult,
6};
7use anyhow::Result;
8use cloud_llm_client::{
9 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
10};
11use edit_prediction_types::PredictedCursorPosition;
12use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
13use language::{
14 Buffer, BufferSnapshot, ToOffset as _, ToPoint, language_settings::all_language_settings,
15 text_diff,
16};
17use release_channel::AppVersion;
18use settings::EditPredictionPromptFormat;
19use text::{Anchor, Bias};
20use ui::SharedString;
21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
22use zeta_prompt::ZetaPromptInput;
23
24use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
25use zeta_prompt::{
26 CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
27 prompt_input_contains_special_tokens,
28 zeta1::{self, EDITABLE_REGION_END_MARKER},
29};
30
31use crate::open_ai_compatible::{
32 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
33};
34
/// Requests an edit prediction for `buffer` at `position` using one of the
/// Zeta backends, returning the prediction result (or an `Empty` reject
/// reason when the model produced no usable output).
///
/// Three request paths are selected, in priority order:
/// 1. A custom server (Ollama / OpenAI-compatible API) when configured in
///    settings, using either the Zeta1 or Zeta2 prompt format.
/// 2. A "raw" LLM request when a zeta2 raw config is present on the store.
/// 3. The hosted `predict_edits` V3 endpoint otherwise (the server selects
///    the model/version and strips the suffix).
///
/// When `can_collect_data` is set, the settled prediction (plus an optional
/// example captured from `capture_data`) is enqueued for data collection.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Only the Ollama and OpenAI-compatible providers carry custom-server
    // settings; any other provider falls through to the Zed-hosted paths.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Timestamp used later to measure end-to-end prediction latency.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Untitled buffers still get a placeholder path so prompts always have one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // The repository URL is only attached when the user consented to data
    // collection.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Everything the foreground task needs from a successful background
    // request, bundled so one value crosses the task boundary.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        received_response_at: Instant,
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    // Background task: build the prompt, call the selected backend, and
    // convert the model output into buffer edits.
    let request_task = cx.background_spawn({
        async move {
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send prompts that already contain the model's special
            // tokens, since they would corrupt parsing of the response.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Each arm yields (request id, optional (editable range, output
            // text), optional model version, optional usage info).
            let (request_id, output, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): the `* 4` presumably widens the configured
                    // output-token budget for the request — confirm against the
                    // custom-server API semantics.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            // Stop generation at the editable-region end marker,
                            // tolerating up to three trailing newlines after it.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);

                            (
                                request_id,
                                // `zip` collapses to `None` when the cleaned
                                // output was empty/unusable.
                                Some(editable_range_in_excerpt).zip(output_text),
                                None,
                                None,
                            )
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                // Re-attach the prefill before parsing: the
                                // model's completion continues from it.
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    // Default the environment name to the lowercased format
                    // when the config doesn't specify one.
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    // Take the first (and only expected) choice, if any.
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    (request_id, output, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;

                    (
                        request_id,
                        Some(response.editable_range).zip(output_text),
                        model_version,
                        usage,
                    )
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No output is an empty prediction, not an error: report the id
            // with no `Prediction` payload.
            let Some((editable_range_in_excerpt, mut output_text)) = output else {
                return Ok(((request_id, None), None));
            };

            // Translate the excerpt-relative editable range into whole-buffer
            // byte offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline so the diff doesn't
            // produce a spurious trailing-newline edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                (
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                        model_version,
                    }),
                ),
                usage,
            ))
        }
    });

    // Foreground task: record usage / surface errors, optionally enqueue the
    // settled prediction for data collection, then build the final result.
    cx.spawn(async move |this, cx| {
        let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;

        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            // The backend returned nothing usable: reject as Empty.
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            let weak_this = this.clone();
            let id = id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            // Optionally capture a data-collection example from the stored
            // events; runs concurrently and never blocks the prediction.
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            example_spec,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
430
431fn handle_api_response<T>(
432 this: &WeakEntity<EditPredictionStore>,
433 response: Result<(T, Option<client::EditPredictionUsage>)>,
434 cx: &mut gpui::AsyncApp,
435) -> Result<T> {
436 match response {
437 Ok((data, usage)) => {
438 if let Some(usage) = usage {
439 this.update(cx, |this, cx| {
440 this.user_store.update(cx, |user_store, cx| {
441 user_store.update_edit_prediction_usage(usage, cx);
442 });
443 })
444 .ok();
445 }
446 Ok(data)
447 }
448 Err(err) => {
449 if err.is::<ZedUpdateRequiredError>() {
450 cx.update(|cx| {
451 this.update(cx, |this, _cx| {
452 this.update_required = true;
453 })
454 .ok();
455
456 let error_message: SharedString = err.to_string().into();
457 show_app_notification(
458 NotificationId::unique::<ZedUpdateRequiredError>(),
459 cx,
460 move |cx| {
461 cx.new(|cx| {
462 ErrorMessagePrompt::new(error_message.clone(), cx)
463 .with_link_button("Update Zed", "https://zed.dev/releases")
464 })
465 },
466 );
467 });
468 }
469 Err(err)
470 }
471 }
472}
473
474pub fn zeta2_prompt_input(
475 snapshot: &language::BufferSnapshot,
476 related_files: Vec<zeta_prompt::RelatedFile>,
477 events: Vec<Arc<zeta_prompt::Event>>,
478 excerpt_path: Arc<Path>,
479 cursor_offset: usize,
480 preferred_experiment: Option<String>,
481 is_open_source: bool,
482 can_collect_data: bool,
483 repo_url: Option<String>,
484) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
485 let cursor_point = cursor_offset.to_point(snapshot);
486
487 let (full_context, full_context_offset_range, excerpt_ranges) =
488 compute_excerpt_ranges(cursor_point, snapshot);
489
490 let full_context_start_offset = full_context_offset_range.start;
491 let full_context_start_row = full_context.start.row;
492
493 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
494
495 let prompt_input = zeta_prompt::ZetaPromptInput {
496 cursor_path: excerpt_path,
497 cursor_excerpt: snapshot
498 .text_for_range(full_context)
499 .collect::<String>()
500 .into(),
501 cursor_offset_in_excerpt,
502 excerpt_start_row: Some(full_context_start_row),
503 events,
504 related_files,
505 excerpt_ranges,
506 experiment: preferred_experiment,
507 in_open_source_repo: is_open_source,
508 can_collect_data,
509 repo_url,
510 };
511 (full_context_offset_range, prompt_input)
512}
513
/// Reports an accepted edit prediction to the server, or to the URL in the
/// `ZED_ACCEPT_PREDICTION_URL` environment variable when set.
///
/// No-ops when a zeta2 raw config is active and no custom accept URL is
/// provided. The request runs detached in the background; failures are
/// logged rather than surfaced.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Authentication is only required when posting to the Zed LLM endpoint,
    // not to a custom accept URL.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        // Prefer the explicitly-configured URL; otherwise target the Zed LLM
        // accept endpoint.
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                // Body carries the prediction's request id and model version
                // so the server can attribute the acceptance.
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
569
570pub fn compute_edits(
571 old_text: String,
572 new_text: &str,
573 offset: usize,
574 snapshot: &BufferSnapshot,
575) -> Vec<(Range<Anchor>, Arc<str>)> {
576 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
577}
578
/// Diffs `old_text` (the editable region starting at byte `offset` in the
/// buffer) against `new_text`, converting each diff hunk into an
/// anchor-ranged edit and — when `cursor_offset_in_new_text` is given —
/// resolving where that cursor (a byte offset into `new_text`) lands in the
/// buffer.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor sits in unchanged text before this hunk: map it
                    // back through the accumulated delta and anchor directly.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor sits inside this hunk's inserted text: anchor at
                    // the hunk start and record the offset within the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                // Delta is only maintained while the cursor is still
                // unresolved; it is consumed by the post-loop fallback below,
                // which only runs in that same case.
                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Trim the common prefix, then the common suffix of what remains,
            // so the edit covers only the genuinely changed bytes.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp to the snapshot in case the editable region extends past
            // the end of the buffer.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: a single anchor keeps the edit attached to
                // the insertion point.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after every hunk: map it through the total delta and clip
    // to the buffer.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
662
/// Returns the byte length (UTF-8) of the longest common prefix shared by
/// the two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (ca, cb) in a.zip(b) {
        if ca != cb {
            break;
        }
        bytes += ca.len_utf8();
    }
    bytes
}