1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::Result;
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use edit_prediction_types::PredictedCursorPosition;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
15use language::{
16 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
17};
18use project::{Project, ProjectPath};
19use release_channel::AppVersion;
20use text::Bias;
21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
22use zeta_prompt::{
23 Event, ZetaPromptInput,
24 zeta1::{
25 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
26 START_OF_FILE_MARKER,
27 },
28};
29
/// Token budget for the read-only context around the editable region; `gather_context`
/// passes it as `context_token_limit` to `excerpt_for_cursor_position`.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget for the editable (rewritable) region around the cursor; `gather_context`
/// passes it as `editable_region_token_limit` to `excerpt_for_cursor_position`.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the recent-edit history that `request_prediction_with_zeta1`
/// includes in the prompt.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
33
34pub(crate) fn request_prediction_with_zeta1(
35 store: &mut EditPredictionStore,
36 EditPredictionModelInput {
37 project,
38 buffer,
39 snapshot,
40 position,
41 events,
42 trigger,
43 debug_tx,
44 ..
45 }: EditPredictionModelInput,
46 cx: &mut Context<EditPredictionStore>,
47) -> Task<Result<Option<EditPredictionResult>>> {
48 let buffer_snapshotted_at = Instant::now();
49 let client = store.client.clone();
50 let llm_token = store.llm_token.clone();
51 let app_version = AppVersion::global(cx);
52
53 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
54 let can_collect_file = store.can_collect_file(&project, file, cx);
55 let git_info = if can_collect_file {
56 git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
57 } else {
58 None
59 };
60 (git_info, can_collect_file)
61 } else {
62 (None, false)
63 };
64
65 let full_path: Arc<Path> = snapshot
66 .file()
67 .map(|f| Arc::from(f.full_path(cx).as_path()))
68 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
69 let full_path_str = full_path.to_string_lossy().into_owned();
70 let cursor_point = position.to_point(&snapshot);
71 let prompt_for_events = {
72 let events = events.clone();
73 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
74 };
75 let gather_task = gather_context(
76 full_path_str,
77 &snapshot,
78 cursor_point,
79 prompt_for_events,
80 trigger,
81 cx,
82 );
83
84 let (uri, require_auth) = match &store.custom_predict_edits_url {
85 Some(custom_url) => (custom_url.clone(), false),
86 None => {
87 match client
88 .http_client()
89 .build_zed_llm_url("/predict_edits/v2", &[])
90 {
91 Ok(url) => (url.into(), true),
92 Err(err) => return Task::ready(Err(err)),
93 }
94 }
95 };
96
97 cx.spawn(async move |this, cx| {
98 let GatherContextOutput {
99 mut body,
100 context_range,
101 editable_range,
102 included_events_count,
103 } = gather_task.await?;
104 let done_gathering_context_at = Instant::now();
105
106 let included_events = &events[events.len() - included_events_count..events.len()];
107 body.can_collect_data = can_collect_file
108 && this
109 .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
110 .unwrap_or(false);
111 if body.can_collect_data {
112 body.git_info = git_info;
113 }
114
115 log::debug!(
116 "Events:\n{}\nExcerpt:\n{:?}",
117 body.input_events,
118 body.input_excerpt
119 );
120
121 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
122 |request| {
123 Ok(request
124 .uri(uri.as_str())
125 .body(serde_json::to_string(&body)?.into())?)
126 },
127 client,
128 llm_token,
129 app_version,
130 require_auth,
131 )
132 .await;
133
134 let context_start_offset = context_range.start.to_offset(&snapshot);
135 let context_start_row = context_range.start.row;
136 let editable_offset_range = editable_range.to_offset(&snapshot);
137
138 let inputs = ZetaPromptInput {
139 events: included_events.into(),
140 related_files: vec![],
141 cursor_path: full_path,
142 cursor_excerpt: snapshot
143 .text_for_range(context_range)
144 .collect::<String>()
145 .into(),
146 editable_range_in_excerpt: (editable_range.start - context_start_offset)
147 ..(editable_offset_range.end - context_start_offset),
148 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
149 excerpt_start_row: Some(context_start_row),
150 };
151
152 if let Some(debug_tx) = &debug_tx {
153 debug_tx
154 .unbounded_send(DebugEvent::EditPredictionStarted(
155 EditPredictionStartedDebugEvent {
156 buffer: buffer.downgrade(),
157 prompt: Some(serde_json::to_string(&inputs).unwrap()),
158 position,
159 },
160 ))
161 .ok();
162 }
163
164 let (response, usage) = match response {
165 Ok(response) => response,
166 Err(err) => {
167 if err.is::<ZedUpdateRequiredError>() {
168 cx.update(|cx| {
169 this.update(cx, |ep_store, _cx| {
170 ep_store.update_required = true;
171 })
172 .ok();
173
174 let error_message: SharedString = err.to_string().into();
175 show_app_notification(
176 NotificationId::unique::<ZedUpdateRequiredError>(),
177 cx,
178 move |cx| {
179 cx.new(|cx| {
180 ErrorMessagePrompt::new(error_message.clone(), cx)
181 .with_link_button("Update Zed", "https://zed.dev/releases")
182 })
183 },
184 );
185 });
186 }
187
188 return Err(err);
189 }
190 };
191
192 let received_response_at = Instant::now();
193 log::debug!("completion response: {}", &response.output_excerpt);
194
195 if let Some(usage) = usage {
196 this.update(cx, |this, cx| {
197 this.user_store.update(cx, |user_store, cx| {
198 user_store.update_edit_prediction_usage(usage, cx);
199 });
200 })
201 .ok();
202 }
203
204 if let Some(debug_tx) = &debug_tx {
205 debug_tx
206 .unbounded_send(DebugEvent::EditPredictionFinished(
207 EditPredictionFinishedDebugEvent {
208 buffer: buffer.downgrade(),
209 model_output: Some(response.output_excerpt.clone()),
210 position,
211 },
212 ))
213 .ok();
214 }
215
216 let edit_prediction = process_completion_response(
217 response,
218 buffer,
219 &snapshot,
220 editable_range,
221 inputs,
222 buffer_snapshotted_at,
223 received_response_at,
224 cx,
225 )
226 .await;
227
228 let finished_at = Instant::now();
229
230 // record latency for ~1% of requests
231 if rand::random::<u8>() <= 2 {
232 telemetry::event!(
233 "Edit Prediction Request",
234 context_latency = done_gathering_context_at
235 .duration_since(buffer_snapshotted_at)
236 .as_millis(),
237 request_latency = received_response_at
238 .duration_since(done_gathering_context_at)
239 .as_millis(),
240 process_latency = finished_at.duration_since(received_response_at).as_millis()
241 );
242 }
243
244 edit_prediction.map(Some)
245 })
246}
247
248fn process_completion_response(
249 prediction_response: PredictEditsResponse,
250 buffer: Entity<Buffer>,
251 snapshot: &BufferSnapshot,
252 editable_range: Range<usize>,
253 inputs: ZetaPromptInput,
254 buffer_snapshotted_at: Instant,
255 received_response_at: Instant,
256 cx: &AsyncApp,
257) -> Task<Result<EditPredictionResult>> {
258 let snapshot = snapshot.clone();
259 let request_id = prediction_response.request_id;
260 let output_excerpt = prediction_response.output_excerpt;
261 cx.spawn(async move |cx| {
262 let output_excerpt: Arc<str> = output_excerpt.into();
263
264 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
265 .background_spawn({
266 let output_excerpt = output_excerpt.clone();
267 let editable_range = editable_range.clone();
268 let snapshot = snapshot.clone();
269 async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
270 })
271 .await?
272 .into();
273
274 let id = EditPredictionId(request_id.into());
275 Ok(EditPredictionResult::new(
276 id,
277 &buffer,
278 &snapshot,
279 edits,
280 None,
281 buffer_snapshotted_at,
282 received_response_at,
283 inputs,
284 cx,
285 )
286 .await)
287 })
288}
289
290pub(crate) fn parse_edits(
291 output_excerpt: &str,
292 editable_range: Range<usize>,
293 snapshot: &BufferSnapshot,
294) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
295 let content = output_excerpt.replace(CURSOR_MARKER, "");
296
297 let start_markers = content
298 .match_indices(EDITABLE_REGION_START_MARKER)
299 .collect::<Vec<_>>();
300 anyhow::ensure!(
301 start_markers.len() <= 1,
302 "expected at most one start marker, found {}",
303 start_markers.len()
304 );
305
306 let end_markers = content
307 .match_indices(EDITABLE_REGION_END_MARKER)
308 .collect::<Vec<_>>();
309 anyhow::ensure!(
310 end_markers.len() <= 1,
311 "expected at most one end marker, found {}",
312 end_markers.len()
313 );
314
315 let sof_markers = content
316 .match_indices(START_OF_FILE_MARKER)
317 .collect::<Vec<_>>();
318 anyhow::ensure!(
319 sof_markers.len() <= 1,
320 "expected at most one start-of-file marker, found {}",
321 sof_markers.len()
322 );
323
324 let content_start = start_markers
325 .first()
326 .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
327 .unwrap_or(0);
328 let content_end = end_markers
329 .first()
330 .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
331 .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
332
333 let new_text = &content[content_start..content_end];
334
335 let old_text = snapshot
336 .text_for_range(editable_range.clone())
337 .collect::<String>();
338
339 Ok(compute_edits(
340 old_text,
341 new_text,
342 editable_range.start,
343 snapshot,
344 ))
345}
346
347pub fn compute_edits(
348 old_text: String,
349 new_text: &str,
350 offset: usize,
351 snapshot: &BufferSnapshot,
352) -> Vec<(Range<Anchor>, Arc<str>)> {
353 compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
354}
355
/// Diffs `old_text` against `new_text` and turns each diff hunk into an anchored buffer
/// edit. Optionally, it also maps a cursor offset in `new_text` to a buffer position.
///
/// * `old_text`: buffer text that begins at byte `offset` in `snapshot`. For example,
///   `parse_edits` passes the editable range's text and the range's start.
/// * `new_text`: the proposed replacement for `old_text`.
/// * `offset`: buffer byte offset of `old_text`'s first byte; added to every hunk range.
/// * `cursor_offset_in_new_text`: optional byte offset of a cursor within `new_text`.
///
/// Returns the edits and, when a cursor offset was supplied, the predicted cursor position.
/// Each edit has the common prefix and suffix of its old and new text trimmed away.
/// The cursor position is derived as follows:
/// * Cursor before a hunk, or after every hunk: anchored at the corresponding position in
///   the old text.
/// * Cursor inside a hunk's inserted text: anchored before the hunk's untrimmed start,
///   together with the cursor's byte offset within that inserted text.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Once the cursor has been placed, `delta` is no longer needed and stops being
            // updated, so this check must run before `delta` absorbs the current hunk.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                // Where this hunk's inserted text starts and ends in `new_text` coordinates.
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor sits in unchanged text before this hunk: map it back to old text.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(offset + cursor_in_old),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor sits inside the inserted text: keep its offset within the insertion.
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(offset + raw_old_range.start),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Byte lengths of the shared leading / trailing text between old and new hunk.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Shift into buffer coordinates, then shrink to the part that actually changes.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: an empty range at a single anchor.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                // Replacement/deletion: anchors biased inward (after start, before end).
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // The cursor was not placed before or inside any hunk, so it lies after all of them.
    // Map it through the total delta and clip it to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
433
/// Returns the UTF-8 byte length of the longest common prefix of two `char` streams.
///
/// Callers pass reversed iterators to measure a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
440
441fn git_info_for_file(
442 project: &Entity<Project>,
443 project_path: &ProjectPath,
444 cx: &App,
445) -> Option<PredictEditsGitInfo> {
446 let git_store = project.read(cx).git_store().read(cx);
447 if let Some((repository, _repo_path)) =
448 git_store.repository_and_path_for_project_path(project_path, cx)
449 {
450 let repository = repository.read(cx);
451 let head_sha = repository
452 .head_commit
453 .as_ref()
454 .map(|head_commit| head_commit.sha.to_string());
455 let remote_origin_url = repository.remote_origin_url.clone();
456 let remote_upstream_url = repository.remote_upstream_url.clone();
457 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
458 return None;
459 }
460 Some(PredictEditsGitInfo {
461 head_sha,
462 remote_origin_url,
463 remote_upstream_url,
464 })
465 } else {
466 None
467 }
468}
469
/// Result of [`gather_context`]: the request body plus the ranges needed to interpret the
/// model's response.
pub struct GatherContextOutput {
    /// Request body. `gather_context` leaves `can_collect_data` false and `git_info` unset;
    /// callers fill these in once data-collection eligibility is known.
    pub body: PredictEditsBody,
    /// Range (in rows/columns) of the whole excerpt sent to the model.
    pub context_range: Range<Point>,
    /// Byte-offset range of the editable region within the buffer snapshot.
    pub editable_range: Range<usize>,
    /// How many of the most recent events fit in the event token budget and were included.
    pub included_events_count: usize,
}
476
477pub fn gather_context(
478 full_path_str: String,
479 snapshot: &BufferSnapshot,
480 cursor_point: language::Point,
481 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
482 trigger: PredictEditsRequestTrigger,
483 cx: &App,
484) -> Task<Result<GatherContextOutput>> {
485 cx.background_spawn({
486 let snapshot = snapshot.clone();
487 async move {
488 let input_excerpt = excerpt_for_cursor_position(
489 cursor_point,
490 &full_path_str,
491 &snapshot,
492 MAX_REWRITE_TOKENS,
493 MAX_CONTEXT_TOKENS,
494 );
495 let (input_events, included_events_count) = prompt_for_events();
496 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
497
498 let body = PredictEditsBody {
499 input_events,
500 input_excerpt: input_excerpt.prompt,
501 can_collect_data: false,
502 diagnostic_groups: None,
503 git_info: None,
504 outline: None,
505 speculated_output: None,
506 trigger,
507 };
508
509 Ok(GatherContextOutput {
510 body,
511 context_range: input_excerpt.context_range,
512 editable_range,
513 included_events_count,
514 })
515 }
516 })
517}
518
519pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
520 prompt_for_events_impl(events, max_tokens).0
521}
522
523fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
524 let mut result = String::new();
525 for (ix, event) in events.iter().rev().enumerate() {
526 let event_string = format_event(event.as_ref());
527 let event_tokens = guess_token_count(event_string.len());
528 if event_tokens > remaining_tokens {
529 return (result, ix);
530 }
531
532 if !result.is_empty() {
533 result.insert_str(0, "\n\n");
534 }
535 result.insert_str(0, &event_string);
536 remaining_tokens -= event_tokens;
537 }
538 return (result, events.len());
539}
540
541pub fn format_event(event: &Event) -> String {
542 match event {
543 Event::BufferChange {
544 path,
545 old_path,
546 diff,
547 ..
548 } => {
549 let mut prompt = String::new();
550
551 if old_path != path {
552 writeln!(
553 prompt,
554 "User renamed {} to {}\n",
555 old_path.display(),
556 path.display()
557 )
558 .unwrap();
559 }
560
561 if !diff.is_empty() {
562 write!(
563 prompt,
564 "User edited {}:\n```diff\n{}\n```",
565 path.display(),
566 diff
567 )
568 .unwrap();
569 }
570
571 prompt
572 }
573 }
574}
575
/// Prompt excerpt around the cursor, as produced by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Full range of buffer text included in the prompt (context plus editable region).
    pub context_range: Range<Point>,
    /// Sub-range that the model is allowed to rewrite.
    pub editable_range: Range<Point>,
    /// Rendered prompt. It contains a path fence, an optional start-of-file marker, and the
    /// editable region wrapped in markers with a cursor marker inside.
    pub prompt: String,
}
582
583pub fn excerpt_for_cursor_position(
584 position: Point,
585 path: &str,
586 snapshot: &BufferSnapshot,
587 editable_region_token_limit: usize,
588 context_token_limit: usize,
589) -> InputExcerpt {
590 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
591 position,
592 snapshot,
593 editable_region_token_limit,
594 context_token_limit,
595 );
596
597 let mut prompt = String::new();
598
599 writeln!(&mut prompt, "```{path}").unwrap();
600 if context_range.start == Point::zero() {
601 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
602 }
603
604 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
605 prompt.push_str(chunk.text);
606 }
607
608 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
609
610 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
611 prompt.push_str(chunk.text);
612 }
613 write!(prompt, "\n```").unwrap();
614
615 InputExcerpt {
616 context_range,
617 editable_range,
618 prompt,
619 }
620}
621
622fn push_editable_range(
623 cursor_position: Point,
624 snapshot: &BufferSnapshot,
625 editable_range: Range<Point>,
626 prompt: &mut String,
627) {
628 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
629 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
630 prompt.push_str(chunk.text);
631 }
632 prompt.push_str(CURSOR_MARKER);
633 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
634 prompt.push_str(chunk.text);
635 }
636 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the exact prompt text `excerpt_for_cursor_position` renders for a small Rust
    /// buffer, at two different editable-region token budgets.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The excerpt expands to syntax boundaries.
        // With 50 token editable limit, we get a region that expands to syntax nodes.
        // Point::new(12, 5) is just after the `r` of `return sum;` inside `bar`.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With smaller budget, the region expands to syntax boundaries but is tighter.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }
}