1mod input_excerpt;
2
3use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
4
5use crate::{
6 EditPredictionId, ZedUpdateRequiredError, Zeta,
7 prediction::{EditPredictionInputs, EditPredictionResult},
8};
9use anyhow::{Context as _, Result};
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use input_excerpt::excerpt_for_cursor_position;
15use language::{
16 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
17};
18use project::{Project, ProjectPath};
19use release_channel::AppVersion;
20use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
21
/// Marker for the user's cursor position; `parse_edits` strips every occurrence from the
/// model output before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `parse_edits` rejects model output containing more than one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker opening the editable region; `parse_edits` requires exactly one in the model output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker closing the editable region; `parse_edits` requires exactly one in the model output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Context token budget handed to `excerpt_for_cursor_position` by `gather_context`.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Rewrite token budget handed to `excerpt_for_cursor_position` by `gather_context`.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the formatted edit events included in a prediction request.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
30
31pub(crate) fn request_prediction_with_zeta1(
32 zeta: &mut Zeta,
33 project: &Entity<Project>,
34 buffer: &Entity<Buffer>,
35 snapshot: BufferSnapshot,
36 position: language::Anchor,
37 events: Vec<Arc<Event>>,
38 cx: &mut Context<Zeta>,
39) -> Task<Result<Option<EditPredictionResult>>> {
40 let buffer = buffer.clone();
41 let buffer_snapshotted_at = Instant::now();
42 let client = zeta.client.clone();
43 let llm_token = zeta.llm_token.clone();
44 let app_version = AppVersion::global(cx);
45
46 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
47 let can_collect_file = zeta.can_collect_file(project, file, cx);
48 let git_info = if can_collect_file {
49 git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
50 } else {
51 None
52 };
53 (git_info, can_collect_file)
54 } else {
55 (None, false)
56 };
57
58 let full_path: Arc<Path> = snapshot
59 .file()
60 .map(|f| Arc::from(f.full_path(cx).as_path()))
61 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
62 let full_path_str = full_path.to_string_lossy().into_owned();
63 let cursor_point = position.to_point(&snapshot);
64 let prompt_for_events = {
65 let events = events.clone();
66 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
67 };
68 let gather_task = gather_context(
69 full_path_str,
70 &snapshot,
71 cursor_point,
72 prompt_for_events,
73 cx,
74 );
75
76 cx.spawn(async move |this, cx| {
77 let GatherContextOutput {
78 mut body,
79 context_range,
80 editable_range,
81 included_events_count,
82 } = gather_task.await?;
83 let done_gathering_context_at = Instant::now();
84
85 let included_events = &events[events.len() - included_events_count..events.len()];
86 body.can_collect_data = can_collect_file
87 && this
88 .read_with(cx, |this, _| this.can_collect_events(included_events))
89 .unwrap_or(false);
90 if body.can_collect_data {
91 body.git_info = git_info;
92 }
93
94 log::debug!(
95 "Events:\n{}\nExcerpt:\n{:?}",
96 body.input_events,
97 body.input_excerpt
98 );
99
100 let http_client = client.http_client();
101
102 let response = Zeta::send_api_request::<PredictEditsResponse>(
103 |request| {
104 let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
105 predict_edits_url
106 } else {
107 http_client
108 .build_zed_llm_url("/predict_edits/v2", &[])?
109 .as_str()
110 .into()
111 };
112 Ok(request
113 .uri(uri)
114 .body(serde_json::to_string(&body)?.into())?)
115 },
116 client,
117 llm_token,
118 app_version,
119 )
120 .await;
121
122 let inputs = EditPredictionInputs {
123 events: included_events.into(),
124 included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
125 path: full_path.clone(),
126 max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
127 excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
128 start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
129 text: snapshot
130 .text_for_range(context_range)
131 .collect::<String>()
132 .into(),
133 }],
134 }],
135 cursor_point: cloud_llm_client::predict_edits_v3::Point {
136 column: cursor_point.column,
137 line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
138 },
139 cursor_path: full_path,
140 };
141
142 // let response = perform_predict_edits(PerformPredictEditsParams {
143 // client,
144 // llm_token,
145 // app_version,
146 // body,
147 // })
148 // .await;
149
150 let (response, usage) = match response {
151 Ok(response) => response,
152 Err(err) => {
153 if err.is::<ZedUpdateRequiredError>() {
154 cx.update(|cx| {
155 this.update(cx, |zeta, _cx| {
156 zeta.update_required = true;
157 })
158 .ok();
159
160 let error_message: SharedString = err.to_string().into();
161 show_app_notification(
162 NotificationId::unique::<ZedUpdateRequiredError>(),
163 cx,
164 move |cx| {
165 cx.new(|cx| {
166 ErrorMessagePrompt::new(error_message.clone(), cx)
167 .with_link_button("Update Zed", "https://zed.dev/releases")
168 })
169 },
170 );
171 })
172 .ok();
173 }
174
175 return Err(err);
176 }
177 };
178
179 let received_response_at = Instant::now();
180 log::debug!("completion response: {}", &response.output_excerpt);
181
182 if let Some(usage) = usage {
183 this.update(cx, |this, cx| {
184 this.user_store.update(cx, |user_store, cx| {
185 user_store.update_edit_prediction_usage(usage, cx);
186 });
187 })
188 .ok();
189 }
190
191 let edit_prediction = process_completion_response(
192 response,
193 buffer,
194 &snapshot,
195 editable_range,
196 inputs,
197 buffer_snapshotted_at,
198 received_response_at,
199 cx,
200 )
201 .await;
202
203 let finished_at = Instant::now();
204
205 // record latency for ~1% of requests
206 if rand::random::<u8>() <= 2 {
207 telemetry::event!(
208 "Edit Prediction Request",
209 context_latency = done_gathering_context_at
210 .duration_since(buffer_snapshotted_at)
211 .as_millis(),
212 request_latency = received_response_at
213 .duration_since(done_gathering_context_at)
214 .as_millis(),
215 process_latency = finished_at.duration_since(received_response_at).as_millis()
216 );
217 }
218
219 edit_prediction.map(Some)
220 })
221}
222
223fn process_completion_response(
224 prediction_response: PredictEditsResponse,
225 buffer: Entity<Buffer>,
226 snapshot: &BufferSnapshot,
227 editable_range: Range<usize>,
228 inputs: EditPredictionInputs,
229 buffer_snapshotted_at: Instant,
230 received_response_at: Instant,
231 cx: &AsyncApp,
232) -> Task<Result<EditPredictionResult>> {
233 let snapshot = snapshot.clone();
234 let request_id = prediction_response.request_id;
235 let output_excerpt = prediction_response.output_excerpt;
236 cx.spawn(async move |cx| {
237 let output_excerpt: Arc<str> = output_excerpt.into();
238
239 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
240 .background_spawn({
241 let output_excerpt = output_excerpt.clone();
242 let editable_range = editable_range.clone();
243 let snapshot = snapshot.clone();
244 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
245 })
246 .await?
247 .into();
248
249 let id = EditPredictionId(request_id.into());
250 Ok(EditPredictionResult::new(
251 id,
252 &buffer,
253 &snapshot,
254 edits,
255 buffer_snapshotted_at,
256 received_response_at,
257 inputs,
258 cx,
259 )
260 .await)
261 })
262}
263
264fn parse_edits(
265 output_excerpt: Arc<str>,
266 editable_range: Range<usize>,
267 snapshot: &BufferSnapshot,
268) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
269 let content = output_excerpt.replace(CURSOR_MARKER, "");
270
271 let start_markers = content
272 .match_indices(EDITABLE_REGION_START_MARKER)
273 .collect::<Vec<_>>();
274 anyhow::ensure!(
275 start_markers.len() == 1,
276 "expected exactly one start marker, found {}",
277 start_markers.len()
278 );
279
280 let end_markers = content
281 .match_indices(EDITABLE_REGION_END_MARKER)
282 .collect::<Vec<_>>();
283 anyhow::ensure!(
284 end_markers.len() == 1,
285 "expected exactly one end marker, found {}",
286 end_markers.len()
287 );
288
289 let sof_markers = content
290 .match_indices(START_OF_FILE_MARKER)
291 .collect::<Vec<_>>();
292 anyhow::ensure!(
293 sof_markers.len() <= 1,
294 "expected at most one start-of-file marker, found {}",
295 sof_markers.len()
296 );
297
298 let codefence_start = start_markers[0].0;
299 let content = &content[codefence_start..];
300
301 let newline_ix = content.find('\n').context("could not find newline")?;
302 let content = &content[newline_ix + 1..];
303
304 let codefence_end = content
305 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
306 .context("could not find end marker")?;
307 let new_text = &content[..codefence_end];
308
309 let old_text = snapshot
310 .text_for_range(editable_range.clone())
311 .collect::<String>();
312
313 Ok(compute_edits(
314 old_text,
315 new_text,
316 editable_range.start,
317 snapshot,
318 ))
319}
320
321pub fn compute_edits(
322 old_text: String,
323 new_text: &str,
324 offset: usize,
325 snapshot: &BufferSnapshot,
326) -> Vec<(Range<Anchor>, Arc<str>)> {
327 text_diff(&old_text, new_text)
328 .into_iter()
329 .map(|(mut old_range, new_text)| {
330 old_range.start += offset;
331 old_range.end += offset;
332
333 let prefix_len = common_prefix(
334 snapshot.chars_for_range(old_range.clone()),
335 new_text.chars(),
336 );
337 old_range.start += prefix_len;
338
339 let suffix_len = common_prefix(
340 snapshot.reversed_chars_for_range(old_range.clone()),
341 new_text[prefix_len..].chars().rev(),
342 );
343 old_range.end = old_range.end.saturating_sub(suffix_len);
344
345 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
346 let range = if old_range.is_empty() {
347 let anchor = snapshot.anchor_after(old_range.start);
348 anchor..anchor
349 } else {
350 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
351 };
352 (range, new_text)
353 })
354 .collect()
355}
356
/// Returns the UTF-8 byte length of the longest run of equal leading characters yielded by
/// `a` and `b`.
///
/// Comparison stops at the first differing pair or when either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
363
364fn git_info_for_file(
365 project: &Entity<Project>,
366 project_path: &ProjectPath,
367 cx: &App,
368) -> Option<PredictEditsGitInfo> {
369 let git_store = project.read(cx).git_store().read(cx);
370 if let Some((repository, _repo_path)) =
371 git_store.repository_and_path_for_project_path(project_path, cx)
372 {
373 let repository = repository.read(cx);
374 let head_sha = repository
375 .head_commit
376 .as_ref()
377 .map(|head_commit| head_commit.sha.to_string());
378 let remote_origin_url = repository.remote_origin_url.clone();
379 let remote_upstream_url = repository.remote_upstream_url.clone();
380 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
381 return None;
382 }
383 Some(PredictEditsGitInfo {
384 head_sha,
385 remote_origin_url,
386 remote_upstream_url,
387 })
388 } else {
389 None
390 }
391}
392
/// Result of [`gather_context`]: the request body plus the ranges and event count that were
/// used to build it.
pub struct GatherContextOutput {
    /// Request body holding the formatted events and the input excerpt prompt.
    pub body: PredictEditsBody,
    /// Context range reported by the input excerpt, in buffer points.
    pub context_range: Range<Point>,
    /// Editable range of the input excerpt, converted to buffer offsets.
    pub editable_range: Range<usize>,
    /// Number of events reported as included by the `prompt_for_events` callback.
    pub included_events_count: usize,
}
399
400pub fn gather_context(
401 full_path_str: String,
402 snapshot: &BufferSnapshot,
403 cursor_point: language::Point,
404 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
405 cx: &App,
406) -> Task<Result<GatherContextOutput>> {
407 cx.background_spawn({
408 let snapshot = snapshot.clone();
409 async move {
410 let input_excerpt = excerpt_for_cursor_position(
411 cursor_point,
412 &full_path_str,
413 &snapshot,
414 MAX_REWRITE_TOKENS,
415 MAX_CONTEXT_TOKENS,
416 );
417 let (input_events, included_events_count) = prompt_for_events();
418 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
419
420 let body = PredictEditsBody {
421 input_events,
422 input_excerpt: input_excerpt.prompt,
423 can_collect_data: false,
424 diagnostic_groups: None,
425 git_info: None,
426 outline: None,
427 speculated_output: None,
428 };
429
430 Ok(GatherContextOutput {
431 body,
432 context_range: input_excerpt.context_range,
433 editable_range,
434 included_events_count,
435 })
436 }
437 })
438}
439
440fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
441 let mut result = String::new();
442 for (ix, event) in events.iter().rev().enumerate() {
443 let event_string = format_event(event.as_ref());
444 let event_tokens = guess_token_count(event_string.len());
445 if event_tokens > remaining_tokens {
446 return (result, ix);
447 }
448
449 if !result.is_empty() {
450 result.insert_str(0, "\n\n");
451 }
452 result.insert_str(0, &event_string);
453 remaining_tokens -= event_tokens;
454 }
455 return (result, events.len());
456}
457
458pub fn format_event(event: &Event) -> String {
459 match event {
460 Event::BufferChange {
461 path,
462 old_path,
463 diff,
464 ..
465 } => {
466 let mut prompt = String::new();
467
468 if old_path != path {
469 writeln!(
470 prompt,
471 "User renamed {} to {}\n",
472 old_path.display(),
473 path.display()
474 )
475 .unwrap();
476 }
477
478 if !diff.is_empty() {
479 write!(
480 prompt,
481 "User edited {}:\n```diff\n{}\n```",
482 path.display(),
483 diff
484 )
485 .unwrap();
486 }
487
488 prompt
489 }
490 }
491}
492
/// Assumed number of string bytes per token when limiting model input. The value is kept
/// deliberately low so that limits are underestimated rather than exceeded.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, using [`BYTES_PER_TOKEN_GUESS`].
///
/// Integer division is used, so any remainder is discarded.
fn guess_token_count(bytes: usize) -> usize {
    bytes / BYTES_PER_TOKEN_GUESS
}