1mod input_excerpt;
2
3use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
4
5use crate::{
6 EditPredictionId, ZedUpdateRequiredError, Zeta,
7 prediction::{EditPrediction, EditPredictionInputs},
8};
9use anyhow::{Context as _, Result};
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use input_excerpt::excerpt_for_cursor_position;
15use language::{
16 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
17};
18use project::{Project, ProjectPath};
19use release_channel::AppVersion;
20use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
21
/// Marker for the user's cursor; `parse_edits` strips every occurrence of it from the model
/// output before extracting the rewritten region.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `parse_edits` rejects output that contains it more than once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region; `parse_edits` requires exactly one occurrence in the output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region; `parse_edits` requires exactly one occurrence in the output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Context token budget handed to `excerpt_for_cursor_position` by `gather_context`.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Rewrite token budget handed to `excerpt_for_cursor_position` by `gather_context`.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the events prompt built by `prompt_for_events_impl`.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
30
/// Requests an edit prediction for `position` in `buffer` from the `/predict_edits/v2` endpoint.
///
/// The flow is:
/// 1. Build the request body with [`gather_context`] (excerpt around the cursor plus a prompt
///    made from the most recent `events`).
/// 2. Decide whether data collection is allowed and, only if so, attach git info to the body.
/// 3. Send the request through `Zeta::send_api_request`.
/// 4. Turn the response into an [`EditPrediction`] via [`process_completion_response`].
///
/// If the request fails with a [`ZedUpdateRequiredError`], `update_required` is set on the
/// `Zeta` entity and a notification linking to the releases page is shown; the error is still
/// returned to the caller.
pub(crate) fn request_prediction_with_zeta1(
    zeta: &mut Zeta,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: language::Anchor,
    events: Vec<Arc<Event>>,
    cx: &mut Context<Zeta>,
) -> Task<Result<Option<EditPrediction>>> {
    let buffer = buffer.clone();
    // Reference point for the latency telemetry recorded at the end of the task.
    let buffer_snapshotted_at = Instant::now();
    let client = zeta.client.clone();
    let llm_token = zeta.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Git info is only looked up when the file itself may be collected; buffers without a file
    // never qualify.
    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
        let can_collect_file = zeta.can_collect_file(project, file, cx);
        let git_info = if can_collect_file {
            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };
        (git_info, can_collect_file)
    } else {
        (None, false)
    };

    // Buffers without a backing file are reported under the path "untitled".
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
    let full_path_str = full_path.to_string_lossy().into_owned();
    let cursor_point = position.to_point(&snapshot);
    // The closure owns its own copy of the events so it can be moved into `gather_context`,
    // while `events` stays available to the task spawned below.
    let prompt_for_events = {
        let events = events.clone();
        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
    };
    let gather_task = gather_context(
        full_path_str,
        &snapshot,
        cursor_point,
        prompt_for_events,
        cx,
    );

    cx.spawn(async move |this, cx| {
        let GatherContextOutput {
            mut body,
            context_range,
            editable_range,
            included_events_count,
        } = gather_task.await?;
        let done_gathering_context_at = Instant::now();

        // The prompt is built from the end of `events`, so the included events are the last
        // `included_events_count` entries.
        let included_events = &events[events.len() - included_events_count..events.len()];
        // Collection requires both the file and every included event to be collectable; if the
        // `Zeta` entity can no longer be read, collection is treated as disallowed.
        body.can_collect_data = can_collect_file
            && this
                .read_with(cx, |this, _| this.can_collect_events(included_events))
                .unwrap_or(false);
        if body.can_collect_data {
            body.git_info = git_info;
        }

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            body.input_events,
            body.input_excerpt
        );

        let http_client = client.http_client();

        let response = Zeta::send_api_request::<PredictEditsResponse>(
            |request| {
                // The `ZED_PREDICT_EDITS_URL` environment variable overrides the default
                // endpoint.
                let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    predict_edits_url
                } else {
                    http_client
                        .build_zed_llm_url("/predict_edits/v2", &[])?
                        .as_str()
                        .into()
                };
                Ok(request
                    .uri(uri)
                    .body(serde_json::to_string(&body)?.into())?)
            },
            client,
            llm_token,
            app_version,
        )
        .await;

        // Record what the request was built from: the included events, the context excerpt of
        // the current file, and the cursor location.
        let inputs = EditPredictionInputs {
            events: included_events.into(),
            included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
                path: full_path.clone(),
                max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
                excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
                    start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
                    text: snapshot
                        .text_for_range(context_range)
                        .collect::<String>()
                        .into(),
                }],
            }],
            cursor_point: cloud_llm_client::predict_edits_v3::Point {
                column: cursor_point.column,
                line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
            },
            cursor_path: full_path,
        };

        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                // An "update required" failure is surfaced to the user in addition to being
                // returned; failures to update the entity or the app are ignored.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |zeta, _cx| {
                            zeta.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }

                return Err(err);
            }
        };

        let received_response_at = Instant::now();
        log::debug!("completion response: {}", &response.output_excerpt);

        // Forward usage information, when the response carried any, to the user store.
        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        let edit_prediction = process_completion_response(
            response,
            buffer,
            &snapshot,
            editable_range,
            inputs,
            buffer_snapshotted_at,
            received_response_at,
            cx,
        )
        .await;

        let finished_at = Instant::now();

        // Record latency for roughly 1% of requests (3 of the 256 possible `u8` values).
        if rand::random::<u8>() <= 2 {
            telemetry::event!(
                "Edit Prediction Request",
                context_latency = done_gathering_context_at
                    .duration_since(buffer_snapshotted_at)
                    .as_millis(),
                request_latency = received_response_at
                    .duration_since(done_gathering_context_at)
                    .as_millis(),
                process_latency = finished_at.duration_since(received_response_at).as_millis()
            );
        }

        edit_prediction
    })
}
222
223fn process_completion_response(
224 prediction_response: PredictEditsResponse,
225 buffer: Entity<Buffer>,
226 snapshot: &BufferSnapshot,
227 editable_range: Range<usize>,
228 inputs: EditPredictionInputs,
229 buffer_snapshotted_at: Instant,
230 received_response_at: Instant,
231 cx: &AsyncApp,
232) -> Task<Result<Option<EditPrediction>>> {
233 let snapshot = snapshot.clone();
234 let request_id = prediction_response.request_id;
235 let output_excerpt = prediction_response.output_excerpt;
236 cx.spawn(async move |cx| {
237 let output_excerpt: Arc<str> = output_excerpt.into();
238
239 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
240 .background_spawn({
241 let output_excerpt = output_excerpt.clone();
242 let editable_range = editable_range.clone();
243 let snapshot = snapshot.clone();
244 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
245 })
246 .await?
247 .into();
248
249 Ok(EditPrediction::new(
250 EditPredictionId(request_id.into()),
251 &buffer,
252 &snapshot,
253 edits,
254 buffer_snapshotted_at,
255 received_response_at,
256 inputs,
257 cx,
258 )
259 .await)
260 })
261}
262
263fn parse_edits(
264 output_excerpt: Arc<str>,
265 editable_range: Range<usize>,
266 snapshot: &BufferSnapshot,
267) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
268 let content = output_excerpt.replace(CURSOR_MARKER, "");
269
270 let start_markers = content
271 .match_indices(EDITABLE_REGION_START_MARKER)
272 .collect::<Vec<_>>();
273 anyhow::ensure!(
274 start_markers.len() == 1,
275 "expected exactly one start marker, found {}",
276 start_markers.len()
277 );
278
279 let end_markers = content
280 .match_indices(EDITABLE_REGION_END_MARKER)
281 .collect::<Vec<_>>();
282 anyhow::ensure!(
283 end_markers.len() == 1,
284 "expected exactly one end marker, found {}",
285 end_markers.len()
286 );
287
288 let sof_markers = content
289 .match_indices(START_OF_FILE_MARKER)
290 .collect::<Vec<_>>();
291 anyhow::ensure!(
292 sof_markers.len() <= 1,
293 "expected at most one start-of-file marker, found {}",
294 sof_markers.len()
295 );
296
297 let codefence_start = start_markers[0].0;
298 let content = &content[codefence_start..];
299
300 let newline_ix = content.find('\n').context("could not find newline")?;
301 let content = &content[newline_ix + 1..];
302
303 let codefence_end = content
304 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
305 .context("could not find end marker")?;
306 let new_text = &content[..codefence_end];
307
308 let old_text = snapshot
309 .text_for_range(editable_range.clone())
310 .collect::<String>();
311
312 Ok(compute_edits(
313 old_text,
314 new_text,
315 editable_range.start,
316 snapshot,
317 ))
318}
319
320pub fn compute_edits(
321 old_text: String,
322 new_text: &str,
323 offset: usize,
324 snapshot: &BufferSnapshot,
325) -> Vec<(Range<Anchor>, Arc<str>)> {
326 text_diff(&old_text, new_text)
327 .into_iter()
328 .map(|(mut old_range, new_text)| {
329 old_range.start += offset;
330 old_range.end += offset;
331
332 let prefix_len = common_prefix(
333 snapshot.chars_for_range(old_range.clone()),
334 new_text.chars(),
335 );
336 old_range.start += prefix_len;
337
338 let suffix_len = common_prefix(
339 snapshot.reversed_chars_for_range(old_range.clone()),
340 new_text[prefix_len..].chars().rev(),
341 );
342 old_range.end = old_range.end.saturating_sub(suffix_len);
343
344 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
345 let range = if old_range.is_empty() {
346 let anchor = snapshot.anchor_after(old_range.start);
347 anchor..anchor
348 } else {
349 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
350 };
351 (range, new_text)
352 })
353 .collect()
354}
355
/// Returns the UTF-8 byte length of the longest run of leading characters that `a` and `b`
/// have in common.
///
/// Comparison stops at the first differing pair or as soon as either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
362
363fn git_info_for_file(
364 project: &Entity<Project>,
365 project_path: &ProjectPath,
366 cx: &App,
367) -> Option<PredictEditsGitInfo> {
368 let git_store = project.read(cx).git_store().read(cx);
369 if let Some((repository, _repo_path)) =
370 git_store.repository_and_path_for_project_path(project_path, cx)
371 {
372 let repository = repository.read(cx);
373 let head_sha = repository
374 .head_commit
375 .as_ref()
376 .map(|head_commit| head_commit.sha.to_string());
377 let remote_origin_url = repository.remote_origin_url.clone();
378 let remote_upstream_url = repository.remote_upstream_url.clone();
379 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
380 return None;
381 }
382 Some(PredictEditsGitInfo {
383 head_sha,
384 remote_origin_url,
385 remote_upstream_url,
386 })
387 } else {
388 None
389 }
390}
391
/// Everything [`gather_context`] produces for a prediction request.
pub struct GatherContextOutput {
    /// Request payload holding the events prompt and the input excerpt; the remaining fields
    /// are left at `false`/`None` by `gather_context`.
    pub body: PredictEditsBody,
    /// The `context_range` reported by `excerpt_for_cursor_position`.
    pub context_range: Range<Point>,
    /// The excerpt's editable range, converted to offsets in the snapshot.
    pub editable_range: Range<usize>,
    /// Second value returned by the `prompt_for_events` callback passed to `gather_context`.
    pub included_events_count: usize,
}
398
399pub fn gather_context(
400 full_path_str: String,
401 snapshot: &BufferSnapshot,
402 cursor_point: language::Point,
403 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
404 cx: &App,
405) -> Task<Result<GatherContextOutput>> {
406 cx.background_spawn({
407 let snapshot = snapshot.clone();
408 async move {
409 let input_excerpt = excerpt_for_cursor_position(
410 cursor_point,
411 &full_path_str,
412 &snapshot,
413 MAX_REWRITE_TOKENS,
414 MAX_CONTEXT_TOKENS,
415 );
416 let (input_events, included_events_count) = prompt_for_events();
417 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
418
419 let body = PredictEditsBody {
420 input_events,
421 input_excerpt: input_excerpt.prompt,
422 can_collect_data: false,
423 diagnostic_groups: None,
424 git_info: None,
425 outline: None,
426 speculated_output: None,
427 };
428
429 Ok(GatherContextOutput {
430 body,
431 context_range: input_excerpt.context_range,
432 editable_range,
433 included_events_count,
434 })
435 }
436 })
437}
438
439fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
440 let mut result = String::new();
441 for (ix, event) in events.iter().rev().enumerate() {
442 let event_string = format_event(event.as_ref());
443 let event_tokens = guess_token_count(event_string.len());
444 if event_tokens > remaining_tokens {
445 return (result, ix);
446 }
447
448 if !result.is_empty() {
449 result.insert_str(0, "\n\n");
450 }
451 result.insert_str(0, &event_string);
452 remaining_tokens -= event_tokens;
453 }
454 return (result, events.len());
455}
456
457pub fn format_event(event: &Event) -> String {
458 match event {
459 Event::BufferChange {
460 path,
461 old_path,
462 diff,
463 ..
464 } => {
465 let mut prompt = String::new();
466
467 if old_path != path {
468 writeln!(
469 prompt,
470 "User renamed {} to {}\n",
471 old_path.display(),
472 path.display()
473 )
474 .unwrap();
475 }
476
477 if !diff.is_empty() {
478 write!(
479 prompt,
480 "User edited {}:\n```diff\n{}\n```",
481 path.display(),
482 diff
483 )
484 .unwrap();
485 }
486
487 prompt
488 }
489 }
490}
491
/// Assumed number of string bytes per token when limiting model input. The value is
/// deliberately low, which errs on the side of underestimating limits.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens a string of `bytes` bytes takes up, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    let estimated_tokens = bytes / BYTES_PER_TOKEN_GUESS;
    estimated_tokens
}