1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
2
3use crate::{
4 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
5 EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
6 cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
7 prediction::EditPredictionResult,
8};
9use anyhow::Result;
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use language::{
15 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
16};
17use project::{Project, ProjectPath};
18use release_channel::AppVersion;
19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
20use zeta_prompt::{
21 Event, ZetaPromptInput,
22 zeta1::{
23 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
24 START_OF_FILE_MARKER,
25 },
26};
27
/// Token limit passed as `context_token_limit` to `excerpt_for_cursor_position`, bounding the
/// read-only context rendered around the editable region.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token limit passed as `editable_region_token_limit` to `excerpt_for_cursor_position`,
/// bounding the region the model is allowed to rewrite.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the recent-edit history (events) included in a request; token counts are
/// estimated with `guess_token_count`.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
31
32pub(crate) fn request_prediction_with_zeta1(
33 store: &mut EditPredictionStore,
34 EditPredictionModelInput {
35 project,
36 buffer,
37 snapshot,
38 position,
39 events,
40 trigger,
41 debug_tx,
42 ..
43 }: EditPredictionModelInput,
44 cx: &mut Context<EditPredictionStore>,
45) -> Task<Result<Option<EditPredictionResult>>> {
46 let buffer_snapshotted_at = Instant::now();
47 let client = store.client.clone();
48 let llm_token = store.llm_token.clone();
49 let app_version = AppVersion::global(cx);
50
51 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
52 let can_collect_file = store.can_collect_file(&project, file, cx);
53 let git_info = if can_collect_file {
54 git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
55 } else {
56 None
57 };
58 (git_info, can_collect_file)
59 } else {
60 (None, false)
61 };
62
63 let full_path: Arc<Path> = snapshot
64 .file()
65 .map(|f| Arc::from(f.full_path(cx).as_path()))
66 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
67 let full_path_str = full_path.to_string_lossy().into_owned();
68 let cursor_point = position.to_point(&snapshot);
69 let prompt_for_events = {
70 let events = events.clone();
71 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
72 };
73 let gather_task = gather_context(
74 full_path_str,
75 &snapshot,
76 cursor_point,
77 prompt_for_events,
78 trigger,
79 cx,
80 );
81
82 let (uri, require_auth) = match &store.custom_predict_edits_url {
83 Some(custom_url) => (custom_url.clone(), false),
84 None => {
85 match client
86 .http_client()
87 .build_zed_llm_url("/predict_edits/v2", &[])
88 {
89 Ok(url) => (url.into(), true),
90 Err(err) => return Task::ready(Err(err)),
91 }
92 }
93 };
94
95 cx.spawn(async move |this, cx| {
96 let GatherContextOutput {
97 mut body,
98 context_range,
99 editable_range,
100 included_events_count,
101 } = gather_task.await?;
102 let done_gathering_context_at = Instant::now();
103
104 let included_events = &events[events.len() - included_events_count..events.len()];
105 body.can_collect_data = can_collect_file
106 && this
107 .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
108 .unwrap_or(false);
109 if body.can_collect_data {
110 body.git_info = git_info;
111 }
112
113 log::debug!(
114 "Events:\n{}\nExcerpt:\n{:?}",
115 body.input_events,
116 body.input_excerpt
117 );
118
119 let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
120 |request| {
121 Ok(request
122 .uri(uri.as_str())
123 .body(serde_json::to_string(&body)?.into())?)
124 },
125 client,
126 llm_token,
127 app_version,
128 require_auth,
129 )
130 .await;
131
132 let context_start_offset = context_range.start.to_offset(&snapshot);
133 let editable_offset_range = editable_range.to_offset(&snapshot);
134
135 let inputs = ZetaPromptInput {
136 events: included_events.into(),
137 related_files: vec![],
138 cursor_path: full_path,
139 cursor_excerpt: snapshot
140 .text_for_range(context_range)
141 .collect::<String>()
142 .into(),
143 editable_range_in_excerpt: (editable_range.start - context_start_offset)
144 ..(editable_offset_range.end - context_start_offset),
145 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
146 };
147
148 if let Some(debug_tx) = &debug_tx {
149 debug_tx
150 .unbounded_send(DebugEvent::EditPredictionStarted(
151 EditPredictionStartedDebugEvent {
152 buffer: buffer.downgrade(),
153 prompt: Some(serde_json::to_string(&inputs).unwrap()),
154 position,
155 },
156 ))
157 .ok();
158 }
159
160 let (response, usage) = match response {
161 Ok(response) => response,
162 Err(err) => {
163 if err.is::<ZedUpdateRequiredError>() {
164 cx.update(|cx| {
165 this.update(cx, |ep_store, _cx| {
166 ep_store.update_required = true;
167 })
168 .ok();
169
170 let error_message: SharedString = err.to_string().into();
171 show_app_notification(
172 NotificationId::unique::<ZedUpdateRequiredError>(),
173 cx,
174 move |cx| {
175 cx.new(|cx| {
176 ErrorMessagePrompt::new(error_message.clone(), cx)
177 .with_link_button("Update Zed", "https://zed.dev/releases")
178 })
179 },
180 );
181 });
182 }
183
184 return Err(err);
185 }
186 };
187
188 let received_response_at = Instant::now();
189 log::debug!("completion response: {}", &response.output_excerpt);
190
191 if let Some(usage) = usage {
192 this.update(cx, |this, cx| {
193 this.user_store.update(cx, |user_store, cx| {
194 user_store.update_edit_prediction_usage(usage, cx);
195 });
196 })
197 .ok();
198 }
199
200 if let Some(debug_tx) = &debug_tx {
201 debug_tx
202 .unbounded_send(DebugEvent::EditPredictionFinished(
203 EditPredictionFinishedDebugEvent {
204 buffer: buffer.downgrade(),
205 model_output: Some(response.output_excerpt.clone()),
206 position,
207 },
208 ))
209 .ok();
210 }
211
212 let edit_prediction = process_completion_response(
213 response,
214 buffer,
215 &snapshot,
216 editable_range,
217 inputs,
218 buffer_snapshotted_at,
219 received_response_at,
220 cx,
221 )
222 .await;
223
224 let finished_at = Instant::now();
225
226 // record latency for ~1% of requests
227 if rand::random::<u8>() <= 2 {
228 telemetry::event!(
229 "Edit Prediction Request",
230 context_latency = done_gathering_context_at
231 .duration_since(buffer_snapshotted_at)
232 .as_millis(),
233 request_latency = received_response_at
234 .duration_since(done_gathering_context_at)
235 .as_millis(),
236 process_latency = finished_at.duration_since(received_response_at).as_millis()
237 );
238 }
239
240 edit_prediction.map(Some)
241 })
242}
243
244fn process_completion_response(
245 prediction_response: PredictEditsResponse,
246 buffer: Entity<Buffer>,
247 snapshot: &BufferSnapshot,
248 editable_range: Range<usize>,
249 inputs: ZetaPromptInput,
250 buffer_snapshotted_at: Instant,
251 received_response_at: Instant,
252 cx: &AsyncApp,
253) -> Task<Result<EditPredictionResult>> {
254 let snapshot = snapshot.clone();
255 let request_id = prediction_response.request_id;
256 let output_excerpt = prediction_response.output_excerpt;
257 cx.spawn(async move |cx| {
258 let output_excerpt: Arc<str> = output_excerpt.into();
259
260 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
261 .background_spawn({
262 let output_excerpt = output_excerpt.clone();
263 let editable_range = editable_range.clone();
264 let snapshot = snapshot.clone();
265 async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
266 })
267 .await?
268 .into();
269
270 let id = EditPredictionId(request_id.into());
271 Ok(EditPredictionResult::new(
272 id,
273 &buffer,
274 &snapshot,
275 edits,
276 buffer_snapshotted_at,
277 received_response_at,
278 inputs,
279 cx,
280 )
281 .await)
282 })
283}
284
285pub(crate) fn parse_edits(
286 output_excerpt: &str,
287 editable_range: Range<usize>,
288 snapshot: &BufferSnapshot,
289) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
290 let content = output_excerpt.replace(CURSOR_MARKER, "");
291
292 let start_markers = content
293 .match_indices(EDITABLE_REGION_START_MARKER)
294 .collect::<Vec<_>>();
295 anyhow::ensure!(
296 start_markers.len() <= 1,
297 "expected at most one start marker, found {}",
298 start_markers.len()
299 );
300
301 let end_markers = content
302 .match_indices(EDITABLE_REGION_END_MARKER)
303 .collect::<Vec<_>>();
304 anyhow::ensure!(
305 end_markers.len() <= 1,
306 "expected at most one end marker, found {}",
307 end_markers.len()
308 );
309
310 let sof_markers = content
311 .match_indices(START_OF_FILE_MARKER)
312 .collect::<Vec<_>>();
313 anyhow::ensure!(
314 sof_markers.len() <= 1,
315 "expected at most one start-of-file marker, found {}",
316 sof_markers.len()
317 );
318
319 let content_start = start_markers
320 .first()
321 .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
322 .unwrap_or(0);
323 let content_end = end_markers
324 .first()
325 .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
326 .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
327
328 let new_text = &content[content_start..content_end];
329
330 let old_text = snapshot
331 .text_for_range(editable_range.clone())
332 .collect::<String>();
333
334 Ok(compute_edits(
335 old_text,
336 new_text,
337 editable_range.start,
338 snapshot,
339 ))
340}
341
342pub fn compute_edits(
343 old_text: String,
344 new_text: &str,
345 offset: usize,
346 snapshot: &BufferSnapshot,
347) -> Vec<(Range<Anchor>, Arc<str>)> {
348 text_diff(&old_text, new_text)
349 .into_iter()
350 .map(|(mut old_range, new_text)| {
351 let old_slice = &old_text[old_range.clone()];
352
353 let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
354 let suffix_len = common_prefix(
355 old_slice[prefix_len..].chars().rev(),
356 new_text[prefix_len..].chars().rev(),
357 );
358
359 old_range.start += offset;
360 old_range.end += offset;
361 old_range.start += prefix_len;
362 old_range.end -= suffix_len;
363
364 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
365 let range = if old_range.is_empty() {
366 let anchor = snapshot.anchor_after(old_range.start);
367 anchor..anchor
368 } else {
369 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
370 };
371 (range, new_text)
372 })
373 .collect()
374}
375
/// Returns the length in UTF-8 bytes of the longest common prefix of two char sequences.
///
/// Pass `.chars().rev()` iterators to measure a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
382
383fn git_info_for_file(
384 project: &Entity<Project>,
385 project_path: &ProjectPath,
386 cx: &App,
387) -> Option<PredictEditsGitInfo> {
388 let git_store = project.read(cx).git_store().read(cx);
389 if let Some((repository, _repo_path)) =
390 git_store.repository_and_path_for_project_path(project_path, cx)
391 {
392 let repository = repository.read(cx);
393 let head_sha = repository
394 .head_commit
395 .as_ref()
396 .map(|head_commit| head_commit.sha.to_string());
397 let remote_origin_url = repository.remote_origin_url.clone();
398 let remote_upstream_url = repository.remote_upstream_url.clone();
399 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
400 return None;
401 }
402 Some(PredictEditsGitInfo {
403 head_sha,
404 remote_origin_url,
405 remote_upstream_url,
406 })
407 } else {
408 None
409 }
410}
411
/// Output of [`gather_context`]: the request body plus the ranges needed to map the model's
/// response back onto the buffer.
pub struct GatherContextOutput {
    /// Request body. `gather_context` leaves `can_collect_data` as `false` and `git_info` as
    /// `None`; the caller decides whether to fill them in.
    pub body: PredictEditsBody,
    /// Buffer range (in points) rendered into the prompt excerpt.
    pub context_range: Range<Point>,
    /// Editable region of the excerpt, as buffer offsets.
    pub editable_range: Range<usize>,
    /// How many of the newest events were rendered into `body.input_events`.
    pub included_events_count: usize,
}
418
419pub fn gather_context(
420 full_path_str: String,
421 snapshot: &BufferSnapshot,
422 cursor_point: language::Point,
423 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
424 trigger: PredictEditsRequestTrigger,
425 cx: &App,
426) -> Task<Result<GatherContextOutput>> {
427 cx.background_spawn({
428 let snapshot = snapshot.clone();
429 async move {
430 let input_excerpt = excerpt_for_cursor_position(
431 cursor_point,
432 &full_path_str,
433 &snapshot,
434 MAX_REWRITE_TOKENS,
435 MAX_CONTEXT_TOKENS,
436 );
437 let (input_events, included_events_count) = prompt_for_events();
438 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
439
440 let body = PredictEditsBody {
441 input_events,
442 input_excerpt: input_excerpt.prompt,
443 can_collect_data: false,
444 diagnostic_groups: None,
445 git_info: None,
446 outline: None,
447 speculated_output: None,
448 trigger,
449 };
450
451 Ok(GatherContextOutput {
452 body,
453 context_range: input_excerpt.context_range,
454 editable_range,
455 included_events_count,
456 })
457 }
458 })
459}
460
461pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
462 prompt_for_events_impl(events, max_tokens).0
463}
464
465fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
466 let mut result = String::new();
467 for (ix, event) in events.iter().rev().enumerate() {
468 let event_string = format_event(event.as_ref());
469 let event_tokens = guess_token_count(event_string.len());
470 if event_tokens > remaining_tokens {
471 return (result, ix);
472 }
473
474 if !result.is_empty() {
475 result.insert_str(0, "\n\n");
476 }
477 result.insert_str(0, &event_string);
478 remaining_tokens -= event_tokens;
479 }
480 return (result, events.len());
481}
482
483pub fn format_event(event: &Event) -> String {
484 match event {
485 Event::BufferChange {
486 path,
487 old_path,
488 diff,
489 ..
490 } => {
491 let mut prompt = String::new();
492
493 if old_path != path {
494 writeln!(
495 prompt,
496 "User renamed {} to {}\n",
497 old_path.display(),
498 path.display()
499 )
500 .unwrap();
501 }
502
503 if !diff.is_empty() {
504 write!(
505 prompt,
506 "User edited {}:\n```diff\n{}\n```",
507 path.display(),
508 diff
509 )
510 .unwrap();
511 }
512
513 prompt
514 }
515 }
516}
517
/// Prompt excerpt built around the cursor by [`excerpt_for_cursor_position`].
#[derive(Debug)]
pub struct InputExcerpt {
    /// Buffer range rendered into `prompt`: the context before, the editable region, and the
    /// context after.
    pub context_range: Range<Point>,
    /// Buffer range the model may rewrite; delimited in `prompt` by the editable-region markers.
    pub editable_range: Range<Point>,
    /// Fenced prompt text: a "```path" header, an optional start-of-file marker, the context,
    /// and the marked editable region containing the cursor marker.
    pub prompt: String,
}
524
525pub fn excerpt_for_cursor_position(
526 position: Point,
527 path: &str,
528 snapshot: &BufferSnapshot,
529 editable_region_token_limit: usize,
530 context_token_limit: usize,
531) -> InputExcerpt {
532 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
533 position,
534 snapshot,
535 editable_region_token_limit,
536 context_token_limit,
537 );
538
539 let mut prompt = String::new();
540
541 writeln!(&mut prompt, "```{path}").unwrap();
542 if context_range.start == Point::zero() {
543 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
544 }
545
546 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
547 prompt.push_str(chunk.text);
548 }
549
550 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
551
552 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
553 prompt.push_str(chunk.text);
554 }
555 write!(prompt, "\n```").unwrap();
556
557 InputExcerpt {
558 context_range,
559 editable_range,
560 prompt,
561 }
562}
563
564fn push_editable_range(
565 cursor_position: Point,
566 snapshot: &BufferSnapshot,
567 editable_range: Range<Point>,
568 prompt: &mut String,
569) {
570 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
571 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
572 prompt.push_str(chunk.text);
573 }
574 prompt.push_str(CURSOR_MARKER);
575 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
576 prompt.push_str(chunk.text);
577 }
578 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the rendered prompt for a cursor inside `bar` (row 12, column 5, i.e. right after
    /// the `r` of `return`) under two editable-region token limits.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}