1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
3 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
4 ZedUpdateRequiredError, cursor_excerpt::compute_excerpt_ranges,
5 prediction::EditPredictionResult,
6};
7use anyhow::Result;
8use cloud_llm_client::{
9 AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
10};
11use edit_prediction_types::PredictedCursorPosition;
12use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
13use language::{
14 Buffer, BufferSnapshot, ToOffset as _, ToPoint, language_settings::all_language_settings,
15 text_diff,
16};
17use release_channel::AppVersion;
18use settings::EditPredictionPromptFormat;
19use text::{Anchor, Bias};
20use ui::SharedString;
21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
22use zeta_prompt::{ParsedOutput, ZetaPromptInput};
23
24use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
25use zeta_prompt::{
26 CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
27 prompt_input_contains_special_tokens,
28 zeta1::{self, EDITABLE_REGION_END_MARKER},
29};
30
31use crate::open_ai_compatible::{
32 load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
33};
34
/// Requests an edit prediction for `buffer` at `position` using the Zeta
/// family of models, returning a task that resolves to the prediction result.
///
/// The backend is chosen in priority order:
/// 1. a custom server (Ollama or an OpenAI-compatible API) when one is
///    configured in the edit-prediction settings,
/// 2. the raw completion endpoint when the store carries a zeta2 raw config,
/// 3. the cloud V3 endpoint otherwise (the server selects model/version and
///    strips the suffix).
///
/// When `can_collect_data` is true, `capture_data` may carry stored events
/// used to capture an example alongside the prediction.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Only the two self-hosted providers carry custom-server settings.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Timestamp used later to measure end-to-end latency from snapshot to response.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Untitled buffers get a placeholder path so the prompt always has one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // The repo URL is only looked up and attached when data collection is permitted.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Everything the foreground task needs from a successful model response.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        received_response_at: Instant,
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    let request_task = cx.background_spawn({
        async move {
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send buffer content that already contains the model's
            // special tokens, since that would corrupt prompt parsing.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): the 4x multiplier presumably leaves headroom
                    // over the configured output budget — TODO confirm intent.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            // Zeta1-style prompt: an editable region delimited by
                            // markers; generation is stopped at the end marker.
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            // Include trailing-newline variants of the end marker
                            // so generation stops however the marker is emitted.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                            });

                            // Custom servers report no model version and no usage.
                            (request_id, parsed_output, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // Zeta2-style prompt: append the prefill to the prompt,
                            // then re-prepend it to the response before parsing.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            // Empty response means "no prediction".
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    // Raw zeta2 mode: build the prompt locally and call the raw
                    // completion endpoint with an explicit model/environment.
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    // Default the environment to the lowercased format name.
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        // Re-prepend the prefill so the parser sees the full output.
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    (request_id, output, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    // NOTE(review): unlike the branches above, an empty V3 output
                    // still yields Some(ParsedOutput) with an empty region, which
                    // downstream diffs as a deletion of the editable range —
                    // confirm the server never sends an empty output.
                    let parsed_output = ParsedOutput {
                        new_editable_region: output_text.unwrap_or_default(),
                        range_in_excerpt: response.editable_range,
                    };

                    (request_id, Some(parsed_output), model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // Nothing parseable: report the id with no prediction.
            // NOTE(review): `usage` is dropped on this early return, so account
            // usage isn't updated for empty predictions — confirm intended.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
            }) = output
            else {
                return Ok(((request_id, None), None));
            };

            // Map the editable range from excerpt-relative to buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline so a trailing-newline
            // difference alone doesn't surface as an edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                (
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                        model_version,
                    }),
                ),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        // Record usage / surface update-required errors, then unpack the result.
        let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;

        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            // The model produced no output: report the prediction as empty.
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            // Optionally capture an example, then enqueue the settled
            // prediction for data collection on a detached background task.
            let weak_this = this.clone();
            let id = id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            example_spec,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
432
433fn handle_api_response<T>(
434 this: &WeakEntity<EditPredictionStore>,
435 response: Result<(T, Option<client::EditPredictionUsage>)>,
436 cx: &mut gpui::AsyncApp,
437) -> Result<T> {
438 match response {
439 Ok((data, usage)) => {
440 if let Some(usage) = usage {
441 this.update(cx, |this, cx| {
442 this.user_store.update(cx, |user_store, cx| {
443 user_store.update_edit_prediction_usage(usage, cx);
444 });
445 })
446 .ok();
447 }
448 Ok(data)
449 }
450 Err(err) => {
451 if err.is::<ZedUpdateRequiredError>() {
452 cx.update(|cx| {
453 this.update(cx, |this, _cx| {
454 this.update_required = true;
455 })
456 .ok();
457
458 let error_message: SharedString = err.to_string().into();
459 show_app_notification(
460 NotificationId::unique::<ZedUpdateRequiredError>(),
461 cx,
462 move |cx| {
463 cx.new(|cx| {
464 ErrorMessagePrompt::new(error_message.clone(), cx)
465 .with_link_button("Update Zed", "https://zed.dev/releases")
466 })
467 },
468 );
469 });
470 }
471 Err(err)
472 }
473 }
474}
475
476pub fn zeta2_prompt_input(
477 snapshot: &language::BufferSnapshot,
478 related_files: Vec<zeta_prompt::RelatedFile>,
479 events: Vec<Arc<zeta_prompt::Event>>,
480 excerpt_path: Arc<Path>,
481 cursor_offset: usize,
482 preferred_experiment: Option<String>,
483 is_open_source: bool,
484 can_collect_data: bool,
485 repo_url: Option<String>,
486) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
487 let cursor_point = cursor_offset.to_point(snapshot);
488
489 let (full_context, full_context_offset_range, excerpt_ranges) =
490 compute_excerpt_ranges(cursor_point, snapshot);
491
492 let full_context_start_offset = full_context_offset_range.start;
493 let full_context_start_row = full_context.start.row;
494
495 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
496
497 let prompt_input = zeta_prompt::ZetaPromptInput {
498 cursor_path: excerpt_path,
499 cursor_excerpt: snapshot
500 .text_for_range(full_context)
501 .collect::<String>()
502 .into(),
503 cursor_offset_in_excerpt,
504 excerpt_start_row: Some(full_context_start_row),
505 events,
506 related_files,
507 excerpt_ranges,
508 experiment: preferred_experiment,
509 in_open_source_repo: is_open_source,
510 can_collect_data,
511 repo_url,
512 };
513 (full_context_offset_range, prompt_input)
514}
515
/// Reports acceptance of an edit prediction to the backend.
///
/// The report goes to the cloud `/predict_edits/accept` endpoint, or to an
/// override URL taken from the `ZED_ACCEPT_PREDICTION_URL` environment
/// variable; auth is only required for the default cloud endpoint. The
/// request runs on a detached background task and failures are only logged.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // With a raw zeta2 config there is no cloud prediction to report, so do
    // nothing unless an override URL was provided.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Skip auth when posting to a custom (e.g. local testing) URL.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        // The builder closure may be retried, hence the clones inside it.
        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
571
572pub fn compute_edits(
573 old_text: String,
574 new_text: &str,
575 offset: usize,
576 snapshot: &BufferSnapshot,
577) -> Vec<(Range<Anchor>, Arc<str>)> {
578 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
579}
580
/// Diffs `old_text` against `new_text` and converts each hunk into an
/// anchored buffer edit, where `offset` is the buffer offset at which
/// `old_text` begins.
///
/// When `cursor_offset_in_new_text` is provided (a byte offset into
/// `new_text`), the corresponding predicted cursor position in the buffer is
/// computed as well: either anchored in unchanged text, or anchored at an
/// edit start with an offset into that edit's inserted text.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // (`delta` is only maintained while the cursor is unresolved; it is
            // read nowhere else, so skipping the update afterwards is fine.)
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // straight back to old-text coordinates via `delta`.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and record the offset into the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Trim the common prefix, then the common suffix of what remains,
            // so only the genuinely changed bytes are replaced.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp to the buffer in case the range runs past its end.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: a single anchor keeps the edit glued to the
                // insertion point.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after the last edit (or there were no edits): map it back
    // through the final delta and clip to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
664
/// Returns the length in bytes of the longest common prefix of two character
/// streams (each matching char contributes its UTF-8 encoded length).
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (ca, cb) in a.zip(b) {
        if ca != cb {
            break;
        }
        bytes += ca.len_utf8();
    }
    bytes
}