1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::{Context as _, Result};
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use language::{
15 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
16};
17use project::{Project, ProjectPath};
18use release_channel::AppVersion;
19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
20use zeta_prompt::{Event, ZetaPromptInput};
21
// Special tokens embedded in the Zeta 1 prompt. `parse_edits` looks for the editable-region and
// start-of-file markers in the model output and strips the cursor marker from it.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Estimated-token budget for the read-only context around the editable region
/// (passed as `context_token_limit` to `excerpt_for_cursor_position`).
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Estimated-token budget for the editable region the model may rewrite
/// (passed as `editable_region_token_limit` to `excerpt_for_cursor_position`).
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Estimated-token budget for the recent-edit events rendered into the prompt.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
30
31pub(crate) fn request_prediction_with_zeta1(
32 store: &mut EditPredictionStore,
33 EditPredictionModelInput {
34 project,
35 buffer,
36 snapshot,
37 position,
38 events,
39 trigger,
40 debug_tx,
41 ..
42 }: EditPredictionModelInput,
43 cx: &mut Context<EditPredictionStore>,
44) -> Task<Result<Option<EditPredictionResult>>> {
45 let buffer_snapshotted_at = Instant::now();
46 let client = store.client.clone();
47 let llm_token = store.llm_token.clone();
48 let app_version = AppVersion::global(cx);
49
50 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
51 let can_collect_file = store.can_collect_file(&project, file, cx);
52 let git_info = if can_collect_file {
53 git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
54 } else {
55 None
56 };
57 (git_info, can_collect_file)
58 } else {
59 (None, false)
60 };
61
62 let full_path: Arc<Path> = snapshot
63 .file()
64 .map(|f| Arc::from(f.full_path(cx).as_path()))
65 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
66 let full_path_str = full_path.to_string_lossy().into_owned();
67 let cursor_point = position.to_point(&snapshot);
68 let prompt_for_events = {
69 let events = events.clone();
70 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
71 };
72 let gather_task = gather_context(
73 full_path_str,
74 &snapshot,
75 cursor_point,
76 prompt_for_events,
77 trigger,
78 cx,
79 );
80
81 cx.spawn(async move |this, cx| {
82 let GatherContextOutput {
83 mut body,
84 context_range,
85 editable_range,
86 included_events_count,
87 } = gather_task.await?;
88 let done_gathering_context_at = Instant::now();
89
90 let included_events = &events[events.len() - included_events_count..events.len()];
91 body.can_collect_data = can_collect_file
92 && this
93 .read_with(cx, |this, _| this.can_collect_events(included_events))
94 .unwrap_or(false);
95 if body.can_collect_data {
96 body.git_info = git_info;
97 }
98
99 log::debug!(
100 "Events:\n{}\nExcerpt:\n{:?}",
101 body.input_events,
102 body.input_excerpt
103 );
104
105 let http_client = client.http_client();
106
107 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
108 |request| {
109 let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
110 predict_edits_url
111 } else {
112 http_client
113 .build_zed_llm_url("/predict_edits/v2", &[])?
114 .as_str()
115 .into()
116 };
117 Ok(request
118 .uri(uri)
119 .body(serde_json::to_string(&body)?.into())?)
120 },
121 client,
122 llm_token,
123 app_version,
124 )
125 .await;
126
127 let context_start_offset = context_range.start.to_offset(&snapshot);
128 let editable_offset_range = editable_range.to_offset(&snapshot);
129
130 let inputs = ZetaPromptInput {
131 events: included_events.into(),
132 related_files: vec![].into(),
133 cursor_path: full_path,
134 cursor_excerpt: snapshot
135 .text_for_range(context_range)
136 .collect::<String>()
137 .into(),
138 editable_range_in_excerpt: (editable_range.start - context_start_offset)
139 ..(editable_offset_range.end - context_start_offset),
140 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
141 };
142
143 if let Some(debug_tx) = &debug_tx {
144 debug_tx
145 .unbounded_send(DebugEvent::EditPredictionStarted(
146 EditPredictionStartedDebugEvent {
147 buffer: buffer.downgrade(),
148 prompt: Some(serde_json::to_string(&inputs).unwrap()),
149 position,
150 },
151 ))
152 .ok();
153 }
154
155 let (response, usage) = match response {
156 Ok(response) => response,
157 Err(err) => {
158 if err.is::<ZedUpdateRequiredError>() {
159 cx.update(|cx| {
160 this.update(cx, |ep_store, _cx| {
161 ep_store.update_required = true;
162 })
163 .ok();
164
165 let error_message: SharedString = err.to_string().into();
166 show_app_notification(
167 NotificationId::unique::<ZedUpdateRequiredError>(),
168 cx,
169 move |cx| {
170 cx.new(|cx| {
171 ErrorMessagePrompt::new(error_message.clone(), cx)
172 .with_link_button("Update Zed", "https://zed.dev/releases")
173 })
174 },
175 );
176 })
177 .ok();
178 }
179
180 return Err(err);
181 }
182 };
183
184 let received_response_at = Instant::now();
185 log::debug!("completion response: {}", &response.output_excerpt);
186
187 if let Some(usage) = usage {
188 this.update(cx, |this, cx| {
189 this.user_store.update(cx, |user_store, cx| {
190 user_store.update_edit_prediction_usage(usage, cx);
191 });
192 })
193 .ok();
194 }
195
196 if let Some(debug_tx) = &debug_tx {
197 debug_tx
198 .unbounded_send(DebugEvent::EditPredictionFinished(
199 EditPredictionFinishedDebugEvent {
200 buffer: buffer.downgrade(),
201 model_output: Some(response.output_excerpt.clone()),
202 position,
203 },
204 ))
205 .ok();
206 }
207
208 let edit_prediction = process_completion_response(
209 response,
210 buffer,
211 &snapshot,
212 editable_range,
213 inputs,
214 buffer_snapshotted_at,
215 received_response_at,
216 cx,
217 )
218 .await;
219
220 let finished_at = Instant::now();
221
222 // record latency for ~1% of requests
223 if rand::random::<u8>() <= 2 {
224 telemetry::event!(
225 "Edit Prediction Request",
226 context_latency = done_gathering_context_at
227 .duration_since(buffer_snapshotted_at)
228 .as_millis(),
229 request_latency = received_response_at
230 .duration_since(done_gathering_context_at)
231 .as_millis(),
232 process_latency = finished_at.duration_since(received_response_at).as_millis()
233 );
234 }
235
236 edit_prediction.map(Some)
237 })
238}
239
240fn process_completion_response(
241 prediction_response: PredictEditsResponse,
242 buffer: Entity<Buffer>,
243 snapshot: &BufferSnapshot,
244 editable_range: Range<usize>,
245 inputs: ZetaPromptInput,
246 buffer_snapshotted_at: Instant,
247 received_response_at: Instant,
248 cx: &AsyncApp,
249) -> Task<Result<EditPredictionResult>> {
250 let snapshot = snapshot.clone();
251 let request_id = prediction_response.request_id;
252 let output_excerpt = prediction_response.output_excerpt;
253 cx.spawn(async move |cx| {
254 let output_excerpt: Arc<str> = output_excerpt.into();
255
256 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
257 .background_spawn({
258 let output_excerpt = output_excerpt.clone();
259 let editable_range = editable_range.clone();
260 let snapshot = snapshot.clone();
261 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
262 })
263 .await?
264 .into();
265
266 let id = EditPredictionId(request_id.into());
267 Ok(EditPredictionResult::new(
268 id,
269 &buffer,
270 &snapshot,
271 edits,
272 buffer_snapshotted_at,
273 received_response_at,
274 inputs,
275 cx,
276 )
277 .await)
278 })
279}
280
281fn parse_edits(
282 output_excerpt: Arc<str>,
283 editable_range: Range<usize>,
284 snapshot: &BufferSnapshot,
285) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
286 let content = output_excerpt.replace(CURSOR_MARKER, "");
287
288 let start_markers = content
289 .match_indices(EDITABLE_REGION_START_MARKER)
290 .collect::<Vec<_>>();
291 anyhow::ensure!(
292 start_markers.len() == 1,
293 "expected exactly one start marker, found {}",
294 start_markers.len()
295 );
296
297 let end_markers = content
298 .match_indices(EDITABLE_REGION_END_MARKER)
299 .collect::<Vec<_>>();
300 anyhow::ensure!(
301 end_markers.len() == 1,
302 "expected exactly one end marker, found {}",
303 end_markers.len()
304 );
305
306 let sof_markers = content
307 .match_indices(START_OF_FILE_MARKER)
308 .collect::<Vec<_>>();
309 anyhow::ensure!(
310 sof_markers.len() <= 1,
311 "expected at most one start-of-file marker, found {}",
312 sof_markers.len()
313 );
314
315 let codefence_start = start_markers[0].0;
316 let content = &content[codefence_start..];
317
318 let newline_ix = content.find('\n').context("could not find newline")?;
319 let content = &content[newline_ix + 1..];
320
321 let codefence_end = content
322 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
323 .context("could not find end marker")?;
324 let new_text = &content[..codefence_end];
325
326 let old_text = snapshot
327 .text_for_range(editable_range.clone())
328 .collect::<String>();
329
330 Ok(compute_edits(
331 old_text,
332 new_text,
333 editable_range.start,
334 snapshot,
335 ))
336}
337
338pub fn compute_edits(
339 old_text: String,
340 new_text: &str,
341 offset: usize,
342 snapshot: &BufferSnapshot,
343) -> Vec<(Range<Anchor>, Arc<str>)> {
344 text_diff(&old_text, new_text)
345 .into_iter()
346 .map(|(mut old_range, new_text)| {
347 old_range.start += offset;
348 old_range.end += offset;
349
350 let prefix_len = common_prefix(
351 snapshot.chars_for_range(old_range.clone()),
352 new_text.chars(),
353 );
354 old_range.start += prefix_len;
355
356 let suffix_len = common_prefix(
357 snapshot.reversed_chars_for_range(old_range.clone()),
358 new_text[prefix_len..].chars().rev(),
359 );
360 old_range.end = old_range.end.saturating_sub(suffix_len);
361
362 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
363 let range = if old_range.is_empty() {
364 let anchor = snapshot.anchor_after(old_range.start);
365 anchor..anchor
366 } else {
367 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
368 };
369 (range, new_text)
370 })
371 .collect()
372}
373
/// Returns the UTF-8 byte length of the longest common prefix of two character streams.
///
/// Lengths are summed from the characters of `a`. Over the shared prefix both sides hold the
/// same characters, so the result is the byte length of that prefix on either side.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
380
381fn git_info_for_file(
382 project: &Entity<Project>,
383 project_path: &ProjectPath,
384 cx: &App,
385) -> Option<PredictEditsGitInfo> {
386 let git_store = project.read(cx).git_store().read(cx);
387 if let Some((repository, _repo_path)) =
388 git_store.repository_and_path_for_project_path(project_path, cx)
389 {
390 let repository = repository.read(cx);
391 let head_sha = repository
392 .head_commit
393 .as_ref()
394 .map(|head_commit| head_commit.sha.to_string());
395 let remote_origin_url = repository.remote_origin_url.clone();
396 let remote_upstream_url = repository.remote_upstream_url.clone();
397 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
398 return None;
399 }
400 Some(PredictEditsGitInfo {
401 head_sha,
402 remote_origin_url,
403 remote_upstream_url,
404 })
405 } else {
406 None
407 }
408}
409
/// Output of [`gather_context`]: the request body plus the ranges needed to interpret the
/// model's response.
pub struct GatherContextOutput {
    /// Request body. `gather_context` leaves `can_collect_data` as `false` and `git_info` as
    /// `None`.
    pub body: PredictEditsBody,
    /// Buffer range (in points) covered by the prompt excerpt.
    pub context_range: Range<Point>,
    /// The excerpt's editable region, as byte offsets into the snapshot.
    pub editable_range: Range<usize>,
    /// Number of events rendered into `body.input_events`, as reported by the
    /// `prompt_for_events` callback.
    pub included_events_count: usize,
}
416
417pub fn gather_context(
418 full_path_str: String,
419 snapshot: &BufferSnapshot,
420 cursor_point: language::Point,
421 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
422 trigger: PredictEditsRequestTrigger,
423 cx: &App,
424) -> Task<Result<GatherContextOutput>> {
425 cx.background_spawn({
426 let snapshot = snapshot.clone();
427 async move {
428 let input_excerpt = excerpt_for_cursor_position(
429 cursor_point,
430 &full_path_str,
431 &snapshot,
432 MAX_REWRITE_TOKENS,
433 MAX_CONTEXT_TOKENS,
434 );
435 let (input_events, included_events_count) = prompt_for_events();
436 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
437
438 let body = PredictEditsBody {
439 input_events,
440 input_excerpt: input_excerpt.prompt,
441 can_collect_data: false,
442 diagnostic_groups: None,
443 git_info: None,
444 outline: None,
445 speculated_output: None,
446 trigger,
447 };
448
449 Ok(GatherContextOutput {
450 body,
451 context_range: input_excerpt.context_range,
452 editable_range,
453 included_events_count,
454 })
455 }
456 })
457}
458
459fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
460 let mut result = String::new();
461 for (ix, event) in events.iter().rev().enumerate() {
462 let event_string = format_event(event.as_ref());
463 let event_tokens = guess_token_count(event_string.len());
464 if event_tokens > remaining_tokens {
465 return (result, ix);
466 }
467
468 if !result.is_empty() {
469 result.insert_str(0, "\n\n");
470 }
471 result.insert_str(0, &event_string);
472 remaining_tokens -= event_tokens;
473 }
474 return (result, events.len());
475}
476
477pub fn format_event(event: &Event) -> String {
478 match event {
479 Event::BufferChange {
480 path,
481 old_path,
482 diff,
483 ..
484 } => {
485 let mut prompt = String::new();
486
487 if old_path != path {
488 writeln!(
489 prompt,
490 "User renamed {} to {}\n",
491 old_path.display(),
492 path.display()
493 )
494 .unwrap();
495 }
496
497 if !diff.is_empty() {
498 write!(
499 prompt,
500 "User edited {}:\n```diff\n{}\n```",
501 path.display(),
502 diff
503 )
504 .unwrap();
505 }
506
507 prompt
508 }
509 }
510}
511
/// A prompt excerpt built by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Buffer range (in points) whose text is included in the prompt.
    pub context_range: Range<Point>,
    /// Sub-range wrapped in the editable-region markers, i.e. the text the model may rewrite.
    pub editable_range: Range<Point>,
    /// The rendered excerpt: a code fence tagged with the path, containing the context text,
    /// the marked editable region, and the cursor marker.
    pub prompt: String,
}
518
519pub fn excerpt_for_cursor_position(
520 position: Point,
521 path: &str,
522 snapshot: &BufferSnapshot,
523 editable_region_token_limit: usize,
524 context_token_limit: usize,
525) -> InputExcerpt {
526 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
527 position,
528 snapshot,
529 editable_region_token_limit,
530 context_token_limit,
531 );
532
533 let mut prompt = String::new();
534
535 writeln!(&mut prompt, "```{path}").unwrap();
536 if context_range.start == Point::zero() {
537 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
538 }
539
540 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
541 prompt.push_str(chunk.text);
542 }
543
544 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
545
546 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
547 prompt.push_str(chunk.text);
548 }
549 write!(prompt, "\n```").unwrap();
550
551 InputExcerpt {
552 context_range,
553 editable_range,
554 prompt,
555 }
556}
557
558fn push_editable_range(
559 cursor_position: Point,
560 snapshot: &BufferSnapshot,
561 editable_range: Range<Point>,
562 prompt: &mut String,
563) {
564 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
565 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
566 prompt.push_str(chunk.text);
567 }
568 prompt.push_str(CURSOR_MARKER);
569 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
570 prompt.push_str(chunk.text);
571 }
572 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    // Pins the exact prompt produced by `excerpt_for_cursor_position` for one cursor position
    // under two different editable-region token budgets.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}