1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
5 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
6 prediction::{EditPredictionInputs, EditPredictionResult},
7};
8use anyhow::{Context as _, Result};
9use cloud_llm_client::{
10 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
11 predict_edits_v3::Event,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use language::{
15 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
16};
17use project::{Project, ProjectPath};
18use release_channel::AppVersion;
19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
20
/// Inserted into the prompt at the cursor position; removed from the model output before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Written on its own line right after the opening code fence when the context excerpt begins at
/// the start of the buffer. The model output may contain it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region, on its own line, in both the prompt and the model output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region; always preceded by a newline in the prompt and expected to be so in
/// the model output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Token budget for the context surrounding the editable region (`context_token_limit`).
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget for the editable region the model may rewrite (`editable_region_token_limit`).
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the formatted recent-edit events included in the prompt.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
29
30pub(crate) fn request_prediction_with_zeta1(
31 store: &mut EditPredictionStore,
32 project: &Entity<Project>,
33 buffer: &Entity<Buffer>,
34 snapshot: BufferSnapshot,
35 position: language::Anchor,
36 events: Vec<Arc<Event>>,
37 trigger: PredictEditsRequestTrigger,
38 cx: &mut Context<EditPredictionStore>,
39) -> Task<Result<Option<EditPredictionResult>>> {
40 let buffer = buffer.clone();
41 let buffer_snapshotted_at = Instant::now();
42 let client = store.client.clone();
43 let llm_token = store.llm_token.clone();
44 let app_version = AppVersion::global(cx);
45
46 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
47 let can_collect_file = store.can_collect_file(project, file, cx);
48 let git_info = if can_collect_file {
49 git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
50 } else {
51 None
52 };
53 (git_info, can_collect_file)
54 } else {
55 (None, false)
56 };
57
58 let full_path: Arc<Path> = snapshot
59 .file()
60 .map(|f| Arc::from(f.full_path(cx).as_path()))
61 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
62 let full_path_str = full_path.to_string_lossy().into_owned();
63 let cursor_point = position.to_point(&snapshot);
64 let prompt_for_events = {
65 let events = events.clone();
66 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
67 };
68 let gather_task = gather_context(
69 full_path_str,
70 &snapshot,
71 cursor_point,
72 prompt_for_events,
73 trigger,
74 cx,
75 );
76
77 cx.spawn(async move |this, cx| {
78 let GatherContextOutput {
79 mut body,
80 context_range,
81 editable_range,
82 included_events_count,
83 } = gather_task.await?;
84 let done_gathering_context_at = Instant::now();
85
86 let included_events = &events[events.len() - included_events_count..events.len()];
87 body.can_collect_data = can_collect_file
88 && this
89 .read_with(cx, |this, _| this.can_collect_events(included_events))
90 .unwrap_or(false);
91 if body.can_collect_data {
92 body.git_info = git_info;
93 }
94
95 log::debug!(
96 "Events:\n{}\nExcerpt:\n{:?}",
97 body.input_events,
98 body.input_excerpt
99 );
100
101 let http_client = client.http_client();
102
103 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
104 |request| {
105 let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
106 predict_edits_url
107 } else {
108 http_client
109 .build_zed_llm_url("/predict_edits/v2", &[])?
110 .as_str()
111 .into()
112 };
113 Ok(request
114 .uri(uri)
115 .body(serde_json::to_string(&body)?.into())?)
116 },
117 client,
118 llm_token,
119 app_version,
120 )
121 .await;
122
123 let inputs = EditPredictionInputs {
124 events: included_events.into(),
125 included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
126 path: full_path.clone(),
127 max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
128 excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
129 start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
130 text: snapshot
131 .text_for_range(context_range)
132 .collect::<String>()
133 .into(),
134 }],
135 }],
136 cursor_point: cloud_llm_client::predict_edits_v3::Point {
137 column: cursor_point.column,
138 line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
139 },
140 cursor_path: full_path,
141 };
142
143 // let response = perform_predict_edits(PerformPredictEditsParams {
144 // client,
145 // llm_token,
146 // app_version,
147 // body,
148 // })
149 // .await;
150
151 let (response, usage) = match response {
152 Ok(response) => response,
153 Err(err) => {
154 if err.is::<ZedUpdateRequiredError>() {
155 cx.update(|cx| {
156 this.update(cx, |ep_store, _cx| {
157 ep_store.update_required = true;
158 })
159 .ok();
160
161 let error_message: SharedString = err.to_string().into();
162 show_app_notification(
163 NotificationId::unique::<ZedUpdateRequiredError>(),
164 cx,
165 move |cx| {
166 cx.new(|cx| {
167 ErrorMessagePrompt::new(error_message.clone(), cx)
168 .with_link_button("Update Zed", "https://zed.dev/releases")
169 })
170 },
171 );
172 })
173 .ok();
174 }
175
176 return Err(err);
177 }
178 };
179
180 let received_response_at = Instant::now();
181 log::debug!("completion response: {}", &response.output_excerpt);
182
183 if let Some(usage) = usage {
184 this.update(cx, |this, cx| {
185 this.user_store.update(cx, |user_store, cx| {
186 user_store.update_edit_prediction_usage(usage, cx);
187 });
188 })
189 .ok();
190 }
191
192 let edit_prediction = process_completion_response(
193 response,
194 buffer,
195 &snapshot,
196 editable_range,
197 inputs,
198 buffer_snapshotted_at,
199 received_response_at,
200 cx,
201 )
202 .await;
203
204 let finished_at = Instant::now();
205
206 // record latency for ~1% of requests
207 if rand::random::<u8>() <= 2 {
208 telemetry::event!(
209 "Edit Prediction Request",
210 context_latency = done_gathering_context_at
211 .duration_since(buffer_snapshotted_at)
212 .as_millis(),
213 request_latency = received_response_at
214 .duration_since(done_gathering_context_at)
215 .as_millis(),
216 process_latency = finished_at.duration_since(received_response_at).as_millis()
217 );
218 }
219
220 edit_prediction.map(Some)
221 })
222}
223
224fn process_completion_response(
225 prediction_response: PredictEditsResponse,
226 buffer: Entity<Buffer>,
227 snapshot: &BufferSnapshot,
228 editable_range: Range<usize>,
229 inputs: EditPredictionInputs,
230 buffer_snapshotted_at: Instant,
231 received_response_at: Instant,
232 cx: &AsyncApp,
233) -> Task<Result<EditPredictionResult>> {
234 let snapshot = snapshot.clone();
235 let request_id = prediction_response.request_id;
236 let output_excerpt = prediction_response.output_excerpt;
237 cx.spawn(async move |cx| {
238 let output_excerpt: Arc<str> = output_excerpt.into();
239
240 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
241 .background_spawn({
242 let output_excerpt = output_excerpt.clone();
243 let editable_range = editable_range.clone();
244 let snapshot = snapshot.clone();
245 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
246 })
247 .await?
248 .into();
249
250 let id = EditPredictionId(request_id.into());
251 Ok(EditPredictionResult::new(
252 id,
253 &buffer,
254 &snapshot,
255 edits,
256 buffer_snapshotted_at,
257 received_response_at,
258 inputs,
259 cx,
260 )
261 .await)
262 })
263}
264
265fn parse_edits(
266 output_excerpt: Arc<str>,
267 editable_range: Range<usize>,
268 snapshot: &BufferSnapshot,
269) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
270 let content = output_excerpt.replace(CURSOR_MARKER, "");
271
272 let start_markers = content
273 .match_indices(EDITABLE_REGION_START_MARKER)
274 .collect::<Vec<_>>();
275 anyhow::ensure!(
276 start_markers.len() == 1,
277 "expected exactly one start marker, found {}",
278 start_markers.len()
279 );
280
281 let end_markers = content
282 .match_indices(EDITABLE_REGION_END_MARKER)
283 .collect::<Vec<_>>();
284 anyhow::ensure!(
285 end_markers.len() == 1,
286 "expected exactly one end marker, found {}",
287 end_markers.len()
288 );
289
290 let sof_markers = content
291 .match_indices(START_OF_FILE_MARKER)
292 .collect::<Vec<_>>();
293 anyhow::ensure!(
294 sof_markers.len() <= 1,
295 "expected at most one start-of-file marker, found {}",
296 sof_markers.len()
297 );
298
299 let codefence_start = start_markers[0].0;
300 let content = &content[codefence_start..];
301
302 let newline_ix = content.find('\n').context("could not find newline")?;
303 let content = &content[newline_ix + 1..];
304
305 let codefence_end = content
306 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
307 .context("could not find end marker")?;
308 let new_text = &content[..codefence_end];
309
310 let old_text = snapshot
311 .text_for_range(editable_range.clone())
312 .collect::<String>();
313
314 Ok(compute_edits(
315 old_text,
316 new_text,
317 editable_range.start,
318 snapshot,
319 ))
320}
321
322pub fn compute_edits(
323 old_text: String,
324 new_text: &str,
325 offset: usize,
326 snapshot: &BufferSnapshot,
327) -> Vec<(Range<Anchor>, Arc<str>)> {
328 text_diff(&old_text, new_text)
329 .into_iter()
330 .map(|(mut old_range, new_text)| {
331 old_range.start += offset;
332 old_range.end += offset;
333
334 let prefix_len = common_prefix(
335 snapshot.chars_for_range(old_range.clone()),
336 new_text.chars(),
337 );
338 old_range.start += prefix_len;
339
340 let suffix_len = common_prefix(
341 snapshot.reversed_chars_for_range(old_range.clone()),
342 new_text[prefix_len..].chars().rev(),
343 );
344 old_range.end = old_range.end.saturating_sub(suffix_len);
345
346 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
347 let range = if old_range.is_empty() {
348 let anchor = snapshot.anchor_after(old_range.start);
349 anchor..anchor
350 } else {
351 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
352 };
353 (range, new_text)
354 })
355 .collect()
356}
357
358fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
359 a.zip(b)
360 .take_while(|(a, b)| a == b)
361 .map(|(a, _)| a.len_utf8())
362 .sum()
363}
364
365fn git_info_for_file(
366 project: &Entity<Project>,
367 project_path: &ProjectPath,
368 cx: &App,
369) -> Option<PredictEditsGitInfo> {
370 let git_store = project.read(cx).git_store().read(cx);
371 if let Some((repository, _repo_path)) =
372 git_store.repository_and_path_for_project_path(project_path, cx)
373 {
374 let repository = repository.read(cx);
375 let head_sha = repository
376 .head_commit
377 .as_ref()
378 .map(|head_commit| head_commit.sha.to_string());
379 let remote_origin_url = repository.remote_origin_url.clone();
380 let remote_upstream_url = repository.remote_upstream_url.clone();
381 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
382 return None;
383 }
384 Some(PredictEditsGitInfo {
385 head_sha,
386 remote_origin_url,
387 remote_upstream_url,
388 })
389 } else {
390 None
391 }
392}
393
/// Output of [`gather_context`]: the request body plus the ranges needed to interpret the
/// model's response.
pub struct GatherContextOutput {
    /// Request body as built by `gather_context`: `can_collect_data` is `false` and `git_info` is
    /// `None`; `request_prediction_with_zeta1` may set them afterwards.
    pub body: PredictEditsBody,
    /// Point range of the buffer text included in the prompt excerpt.
    pub context_range: Range<Point>,
    /// Editable region converted to buffer offsets (via `to_offset`).
    pub editable_range: Range<usize>,
    /// Number of most recent events included in `body.input_events`.
    pub included_events_count: usize,
}
400
401pub fn gather_context(
402 full_path_str: String,
403 snapshot: &BufferSnapshot,
404 cursor_point: language::Point,
405 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
406 trigger: PredictEditsRequestTrigger,
407 cx: &App,
408) -> Task<Result<GatherContextOutput>> {
409 cx.background_spawn({
410 let snapshot = snapshot.clone();
411 async move {
412 let input_excerpt = excerpt_for_cursor_position(
413 cursor_point,
414 &full_path_str,
415 &snapshot,
416 MAX_REWRITE_TOKENS,
417 MAX_CONTEXT_TOKENS,
418 );
419 let (input_events, included_events_count) = prompt_for_events();
420 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
421
422 let body = PredictEditsBody {
423 input_events,
424 input_excerpt: input_excerpt.prompt,
425 can_collect_data: false,
426 diagnostic_groups: None,
427 git_info: None,
428 outline: None,
429 speculated_output: None,
430 trigger,
431 };
432
433 Ok(GatherContextOutput {
434 body,
435 context_range: input_excerpt.context_range,
436 editable_range,
437 included_events_count,
438 })
439 }
440 })
441}
442
443fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
444 let mut result = String::new();
445 for (ix, event) in events.iter().rev().enumerate() {
446 let event_string = format_event(event.as_ref());
447 let event_tokens = guess_token_count(event_string.len());
448 if event_tokens > remaining_tokens {
449 return (result, ix);
450 }
451
452 if !result.is_empty() {
453 result.insert_str(0, "\n\n");
454 }
455 result.insert_str(0, &event_string);
456 remaining_tokens -= event_tokens;
457 }
458 return (result, events.len());
459}
460
461pub fn format_event(event: &Event) -> String {
462 match event {
463 Event::BufferChange {
464 path,
465 old_path,
466 diff,
467 ..
468 } => {
469 let mut prompt = String::new();
470
471 if old_path != path {
472 writeln!(
473 prompt,
474 "User renamed {} to {}\n",
475 old_path.display(),
476 path.display()
477 )
478 .unwrap();
479 }
480
481 if !diff.is_empty() {
482 write!(
483 prompt,
484 "User edited {}:\n```diff\n{}\n```",
485 path.display(),
486 diff
487 )
488 .unwrap();
489 }
490
491 prompt
492 }
493 }
494}
495
/// Prompt excerpt built around the cursor by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Range of buffer text written into the prompt. The prompt contains the text before the
    /// editable region, the editable region, and the text after it.
    pub context_range: Range<Point>,
    /// Range of buffer text wrapped in the editable-region markers, i.e. the text the model is
    /// asked to rewrite.
    pub editable_range: Range<Point>,
    /// The fenced prompt text, including editable-region and cursor markers.
    pub prompt: String,
}
502
503pub fn excerpt_for_cursor_position(
504 position: Point,
505 path: &str,
506 snapshot: &BufferSnapshot,
507 editable_region_token_limit: usize,
508 context_token_limit: usize,
509) -> InputExcerpt {
510 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
511 position,
512 snapshot,
513 editable_region_token_limit,
514 context_token_limit,
515 );
516
517 let mut prompt = String::new();
518
519 writeln!(&mut prompt, "```{path}").unwrap();
520 if context_range.start == Point::zero() {
521 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
522 }
523
524 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
525 prompt.push_str(chunk.text);
526 }
527
528 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
529
530 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
531 prompt.push_str(chunk.text);
532 }
533 write!(prompt, "\n```").unwrap();
534
535 InputExcerpt {
536 context_range,
537 editable_range,
538 prompt,
539 }
540}
541
542fn push_editable_range(
543 cursor_position: Point,
544 snapshot: &BufferSnapshot,
545 editable_range: Range<Point>,
546 prompt: &mut String,
547) {
548 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
549 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
550 prompt.push_str(chunk.text);
551 }
552 prompt.push_str(CURSOR_MARKER);
553 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
554 prompt.push_str(chunk.text);
555 }
556 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the exact prompt built for a cursor inside `bar`, with a context budget of 32
    /// tokens and two editable-region budgets: 50 (fits a larger syntax scope) and 40 (falls back
    /// to line-based expansion).
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        // Point::new(12, 5) is just after the `r` of `return sum;`.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}