1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::Result;
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use edit_prediction_types::PredictedCursorPosition;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
15use language::{
16 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
17};
18use project::{Project, ProjectPath};
19use release_channel::AppVersion;
20use text::Bias;
21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
22use zeta_prompt::{
23 Event, ZetaPromptInput,
24 zeta1::{
25 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
26 START_OF_FILE_MARKER,
27 },
28};
29
/// Token limit passed to `excerpt_for_cursor_position` for the read-only context
/// surrounding the editable region.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token limit passed to `excerpt_for_cursor_position` for the editable region
/// around the cursor.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token limit for the recent-edit events rendered into the request's `input_events`.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
33
34pub(crate) fn request_prediction_with_zeta1(
35 store: &mut EditPredictionStore,
36 EditPredictionModelInput {
37 project,
38 buffer,
39 snapshot,
40 position,
41 events,
42 trigger,
43 debug_tx,
44 ..
45 }: EditPredictionModelInput,
46 cx: &mut Context<EditPredictionStore>,
47) -> Task<Result<Option<EditPredictionResult>>> {
48 let buffer_snapshotted_at = Instant::now();
49 let client = store.client.clone();
50 let llm_token = store.llm_token.clone();
51 let app_version = AppVersion::global(cx);
52
53 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
54 let can_collect_file = store.can_collect_file(&project, file, cx);
55 let git_info = if can_collect_file {
56 git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
57 } else {
58 None
59 };
60 (git_info, can_collect_file)
61 } else {
62 (None, false)
63 };
64
65 let full_path: Arc<Path> = snapshot
66 .file()
67 .map(|f| Arc::from(f.full_path(cx).as_path()))
68 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
69 let full_path_str = full_path.to_string_lossy().into_owned();
70 let cursor_point = position.to_point(&snapshot);
71 let prompt_for_events = {
72 let events = events.clone();
73 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
74 };
75 let gather_task = gather_context(
76 full_path_str,
77 &snapshot,
78 cursor_point,
79 prompt_for_events,
80 trigger,
81 cx,
82 );
83
84 let (uri, require_auth) = match &store.custom_predict_edits_url {
85 Some(custom_url) => (custom_url.clone(), false),
86 None => {
87 match client
88 .http_client()
89 .build_zed_llm_url("/predict_edits/v2", &[])
90 {
91 Ok(url) => (url.into(), true),
92 Err(err) => return Task::ready(Err(err)),
93 }
94 }
95 };
96
97 cx.spawn(async move |this, cx| {
98 let GatherContextOutput {
99 mut body,
100 context_range,
101 editable_range,
102 included_events_count,
103 } = gather_task.await?;
104 let done_gathering_context_at = Instant::now();
105
106 let included_events = &events[events.len() - included_events_count..events.len()];
107 body.can_collect_data = can_collect_file
108 && this
109 .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
110 .unwrap_or(false);
111 if body.can_collect_data {
112 body.git_info = git_info;
113 }
114
115 log::debug!(
116 "Events:\n{}\nExcerpt:\n{:?}",
117 body.input_events,
118 body.input_excerpt
119 );
120
121 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
122 |request| {
123 Ok(request
124 .uri(uri.as_str())
125 .body(serde_json::to_string(&body)?.into())?)
126 },
127 client,
128 llm_token,
129 app_version,
130 require_auth,
131 )
132 .await;
133
134 let context_start_offset = context_range.start.to_offset(&snapshot);
135 let context_start_row = context_range.start.row;
136 let editable_offset_range = editable_range.to_offset(&snapshot);
137
138 let inputs = ZetaPromptInput {
139 events: included_events.into(),
140 related_files: vec![],
141 cursor_path: full_path,
142 cursor_excerpt: snapshot
143 .text_for_range(context_range)
144 .collect::<String>()
145 .into(),
146 editable_range_in_excerpt: (editable_range.start - context_start_offset)
147 ..(editable_offset_range.end - context_start_offset),
148 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
149 excerpt_start_row: Some(context_start_row),
150 };
151
152 if let Some(debug_tx) = &debug_tx {
153 debug_tx
154 .unbounded_send(DebugEvent::EditPredictionStarted(
155 EditPredictionStartedDebugEvent {
156 buffer: buffer.downgrade(),
157 prompt: Some(serde_json::to_string(&inputs).unwrap()),
158 position,
159 },
160 ))
161 .ok();
162 }
163
164 let (response, usage) = match response {
165 Ok(response) => response,
166 Err(err) => {
167 if err.is::<ZedUpdateRequiredError>() {
168 cx.update(|cx| {
169 this.update(cx, |ep_store, _cx| {
170 ep_store.update_required = true;
171 })
172 .ok();
173
174 let error_message: SharedString = err.to_string().into();
175 show_app_notification(
176 NotificationId::unique::<ZedUpdateRequiredError>(),
177 cx,
178 move |cx| {
179 cx.new(|cx| {
180 ErrorMessagePrompt::new(error_message.clone(), cx)
181 .with_link_button("Update Zed", "https://zed.dev/releases")
182 })
183 },
184 );
185 });
186 }
187
188 return Err(err);
189 }
190 };
191
192 let received_response_at = Instant::now();
193 log::debug!("completion response: {}", &response.output_excerpt);
194
195 if let Some(usage) = usage {
196 this.update(cx, |this, cx| {
197 this.user_store.update(cx, |user_store, cx| {
198 user_store.update_edit_prediction_usage(usage, cx);
199 });
200 })
201 .ok();
202 }
203
204 if let Some(debug_tx) = &debug_tx {
205 debug_tx
206 .unbounded_send(DebugEvent::EditPredictionFinished(
207 EditPredictionFinishedDebugEvent {
208 buffer: buffer.downgrade(),
209 model_output: Some(response.output_excerpt.clone()),
210 position,
211 },
212 ))
213 .ok();
214 }
215
216 let edit_prediction = process_completion_response(
217 response,
218 buffer,
219 &snapshot,
220 editable_range,
221 inputs,
222 buffer_snapshotted_at,
223 received_response_at,
224 cx,
225 )
226 .await;
227
228 let finished_at = Instant::now();
229
230 // record latency for ~1% of requests
231 if rand::random::<u8>() <= 2 {
232 telemetry::event!(
233 "Edit Prediction Request",
234 context_latency = done_gathering_context_at
235 .duration_since(buffer_snapshotted_at)
236 .as_millis(),
237 request_latency = received_response_at
238 .duration_since(done_gathering_context_at)
239 .as_millis(),
240 process_latency = finished_at.duration_since(received_response_at).as_millis()
241 );
242 }
243
244 edit_prediction.map(Some)
245 })
246}
247
248fn process_completion_response(
249 prediction_response: PredictEditsResponse,
250 buffer: Entity<Buffer>,
251 snapshot: &BufferSnapshot,
252 editable_range: Range<usize>,
253 inputs: ZetaPromptInput,
254 buffer_snapshotted_at: Instant,
255 received_response_at: Instant,
256 cx: &AsyncApp,
257) -> Task<Result<EditPredictionResult>> {
258 let snapshot = snapshot.clone();
259 let request_id = prediction_response.request_id;
260 let output_excerpt = prediction_response.output_excerpt;
261 cx.spawn(async move |cx| {
262 let output_excerpt: Arc<str> = output_excerpt.into();
263
264 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
265 .background_spawn({
266 let output_excerpt = output_excerpt.clone();
267 let editable_range = editable_range.clone();
268 let snapshot = snapshot.clone();
269 async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
270 })
271 .await?
272 .into();
273
274 let id = EditPredictionId(request_id.into());
275 Ok(EditPredictionResult::new(
276 id,
277 &buffer,
278 &snapshot,
279 edits,
280 None,
281 buffer_snapshotted_at,
282 received_response_at,
283 inputs,
284 cx,
285 )
286 .await)
287 })
288}
289
290pub(crate) fn parse_edits(
291 output_excerpt: &str,
292 editable_range: Range<usize>,
293 snapshot: &BufferSnapshot,
294) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
295 let content = output_excerpt.replace(CURSOR_MARKER, "");
296
297 let start_markers = content
298 .match_indices(EDITABLE_REGION_START_MARKER)
299 .collect::<Vec<_>>();
300 anyhow::ensure!(
301 start_markers.len() <= 1,
302 "expected at most one start marker, found {}",
303 start_markers.len()
304 );
305
306 let end_markers = content
307 .match_indices(EDITABLE_REGION_END_MARKER)
308 .collect::<Vec<_>>();
309 anyhow::ensure!(
310 end_markers.len() <= 1,
311 "expected at most one end marker, found {}",
312 end_markers.len()
313 );
314
315 let sof_markers = content
316 .match_indices(START_OF_FILE_MARKER)
317 .collect::<Vec<_>>();
318 anyhow::ensure!(
319 sof_markers.len() <= 1,
320 "expected at most one start-of-file marker, found {}",
321 sof_markers.len()
322 );
323
324 let content_start = start_markers
325 .first()
326 .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
327 .unwrap_or(0);
328 let content_end = end_markers
329 .first()
330 .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
331 .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
332
333 // if there is a single newline between markers, content_start will be 1 more than content_end. .min ensures empty slice in that case
334 let new_text = &content[content_start.min(content_end)..content_end];
335
336 let old_text = snapshot
337 .text_for_range(editable_range.clone())
338 .collect::<String>();
339
340 Ok(compute_edits(
341 old_text,
342 new_text,
343 editable_range.start,
344 snapshot,
345 ))
346}
347
348pub fn compute_edits(
349 old_text: String,
350 new_text: &str,
351 offset: usize,
352 snapshot: &BufferSnapshot,
353) -> Vec<(Range<Anchor>, Arc<str>)> {
354 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
355}
356
357pub fn compute_edits_and_cursor_position(
358 old_text: String,
359 new_text: &str,
360 offset: usize,
361 cursor_offset_in_new_text: Option<usize>,
362 snapshot: &BufferSnapshot,
363) -> (
364 Vec<(Range<Anchor>, Arc<str>)>,
365 Option<PredictedCursorPosition>,
366) {
367 let diffs = text_diff(&old_text, new_text);
368
369 // Delta represents the cumulative change in byte count from all preceding edits.
370 // new_offset = old_offset + delta, so old_offset = new_offset - delta
371 let mut delta: isize = 0;
372 let mut cursor_position: Option<PredictedCursorPosition> = None;
373
374 let edits = diffs
375 .iter()
376 .map(|(raw_old_range, new_text)| {
377 // Compute cursor position if it falls within or before this edit.
378 if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
379 let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
380 let edit_end_in_new = edit_start_in_new + new_text.len();
381
382 if cursor_offset < edit_start_in_new {
383 let cursor_in_old = (cursor_offset as isize - delta) as usize;
384 cursor_position = Some(PredictedCursorPosition::at_anchor(
385 snapshot.anchor_after(offset + cursor_in_old),
386 ));
387 } else if cursor_offset < edit_end_in_new {
388 let offset_within_insertion = cursor_offset - edit_start_in_new;
389 cursor_position = Some(PredictedCursorPosition::new(
390 snapshot.anchor_before(offset + raw_old_range.start),
391 offset_within_insertion,
392 ));
393 }
394
395 delta += new_text.len() as isize - raw_old_range.len() as isize;
396 }
397
398 // Compute the edit with prefix/suffix trimming.
399 let mut old_range = raw_old_range.clone();
400 let old_slice = &old_text[old_range.clone()];
401
402 let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
403 let suffix_len = common_prefix(
404 old_slice[prefix_len..].chars().rev(),
405 new_text[prefix_len..].chars().rev(),
406 );
407
408 old_range.start += offset;
409 old_range.end += offset;
410 old_range.start += prefix_len;
411 old_range.end -= suffix_len;
412
413 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
414 let range = if old_range.is_empty() {
415 let anchor = snapshot.anchor_after(old_range.start);
416 anchor..anchor
417 } else {
418 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
419 };
420 (range, new_text)
421 })
422 .collect();
423
424 if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
425 let cursor_in_old = (cursor_offset as isize - delta) as usize;
426 let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
427 cursor_position = Some(PredictedCursorPosition::at_anchor(
428 snapshot.anchor_after(buffer_offset),
429 ));
430 }
431
432 (edits, cursor_position)
433}
434
/// Returns the length in UTF-8 bytes of the leading run of characters that `a` and
/// `b` have in common.
///
/// Passing reversed iterators yields the length of the common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
441
442fn git_info_for_file(
443 project: &Entity<Project>,
444 project_path: &ProjectPath,
445 cx: &App,
446) -> Option<PredictEditsGitInfo> {
447 let git_store = project.read(cx).git_store().read(cx);
448 if let Some((repository, _repo_path)) =
449 git_store.repository_and_path_for_project_path(project_path, cx)
450 {
451 let repository = repository.read(cx);
452 let head_sha = repository
453 .head_commit
454 .as_ref()
455 .map(|head_commit| head_commit.sha.to_string());
456 let remote_origin_url = repository.remote_origin_url.clone();
457 let remote_upstream_url = repository.remote_upstream_url.clone();
458 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
459 return None;
460 }
461 Some(PredictEditsGitInfo {
462 head_sha,
463 remote_origin_url,
464 remote_upstream_url,
465 })
466 } else {
467 None
468 }
469}
470
/// Result of [`gather_context`]: the request body plus the ranges and event count
/// used to build it.
pub struct GatherContextOutput {
    /// Request body holding the rendered events and excerpt prompt.
    pub body: PredictEditsBody,
    /// Range of the buffer included in the excerpt as context.
    pub context_range: Range<Point>,
    /// Editable region of the excerpt, as buffer offsets.
    pub editable_range: Range<usize>,
    /// Number of events included in `body.input_events`.
    pub included_events_count: usize,
}
477
478pub fn gather_context(
479 full_path_str: String,
480 snapshot: &BufferSnapshot,
481 cursor_point: language::Point,
482 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
483 trigger: PredictEditsRequestTrigger,
484 cx: &App,
485) -> Task<Result<GatherContextOutput>> {
486 cx.background_spawn({
487 let snapshot = snapshot.clone();
488 async move {
489 let input_excerpt = excerpt_for_cursor_position(
490 cursor_point,
491 &full_path_str,
492 &snapshot,
493 MAX_REWRITE_TOKENS,
494 MAX_CONTEXT_TOKENS,
495 );
496 let (input_events, included_events_count) = prompt_for_events();
497 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
498
499 let body = PredictEditsBody {
500 input_events,
501 input_excerpt: input_excerpt.prompt,
502 can_collect_data: false,
503 diagnostic_groups: None,
504 git_info: None,
505 outline: None,
506 speculated_output: None,
507 trigger,
508 };
509
510 Ok(GatherContextOutput {
511 body,
512 context_range: input_excerpt.context_range,
513 editable_range,
514 included_events_count,
515 })
516 }
517 })
518}
519
520pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
521 prompt_for_events_impl(events, max_tokens).0
522}
523
524fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
525 let mut result = String::new();
526 for (ix, event) in events.iter().rev().enumerate() {
527 let event_string = format_event(event.as_ref());
528 let event_tokens = guess_token_count(event_string.len());
529 if event_tokens > remaining_tokens {
530 return (result, ix);
531 }
532
533 if !result.is_empty() {
534 result.insert_str(0, "\n\n");
535 }
536 result.insert_str(0, &event_string);
537 remaining_tokens -= event_tokens;
538 }
539 return (result, events.len());
540}
541
542pub fn format_event(event: &Event) -> String {
543 match event {
544 Event::BufferChange {
545 path,
546 old_path,
547 diff,
548 ..
549 } => {
550 let mut prompt = String::new();
551
552 if old_path != path {
553 writeln!(
554 prompt,
555 "User renamed {} to {}\n",
556 old_path.display(),
557 path.display()
558 )
559 .unwrap();
560 }
561
562 if !diff.is_empty() {
563 write!(
564 prompt,
565 "User edited {}:\n```diff\n{}\n```",
566 path.display(),
567 diff
568 )
569 .unwrap();
570 }
571
572 prompt
573 }
574 }
575}
576
/// Excerpt of a buffer rendered as prompt text, along with the ranges it covers.
#[derive(Debug)]
pub struct InputExcerpt {
    /// Full range of the buffer included in the prompt.
    pub context_range: Range<Point>,
    /// Portion of the buffer wrapped in the editable-region markers.
    pub editable_range: Range<Point>,
    /// Rendered prompt text, fenced with the file path and including the markers.
    pub prompt: String,
}
583
584pub fn excerpt_for_cursor_position(
585 position: Point,
586 path: &str,
587 snapshot: &BufferSnapshot,
588 editable_region_token_limit: usize,
589 context_token_limit: usize,
590) -> InputExcerpt {
591 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
592 position,
593 snapshot,
594 editable_region_token_limit,
595 context_token_limit,
596 );
597
598 let mut prompt = String::new();
599
600 writeln!(&mut prompt, "```{path}").unwrap();
601 if context_range.start == Point::zero() {
602 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
603 }
604
605 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
606 prompt.push_str(chunk.text);
607 }
608
609 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
610
611 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
612 prompt.push_str(chunk.text);
613 }
614 write!(prompt, "\n```").unwrap();
615
616 InputExcerpt {
617 context_range,
618 editable_range,
619 prompt,
620 }
621}
622
623fn push_editable_range(
624 cursor_position: Point,
625 snapshot: &BufferSnapshot,
626 editable_range: Range<Point>,
627 prompt: &mut String,
628) {
629 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
630 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
631 prompt.push_str(chunk.text);
632 }
633 prompt.push_str(CURSOR_MARKER);
634 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
635 prompt.push_str(chunk.text);
636 }
637 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the rendered prompt for a cursor inside `bar` under two editable-region
    /// token limits.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The excerpt expands to syntax boundaries.
        // With 50 token editable limit, we get a region that expands to syntax nodes.
        // The cursor at row 12, column 5 sits right after the `r` of `return sum;`.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With smaller budget, the region expands to syntax boundaries but is tighter.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }

    /// Output with nothing but a newline between the markers must replace the whole
    /// editable range with an empty string.
    #[gpui::test]
    fn test_parse_edits_empty_editable_region(cx: &mut App) {
        let text = "fn foo() {\n    let x = 42;\n}\n";
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();

        let output = "<|editable_region_start|>\n<|editable_region_end|>";
        let editable_range = 0..text.len();
        let edits = parse_edits(output, editable_range, &snapshot).unwrap();
        assert_eq!(edits.len(), 1);
        let (range, new_text) = &edits[0];
        assert_eq!(range.to_offset(&snapshot), 0..text.len(),);
        assert_eq!(new_text.as_ref(), "");
    }
}