1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::{Context as _, Result};
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use language::{
15 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
16};
17use project::{Project, ProjectPath};
18use release_channel::AppVersion;
19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
20use zeta_prompt::{Event, ZetaPromptInput};
21
/// Inserted into the prompt at the cursor position; removed from the model output
/// before the editable region is parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Written on its own line when the prompt's context range begins at the start of the
/// buffer; the model output may contain it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Line that opens the editable region in the prompt and in the model output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Line that closes the editable region in the prompt and in the model output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Token budget for surrounding (read-only) context passed to `excerpt_for_cursor_position`.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget for the editable region the model may rewrite.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the formatted recent-edit history included in the request.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
30
31pub(crate) fn request_prediction_with_zeta1(
32 store: &mut EditPredictionStore,
33 EditPredictionModelInput {
34 project,
35 buffer,
36 snapshot,
37 position,
38 events,
39 trigger,
40 debug_tx,
41 ..
42 }: EditPredictionModelInput,
43 cx: &mut Context<EditPredictionStore>,
44) -> Task<Result<Option<EditPredictionResult>>> {
45 let buffer_snapshotted_at = Instant::now();
46 let client = store.client.clone();
47 let llm_token = store.llm_token.clone();
48 let app_version = AppVersion::global(cx);
49
50 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
51 let can_collect_file = store.can_collect_file(&project, file, cx);
52 let git_info = if can_collect_file {
53 git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
54 } else {
55 None
56 };
57 (git_info, can_collect_file)
58 } else {
59 (None, false)
60 };
61
62 let full_path: Arc<Path> = snapshot
63 .file()
64 .map(|f| Arc::from(f.full_path(cx).as_path()))
65 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
66 let full_path_str = full_path.to_string_lossy().into_owned();
67 let cursor_point = position.to_point(&snapshot);
68 let prompt_for_events = {
69 let events = events.clone();
70 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
71 };
72 let gather_task = gather_context(
73 full_path_str,
74 &snapshot,
75 cursor_point,
76 prompt_for_events,
77 trigger,
78 cx,
79 );
80
81 let (uri, require_auth) = match &store.custom_predict_edits_url {
82 Some(custom_url) => (custom_url.clone(), false),
83 None => {
84 match client
85 .http_client()
86 .build_zed_llm_url("/predict_edits/v2", &[])
87 {
88 Ok(url) => (url.into(), true),
89 Err(err) => return Task::ready(Err(err)),
90 }
91 }
92 };
93
94 cx.spawn(async move |this, cx| {
95 let GatherContextOutput {
96 mut body,
97 context_range,
98 editable_range,
99 included_events_count,
100 } = gather_task.await?;
101 let done_gathering_context_at = Instant::now();
102
103 let included_events = &events[events.len() - included_events_count..events.len()];
104 body.can_collect_data = can_collect_file
105 && this
106 .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
107 .unwrap_or(false);
108 if body.can_collect_data {
109 body.git_info = git_info;
110 }
111
112 log::debug!(
113 "Events:\n{}\nExcerpt:\n{:?}",
114 body.input_events,
115 body.input_excerpt
116 );
117
118 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
119 |request| {
120 Ok(request
121 .uri(uri.as_str())
122 .body(serde_json::to_string(&body)?.into())?)
123 },
124 client,
125 llm_token,
126 app_version,
127 require_auth,
128 )
129 .await;
130
131 let context_start_offset = context_range.start.to_offset(&snapshot);
132 let editable_offset_range = editable_range.to_offset(&snapshot);
133
134 let inputs = ZetaPromptInput {
135 events: included_events.into(),
136 related_files: vec![],
137 cursor_path: full_path,
138 cursor_excerpt: snapshot
139 .text_for_range(context_range)
140 .collect::<String>()
141 .into(),
142 editable_range_in_excerpt: (editable_range.start - context_start_offset)
143 ..(editable_offset_range.end - context_start_offset),
144 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
145 };
146
147 if let Some(debug_tx) = &debug_tx {
148 debug_tx
149 .unbounded_send(DebugEvent::EditPredictionStarted(
150 EditPredictionStartedDebugEvent {
151 buffer: buffer.downgrade(),
152 prompt: Some(serde_json::to_string(&inputs).unwrap()),
153 position,
154 },
155 ))
156 .ok();
157 }
158
159 let (response, usage) = match response {
160 Ok(response) => response,
161 Err(err) => {
162 if err.is::<ZedUpdateRequiredError>() {
163 cx.update(|cx| {
164 this.update(cx, |ep_store, _cx| {
165 ep_store.update_required = true;
166 })
167 .ok();
168
169 let error_message: SharedString = err.to_string().into();
170 show_app_notification(
171 NotificationId::unique::<ZedUpdateRequiredError>(),
172 cx,
173 move |cx| {
174 cx.new(|cx| {
175 ErrorMessagePrompt::new(error_message.clone(), cx)
176 .with_link_button("Update Zed", "https://zed.dev/releases")
177 })
178 },
179 );
180 });
181 }
182
183 return Err(err);
184 }
185 };
186
187 let received_response_at = Instant::now();
188 log::debug!("completion response: {}", &response.output_excerpt);
189
190 if let Some(usage) = usage {
191 this.update(cx, |this, cx| {
192 this.user_store.update(cx, |user_store, cx| {
193 user_store.update_edit_prediction_usage(usage, cx);
194 });
195 })
196 .ok();
197 }
198
199 if let Some(debug_tx) = &debug_tx {
200 debug_tx
201 .unbounded_send(DebugEvent::EditPredictionFinished(
202 EditPredictionFinishedDebugEvent {
203 buffer: buffer.downgrade(),
204 model_output: Some(response.output_excerpt.clone()),
205 position,
206 },
207 ))
208 .ok();
209 }
210
211 let edit_prediction = process_completion_response(
212 response,
213 buffer,
214 &snapshot,
215 editable_range,
216 inputs,
217 buffer_snapshotted_at,
218 received_response_at,
219 cx,
220 )
221 .await;
222
223 let finished_at = Instant::now();
224
225 // record latency for ~1% of requests
226 if rand::random::<u8>() <= 2 {
227 telemetry::event!(
228 "Edit Prediction Request",
229 context_latency = done_gathering_context_at
230 .duration_since(buffer_snapshotted_at)
231 .as_millis(),
232 request_latency = received_response_at
233 .duration_since(done_gathering_context_at)
234 .as_millis(),
235 process_latency = finished_at.duration_since(received_response_at).as_millis()
236 );
237 }
238
239 edit_prediction.map(Some)
240 })
241}
242
243fn process_completion_response(
244 prediction_response: PredictEditsResponse,
245 buffer: Entity<Buffer>,
246 snapshot: &BufferSnapshot,
247 editable_range: Range<usize>,
248 inputs: ZetaPromptInput,
249 buffer_snapshotted_at: Instant,
250 received_response_at: Instant,
251 cx: &AsyncApp,
252) -> Task<Result<EditPredictionResult>> {
253 let snapshot = snapshot.clone();
254 let request_id = prediction_response.request_id;
255 let output_excerpt = prediction_response.output_excerpt;
256 cx.spawn(async move |cx| {
257 let output_excerpt: Arc<str> = output_excerpt.into();
258
259 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
260 .background_spawn({
261 let output_excerpt = output_excerpt.clone();
262 let editable_range = editable_range.clone();
263 let snapshot = snapshot.clone();
264 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
265 })
266 .await?
267 .into();
268
269 let id = EditPredictionId(request_id.into());
270 Ok(EditPredictionResult::new(
271 id,
272 &buffer,
273 &snapshot,
274 edits,
275 buffer_snapshotted_at,
276 received_response_at,
277 inputs,
278 cx,
279 )
280 .await)
281 })
282}
283
284fn parse_edits(
285 output_excerpt: Arc<str>,
286 editable_range: Range<usize>,
287 snapshot: &BufferSnapshot,
288) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
289 let content = output_excerpt.replace(CURSOR_MARKER, "");
290
291 let start_markers = content
292 .match_indices(EDITABLE_REGION_START_MARKER)
293 .collect::<Vec<_>>();
294 anyhow::ensure!(
295 start_markers.len() == 1,
296 "expected exactly one start marker, found {}",
297 start_markers.len()
298 );
299
300 let end_markers = content
301 .match_indices(EDITABLE_REGION_END_MARKER)
302 .collect::<Vec<_>>();
303 anyhow::ensure!(
304 end_markers.len() == 1,
305 "expected exactly one end marker, found {}",
306 end_markers.len()
307 );
308
309 let sof_markers = content
310 .match_indices(START_OF_FILE_MARKER)
311 .collect::<Vec<_>>();
312 anyhow::ensure!(
313 sof_markers.len() <= 1,
314 "expected at most one start-of-file marker, found {}",
315 sof_markers.len()
316 );
317
318 let codefence_start = start_markers[0].0;
319 let content = &content[codefence_start..];
320
321 let newline_ix = content.find('\n').context("could not find newline")?;
322 let content = &content[newline_ix + 1..];
323
324 let codefence_end = content
325 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
326 .context("could not find end marker")?;
327 let new_text = &content[..codefence_end];
328
329 let old_text = snapshot
330 .text_for_range(editable_range.clone())
331 .collect::<String>();
332
333 Ok(compute_edits(
334 old_text,
335 new_text,
336 editable_range.start,
337 snapshot,
338 ))
339}
340
341pub fn compute_edits(
342 old_text: String,
343 new_text: &str,
344 offset: usize,
345 snapshot: &BufferSnapshot,
346) -> Vec<(Range<Anchor>, Arc<str>)> {
347 text_diff(&old_text, new_text)
348 .into_iter()
349 .map(|(mut old_range, new_text)| {
350 let old_slice = &old_text[old_range.clone()];
351
352 let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
353 let suffix_len = common_prefix(
354 old_slice[prefix_len..].chars().rev(),
355 new_text[prefix_len..].chars().rev(),
356 );
357
358 old_range.start += offset;
359 old_range.end += offset;
360 old_range.start += prefix_len;
361 old_range.end -= suffix_len;
362
363 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
364 let range = if old_range.is_empty() {
365 let anchor = snapshot.anchor_after(old_range.start);
366 anchor..anchor
367 } else {
368 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
369 };
370 (range, new_text)
371 })
372 .collect()
373}
374
/// Returns the length in UTF-8 bytes of the longest run of equal characters at the start
/// of `a` and `b`.
///
/// Callers pass reversed iterators to measure a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
381
382fn git_info_for_file(
383 project: &Entity<Project>,
384 project_path: &ProjectPath,
385 cx: &App,
386) -> Option<PredictEditsGitInfo> {
387 let git_store = project.read(cx).git_store().read(cx);
388 if let Some((repository, _repo_path)) =
389 git_store.repository_and_path_for_project_path(project_path, cx)
390 {
391 let repository = repository.read(cx);
392 let head_sha = repository
393 .head_commit
394 .as_ref()
395 .map(|head_commit| head_commit.sha.to_string());
396 let remote_origin_url = repository.remote_origin_url.clone();
397 let remote_upstream_url = repository.remote_upstream_url.clone();
398 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
399 return None;
400 }
401 Some(PredictEditsGitInfo {
402 head_sha,
403 remote_origin_url,
404 remote_upstream_url,
405 })
406 } else {
407 None
408 }
409}
410
/// Result of [`gather_context`]: the request body plus the ranges needed to interpret
/// the model's response.
pub struct GatherContextOutput {
    /// Request body. `gather_context` leaves `can_collect_data` as `false` and `git_info`
    /// as `None`; callers fill them in when collection is allowed.
    pub body: PredictEditsBody,
    /// Buffer range (in points) of the full excerpt shown to the model.
    pub context_range: Range<Point>,
    /// Editable region as byte offsets into the snapshot used for gathering.
    pub editable_range: Range<usize>,
    /// Number of events reported as included by the `prompt_for_events` callback.
    pub included_events_count: usize,
}
417
418pub fn gather_context(
419 full_path_str: String,
420 snapshot: &BufferSnapshot,
421 cursor_point: language::Point,
422 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
423 trigger: PredictEditsRequestTrigger,
424 cx: &App,
425) -> Task<Result<GatherContextOutput>> {
426 cx.background_spawn({
427 let snapshot = snapshot.clone();
428 async move {
429 let input_excerpt = excerpt_for_cursor_position(
430 cursor_point,
431 &full_path_str,
432 &snapshot,
433 MAX_REWRITE_TOKENS,
434 MAX_CONTEXT_TOKENS,
435 );
436 let (input_events, included_events_count) = prompt_for_events();
437 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
438
439 let body = PredictEditsBody {
440 input_events,
441 input_excerpt: input_excerpt.prompt,
442 can_collect_data: false,
443 diagnostic_groups: None,
444 git_info: None,
445 outline: None,
446 speculated_output: None,
447 trigger,
448 };
449
450 Ok(GatherContextOutput {
451 body,
452 context_range: input_excerpt.context_range,
453 editable_range,
454 included_events_count,
455 })
456 }
457 })
458}
459
460fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
461 let mut result = String::new();
462 for (ix, event) in events.iter().rev().enumerate() {
463 let event_string = format_event(event.as_ref());
464 let event_tokens = guess_token_count(event_string.len());
465 if event_tokens > remaining_tokens {
466 return (result, ix);
467 }
468
469 if !result.is_empty() {
470 result.insert_str(0, "\n\n");
471 }
472 result.insert_str(0, &event_string);
473 remaining_tokens -= event_tokens;
474 }
475 return (result, events.len());
476}
477
478pub fn format_event(event: &Event) -> String {
479 match event {
480 Event::BufferChange {
481 path,
482 old_path,
483 diff,
484 ..
485 } => {
486 let mut prompt = String::new();
487
488 if old_path != path {
489 writeln!(
490 prompt,
491 "User renamed {} to {}\n",
492 old_path.display(),
493 path.display()
494 )
495 .unwrap();
496 }
497
498 if !diff.is_empty() {
499 write!(
500 prompt,
501 "User edited {}:\n```diff\n{}\n```",
502 path.display(),
503 diff
504 )
505 .unwrap();
506 }
507
508 prompt
509 }
510 }
511}
512
/// Prompt excerpt around the cursor, produced by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Full buffer range rendered into the prompt; it encloses `editable_range`.
    pub context_range: Range<Point>,
    /// Sub-range wrapped in the editable-region markers, which the model may rewrite.
    pub editable_range: Range<Point>,
    /// Rendered prompt: a fenced block headed by the file path containing the context, the
    /// marked editable region and the cursor marker.
    pub prompt: String,
}
519
520pub fn excerpt_for_cursor_position(
521 position: Point,
522 path: &str,
523 snapshot: &BufferSnapshot,
524 editable_region_token_limit: usize,
525 context_token_limit: usize,
526) -> InputExcerpt {
527 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
528 position,
529 snapshot,
530 editable_region_token_limit,
531 context_token_limit,
532 );
533
534 let mut prompt = String::new();
535
536 writeln!(&mut prompt, "```{path}").unwrap();
537 if context_range.start == Point::zero() {
538 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
539 }
540
541 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
542 prompt.push_str(chunk.text);
543 }
544
545 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
546
547 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
548 prompt.push_str(chunk.text);
549 }
550 write!(prompt, "\n```").unwrap();
551
552 InputExcerpt {
553 context_range,
554 editable_range,
555 prompt,
556 }
557}
558
559fn push_editable_range(
560 cursor_position: Point,
561 snapshot: &BufferSnapshot,
562 editable_range: Range<Point>,
563 prompt: &mut String,
564) {
565 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
566 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
567 prompt.push_str(chunk.text);
568 }
569 prompt.push_str(CURSOR_MARKER);
570 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
571 prompt.push_str(chunk.text);
572 }
573 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the rendered prompt for a cursor inside `return sum;` (row 12, column 5) at
    /// two editable-region budgets, with the context budget fixed at 32 tokens.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The excerpt expands to syntax boundaries.
        // With 50 token editable limit, we get a region that expands to syntax nodes.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With smaller budget, the region expands to syntax boundaries but is tighter.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }
}