1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::{Context as _, Result};
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use language::{
15 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
16};
17use project::{Project, ProjectPath};
18use release_channel::AppVersion;
19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
20use zeta_prompt::{Event, ZetaPromptInput};
21
/// Inserted into the prompt at the cursor position; stripped from the model
/// output before it is diffed against the buffer.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Written on its own line when the prompt excerpt begins at the very start of
/// the buffer.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the region of the excerpt that the model is asked to rewrite.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the region of the excerpt that the model is asked to rewrite.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Token limit for the surrounding (read-only) context of the excerpt.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token limit for the editable region of the excerpt.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Estimated-token budget for the recent edit events rendered into the prompt.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
30
31pub(crate) fn request_prediction_with_zeta1(
32 store: &mut EditPredictionStore,
33 EditPredictionModelInput {
34 project,
35 buffer,
36 snapshot,
37 position,
38 events,
39 trigger,
40 debug_tx,
41 ..
42 }: EditPredictionModelInput,
43 cx: &mut Context<EditPredictionStore>,
44) -> Task<Result<Option<EditPredictionResult>>> {
45 let buffer_snapshotted_at = Instant::now();
46 let client = store.client.clone();
47 let llm_token = store.llm_token.clone();
48 let app_version = AppVersion::global(cx);
49
50 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
51 let can_collect_file = store.can_collect_file(&project, file, cx);
52 let git_info = if can_collect_file {
53 git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
54 } else {
55 None
56 };
57 (git_info, can_collect_file)
58 } else {
59 (None, false)
60 };
61
62 let full_path: Arc<Path> = snapshot
63 .file()
64 .map(|f| Arc::from(f.full_path(cx).as_path()))
65 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
66 let full_path_str = full_path.to_string_lossy().into_owned();
67 let cursor_point = position.to_point(&snapshot);
68 let prompt_for_events = {
69 let events = events.clone();
70 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
71 };
72 let gather_task = gather_context(
73 full_path_str,
74 &snapshot,
75 cursor_point,
76 prompt_for_events,
77 trigger,
78 cx,
79 );
80
81 let (uri, require_auth) = match &store.custom_predict_edits_url {
82 Some(custom_url) => (custom_url.clone(), false),
83 None => {
84 match client
85 .http_client()
86 .build_zed_llm_url("/predict_edits/v2", &[])
87 {
88 Ok(url) => (url.into(), true),
89 Err(err) => return Task::ready(Err(err)),
90 }
91 }
92 };
93
94 cx.spawn(async move |this, cx| {
95 let GatherContextOutput {
96 mut body,
97 context_range,
98 editable_range,
99 included_events_count,
100 } = gather_task.await?;
101 let done_gathering_context_at = Instant::now();
102
103 let included_events = &events[events.len() - included_events_count..events.len()];
104 body.can_collect_data = can_collect_file
105 && this
106 .read_with(cx, |this, _| this.can_collect_events(included_events))
107 .unwrap_or(false);
108 if body.can_collect_data {
109 body.git_info = git_info;
110 }
111
112 log::debug!(
113 "Events:\n{}\nExcerpt:\n{:?}",
114 body.input_events,
115 body.input_excerpt
116 );
117
118 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
119 |request| {
120 Ok(request
121 .uri(uri.as_str())
122 .body(serde_json::to_string(&body)?.into())?)
123 },
124 client,
125 llm_token,
126 app_version,
127 require_auth,
128 )
129 .await;
130
131 let context_start_offset = context_range.start.to_offset(&snapshot);
132 let editable_offset_range = editable_range.to_offset(&snapshot);
133
134 let inputs = ZetaPromptInput {
135 events: included_events.into(),
136 related_files: vec![].into(),
137 cursor_path: full_path,
138 cursor_excerpt: snapshot
139 .text_for_range(context_range)
140 .collect::<String>()
141 .into(),
142 editable_range_in_excerpt: (editable_range.start - context_start_offset)
143 ..(editable_offset_range.end - context_start_offset),
144 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
145 };
146
147 if let Some(debug_tx) = &debug_tx {
148 debug_tx
149 .unbounded_send(DebugEvent::EditPredictionStarted(
150 EditPredictionStartedDebugEvent {
151 buffer: buffer.downgrade(),
152 prompt: Some(serde_json::to_string(&inputs).unwrap()),
153 position,
154 },
155 ))
156 .ok();
157 }
158
159 let (response, usage) = match response {
160 Ok(response) => response,
161 Err(err) => {
162 if err.is::<ZedUpdateRequiredError>() {
163 cx.update(|cx| {
164 this.update(cx, |ep_store, _cx| {
165 ep_store.update_required = true;
166 })
167 .ok();
168
169 let error_message: SharedString = err.to_string().into();
170 show_app_notification(
171 NotificationId::unique::<ZedUpdateRequiredError>(),
172 cx,
173 move |cx| {
174 cx.new(|cx| {
175 ErrorMessagePrompt::new(error_message.clone(), cx)
176 .with_link_button("Update Zed", "https://zed.dev/releases")
177 })
178 },
179 );
180 })
181 .ok();
182 }
183
184 return Err(err);
185 }
186 };
187
188 let received_response_at = Instant::now();
189 log::debug!("completion response: {}", &response.output_excerpt);
190
191 if let Some(usage) = usage {
192 this.update(cx, |this, cx| {
193 this.user_store.update(cx, |user_store, cx| {
194 user_store.update_edit_prediction_usage(usage, cx);
195 });
196 })
197 .ok();
198 }
199
200 if let Some(debug_tx) = &debug_tx {
201 debug_tx
202 .unbounded_send(DebugEvent::EditPredictionFinished(
203 EditPredictionFinishedDebugEvent {
204 buffer: buffer.downgrade(),
205 model_output: Some(response.output_excerpt.clone()),
206 position,
207 },
208 ))
209 .ok();
210 }
211
212 let edit_prediction = process_completion_response(
213 response,
214 buffer,
215 &snapshot,
216 editable_range,
217 inputs,
218 buffer_snapshotted_at,
219 received_response_at,
220 cx,
221 )
222 .await;
223
224 let finished_at = Instant::now();
225
226 // record latency for ~1% of requests
227 if rand::random::<u8>() <= 2 {
228 telemetry::event!(
229 "Edit Prediction Request",
230 context_latency = done_gathering_context_at
231 .duration_since(buffer_snapshotted_at)
232 .as_millis(),
233 request_latency = received_response_at
234 .duration_since(done_gathering_context_at)
235 .as_millis(),
236 process_latency = finished_at.duration_since(received_response_at).as_millis()
237 );
238 }
239
240 edit_prediction.map(Some)
241 })
242}
243
244fn process_completion_response(
245 prediction_response: PredictEditsResponse,
246 buffer: Entity<Buffer>,
247 snapshot: &BufferSnapshot,
248 editable_range: Range<usize>,
249 inputs: ZetaPromptInput,
250 buffer_snapshotted_at: Instant,
251 received_response_at: Instant,
252 cx: &AsyncApp,
253) -> Task<Result<EditPredictionResult>> {
254 let snapshot = snapshot.clone();
255 let request_id = prediction_response.request_id;
256 let output_excerpt = prediction_response.output_excerpt;
257 cx.spawn(async move |cx| {
258 let output_excerpt: Arc<str> = output_excerpt.into();
259
260 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
261 .background_spawn({
262 let output_excerpt = output_excerpt.clone();
263 let editable_range = editable_range.clone();
264 let snapshot = snapshot.clone();
265 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
266 })
267 .await?
268 .into();
269
270 let id = EditPredictionId(request_id.into());
271 Ok(EditPredictionResult::new(
272 id,
273 &buffer,
274 &snapshot,
275 edits,
276 buffer_snapshotted_at,
277 received_response_at,
278 inputs,
279 cx,
280 )
281 .await)
282 })
283}
284
285fn parse_edits(
286 output_excerpt: Arc<str>,
287 editable_range: Range<usize>,
288 snapshot: &BufferSnapshot,
289) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
290 let content = output_excerpt.replace(CURSOR_MARKER, "");
291
292 let start_markers = content
293 .match_indices(EDITABLE_REGION_START_MARKER)
294 .collect::<Vec<_>>();
295 anyhow::ensure!(
296 start_markers.len() == 1,
297 "expected exactly one start marker, found {}",
298 start_markers.len()
299 );
300
301 let end_markers = content
302 .match_indices(EDITABLE_REGION_END_MARKER)
303 .collect::<Vec<_>>();
304 anyhow::ensure!(
305 end_markers.len() == 1,
306 "expected exactly one end marker, found {}",
307 end_markers.len()
308 );
309
310 let sof_markers = content
311 .match_indices(START_OF_FILE_MARKER)
312 .collect::<Vec<_>>();
313 anyhow::ensure!(
314 sof_markers.len() <= 1,
315 "expected at most one start-of-file marker, found {}",
316 sof_markers.len()
317 );
318
319 let codefence_start = start_markers[0].0;
320 let content = &content[codefence_start..];
321
322 let newline_ix = content.find('\n').context("could not find newline")?;
323 let content = &content[newline_ix + 1..];
324
325 let codefence_end = content
326 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
327 .context("could not find end marker")?;
328 let new_text = &content[..codefence_end];
329
330 let old_text = snapshot
331 .text_for_range(editable_range.clone())
332 .collect::<String>();
333
334 Ok(compute_edits(
335 old_text,
336 new_text,
337 editable_range.start,
338 snapshot,
339 ))
340}
341
342pub fn compute_edits(
343 old_text: String,
344 new_text: &str,
345 offset: usize,
346 snapshot: &BufferSnapshot,
347) -> Vec<(Range<Anchor>, Arc<str>)> {
348 text_diff(&old_text, new_text)
349 .into_iter()
350 .map(|(mut old_range, new_text)| {
351 old_range.start += offset;
352 old_range.end += offset;
353
354 let prefix_len = common_prefix(
355 snapshot.chars_for_range(old_range.clone()),
356 new_text.chars(),
357 );
358 old_range.start += prefix_len;
359
360 let suffix_len = common_prefix(
361 snapshot.reversed_chars_for_range(old_range.clone()),
362 new_text[prefix_len..].chars().rev(),
363 );
364 old_range.end = old_range.end.saturating_sub(suffix_len);
365
366 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
367 let range = if old_range.is_empty() {
368 let anchor = snapshot.anchor_after(old_range.start);
369 anchor..anchor
370 } else {
371 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
372 };
373 (range, new_text)
374 })
375 .collect()
376}
377
/// Returns the UTF-8 byte length of the longest common prefix of two character
/// sequences. Comparison stops at the first mismatch or when either sequence
/// ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
384
385fn git_info_for_file(
386 project: &Entity<Project>,
387 project_path: &ProjectPath,
388 cx: &App,
389) -> Option<PredictEditsGitInfo> {
390 let git_store = project.read(cx).git_store().read(cx);
391 if let Some((repository, _repo_path)) =
392 git_store.repository_and_path_for_project_path(project_path, cx)
393 {
394 let repository = repository.read(cx);
395 let head_sha = repository
396 .head_commit
397 .as_ref()
398 .map(|head_commit| head_commit.sha.to_string());
399 let remote_origin_url = repository.remote_origin_url.clone();
400 let remote_upstream_url = repository.remote_upstream_url.clone();
401 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
402 return None;
403 }
404 Some(PredictEditsGitInfo {
405 head_sha,
406 remote_origin_url,
407 remote_upstream_url,
408 })
409 } else {
410 None
411 }
412}
413
/// Output of [`gather_context`]: the request body, plus the ranges and event
/// count needed to interpret the model's response.
pub struct GatherContextOutput {
    /// Request body. `can_collect_data` is `false` and `git_info` is `None`;
    /// the caller decides whether to enable them.
    pub body: PredictEditsBody,
    /// Buffer range covered by the whole prompt excerpt, including the editable region.
    pub context_range: Range<Point>,
    /// Editable region converted to buffer offsets.
    pub editable_range: Range<usize>,
    /// Number of most-recent events rendered into `body.input_events`.
    pub included_events_count: usize,
}
420
421pub fn gather_context(
422 full_path_str: String,
423 snapshot: &BufferSnapshot,
424 cursor_point: language::Point,
425 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
426 trigger: PredictEditsRequestTrigger,
427 cx: &App,
428) -> Task<Result<GatherContextOutput>> {
429 cx.background_spawn({
430 let snapshot = snapshot.clone();
431 async move {
432 let input_excerpt = excerpt_for_cursor_position(
433 cursor_point,
434 &full_path_str,
435 &snapshot,
436 MAX_REWRITE_TOKENS,
437 MAX_CONTEXT_TOKENS,
438 );
439 let (input_events, included_events_count) = prompt_for_events();
440 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
441
442 let body = PredictEditsBody {
443 input_events,
444 input_excerpt: input_excerpt.prompt,
445 can_collect_data: false,
446 diagnostic_groups: None,
447 git_info: None,
448 outline: None,
449 speculated_output: None,
450 trigger,
451 };
452
453 Ok(GatherContextOutput {
454 body,
455 context_range: input_excerpt.context_range,
456 editable_range,
457 included_events_count,
458 })
459 }
460 })
461}
462
463fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
464 let mut result = String::new();
465 for (ix, event) in events.iter().rev().enumerate() {
466 let event_string = format_event(event.as_ref());
467 let event_tokens = guess_token_count(event_string.len());
468 if event_tokens > remaining_tokens {
469 return (result, ix);
470 }
471
472 if !result.is_empty() {
473 result.insert_str(0, "\n\n");
474 }
475 result.insert_str(0, &event_string);
476 remaining_tokens -= event_tokens;
477 }
478 return (result, events.len());
479}
480
481pub fn format_event(event: &Event) -> String {
482 match event {
483 Event::BufferChange {
484 path,
485 old_path,
486 diff,
487 ..
488 } => {
489 let mut prompt = String::new();
490
491 if old_path != path {
492 writeln!(
493 prompt,
494 "User renamed {} to {}\n",
495 old_path.display(),
496 path.display()
497 )
498 .unwrap();
499 }
500
501 if !diff.is_empty() {
502 write!(
503 prompt,
504 "User edited {}:\n```diff\n{}\n```",
505 path.display(),
506 diff
507 )
508 .unwrap();
509 }
510
511 prompt
512 }
513 }
514}
515
/// Prompt excerpt around the cursor, as produced by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Buffer range covered by the whole excerpt, including the editable region.
    pub context_range: Range<Point>,
    /// Buffer range the model may rewrite; wrapped in the editable-region markers in `prompt`.
    pub editable_range: Range<Point>,
    /// Rendered prompt text. It is a code fence labelled with the file path,
    /// containing an optional start-of-file marker and the excerpt. The excerpt
    /// includes the editable-region and cursor markers.
    pub prompt: String,
}
522
523pub fn excerpt_for_cursor_position(
524 position: Point,
525 path: &str,
526 snapshot: &BufferSnapshot,
527 editable_region_token_limit: usize,
528 context_token_limit: usize,
529) -> InputExcerpt {
530 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
531 position,
532 snapshot,
533 editable_region_token_limit,
534 context_token_limit,
535 );
536
537 let mut prompt = String::new();
538
539 writeln!(&mut prompt, "```{path}").unwrap();
540 if context_range.start == Point::zero() {
541 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
542 }
543
544 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
545 prompt.push_str(chunk.text);
546 }
547
548 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
549
550 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
551 prompt.push_str(chunk.text);
552 }
553 write!(prompt, "\n```").unwrap();
554
555 InputExcerpt {
556 context_range,
557 editable_range,
558 prompt,
559 }
560}
561
562fn push_editable_range(
563 cursor_position: Point,
564 snapshot: &BufferSnapshot,
565 editable_range: Range<Point>,
566 prompt: &mut String,
567) {
568 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
569 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
570 prompt.push_str(chunk.text);
571 }
572 prompt.push_str(CURSOR_MARKER);
573 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
574 prompt.push_str(chunk.text);
575 }
576 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the prompt rendered by `excerpt_for_cursor_position` for a cursor
    /// inside `bar`, under two different editable-region token limits.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        // Point::new(12, 5) places the cursor right after the `r` of `return sum;`.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
                ```main.rs
                    let x = 42;
                    println!("Hello, world!");
                <|editable_region_start|>
                }

                fn bar() {
                    let x = 42;
                    let mut sum = 0;
                    for i in 0..x {
                        sum += i;
                    }
                    println!("Sum: {}", sum);
                    r<|user_cursor_is_here|>eturn sum;
                }

                fn generate_random_numbers() -> Vec<i32> {
                <|editable_region_end|>
                    let mut rng = rand::thread_rng();
                    let mut numbers = Vec::new();
                ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
                ```main.rs
                fn bar() {
                    let x = 42;
                    let mut sum = 0;
                <|editable_region_start|>
                    for i in 0..x {
                        sum += i;
                    }
                    println!("Sum: {}", sum);
                    r<|user_cursor_is_here|>eturn sum;
                }

                fn generate_random_numbers() -> Vec<i32> {
                    let mut rng = rand::thread_rng();
                <|editable_region_end|>
                    let mut numbers = Vec::new();
                    for _ in 0..5 {
                        numbers.push(rng.random_range(1..101));
                ```"#}
        );
    }
}