1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::Result;
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use language::{
15 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
16};
17use project::{Project, ProjectPath};
18use release_channel::AppVersion;
19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
20use zeta_prompt::{
21 Event, ZetaPromptInput,
22 zeta1::{
23 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
24 START_OF_FILE_MARKER,
25 },
26};
27
/// Estimated-token budget for the read-only context surrounding the editable region
/// (passed as the context limit to `excerpt_for_cursor_position`).
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Estimated-token budget for the editable region the model is asked to rewrite.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Estimated-token budget for the formatted edit-event history included in the prompt.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
31
32pub(crate) fn request_prediction_with_zeta1(
33 store: &mut EditPredictionStore,
34 EditPredictionModelInput {
35 project,
36 buffer,
37 snapshot,
38 position,
39 events,
40 trigger,
41 debug_tx,
42 ..
43 }: EditPredictionModelInput,
44 cx: &mut Context<EditPredictionStore>,
45) -> Task<Result<Option<EditPredictionResult>>> {
46 let buffer_snapshotted_at = Instant::now();
47 let client = store.client.clone();
48 let llm_token = store.llm_token.clone();
49 let app_version = AppVersion::global(cx);
50
51 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
52 let can_collect_file = store.can_collect_file(&project, file, cx);
53 let git_info = if can_collect_file {
54 git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
55 } else {
56 None
57 };
58 (git_info, can_collect_file)
59 } else {
60 (None, false)
61 };
62
63 let full_path: Arc<Path> = snapshot
64 .file()
65 .map(|f| Arc::from(f.full_path(cx).as_path()))
66 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
67 let full_path_str = full_path.to_string_lossy().into_owned();
68 let cursor_point = position.to_point(&snapshot);
69 let prompt_for_events = {
70 let events = events.clone();
71 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
72 };
73 let gather_task = gather_context(
74 full_path_str,
75 &snapshot,
76 cursor_point,
77 prompt_for_events,
78 trigger,
79 cx,
80 );
81
82 let (uri, require_auth) = match &store.custom_predict_edits_url {
83 Some(custom_url) => (custom_url.clone(), false),
84 None => {
85 match client
86 .http_client()
87 .build_zed_llm_url("/predict_edits/v2", &[])
88 {
89 Ok(url) => (url.into(), true),
90 Err(err) => return Task::ready(Err(err)),
91 }
92 }
93 };
94
95 cx.spawn(async move |this, cx| {
96 let GatherContextOutput {
97 mut body,
98 context_range,
99 editable_range,
100 included_events_count,
101 } = gather_task.await?;
102 let done_gathering_context_at = Instant::now();
103
104 let included_events = &events[events.len() - included_events_count..events.len()];
105 body.can_collect_data = can_collect_file
106 && this
107 .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
108 .unwrap_or(false);
109 if body.can_collect_data {
110 body.git_info = git_info;
111 }
112
113 log::debug!(
114 "Events:\n{}\nExcerpt:\n{:?}",
115 body.input_events,
116 body.input_excerpt
117 );
118
119 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
120 |request| {
121 Ok(request
122 .uri(uri.as_str())
123 .body(serde_json::to_string(&body)?.into())?)
124 },
125 client,
126 llm_token,
127 app_version,
128 require_auth,
129 )
130 .await;
131
132 let context_start_offset = context_range.start.to_offset(&snapshot);
133 let editable_offset_range = editable_range.to_offset(&snapshot);
134
135 let inputs = ZetaPromptInput {
136 events: included_events.into(),
137 related_files: vec![],
138 cursor_path: full_path,
139 cursor_excerpt: snapshot
140 .text_for_range(context_range)
141 .collect::<String>()
142 .into(),
143 editable_range_in_excerpt: (editable_range.start - context_start_offset)
144 ..(editable_offset_range.end - context_start_offset),
145 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
146 };
147
148 if let Some(debug_tx) = &debug_tx {
149 debug_tx
150 .unbounded_send(DebugEvent::EditPredictionStarted(
151 EditPredictionStartedDebugEvent {
152 buffer: buffer.downgrade(),
153 prompt: Some(serde_json::to_string(&inputs).unwrap()),
154 position,
155 },
156 ))
157 .ok();
158 }
159
160 let (response, usage) = match response {
161 Ok(response) => response,
162 Err(err) => {
163 if err.is::<ZedUpdateRequiredError>() {
164 cx.update(|cx| {
165 this.update(cx, |ep_store, _cx| {
166 ep_store.update_required = true;
167 })
168 .ok();
169
170 let error_message: SharedString = err.to_string().into();
171 show_app_notification(
172 NotificationId::unique::<ZedUpdateRequiredError>(),
173 cx,
174 move |cx| {
175 cx.new(|cx| {
176 ErrorMessagePrompt::new(error_message.clone(), cx)
177 .with_link_button("Update Zed", "https://zed.dev/releases")
178 })
179 },
180 );
181 });
182 }
183
184 return Err(err);
185 }
186 };
187
188 let received_response_at = Instant::now();
189 log::debug!("completion response: {}", &response.output_excerpt);
190
191 if let Some(usage) = usage {
192 this.update(cx, |this, cx| {
193 this.user_store.update(cx, |user_store, cx| {
194 user_store.update_edit_prediction_usage(usage, cx);
195 });
196 })
197 .ok();
198 }
199
200 if let Some(debug_tx) = &debug_tx {
201 debug_tx
202 .unbounded_send(DebugEvent::EditPredictionFinished(
203 EditPredictionFinishedDebugEvent {
204 buffer: buffer.downgrade(),
205 model_output: Some(response.output_excerpt.clone()),
206 position,
207 },
208 ))
209 .ok();
210 }
211
212 let edit_prediction = process_completion_response(
213 response,
214 buffer,
215 &snapshot,
216 editable_range,
217 inputs,
218 buffer_snapshotted_at,
219 received_response_at,
220 cx,
221 )
222 .await;
223
224 let finished_at = Instant::now();
225
226 // record latency for ~1% of requests
227 if rand::random::<u8>() <= 2 {
228 telemetry::event!(
229 "Edit Prediction Request",
230 context_latency = done_gathering_context_at
231 .duration_since(buffer_snapshotted_at)
232 .as_millis(),
233 request_latency = received_response_at
234 .duration_since(done_gathering_context_at)
235 .as_millis(),
236 process_latency = finished_at.duration_since(received_response_at).as_millis()
237 );
238 }
239
240 edit_prediction.map(Some)
241 })
242}
243
244fn process_completion_response(
245 prediction_response: PredictEditsResponse,
246 buffer: Entity<Buffer>,
247 snapshot: &BufferSnapshot,
248 editable_range: Range<usize>,
249 inputs: ZetaPromptInput,
250 buffer_snapshotted_at: Instant,
251 received_response_at: Instant,
252 cx: &AsyncApp,
253) -> Task<Result<EditPredictionResult>> {
254 let snapshot = snapshot.clone();
255 let request_id = prediction_response.request_id;
256 let output_excerpt = prediction_response.output_excerpt;
257 cx.spawn(async move |cx| {
258 let output_excerpt: Arc<str> = output_excerpt.into();
259
260 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
261 .background_spawn({
262 let output_excerpt = output_excerpt.clone();
263 let editable_range = editable_range.clone();
264 let snapshot = snapshot.clone();
265 async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
266 })
267 .await?
268 .into();
269
270 let id = EditPredictionId(request_id.into());
271 Ok(EditPredictionResult::new(
272 id,
273 &buffer,
274 &snapshot,
275 edits,
276 buffer_snapshotted_at,
277 received_response_at,
278 inputs,
279 cx,
280 )
281 .await)
282 })
283}
284
285pub(crate) fn parse_edits(
286 output_excerpt: &str,
287 editable_range: Range<usize>,
288 snapshot: &BufferSnapshot,
289) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
290 let content = output_excerpt.replace(CURSOR_MARKER, "");
291
292 let start_markers = content
293 .match_indices(EDITABLE_REGION_START_MARKER)
294 .collect::<Vec<_>>();
295 anyhow::ensure!(
296 start_markers.len() <= 1,
297 "expected at most one start marker, found {}",
298 start_markers.len()
299 );
300
301 let end_markers = content
302 .match_indices(EDITABLE_REGION_END_MARKER)
303 .collect::<Vec<_>>();
304 anyhow::ensure!(
305 end_markers.len() <= 1,
306 "expected at most one end marker, found {}",
307 end_markers.len()
308 );
309
310 let sof_markers = content
311 .match_indices(START_OF_FILE_MARKER)
312 .collect::<Vec<_>>();
313 anyhow::ensure!(
314 sof_markers.len() <= 1,
315 "expected at most one start-of-file marker, found {}",
316 sof_markers.len()
317 );
318
319 let content_start = start_markers
320 .first()
321 .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
322 .unwrap_or(0);
323 let content_end = end_markers
324 .first()
325 .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
326 .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
327
328 // if there is a single newline between markers, content_start will be 1 more than content_end. .min ensures empty slice in that case
329 let new_text = &content[content_start.min(content_end)..content_end];
330
331 let old_text = snapshot
332 .text_for_range(editable_range.clone())
333 .collect::<String>();
334
335 Ok(compute_edits(
336 old_text,
337 new_text,
338 editable_range.start,
339 snapshot,
340 ))
341}
342
343pub fn compute_edits(
344 old_text: String,
345 new_text: &str,
346 offset: usize,
347 snapshot: &BufferSnapshot,
348) -> Vec<(Range<Anchor>, Arc<str>)> {
349 text_diff(&old_text, new_text)
350 .into_iter()
351 .map(|(mut old_range, new_text)| {
352 let old_slice = &old_text[old_range.clone()];
353
354 let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
355 let suffix_len = common_prefix(
356 old_slice[prefix_len..].chars().rev(),
357 new_text[prefix_len..].chars().rev(),
358 );
359
360 old_range.start += offset;
361 old_range.end += offset;
362 old_range.start += prefix_len;
363 old_range.end -= suffix_len;
364
365 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
366 let range = if old_range.is_empty() {
367 let anchor = snapshot.anchor_after(old_range.start);
368 anchor..anchor
369 } else {
370 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
371 };
372 (range, new_text)
373 })
374 .collect()
375}
376
/// Returns the length in UTF-8 bytes of the longest common prefix of two character streams.
///
/// Passing reversed iterators yields the common suffix length instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
383
384fn git_info_for_file(
385 project: &Entity<Project>,
386 project_path: &ProjectPath,
387 cx: &App,
388) -> Option<PredictEditsGitInfo> {
389 let git_store = project.read(cx).git_store().read(cx);
390 if let Some((repository, _repo_path)) =
391 git_store.repository_and_path_for_project_path(project_path, cx)
392 {
393 let repository = repository.read(cx);
394 let head_sha = repository
395 .head_commit
396 .as_ref()
397 .map(|head_commit| head_commit.sha.to_string());
398 let remote_origin_url = repository.remote_origin_url.clone();
399 let remote_upstream_url = repository.remote_upstream_url.clone();
400 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
401 return None;
402 }
403 Some(PredictEditsGitInfo {
404 head_sha,
405 remote_origin_url,
406 remote_upstream_url,
407 })
408 } else {
409 None
410 }
411}
412
/// Output of [`gather_context`]: the request body plus the ranges needed to interpret the
/// model's response.
pub struct GatherContextOutput {
    /// Request body. `gather_context` leaves `can_collect_data` false and `git_info` unset; the
    /// caller decides whether to enable them.
    pub body: PredictEditsBody,
    /// Buffer range of the whole excerpt (context plus editable region) sent to the model.
    pub context_range: Range<Point>,
    /// The editable region, as buffer offsets.
    pub editable_range: Range<usize>,
    /// How many of the most recent events fit into the event token budget.
    pub included_events_count: usize,
}
419
420pub fn gather_context(
421 full_path_str: String,
422 snapshot: &BufferSnapshot,
423 cursor_point: language::Point,
424 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
425 trigger: PredictEditsRequestTrigger,
426 cx: &App,
427) -> Task<Result<GatherContextOutput>> {
428 cx.background_spawn({
429 let snapshot = snapshot.clone();
430 async move {
431 let input_excerpt = excerpt_for_cursor_position(
432 cursor_point,
433 &full_path_str,
434 &snapshot,
435 MAX_REWRITE_TOKENS,
436 MAX_CONTEXT_TOKENS,
437 );
438 let (input_events, included_events_count) = prompt_for_events();
439 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
440
441 let body = PredictEditsBody {
442 input_events,
443 input_excerpt: input_excerpt.prompt,
444 can_collect_data: false,
445 diagnostic_groups: None,
446 git_info: None,
447 outline: None,
448 speculated_output: None,
449 trigger,
450 };
451
452 Ok(GatherContextOutput {
453 body,
454 context_range: input_excerpt.context_range,
455 editable_range,
456 included_events_count,
457 })
458 }
459 })
460}
461
462pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
463 prompt_for_events_impl(events, max_tokens).0
464}
465
466fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
467 let mut result = String::new();
468 for (ix, event) in events.iter().rev().enumerate() {
469 let event_string = format_event(event.as_ref());
470 let event_tokens = guess_token_count(event_string.len());
471 if event_tokens > remaining_tokens {
472 return (result, ix);
473 }
474
475 if !result.is_empty() {
476 result.insert_str(0, "\n\n");
477 }
478 result.insert_str(0, &event_string);
479 remaining_tokens -= event_tokens;
480 }
481 return (result, events.len());
482}
483
484pub fn format_event(event: &Event) -> String {
485 match event {
486 Event::BufferChange {
487 path,
488 old_path,
489 diff,
490 ..
491 } => {
492 let mut prompt = String::new();
493
494 if old_path != path {
495 writeln!(
496 prompt,
497 "User renamed {} to {}\n",
498 old_path.display(),
499 path.display()
500 )
501 .unwrap();
502 }
503
504 if !diff.is_empty() {
505 write!(
506 prompt,
507 "User edited {}:\n```diff\n{}\n```",
508 path.display(),
509 diff
510 )
511 .unwrap();
512 }
513
514 prompt
515 }
516 }
517}
518
/// Prompt excerpt built around the cursor by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Buffer range of all text included in the prompt (context plus editable region).
    pub context_range: Range<Point>,
    /// Buffer range wrapped in the editable-region markers; the prompt construction assumes it
    /// lies within `context_range`.
    pub editable_range: Range<Point>,
    /// Rendered prompt: a fenced block tagged with the path, holding the context, the
    /// editable-region markers, and the cursor marker.
    pub prompt: String,
}
525
526pub fn excerpt_for_cursor_position(
527 position: Point,
528 path: &str,
529 snapshot: &BufferSnapshot,
530 editable_region_token_limit: usize,
531 context_token_limit: usize,
532) -> InputExcerpt {
533 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
534 position,
535 snapshot,
536 editable_region_token_limit,
537 context_token_limit,
538 );
539
540 let mut prompt = String::new();
541
542 writeln!(&mut prompt, "```{path}").unwrap();
543 if context_range.start == Point::zero() {
544 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
545 }
546
547 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
548 prompt.push_str(chunk.text);
549 }
550
551 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
552
553 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
554 prompt.push_str(chunk.text);
555 }
556 write!(prompt, "\n```").unwrap();
557
558 InputExcerpt {
559 context_range,
560 editable_range,
561 prompt,
562 }
563}
564
565fn push_editable_range(
566 cursor_position: Point,
567 snapshot: &BufferSnapshot,
568 editable_range: Range<Point>,
569 prompt: &mut String,
570) {
571 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
572 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
573 prompt.push_str(chunk.text);
574 }
575 prompt.push_str(CURSOR_MARKER);
576 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
577 prompt.push_str(chunk.text);
578 }
579 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    // Checks the rendered prompt (markers, cursor, surrounding context) for a cursor inside
    // `bar`, under two different editable-region token budgets.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
                ```main.rs
                    let x = 42;
                    println!("Hello, world!");
                <|editable_region_start|>
                }

                fn bar() {
                    let x = 42;
                    let mut sum = 0;
                    for i in 0..x {
                        sum += i;
                    }
                    println!("Sum: {}", sum);
                    r<|user_cursor_is_here|>eturn sum;
                }

                fn generate_random_numbers() -> Vec<i32> {
                <|editable_region_end|>
                    let mut rng = rand::thread_rng();
                    let mut numbers = Vec::new();
                ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
                ```main.rs
                fn bar() {
                    let x = 42;
                    let mut sum = 0;
                <|editable_region_start|>
                    for i in 0..x {
                        sum += i;
                    }
                    println!("Sum: {}", sum);
                    r<|user_cursor_is_here|>eturn sum;
                }

                fn generate_random_numbers() -> Vec<i32> {
                    let mut rng = rand::thread_rng();
                <|editable_region_end|>
                    let mut numbers = Vec::new();
                    for _ in 0..5 {
                        numbers.push(rng.random_range(1..101));
                ```"#}
        );
    }

    // An output whose editable region is empty must produce a single edit that deletes the
    // entire editable range.
    #[gpui::test]
    fn test_parse_edits_empty_editable_region(cx: &mut App) {
        let text = "fn foo() {\n let x = 42;\n}\n";
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();

        let output = "<|editable_region_start|>\n<|editable_region_end|>";
        let editable_range = 0..text.len();
        let edits = parse_edits(output, editable_range, &snapshot).unwrap();
        assert_eq!(edits.len(), 1);
        let (range, new_text) = &edits[0];
        assert_eq!(range.to_offset(&snapshot), 0..text.len(),);
        assert_eq!(new_text.as_ref(), "");
    }
}