1mod input_excerpt;
2
3use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
4
5use crate::{
6 EditPredictionId, ZedUpdateRequiredError, Zeta,
7 prediction::{EditPrediction, EditPredictionInputs},
8};
9use anyhow::{Context as _, Result};
10use cloud_llm_client::{
11 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
12};
13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
14use input_excerpt::excerpt_for_cursor_position;
15use language::{
16 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
17};
18use project::{Project, ProjectPath};
19use release_channel::AppVersion;
20use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
21
/// Cursor marker; `parse_edits` strips every occurrence from the model output before parsing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `parse_edits` rejects output that contains it more than once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region; `parse_edits` requires exactly one occurrence in the output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region; `parse_edits` requires exactly one occurrence in the output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Passed to `excerpt_for_cursor_position` by `gather_context`.
/// NOTE(review): presumably caps the read-only context around the editable region; confirm in
/// `input_excerpt`.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Passed to `excerpt_for_cursor_position` by `gather_context`.
/// NOTE(review): presumably caps the size of the editable (rewritten) region; confirm in
/// `input_excerpt`.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget handed to `prompt_for_events_impl` when building the events part of the prompt.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
30
31pub(crate) fn request_prediction_with_zeta1(
32 zeta: &mut Zeta,
33 project: &Entity<Project>,
34 buffer: &Entity<Buffer>,
35 position: language::Anchor,
36 cx: &mut Context<Zeta>,
37) -> Task<Result<Option<EditPrediction>>> {
38 let buffer = buffer.clone();
39 let buffer_snapshotted_at = Instant::now();
40 let snapshot = buffer.read(cx).snapshot();
41 let client = zeta.client.clone();
42 let llm_token = zeta.llm_token.clone();
43 let app_version = AppVersion::global(cx);
44
45 let zeta_project = zeta.get_or_init_zeta_project(project, cx);
46 let events = Arc::new(zeta_project.events(cx));
47
48 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
49 let can_collect_file = zeta.can_collect_file(project, file, cx);
50 let git_info = if can_collect_file {
51 git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
52 } else {
53 None
54 };
55 (git_info, can_collect_file)
56 } else {
57 (None, false)
58 };
59
60 let full_path: Arc<Path> = snapshot
61 .file()
62 .map(|f| Arc::from(f.full_path(cx).as_path()))
63 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
64 let full_path_str = full_path.to_string_lossy().into_owned();
65 let cursor_point = position.to_point(&snapshot);
66 let prompt_for_events = {
67 let events = events.clone();
68 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
69 };
70 let gather_task = gather_context(
71 full_path_str,
72 &snapshot,
73 cursor_point,
74 prompt_for_events,
75 cx,
76 );
77
78 cx.spawn(async move |this, cx| {
79 let GatherContextOutput {
80 mut body,
81 context_range,
82 editable_range,
83 included_events_count,
84 } = gather_task.await?;
85 let done_gathering_context_at = Instant::now();
86
87 let included_events = &events[events.len() - included_events_count..events.len()];
88 body.can_collect_data = can_collect_file
89 && this
90 .read_with(cx, |this, _| this.can_collect_events(included_events))
91 .unwrap_or(false);
92 if body.can_collect_data {
93 body.git_info = git_info;
94 }
95
96 log::debug!(
97 "Events:\n{}\nExcerpt:\n{:?}",
98 body.input_events,
99 body.input_excerpt
100 );
101
102 let http_client = client.http_client();
103
104 let response = Zeta::send_api_request::<PredictEditsResponse>(
105 |request| {
106 let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
107 predict_edits_url
108 } else {
109 http_client
110 .build_zed_llm_url("/predict_edits/v2", &[])?
111 .as_str()
112 .into()
113 };
114 Ok(request
115 .uri(uri)
116 .body(serde_json::to_string(&body)?.into())?)
117 },
118 client,
119 llm_token,
120 app_version,
121 )
122 .await;
123
124 let inputs = EditPredictionInputs {
125 events: included_events.into(),
126 included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
127 path: full_path.clone(),
128 max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
129 excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
130 start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
131 text: snapshot
132 .text_for_range(context_range)
133 .collect::<String>()
134 .into(),
135 }],
136 }],
137 cursor_point: cloud_llm_client::predict_edits_v3::Point {
138 column: cursor_point.column,
139 line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
140 },
141 cursor_path: full_path,
142 };
143
144 // let response = perform_predict_edits(PerformPredictEditsParams {
145 // client,
146 // llm_token,
147 // app_version,
148 // body,
149 // })
150 // .await;
151
152 let (response, usage) = match response {
153 Ok(response) => response,
154 Err(err) => {
155 if err.is::<ZedUpdateRequiredError>() {
156 cx.update(|cx| {
157 this.update(cx, |zeta, _cx| {
158 zeta.update_required = true;
159 })
160 .ok();
161
162 let error_message: SharedString = err.to_string().into();
163 show_app_notification(
164 NotificationId::unique::<ZedUpdateRequiredError>(),
165 cx,
166 move |cx| {
167 cx.new(|cx| {
168 ErrorMessagePrompt::new(error_message.clone(), cx)
169 .with_link_button("Update Zed", "https://zed.dev/releases")
170 })
171 },
172 );
173 })
174 .ok();
175 }
176
177 return Err(err);
178 }
179 };
180
181 let received_response_at = Instant::now();
182 log::debug!("completion response: {}", &response.output_excerpt);
183
184 if let Some(usage) = usage {
185 this.update(cx, |this, cx| {
186 this.user_store.update(cx, |user_store, cx| {
187 user_store.update_edit_prediction_usage(usage, cx);
188 });
189 })
190 .ok();
191 }
192
193 let edit_prediction = process_completion_response(
194 response,
195 buffer,
196 &snapshot,
197 editable_range,
198 inputs,
199 buffer_snapshotted_at,
200 received_response_at,
201 cx,
202 )
203 .await;
204
205 let finished_at = Instant::now();
206
207 // record latency for ~1% of requests
208 if rand::random::<u8>() <= 2 {
209 telemetry::event!(
210 "Edit Prediction Request",
211 context_latency = done_gathering_context_at
212 .duration_since(buffer_snapshotted_at)
213 .as_millis(),
214 request_latency = received_response_at
215 .duration_since(done_gathering_context_at)
216 .as_millis(),
217 process_latency = finished_at.duration_since(received_response_at).as_millis()
218 );
219 }
220
221 edit_prediction
222 })
223}
224
225fn process_completion_response(
226 prediction_response: PredictEditsResponse,
227 buffer: Entity<Buffer>,
228 snapshot: &BufferSnapshot,
229 editable_range: Range<usize>,
230 inputs: EditPredictionInputs,
231 buffer_snapshotted_at: Instant,
232 received_response_at: Instant,
233 cx: &AsyncApp,
234) -> Task<Result<Option<EditPrediction>>> {
235 let snapshot = snapshot.clone();
236 let request_id = prediction_response.request_id;
237 let output_excerpt = prediction_response.output_excerpt;
238 cx.spawn(async move |cx| {
239 let output_excerpt: Arc<str> = output_excerpt.into();
240
241 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
242 .background_spawn({
243 let output_excerpt = output_excerpt.clone();
244 let editable_range = editable_range.clone();
245 let snapshot = snapshot.clone();
246 async move { parse_edits(output_excerpt, editable_range, &snapshot) }
247 })
248 .await?
249 .into();
250
251 Ok(EditPrediction::new(
252 EditPredictionId(request_id.into()),
253 &buffer,
254 &snapshot,
255 edits,
256 buffer_snapshotted_at,
257 received_response_at,
258 inputs,
259 cx,
260 )
261 .await)
262 })
263}
264
265fn parse_edits(
266 output_excerpt: Arc<str>,
267 editable_range: Range<usize>,
268 snapshot: &BufferSnapshot,
269) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
270 let content = output_excerpt.replace(CURSOR_MARKER, "");
271
272 let start_markers = content
273 .match_indices(EDITABLE_REGION_START_MARKER)
274 .collect::<Vec<_>>();
275 anyhow::ensure!(
276 start_markers.len() == 1,
277 "expected exactly one start marker, found {}",
278 start_markers.len()
279 );
280
281 let end_markers = content
282 .match_indices(EDITABLE_REGION_END_MARKER)
283 .collect::<Vec<_>>();
284 anyhow::ensure!(
285 end_markers.len() == 1,
286 "expected exactly one end marker, found {}",
287 end_markers.len()
288 );
289
290 let sof_markers = content
291 .match_indices(START_OF_FILE_MARKER)
292 .collect::<Vec<_>>();
293 anyhow::ensure!(
294 sof_markers.len() <= 1,
295 "expected at most one start-of-file marker, found {}",
296 sof_markers.len()
297 );
298
299 let codefence_start = start_markers[0].0;
300 let content = &content[codefence_start..];
301
302 let newline_ix = content.find('\n').context("could not find newline")?;
303 let content = &content[newline_ix + 1..];
304
305 let codefence_end = content
306 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
307 .context("could not find end marker")?;
308 let new_text = &content[..codefence_end];
309
310 let old_text = snapshot
311 .text_for_range(editable_range.clone())
312 .collect::<String>();
313
314 Ok(compute_edits(
315 old_text,
316 new_text,
317 editable_range.start,
318 snapshot,
319 ))
320}
321
322pub fn compute_edits(
323 old_text: String,
324 new_text: &str,
325 offset: usize,
326 snapshot: &BufferSnapshot,
327) -> Vec<(Range<Anchor>, Arc<str>)> {
328 text_diff(&old_text, new_text)
329 .into_iter()
330 .map(|(mut old_range, new_text)| {
331 old_range.start += offset;
332 old_range.end += offset;
333
334 let prefix_len = common_prefix(
335 snapshot.chars_for_range(old_range.clone()),
336 new_text.chars(),
337 );
338 old_range.start += prefix_len;
339
340 let suffix_len = common_prefix(
341 snapshot.reversed_chars_for_range(old_range.clone()),
342 new_text[prefix_len..].chars().rev(),
343 );
344 old_range.end = old_range.end.saturating_sub(suffix_len);
345
346 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
347 let range = if old_range.is_empty() {
348 let anchor = snapshot.anchor_after(old_range.start);
349 anchor..anchor
350 } else {
351 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
352 };
353 (range, new_text)
354 })
355 .collect()
356}
357
/// Returns the UTF-8 byte length of the longest run of equal characters at the start of both
/// iterators. Comparison stops at the first mismatch or when either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
364
365fn git_info_for_file(
366 project: &Entity<Project>,
367 project_path: &ProjectPath,
368 cx: &App,
369) -> Option<PredictEditsGitInfo> {
370 let git_store = project.read(cx).git_store().read(cx);
371 if let Some((repository, _repo_path)) =
372 git_store.repository_and_path_for_project_path(project_path, cx)
373 {
374 let repository = repository.read(cx);
375 let head_sha = repository
376 .head_commit
377 .as_ref()
378 .map(|head_commit| head_commit.sha.to_string());
379 let remote_origin_url = repository.remote_origin_url.clone();
380 let remote_upstream_url = repository.remote_upstream_url.clone();
381 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
382 return None;
383 }
384 Some(PredictEditsGitInfo {
385 head_sha,
386 remote_origin_url,
387 remote_upstream_url,
388 })
389 } else {
390 None
391 }
392}
393
/// Output of `gather_context`: the request body plus the ranges and event count that were used
/// to build it.
pub struct GatherContextOutput {
    /// Request body. `gather_context` fills in the events and excerpt prompts and leaves
    /// `can_collect_data` false and `git_info` unset.
    pub body: PredictEditsBody,
    /// The `context_range` of the excerpt computed around the cursor.
    pub context_range: Range<Point>,
    /// The excerpt's editable range, converted to offsets in the snapshot.
    pub editable_range: Range<usize>,
    /// The count returned by the `prompt_for_events` callback alongside the events prompt.
    pub included_events_count: usize,
}
400
401pub fn gather_context(
402 full_path_str: String,
403 snapshot: &BufferSnapshot,
404 cursor_point: language::Point,
405 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
406 cx: &App,
407) -> Task<Result<GatherContextOutput>> {
408 cx.background_spawn({
409 let snapshot = snapshot.clone();
410 async move {
411 let input_excerpt = excerpt_for_cursor_position(
412 cursor_point,
413 &full_path_str,
414 &snapshot,
415 MAX_REWRITE_TOKENS,
416 MAX_CONTEXT_TOKENS,
417 );
418 let (input_events, included_events_count) = prompt_for_events();
419 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
420
421 let body = PredictEditsBody {
422 input_events,
423 input_excerpt: input_excerpt.prompt,
424 can_collect_data: false,
425 diagnostic_groups: None,
426 git_info: None,
427 outline: None,
428 speculated_output: None,
429 };
430
431 Ok(GatherContextOutput {
432 body,
433 context_range: input_excerpt.context_range,
434 editable_range,
435 included_events_count,
436 })
437 }
438 })
439}
440
441fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
442 let mut result = String::new();
443 for (ix, event) in events.iter().rev().enumerate() {
444 let event_string = format_event(event.as_ref());
445 let event_tokens = guess_token_count(event_string.len());
446 if event_tokens > remaining_tokens {
447 return (result, ix);
448 }
449
450 if !result.is_empty() {
451 result.insert_str(0, "\n\n");
452 }
453 result.insert_str(0, &event_string);
454 remaining_tokens -= event_tokens;
455 }
456 return (result, events.len());
457}
458
459pub fn format_event(event: &Event) -> String {
460 match event {
461 Event::BufferChange {
462 path,
463 old_path,
464 diff,
465 ..
466 } => {
467 let mut prompt = String::new();
468
469 if old_path != path {
470 writeln!(
471 prompt,
472 "User renamed {} to {}\n",
473 old_path.display(),
474 path.display()
475 )
476 .unwrap();
477 }
478
479 if !diff.is_empty() {
480 write!(
481 prompt,
482 "User edited {}:\n```diff\n{}\n```",
483 path.display(),
484 diff
485 )
486 .unwrap();
487 }
488
489 prompt
490 }
491 }
492}
493
/// Assumed number of string bytes per token, used when limiting model input. The value is kept
/// deliberately low so that token estimates come out high and limits are under-filled rather
/// than exceeded.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens a string of `bytes` bytes occupies, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    let bytes_per_token = BYTES_PER_TOKEN_GUESS;
    bytes / bytes_per_token
}