1mod input_excerpt;
2
3use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
4
5use crate::{
6 EditPredictionId, ZedUpdateRequiredError, Zeta,
7 prediction::{EditPredictionInputs, EditPredictionResult},
8};
9use anyhow::{Context as _, Result};
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
12 predict_edits_v3::Event,
13};
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
15use input_excerpt::excerpt_for_cursor_position;
16use language::{
17 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
18};
19use project::{Project, ProjectPath};
20use release_channel::AppVersion;
21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
22
/// Cursor marker; every occurrence is stripped from the model output before it is parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; the model output may contain it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the rewritable region in the model output; it must appear exactly once, and the new
/// text begins on the line after it.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the rewritable region in the model output; it must appear exactly once, preceded by a
/// newline.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Token budget passed to `excerpt_for_cursor_position` alongside [`MAX_REWRITE_TOKENS`].
/// NOTE(review): presumably limits the surrounding read-only context — confirm in `input_excerpt`.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget passed to `excerpt_for_cursor_position` alongside [`MAX_CONTEXT_TOKENS`].
/// NOTE(review): presumably limits the editable (rewritten) region — confirm in `input_excerpt`.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the recent-events section of the prompt (used by `prompt_for_events_impl`).
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
31
32pub(crate) fn request_prediction_with_zeta1(
33 zeta: &mut Zeta,
34 project: &Entity<Project>,
35 buffer: &Entity<Buffer>,
36 snapshot: BufferSnapshot,
37 position: language::Anchor,
38 events: Vec<Arc<Event>>,
39 trigger: PredictEditsRequestTrigger,
40 cx: &mut Context<Zeta>,
41) -> Task<Result<Option<EditPredictionResult>>> {
42 let buffer = buffer.clone();
43 let buffer_snapshotted_at = Instant::now();
44 let client = zeta.client.clone();
45 let llm_token = zeta.llm_token.clone();
46 let app_version = AppVersion::global(cx);
47
48 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
49 let can_collect_file = zeta.can_collect_file(project, file, cx);
50 let git_info = if can_collect_file {
51 git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
52 } else {
53 None
54 };
55 (git_info, can_collect_file)
56 } else {
57 (None, false)
58 };
59
60 let full_path: Arc<Path> = snapshot
61 .file()
62 .map(|f| Arc::from(f.full_path(cx).as_path()))
63 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
64 let full_path_str = full_path.to_string_lossy().into_owned();
65 let cursor_point = position.to_point(&snapshot);
66 let prompt_for_events = {
67 let events = events.clone();
68 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
69 };
70 let gather_task = gather_context(
71 full_path_str,
72 &snapshot,
73 cursor_point,
74 prompt_for_events,
75 trigger,
76 cx,
77 );
78
79 cx.spawn(async move |this, cx| {
80 let GatherContextOutput {
81 mut body,
82 context_range,
83 editable_range,
84 included_events_count,
85 } = gather_task.await?;
86 let done_gathering_context_at = Instant::now();
87
88 let included_events = &events[events.len() - included_events_count..events.len()];
89 body.can_collect_data = can_collect_file
90 && this
91 .read_with(cx, |this, _| this.can_collect_events(included_events))
92 .unwrap_or(false);
93 if body.can_collect_data {
94 body.git_info = git_info;
95 }
96
97 log::debug!(
98 "Events:\n{}\nExcerpt:\n{:?}",
99 body.input_events,
100 body.input_excerpt
101 );
102
103 let http_client = client.http_client();
104
105 let response = Zeta::send_api_request::<PredictEditsResponse>(
106 |request| {
107 let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
108 predict_edits_url
109 } else {
110 http_client
111 .build_zed_llm_url("/predict_edits/v2", &[])?
112 .as_str()
113 .into()
114 };
115 Ok(request
116 .uri(uri)
117 .body(serde_json::to_string(&body)?.into())?)
118 },
119 client,
120 llm_token,
121 app_version,
122 )
123 .await;
124
125 let inputs = EditPredictionInputs {
126 events: included_events.into(),
127 included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
128 path: full_path.clone(),
129 max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
130 excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
131 start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
132 text: snapshot
133 .text_for_range(context_range)
134 .collect::<String>()
135 .into(),
136 }],
137 }],
138 cursor_point: cloud_llm_client::predict_edits_v3::Point {
139 column: cursor_point.column,
140 line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
141 },
142 cursor_path: full_path,
143 };
144
145 // let response = perform_predict_edits(PerformPredictEditsParams {
146 // client,
147 // llm_token,
148 // app_version,
149 // body,
150 // })
151 // .await;
152
153 let (response, usage) = match response {
154 Ok(response) => response,
155 Err(err) => {
156 if err.is::<ZedUpdateRequiredError>() {
157 cx.update(|cx| {
158 this.update(cx, |zeta, _cx| {
159 zeta.update_required = true;
160 })
161 .ok();
162
163 let error_message: SharedString = err.to_string().into();
164 show_app_notification(
165 NotificationId::unique::<ZedUpdateRequiredError>(),
166 cx,
167 move |cx| {
168 cx.new(|cx| {
169 ErrorMessagePrompt::new(error_message.clone(), cx)
170 .with_link_button("Update Zed", "https://zed.dev/releases")
171 })
172 },
173 );
174 })
175 .ok();
176 }
177
178 return Err(err);
179 }
180 };
181
182 let received_response_at = Instant::now();
183 log::debug!("completion response: {}", &response.output_excerpt);
184
185 if let Some(usage) = usage {
186 this.update(cx, |this, cx| {
187 this.user_store.update(cx, |user_store, cx| {
188 user_store.update_edit_prediction_usage(usage, cx);
189 });
190 })
191 .ok();
192 }
193
194 let edit_prediction = process_completion_response(
195 response,
196 buffer,
197 &snapshot,
198 editable_range,
199 inputs,
200 buffer_snapshotted_at,
201 received_response_at,
202 cx,
203 )
204 .await;
205
206 let finished_at = Instant::now();
207
208 // record latency for ~1% of requests
209 if rand::random::<u8>() <= 2 {
210 telemetry::event!(
211 "Edit Prediction Request",
212 context_latency = done_gathering_context_at
213 .duration_since(buffer_snapshotted_at)
214 .as_millis(),
215 request_latency = received_response_at
216 .duration_since(done_gathering_context_at)
217 .as_millis(),
218 process_latency = finished_at.duration_since(received_response_at).as_millis()
219 );
220 }
221
222 edit_prediction.map(Some)
223 })
224}
225
226fn process_completion_response(
227 prediction_response: PredictEditsResponse,
228 buffer: Entity<Buffer>,
229 snapshot: &BufferSnapshot,
230 editable_range: Range<usize>,
231 inputs: EditPredictionInputs,
232 buffer_snapshotted_at: Instant,
233 received_response_at: Instant,
234 cx: &AsyncApp,
235) -> Task<Result<EditPredictionResult>> {
236 let snapshot = snapshot.clone();
237 let request_id = prediction_response.request_id;
238 let output_excerpt = prediction_response.output_excerpt;
239 cx.spawn(async move |cx| {
240 let output_excerpt: Arc<str> = output_excerpt.into();
241
242 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
243 .background_spawn({
244 let output_excerpt = output_excerpt.clone();
245 let editable_range = editable_range.clone();
246 let snapshot = snapshot.clone();
247 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
248 })
249 .await?
250 .into();
251
252 let id = EditPredictionId(request_id.into());
253 Ok(EditPredictionResult::new(
254 id,
255 &buffer,
256 &snapshot,
257 edits,
258 buffer_snapshotted_at,
259 received_response_at,
260 inputs,
261 cx,
262 )
263 .await)
264 })
265}
266
267fn parse_edits(
268 output_excerpt: Arc<str>,
269 editable_range: Range<usize>,
270 snapshot: &BufferSnapshot,
271) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
272 let content = output_excerpt.replace(CURSOR_MARKER, "");
273
274 let start_markers = content
275 .match_indices(EDITABLE_REGION_START_MARKER)
276 .collect::<Vec<_>>();
277 anyhow::ensure!(
278 start_markers.len() == 1,
279 "expected exactly one start marker, found {}",
280 start_markers.len()
281 );
282
283 let end_markers = content
284 .match_indices(EDITABLE_REGION_END_MARKER)
285 .collect::<Vec<_>>();
286 anyhow::ensure!(
287 end_markers.len() == 1,
288 "expected exactly one end marker, found {}",
289 end_markers.len()
290 );
291
292 let sof_markers = content
293 .match_indices(START_OF_FILE_MARKER)
294 .collect::<Vec<_>>();
295 anyhow::ensure!(
296 sof_markers.len() <= 1,
297 "expected at most one start-of-file marker, found {}",
298 sof_markers.len()
299 );
300
301 let codefence_start = start_markers[0].0;
302 let content = &content[codefence_start..];
303
304 let newline_ix = content.find('\n').context("could not find newline")?;
305 let content = &content[newline_ix + 1..];
306
307 let codefence_end = content
308 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
309 .context("could not find end marker")?;
310 let new_text = &content[..codefence_end];
311
312 let old_text = snapshot
313 .text_for_range(editable_range.clone())
314 .collect::<String>();
315
316 Ok(compute_edits(
317 old_text,
318 new_text,
319 editable_range.start,
320 snapshot,
321 ))
322}
323
324pub fn compute_edits(
325 old_text: String,
326 new_text: &str,
327 offset: usize,
328 snapshot: &BufferSnapshot,
329) -> Vec<(Range<Anchor>, Arc<str>)> {
330 text_diff(&old_text, new_text)
331 .into_iter()
332 .map(|(mut old_range, new_text)| {
333 old_range.start += offset;
334 old_range.end += offset;
335
336 let prefix_len = common_prefix(
337 snapshot.chars_for_range(old_range.clone()),
338 new_text.chars(),
339 );
340 old_range.start += prefix_len;
341
342 let suffix_len = common_prefix(
343 snapshot.reversed_chars_for_range(old_range.clone()),
344 new_text[prefix_len..].chars().rev(),
345 );
346 old_range.end = old_range.end.saturating_sub(suffix_len);
347
348 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
349 let range = if old_range.is_empty() {
350 let anchor = snapshot.anchor_after(old_range.start);
351 anchor..anchor
352 } else {
353 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
354 };
355 (range, new_text)
356 })
357 .collect()
358}
359
/// Returns the UTF-8 byte length of the longest run of equal leading characters yielded by `a`
/// and `b`.
///
/// Comparison stops at the first mismatch or when either iterator is exhausted. Passing reversed
/// iterators measures a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
366
367fn git_info_for_file(
368 project: &Entity<Project>,
369 project_path: &ProjectPath,
370 cx: &App,
371) -> Option<PredictEditsGitInfo> {
372 let git_store = project.read(cx).git_store().read(cx);
373 if let Some((repository, _repo_path)) =
374 git_store.repository_and_path_for_project_path(project_path, cx)
375 {
376 let repository = repository.read(cx);
377 let head_sha = repository
378 .head_commit
379 .as_ref()
380 .map(|head_commit| head_commit.sha.to_string());
381 let remote_origin_url = repository.remote_origin_url.clone();
382 let remote_upstream_url = repository.remote_upstream_url.clone();
383 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
384 return None;
385 }
386 Some(PredictEditsGitInfo {
387 head_sha,
388 remote_origin_url,
389 remote_upstream_url,
390 })
391 } else {
392 None
393 }
394}
395
/// Output of [`gather_context`]: the prediction request body plus the ranges and counts needed
/// to interpret the model's reply.
pub struct GatherContextOutput {
    /// Request payload. `gather_context` leaves `can_collect_data` as `false` and `git_info` as
    /// `None`; callers such as `request_prediction_with_zeta1` may set them afterwards.
    pub body: PredictEditsBody,
    /// Row/column range of the buffer excerpt included in the prompt.
    pub context_range: Range<Point>,
    /// Byte-offset range of the region the model is allowed to rewrite.
    pub editable_range: Range<usize>,
    /// How many events, counted from the end of the event list, were rendered into the prompt.
    pub included_events_count: usize,
}
402
403pub fn gather_context(
404 full_path_str: String,
405 snapshot: &BufferSnapshot,
406 cursor_point: language::Point,
407 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
408 trigger: PredictEditsRequestTrigger,
409 cx: &App,
410) -> Task<Result<GatherContextOutput>> {
411 cx.background_spawn({
412 let snapshot = snapshot.clone();
413 async move {
414 let input_excerpt = excerpt_for_cursor_position(
415 cursor_point,
416 &full_path_str,
417 &snapshot,
418 MAX_REWRITE_TOKENS,
419 MAX_CONTEXT_TOKENS,
420 );
421 let (input_events, included_events_count) = prompt_for_events();
422 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
423
424 let body = PredictEditsBody {
425 input_events,
426 input_excerpt: input_excerpt.prompt,
427 can_collect_data: false,
428 diagnostic_groups: None,
429 git_info: None,
430 outline: None,
431 speculated_output: None,
432 trigger,
433 };
434
435 Ok(GatherContextOutput {
436 body,
437 context_range: input_excerpt.context_range,
438 editable_range,
439 included_events_count,
440 })
441 }
442 })
443}
444
445fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
446 let mut result = String::new();
447 for (ix, event) in events.iter().rev().enumerate() {
448 let event_string = format_event(event.as_ref());
449 let event_tokens = guess_token_count(event_string.len());
450 if event_tokens > remaining_tokens {
451 return (result, ix);
452 }
453
454 if !result.is_empty() {
455 result.insert_str(0, "\n\n");
456 }
457 result.insert_str(0, &event_string);
458 remaining_tokens -= event_tokens;
459 }
460 return (result, events.len());
461}
462
463pub fn format_event(event: &Event) -> String {
464 match event {
465 Event::BufferChange {
466 path,
467 old_path,
468 diff,
469 ..
470 } => {
471 let mut prompt = String::new();
472
473 if old_path != path {
474 writeln!(
475 prompt,
476 "User renamed {} to {}\n",
477 old_path.display(),
478 path.display()
479 )
480 .unwrap();
481 }
482
483 if !diff.is_empty() {
484 write!(
485 prompt,
486 "User edited {}:\n```diff\n{}\n```",
487 path.display(),
488 diff
489 )
490 .unwrap();
491 }
492
493 prompt
494 }
495 }
496}
497
/// Typical number of string bytes per token for the purposes of limiting model input. This is
/// intentionally low: a smaller divisor yields more estimated tokens per byte, so budgets err on
/// the side of fitting less text than the model could actually accept.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many model tokens `bytes` bytes of text occupy, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}