1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::Result;
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use edit_prediction_types::PredictedCursorPosition;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
15use language::{
16 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
17};
18use project::{Project, ProjectPath};
19use release_channel::AppVersion;
20use text::Bias;
21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
22use zeta_prompt::{
23 Event, ZetaPromptInput,
24 zeta1::{
25 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
26 START_OF_FILE_MARKER,
27 },
28};
29
/// Estimated-token budget for the read-only context around the editable region
/// (passed as `context_token_limit` to `excerpt_for_cursor_position`).
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Estimated-token budget for the editable region the model may rewrite
/// (passed as `editable_region_token_limit` to `excerpt_for_cursor_position`).
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Estimated-token budget for the recent edit events included in the prompt
/// (passed to `prompt_for_events_impl`).
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
33
34pub(crate) fn request_prediction_with_zeta1(
35 store: &mut EditPredictionStore,
36 EditPredictionModelInput {
37 project,
38 buffer,
39 snapshot,
40 position,
41 events,
42 trigger,
43 debug_tx,
44 ..
45 }: EditPredictionModelInput,
46 cx: &mut Context<EditPredictionStore>,
47) -> Task<Result<Option<EditPredictionResult>>> {
48 let buffer_snapshotted_at = Instant::now();
49 let client = store.client.clone();
50 let llm_token = store.llm_token.clone();
51 let app_version = AppVersion::global(cx);
52
53 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
54 let can_collect_file = store.can_collect_file(&project, file, cx);
55 let git_info = if can_collect_file {
56 git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
57 } else {
58 None
59 };
60 (git_info, can_collect_file)
61 } else {
62 (None, false)
63 };
64
65 let full_path: Arc<Path> = snapshot
66 .file()
67 .map(|f| Arc::from(f.full_path(cx).as_path()))
68 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
69 let full_path_str = full_path.to_string_lossy().into_owned();
70 let cursor_point = position.to_point(&snapshot);
71 let prompt_for_events = {
72 let events = events.clone();
73 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
74 };
75 let gather_task = gather_context(
76 full_path_str,
77 &snapshot,
78 cursor_point,
79 prompt_for_events,
80 trigger,
81 cx,
82 );
83
84 let uri = match client
85 .http_client()
86 .build_zed_llm_url("/predict_edits/v2", &[])
87 {
88 Ok(url) => Arc::from(url),
89 Err(err) => return Task::ready(Err(err)),
90 };
91
92 cx.spawn(async move |this, cx| {
93 let GatherContextOutput {
94 mut body,
95 context_range,
96 editable_range,
97 included_events_count,
98 } = gather_task.await?;
99 let done_gathering_context_at = Instant::now();
100
101 let included_events = &events[events.len() - included_events_count..events.len()];
102 body.can_collect_data = can_collect_file
103 && this
104 .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
105 .unwrap_or(false);
106 if body.can_collect_data {
107 body.git_info = git_info;
108 }
109
110 log::debug!(
111 "Events:\n{}\nExcerpt:\n{:?}",
112 body.input_events,
113 body.input_excerpt
114 );
115
116 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
117 |request| {
118 Ok(request
119 .uri(uri.as_str())
120 .body(serde_json::to_string(&body)?.into())?)
121 },
122 client,
123 llm_token,
124 app_version,
125 true,
126 )
127 .await;
128
129 let context_start_offset = context_range.start.to_offset(&snapshot);
130 let context_start_row = context_range.start.row;
131 let editable_offset_range = editable_range.to_offset(&snapshot);
132
133 let inputs = ZetaPromptInput {
134 events: included_events.into(),
135 related_files: vec![],
136 cursor_path: full_path,
137 cursor_excerpt: snapshot
138 .text_for_range(context_range)
139 .collect::<String>()
140 .into(),
141 editable_range_in_excerpt: (editable_range.start - context_start_offset)
142 ..(editable_offset_range.end - context_start_offset),
143 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
144 excerpt_start_row: Some(context_start_row),
145 };
146
147 if let Some(debug_tx) = &debug_tx {
148 debug_tx
149 .unbounded_send(DebugEvent::EditPredictionStarted(
150 EditPredictionStartedDebugEvent {
151 buffer: buffer.downgrade(),
152 prompt: Some(serde_json::to_string(&inputs).unwrap()),
153 position,
154 },
155 ))
156 .ok();
157 }
158
159 let (response, usage) = match response {
160 Ok(response) => response,
161 Err(err) => {
162 if err.is::<ZedUpdateRequiredError>() {
163 cx.update(|cx| {
164 this.update(cx, |ep_store, _cx| {
165 ep_store.update_required = true;
166 })
167 .ok();
168
169 let error_message: SharedString = err.to_string().into();
170 show_app_notification(
171 NotificationId::unique::<ZedUpdateRequiredError>(),
172 cx,
173 move |cx| {
174 cx.new(|cx| {
175 ErrorMessagePrompt::new(error_message.clone(), cx)
176 .with_link_button("Update Zed", "https://zed.dev/releases")
177 })
178 },
179 );
180 });
181 }
182
183 return Err(err);
184 }
185 };
186
187 let received_response_at = Instant::now();
188 log::debug!("completion response: {}", &response.output_excerpt);
189
190 if let Some(usage) = usage {
191 this.update(cx, |this, cx| {
192 this.user_store.update(cx, |user_store, cx| {
193 user_store.update_edit_prediction_usage(usage, cx);
194 });
195 })
196 .ok();
197 }
198
199 if let Some(debug_tx) = &debug_tx {
200 debug_tx
201 .unbounded_send(DebugEvent::EditPredictionFinished(
202 EditPredictionFinishedDebugEvent {
203 buffer: buffer.downgrade(),
204 model_output: Some(response.output_excerpt.clone()),
205 position,
206 },
207 ))
208 .ok();
209 }
210
211 let edit_prediction = process_completion_response(
212 response,
213 buffer,
214 &snapshot,
215 editable_range,
216 inputs,
217 buffer_snapshotted_at,
218 received_response_at,
219 cx,
220 )
221 .await;
222
223 let finished_at = Instant::now();
224
225 // record latency for ~1% of requests
226 if rand::random::<u8>() <= 2 {
227 telemetry::event!(
228 "Edit Prediction Request",
229 context_latency = done_gathering_context_at
230 .duration_since(buffer_snapshotted_at)
231 .as_millis(),
232 request_latency = received_response_at
233 .duration_since(done_gathering_context_at)
234 .as_millis(),
235 process_latency = finished_at.duration_since(received_response_at).as_millis()
236 );
237 }
238
239 edit_prediction.map(Some)
240 })
241}
242
243fn process_completion_response(
244 prediction_response: PredictEditsResponse,
245 buffer: Entity<Buffer>,
246 snapshot: &BufferSnapshot,
247 editable_range: Range<usize>,
248 inputs: ZetaPromptInput,
249 buffer_snapshotted_at: Instant,
250 received_response_at: Instant,
251 cx: &AsyncApp,
252) -> Task<Result<EditPredictionResult>> {
253 let snapshot = snapshot.clone();
254 let request_id = prediction_response.request_id;
255 let output_excerpt = prediction_response.output_excerpt;
256 cx.spawn(async move |cx| {
257 let output_excerpt: Arc<str> = output_excerpt.into();
258
259 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
260 .background_spawn({
261 let output_excerpt = output_excerpt.clone();
262 let editable_range = editable_range.clone();
263 let snapshot = snapshot.clone();
264 async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
265 })
266 .await?
267 .into();
268
269 let id = EditPredictionId(request_id.into());
270 Ok(EditPredictionResult::new(
271 id,
272 &buffer,
273 &snapshot,
274 edits,
275 None,
276 buffer_snapshotted_at,
277 received_response_at,
278 inputs,
279 cx,
280 )
281 .await)
282 })
283}
284
285pub(crate) fn parse_edits(
286 output_excerpt: &str,
287 editable_range: Range<usize>,
288 snapshot: &BufferSnapshot,
289) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
290 let content = output_excerpt.replace(CURSOR_MARKER, "");
291
292 let start_markers = content
293 .match_indices(EDITABLE_REGION_START_MARKER)
294 .collect::<Vec<_>>();
295 anyhow::ensure!(
296 start_markers.len() <= 1,
297 "expected at most one start marker, found {}",
298 start_markers.len()
299 );
300
301 let end_markers = content
302 .match_indices(EDITABLE_REGION_END_MARKER)
303 .collect::<Vec<_>>();
304 anyhow::ensure!(
305 end_markers.len() <= 1,
306 "expected at most one end marker, found {}",
307 end_markers.len()
308 );
309
310 let sof_markers = content
311 .match_indices(START_OF_FILE_MARKER)
312 .collect::<Vec<_>>();
313 anyhow::ensure!(
314 sof_markers.len() <= 1,
315 "expected at most one start-of-file marker, found {}",
316 sof_markers.len()
317 );
318
319 let content_start = start_markers
320 .first()
321 .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
322 .unwrap_or(0);
323 let content_end = end_markers
324 .first()
325 .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
326 .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
327
328 // if there is a single newline between markers, content_start will be 1 more than content_end. .min ensures empty slice in that case
329 let new_text = &content[content_start.min(content_end)..content_end];
330
331 let old_text = snapshot
332 .text_for_range(editable_range.clone())
333 .collect::<String>();
334
335 Ok(compute_edits(
336 old_text,
337 new_text,
338 editable_range.start,
339 snapshot,
340 ))
341}
342
343pub fn compute_edits(
344 old_text: String,
345 new_text: &str,
346 offset: usize,
347 snapshot: &BufferSnapshot,
348) -> Vec<(Range<Anchor>, Arc<str>)> {
349 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
350}
351
/// Diffs `old_text` against `new_text` and turns each hunk into an anchored
/// buffer edit. Optionally maps a cursor offset in `new_text` into the buffer.
///
/// `old_text` is treated as the buffer text that starts at byte `offset` in
/// `snapshot`, so every old-text range is shifted by `offset`. For example,
/// `parse_edits` passes the editable region's text and its start offset.
///
/// Returns:
/// - the edits: one per hunk from `text_diff`, narrowed by the common prefix and
///   suffix shared by its old and new text. An empty old range becomes a
///   zero-width `anchor_after` range; otherwise the range is
///   `anchor_after(start)..anchor_before(end)`.
/// - the cursor position: `None` when `cursor_offset_in_new_text` is `None`.
///   Otherwise the cursor is resolved at the first hunk where it applies:
///   - before that hunk (in new-text coordinates): anchored at the corresponding
///     old-text position;
///   - inside that hunk's inserted text: an anchor at the hunk's *untrimmed*
///     start plus the offset into the insertion;
///   - after every hunk: mapped back through the accumulated length change and
///     clipped to a valid buffer offset.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // NOTE: `new_text` here is this hunk's replacement text and shadows the
            // function parameter of the same name.

            // Compute cursor position if it falls within or before this edit.
            // `delta` is only maintained while the cursor is still unresolved; once
            // it is resolved, the final fallback below no longer reads `delta`.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor sits in unchanged text before this hunk: translate it back
                    // to old-text coordinates. Expected to be non-negative because the
                    // cursor was not before or inside any earlier hunk.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(offset + cursor_in_old),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor sits inside this hunk's inserted text: anchor at the hunk's
                    // untrimmed start and record how far into the insertion it is.
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(offset + raw_old_range.start),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // `common_prefix` returns byte lengths of whole chars, so the slices
            // below always land on char boundaries.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Shift into buffer coordinates, then drop the shared prefix/suffix.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // The cursor lies after every hunk (or there were no hunks): map it back
    // through the total delta and clip it to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
429
/// Returns the length in UTF-8 bytes (measured on `a`) of the longest run of
/// equal leading chars yielded by `a` and `b`.
///
/// Pass reversed iterators (`.chars().rev()`) to measure a common suffix.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
436
437fn git_info_for_file(
438 project: &Entity<Project>,
439 project_path: &ProjectPath,
440 cx: &App,
441) -> Option<PredictEditsGitInfo> {
442 let git_store = project.read(cx).git_store().read(cx);
443 if let Some((repository, _repo_path)) =
444 git_store.repository_and_path_for_project_path(project_path, cx)
445 {
446 let repository = repository.read(cx);
447 let head_sha = repository
448 .head_commit
449 .as_ref()
450 .map(|head_commit| head_commit.sha.to_string());
451 let remote_origin_url = repository.remote_origin_url.clone();
452 let remote_upstream_url = repository.remote_upstream_url.clone();
453 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
454 return None;
455 }
456 Some(PredictEditsGitInfo {
457 head_sha,
458 remote_origin_url,
459 remote_upstream_url,
460 })
461 } else {
462 None
463 }
464}
465
/// Result of [`gather_context`]: the request body plus the ranges needed to
/// interpret the model's response.
pub struct GatherContextOutput {
    /// Request body. `gather_context` leaves `can_collect_data` as `false` and
    /// `git_info` as `None`.
    pub body: PredictEditsBody,
    /// Buffer range covered by the excerpt sent to the model.
    pub context_range: Range<Point>,
    /// The editable region, converted to buffer offsets.
    pub editable_range: Range<usize>,
    /// Number of events (counted from the most recent) rendered into `body.input_events`.
    pub included_events_count: usize,
}
472
473pub fn gather_context(
474 full_path_str: String,
475 snapshot: &BufferSnapshot,
476 cursor_point: language::Point,
477 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
478 trigger: PredictEditsRequestTrigger,
479 cx: &App,
480) -> Task<Result<GatherContextOutput>> {
481 cx.background_spawn({
482 let snapshot = snapshot.clone();
483 async move {
484 let input_excerpt = excerpt_for_cursor_position(
485 cursor_point,
486 &full_path_str,
487 &snapshot,
488 MAX_REWRITE_TOKENS,
489 MAX_CONTEXT_TOKENS,
490 );
491 let (input_events, included_events_count) = prompt_for_events();
492 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
493
494 let body = PredictEditsBody {
495 input_events,
496 input_excerpt: input_excerpt.prompt,
497 can_collect_data: false,
498 diagnostic_groups: None,
499 git_info: None,
500 outline: None,
501 speculated_output: None,
502 trigger,
503 };
504
505 Ok(GatherContextOutput {
506 body,
507 context_range: input_excerpt.context_range,
508 editable_range,
509 included_events_count,
510 })
511 }
512 })
513}
514
515pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
516 prompt_for_events_impl(events, max_tokens).0
517}
518
519fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
520 let mut result = String::new();
521 for (ix, event) in events.iter().rev().enumerate() {
522 let event_string = format_event(event.as_ref());
523 let event_tokens = guess_token_count(event_string.len());
524 if event_tokens > remaining_tokens {
525 return (result, ix);
526 }
527
528 if !result.is_empty() {
529 result.insert_str(0, "\n\n");
530 }
531 result.insert_str(0, &event_string);
532 remaining_tokens -= event_tokens;
533 }
534 return (result, events.len());
535}
536
537pub fn format_event(event: &Event) -> String {
538 match event {
539 Event::BufferChange {
540 path,
541 old_path,
542 diff,
543 ..
544 } => {
545 let mut prompt = String::new();
546
547 if old_path != path {
548 writeln!(
549 prompt,
550 "User renamed {} to {}\n",
551 old_path.display(),
552 path.display()
553 )
554 .unwrap();
555 }
556
557 if !diff.is_empty() {
558 write!(
559 prompt,
560 "User edited {}:\n```diff\n{}\n```",
561 path.display(),
562 diff
563 )
564 .unwrap();
565 }
566
567 prompt
568 }
569 }
570}
571
/// Prompt excerpt produced by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Full range of buffer text rendered into `prompt`.
    pub context_range: Range<Point>,
    /// Sub-range wrapped in the editable-region markers.
    pub editable_range: Range<Point>,
    /// Rendered prompt: a fenced code block labelled with the path. It contains
    /// the context text, the editable-region markers, and the cursor marker.
    pub prompt: String,
}
578
579pub fn excerpt_for_cursor_position(
580 position: Point,
581 path: &str,
582 snapshot: &BufferSnapshot,
583 editable_region_token_limit: usize,
584 context_token_limit: usize,
585) -> InputExcerpt {
586 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
587 position,
588 snapshot,
589 editable_region_token_limit,
590 context_token_limit,
591 );
592
593 let mut prompt = String::new();
594
595 writeln!(&mut prompt, "```{path}").unwrap();
596 if context_range.start == Point::zero() {
597 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
598 }
599
600 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
601 prompt.push_str(chunk.text);
602 }
603
604 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
605
606 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
607 prompt.push_str(chunk.text);
608 }
609 write!(prompt, "\n```").unwrap();
610
611 InputExcerpt {
612 context_range,
613 editable_range,
614 prompt,
615 }
616}
617
618fn push_editable_range(
619 cursor_position: Point,
620 snapshot: &BufferSnapshot,
621 editable_range: Range<Point>,
622 prompt: &mut String,
623) {
624 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
625 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
626 prompt.push_str(chunk.text);
627 }
628 prompt.push_str(CURSOR_MARKER);
629 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
630 prompt.push_str(chunk.text);
631 }
632 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    // Renders the prompt excerpt for a cursor placed on `return sum;` inside `bar`
    // (row 12, column 5) with two different editable-region token budgets, and
    // compares against the exact prompt text.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The excerpt expands to syntax boundaries.
        // With 50 token editable limit, we get a region that expands to syntax nodes.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With smaller budget, the region expands to syntax boundaries but is tighter.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }

    // A response whose editable region is empty (just the two markers separated by
    // a newline) must produce a single edit that deletes the whole editable range.
    #[gpui::test]
    fn test_parse_edits_empty_editable_region(cx: &mut App) {
        let text = "fn foo() {\n let x = 42;\n}\n";
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();

        let output = "<|editable_region_start|>\n<|editable_region_end|>";
        let editable_range = 0..text.len();
        let edits = parse_edits(output, editable_range, &snapshot).unwrap();
        assert_eq!(edits.len(), 1);
        let (range, new_text) = &edits[0];
        assert_eq!(range.to_offset(&snapshot), 0..text.len(),);
        assert_eq!(new_text.as_ref(), "");
    }
}