1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::{Context as _, Result};
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use language::{
15 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
16};
17use project::{Project, ProjectPath};
18use release_channel::AppVersion;
19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
20use zeta_prompt::{Event, ZetaPromptInput};
21
/// Inserted into the editable region of the prompt at the cursor position; removed from the
/// model output before it is diffed against the buffer.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Written right after the opening fence when the excerpt's context starts at the very
/// beginning of the buffer.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the region of the excerpt that the model is allowed to rewrite.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the region of the excerpt that the model is allowed to rewrite.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Token limit for the read-only context around the editable region (estimated tokens).
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token limit for the editable region the model may rewrite (estimated tokens).
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the recent-edit events rendered into the prompt (estimated tokens).
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
30
/// Requests a Zeta1 edit prediction for `position` in `buffer`.
///
/// The prompt consists of an excerpt around the cursor plus as many of the most recent
/// `events` as fit in [`MAX_EVENT_TOKENS`]. It is gathered on a background thread and sent to
/// the store's custom predict-edits URL when one is configured. Otherwise it goes to Zed's
/// `/predict_edits/v2` LLM endpoint. The response is turned into an [`EditPredictionResult`].
///
/// When `debug_tx` is set, two debug events are sent once the request has completed:
/// - an `EditPredictionStarted` event carrying the serialized prompt inputs;
/// - on success, an `EditPredictionFinished` event carrying the model output.
///
/// # Errors
///
/// The returned task resolves to an error in any of these cases:
/// - the endpoint URL cannot be built;
/// - context gathering fails;
/// - the API request fails. A [`ZedUpdateRequiredError`] additionally marks the store as
///   requiring an update and shows a notification;
/// - the model output cannot be parsed.
pub(crate) fn request_prediction_with_zeta1(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        project,
        buffer,
        snapshot,
        position,
        events,
        trigger,
        debug_tx,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let buffer_snapshotted_at = Instant::now();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Git metadata is only looked up for files the store may collect; buffers without a
    // file are never collectable.
    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
        let can_collect_file = store.can_collect_file(&project, file, cx);
        let git_info = if can_collect_file {
            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };
        (git_info, can_collect_file)
    } else {
        (None, false)
    };

    let full_path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
    let full_path_str = full_path.to_string_lossy().into_owned();
    let cursor_point = position.to_point(&snapshot);
    // Deferred closure so event formatting runs inside the background task of
    // `gather_context` rather than on this thread.
    let prompt_for_events = {
        let events = events.clone();
        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
    };
    let gather_task = gather_context(
        full_path_str,
        &snapshot,
        cursor_point,
        prompt_for_events,
        trigger,
        cx,
    );

    // A custom URL is called without Zed LLM authentication; the default endpoint requires it.
    let (uri, require_auth) = match &store.custom_predict_edits_url {
        Some(custom_url) => (custom_url.clone(), false),
        None => {
            match client
                .http_client()
                .build_zed_llm_url("/predict_edits/v2", &[])
            {
                Ok(url) => (url.into(), true),
                Err(err) => return Task::ready(Err(err)),
            }
        }
    };

    cx.spawn(async move |this, cx| {
        let GatherContextOutput {
            mut body,
            context_range,
            editable_range,
            included_events_count,
        } = gather_task.await?;
        let done_gathering_context_at = Instant::now();

        // The prompt contains the newest `included_events_count` events, i.e. a suffix of
        // `events`; only those are checked for collectability.
        let included_events = &events[events.len() - included_events_count..events.len()];
        body.can_collect_data = can_collect_file
            && this
                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
                .unwrap_or(false);
        // Git info is attached only when data collection is permitted.
        if body.can_collect_data {
            body.git_info = git_info;
        }

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            body.input_events,
            body.input_excerpt
        );

        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
            |request| {
                Ok(request
                    .uri(uri.as_str())
                    .body(serde_json::to_string(&body)?.into())?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        // Describe the same request as a `ZetaPromptInput`, with the editable range and
        // cursor expressed as offsets relative to the start of the excerpt.
        let context_start_offset = context_range.start.to_offset(&snapshot);
        let editable_offset_range = editable_range.to_offset(&snapshot);

        let inputs = ZetaPromptInput {
            events: included_events.into(),
            related_files: vec![].into(),
            cursor_path: full_path,
            cursor_excerpt: snapshot
                .text_for_range(context_range)
                .collect::<String>()
                .into(),
            editable_range_in_excerpt: (editable_range.start - context_start_offset)
                ..(editable_offset_range.end - context_start_offset),
            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
        };

        // Sent after the request has completed, whether or not it succeeded.
        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(DebugEvent::EditPredictionStarted(
                    EditPredictionStartedDebugEvent {
                        buffer: buffer.downgrade(),
                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
                        position,
                    },
                ))
                .ok();
        }

        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                // The server rejected this client version: flag the store and prompt the
                // user to update before propagating the error.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |ep_store, _cx| {
                            ep_store.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    });
                }

                return Err(err);
            }
        };

        let received_response_at = Instant::now();
        log::debug!("completion response: {}", &response.output_excerpt);

        // Forward any usage information returned with the response to the user store.
        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(DebugEvent::EditPredictionFinished(
                    EditPredictionFinishedDebugEvent {
                        buffer: buffer.downgrade(),
                        model_output: Some(response.output_excerpt.clone()),
                        position,
                    },
                ))
                .ok();
        }

        let edit_prediction = process_completion_response(
            response,
            buffer,
            &snapshot,
            editable_range,
            inputs,
            buffer_snapshotted_at,
            received_response_at,
            cx,
        )
        .await;

        let finished_at = Instant::now();

        // record latency for ~1% of requests
        // (`<= 2` accepts 3 of the 256 possible `u8` values, i.e. about 1.2%)
        if rand::random::<u8>() <= 2 {
            telemetry::event!(
                "Edit Prediction Request",
                context_latency = done_gathering_context_at
                    .duration_since(buffer_snapshotted_at)
                    .as_millis(),
                request_latency = received_response_at
                    .duration_since(done_gathering_context_at)
                    .as_millis(),
                process_latency = finished_at.duration_since(received_response_at).as_millis()
            );
        }

        edit_prediction.map(Some)
    })
}
242
243fn process_completion_response(
244 prediction_response: PredictEditsResponse,
245 buffer: Entity<Buffer>,
246 snapshot: &BufferSnapshot,
247 editable_range: Range<usize>,
248 inputs: ZetaPromptInput,
249 buffer_snapshotted_at: Instant,
250 received_response_at: Instant,
251 cx: &AsyncApp,
252) -> Task<Result<EditPredictionResult>> {
253 let snapshot = snapshot.clone();
254 let request_id = prediction_response.request_id;
255 let output_excerpt = prediction_response.output_excerpt;
256 cx.spawn(async move |cx| {
257 let output_excerpt: Arc<str> = output_excerpt.into();
258
259 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
260 .background_spawn({
261 let output_excerpt = output_excerpt.clone();
262 let editable_range = editable_range.clone();
263 let snapshot = snapshot.clone();
264 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
265 })
266 .await?
267 .into();
268
269 let id = EditPredictionId(request_id.into());
270 Ok(EditPredictionResult::new(
271 id,
272 &buffer,
273 &snapshot,
274 edits,
275 buffer_snapshotted_at,
276 received_response_at,
277 inputs,
278 cx,
279 )
280 .await)
281 })
282}
283
284fn parse_edits(
285 output_excerpt: Arc<str>,
286 editable_range: Range<usize>,
287 snapshot: &BufferSnapshot,
288) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
289 let content = output_excerpt.replace(CURSOR_MARKER, "");
290
291 let start_markers = content
292 .match_indices(EDITABLE_REGION_START_MARKER)
293 .collect::<Vec<_>>();
294 anyhow::ensure!(
295 start_markers.len() == 1,
296 "expected exactly one start marker, found {}",
297 start_markers.len()
298 );
299
300 let end_markers = content
301 .match_indices(EDITABLE_REGION_END_MARKER)
302 .collect::<Vec<_>>();
303 anyhow::ensure!(
304 end_markers.len() == 1,
305 "expected exactly one end marker, found {}",
306 end_markers.len()
307 );
308
309 let sof_markers = content
310 .match_indices(START_OF_FILE_MARKER)
311 .collect::<Vec<_>>();
312 anyhow::ensure!(
313 sof_markers.len() <= 1,
314 "expected at most one start-of-file marker, found {}",
315 sof_markers.len()
316 );
317
318 let codefence_start = start_markers[0].0;
319 let content = &content[codefence_start..];
320
321 let newline_ix = content.find('\n').context("could not find newline")?;
322 let content = &content[newline_ix + 1..];
323
324 let codefence_end = content
325 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
326 .context("could not find end marker")?;
327 let new_text = &content[..codefence_end];
328
329 let old_text = snapshot
330 .text_for_range(editable_range.clone())
331 .collect::<String>();
332
333 Ok(compute_edits(
334 old_text,
335 new_text,
336 editable_range.start,
337 snapshot,
338 ))
339}
340
341pub fn compute_edits(
342 old_text: String,
343 new_text: &str,
344 offset: usize,
345 snapshot: &BufferSnapshot,
346) -> Vec<(Range<Anchor>, Arc<str>)> {
347 text_diff(&old_text, new_text)
348 .into_iter()
349 .map(|(mut old_range, new_text)| {
350 old_range.start += offset;
351 old_range.end += offset;
352
353 let prefix_len = common_prefix(
354 snapshot.chars_for_range(old_range.clone()),
355 new_text.chars(),
356 );
357 old_range.start += prefix_len;
358
359 let suffix_len = common_prefix(
360 snapshot.reversed_chars_for_range(old_range.clone()),
361 new_text[prefix_len..].chars().rev(),
362 );
363 old_range.end = old_range.end.saturating_sub(suffix_len);
364
365 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
366 let range = if old_range.is_empty() {
367 let anchor = snapshot.anchor_after(old_range.start);
368 anchor..anchor
369 } else {
370 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
371 };
372 (range, new_text)
373 })
374 .collect()
375}
376
/// Returns the UTF-8 byte length of the longest run of equal leading characters produced by
/// `a` and `b`. Iteration stops at the first mismatch or when either iterator ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
383
384fn git_info_for_file(
385 project: &Entity<Project>,
386 project_path: &ProjectPath,
387 cx: &App,
388) -> Option<PredictEditsGitInfo> {
389 let git_store = project.read(cx).git_store().read(cx);
390 if let Some((repository, _repo_path)) =
391 git_store.repository_and_path_for_project_path(project_path, cx)
392 {
393 let repository = repository.read(cx);
394 let head_sha = repository
395 .head_commit
396 .as_ref()
397 .map(|head_commit| head_commit.sha.to_string());
398 let remote_origin_url = repository.remote_origin_url.clone();
399 let remote_upstream_url = repository.remote_upstream_url.clone();
400 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
401 return None;
402 }
403 Some(PredictEditsGitInfo {
404 head_sha,
405 remote_origin_url,
406 remote_upstream_url,
407 })
408 } else {
409 None
410 }
411}
412
/// Output of [`gather_context`]: the request body plus the ranges needed to interpret the
/// model's response.
pub struct GatherContextOutput {
    /// Body for the predict-edits request. `gather_context` leaves `can_collect_data` as
    /// `false` and `git_info` as `None`; the caller may fill them in.
    pub body: PredictEditsBody,
    /// Buffer range (in points) covered by the prompt excerpt.
    pub context_range: Range<Point>,
    /// Editable region of the excerpt, as buffer offsets.
    pub editable_range: Range<usize>,
    /// Number of trailing (most recent) events rendered into `body.input_events`.
    pub included_events_count: usize,
}
419
420pub fn gather_context(
421 full_path_str: String,
422 snapshot: &BufferSnapshot,
423 cursor_point: language::Point,
424 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
425 trigger: PredictEditsRequestTrigger,
426 cx: &App,
427) -> Task<Result<GatherContextOutput>> {
428 cx.background_spawn({
429 let snapshot = snapshot.clone();
430 async move {
431 let input_excerpt = excerpt_for_cursor_position(
432 cursor_point,
433 &full_path_str,
434 &snapshot,
435 MAX_REWRITE_TOKENS,
436 MAX_CONTEXT_TOKENS,
437 );
438 let (input_events, included_events_count) = prompt_for_events();
439 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
440
441 let body = PredictEditsBody {
442 input_events,
443 input_excerpt: input_excerpt.prompt,
444 can_collect_data: false,
445 diagnostic_groups: None,
446 git_info: None,
447 outline: None,
448 speculated_output: None,
449 trigger,
450 };
451
452 Ok(GatherContextOutput {
453 body,
454 context_range: input_excerpt.context_range,
455 editable_range,
456 included_events_count,
457 })
458 }
459 })
460}
461
462fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
463 let mut result = String::new();
464 for (ix, event) in events.iter().rev().enumerate() {
465 let event_string = format_event(event.as_ref());
466 let event_tokens = guess_token_count(event_string.len());
467 if event_tokens > remaining_tokens {
468 return (result, ix);
469 }
470
471 if !result.is_empty() {
472 result.insert_str(0, "\n\n");
473 }
474 result.insert_str(0, &event_string);
475 remaining_tokens -= event_tokens;
476 }
477 return (result, events.len());
478}
479
480pub fn format_event(event: &Event) -> String {
481 match event {
482 Event::BufferChange {
483 path,
484 old_path,
485 diff,
486 ..
487 } => {
488 let mut prompt = String::new();
489
490 if old_path != path {
491 writeln!(
492 prompt,
493 "User renamed {} to {}\n",
494 old_path.display(),
495 path.display()
496 )
497 .unwrap();
498 }
499
500 if !diff.is_empty() {
501 write!(
502 prompt,
503 "User edited {}:\n```diff\n{}\n```",
504 path.display(),
505 diff
506 )
507 .unwrap();
508 }
509
510 prompt
511 }
512 }
513}
514
/// A prompt excerpt built around the cursor by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Buffer range whose text is included in the prompt, surrounding the editable region.
    pub context_range: Range<Point>,
    /// Sub-range the model may rewrite; wrapped in the editable-region markers in `prompt`.
    pub editable_range: Range<Point>,
    /// Rendered prompt text: a fenced block labelled with the file path, containing the
    /// context, the marked editable region and the cursor marker.
    pub prompt: String,
}
521
522pub fn excerpt_for_cursor_position(
523 position: Point,
524 path: &str,
525 snapshot: &BufferSnapshot,
526 editable_region_token_limit: usize,
527 context_token_limit: usize,
528) -> InputExcerpt {
529 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
530 position,
531 snapshot,
532 editable_region_token_limit,
533 context_token_limit,
534 );
535
536 let mut prompt = String::new();
537
538 writeln!(&mut prompt, "```{path}").unwrap();
539 if context_range.start == Point::zero() {
540 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
541 }
542
543 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
544 prompt.push_str(chunk.text);
545 }
546
547 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
548
549 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
550 prompt.push_str(chunk.text);
551 }
552 write!(prompt, "\n```").unwrap();
553
554 InputExcerpt {
555 context_range,
556 editable_range,
557 prompt,
558 }
559}
560
561fn push_editable_range(
562 cursor_position: Point,
563 snapshot: &BufferSnapshot,
564 editable_range: Range<Point>,
565 prompt: &mut String,
566) {
567 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
568 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
569 prompt.push_str(chunk.text);
570 }
571 prompt.push_str(CURSOR_MARKER);
572 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
573 prompt.push_str(chunk.text);
574 }
575 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the prompt produced by `excerpt_for_cursor_position` for a cursor inside `bar`
    /// under two editable-region token limits (50 and 40), with a context limit of 32.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        // Point::new(12, 5) is just after the `r` of `return sum;`.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}