zeta1.rs

  1mod input_excerpt;
  2
  3use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  4
  5use crate::{
  6    EditPredictionId, ZedUpdateRequiredError, Zeta,
  7    prediction::{EditPredictionInputs, EditPredictionResult},
  8};
  9use anyhow::{Context as _, Result};
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12    predict_edits_v3::Event,
 13};
 14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 15use input_excerpt::excerpt_for_cursor_position;
 16use language::{
 17    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
 18};
 19use project::{Project, ProjectPath};
 20use release_channel::AppVersion;
 21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 22
/// Marker for the user's cursor; it is stripped from the model output before edits are parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; the model output may contain it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region; the model output must contain it exactly once.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region; the model output must contain it exactly once.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Token budget handed to `excerpt_for_cursor_position` as its last argument.
/// NOTE(review): presumably bounds the read-only context around the editable region — confirm.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget handed to `excerpt_for_cursor_position` ahead of the context budget.
/// NOTE(review): presumably bounds the editable (rewritable) region — confirm.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Estimated-token budget for the formatted edit events included in a request.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 31
 32pub(crate) fn request_prediction_with_zeta1(
 33    zeta: &mut Zeta,
 34    project: &Entity<Project>,
 35    buffer: &Entity<Buffer>,
 36    snapshot: BufferSnapshot,
 37    position: language::Anchor,
 38    events: Vec<Arc<Event>>,
 39    trigger: PredictEditsRequestTrigger,
 40    cx: &mut Context<Zeta>,
 41) -> Task<Result<Option<EditPredictionResult>>> {
 42    let buffer = buffer.clone();
 43    let buffer_snapshotted_at = Instant::now();
 44    let client = zeta.client.clone();
 45    let llm_token = zeta.llm_token.clone();
 46    let app_version = AppVersion::global(cx);
 47
 48    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 49        let can_collect_file = zeta.can_collect_file(project, file, cx);
 50        let git_info = if can_collect_file {
 51            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 52        } else {
 53            None
 54        };
 55        (git_info, can_collect_file)
 56    } else {
 57        (None, false)
 58    };
 59
 60    let full_path: Arc<Path> = snapshot
 61        .file()
 62        .map(|f| Arc::from(f.full_path(cx).as_path()))
 63        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 64    let full_path_str = full_path.to_string_lossy().into_owned();
 65    let cursor_point = position.to_point(&snapshot);
 66    let prompt_for_events = {
 67        let events = events.clone();
 68        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 69    };
 70    let gather_task = gather_context(
 71        full_path_str,
 72        &snapshot,
 73        cursor_point,
 74        prompt_for_events,
 75        trigger,
 76        cx,
 77    );
 78
 79    cx.spawn(async move |this, cx| {
 80        let GatherContextOutput {
 81            mut body,
 82            context_range,
 83            editable_range,
 84            included_events_count,
 85        } = gather_task.await?;
 86        let done_gathering_context_at = Instant::now();
 87
 88        let included_events = &events[events.len() - included_events_count..events.len()];
 89        body.can_collect_data = can_collect_file
 90            && this
 91                .read_with(cx, |this, _| this.can_collect_events(included_events))
 92                .unwrap_or(false);
 93        if body.can_collect_data {
 94            body.git_info = git_info;
 95        }
 96
 97        log::debug!(
 98            "Events:\n{}\nExcerpt:\n{:?}",
 99            body.input_events,
100            body.input_excerpt
101        );
102
103        let http_client = client.http_client();
104
105        let response = Zeta::send_api_request::<PredictEditsResponse>(
106            |request| {
107                let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
108                    predict_edits_url
109                } else {
110                    http_client
111                        .build_zed_llm_url("/predict_edits/v2", &[])?
112                        .as_str()
113                        .into()
114                };
115                Ok(request
116                    .uri(uri)
117                    .body(serde_json::to_string(&body)?.into())?)
118            },
119            client,
120            llm_token,
121            app_version,
122        )
123        .await;
124
125        let inputs = EditPredictionInputs {
126            events: included_events.into(),
127            included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
128                path: full_path.clone(),
129                max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
130                excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
131                    start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
132                    text: snapshot
133                        .text_for_range(context_range)
134                        .collect::<String>()
135                        .into(),
136                }],
137            }],
138            cursor_point: cloud_llm_client::predict_edits_v3::Point {
139                column: cursor_point.column,
140                line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
141            },
142            cursor_path: full_path,
143        };
144
145        // let response = perform_predict_edits(PerformPredictEditsParams {
146        //     client,
147        //     llm_token,
148        //     app_version,
149        //     body,
150        // })
151        // .await;
152
153        let (response, usage) = match response {
154            Ok(response) => response,
155            Err(err) => {
156                if err.is::<ZedUpdateRequiredError>() {
157                    cx.update(|cx| {
158                        this.update(cx, |zeta, _cx| {
159                            zeta.update_required = true;
160                        })
161                        .ok();
162
163                        let error_message: SharedString = err.to_string().into();
164                        show_app_notification(
165                            NotificationId::unique::<ZedUpdateRequiredError>(),
166                            cx,
167                            move |cx| {
168                                cx.new(|cx| {
169                                    ErrorMessagePrompt::new(error_message.clone(), cx)
170                                        .with_link_button("Update Zed", "https://zed.dev/releases")
171                                })
172                            },
173                        );
174                    })
175                    .ok();
176                }
177
178                return Err(err);
179            }
180        };
181
182        let received_response_at = Instant::now();
183        log::debug!("completion response: {}", &response.output_excerpt);
184
185        if let Some(usage) = usage {
186            this.update(cx, |this, cx| {
187                this.user_store.update(cx, |user_store, cx| {
188                    user_store.update_edit_prediction_usage(usage, cx);
189                });
190            })
191            .ok();
192        }
193
194        let edit_prediction = process_completion_response(
195            response,
196            buffer,
197            &snapshot,
198            editable_range,
199            inputs,
200            buffer_snapshotted_at,
201            received_response_at,
202            cx,
203        )
204        .await;
205
206        let finished_at = Instant::now();
207
208        // record latency for ~1% of requests
209        if rand::random::<u8>() <= 2 {
210            telemetry::event!(
211                "Edit Prediction Request",
212                context_latency = done_gathering_context_at
213                    .duration_since(buffer_snapshotted_at)
214                    .as_millis(),
215                request_latency = received_response_at
216                    .duration_since(done_gathering_context_at)
217                    .as_millis(),
218                process_latency = finished_at.duration_since(received_response_at).as_millis()
219            );
220        }
221
222        edit_prediction.map(Some)
223    })
224}
225
226fn process_completion_response(
227    prediction_response: PredictEditsResponse,
228    buffer: Entity<Buffer>,
229    snapshot: &BufferSnapshot,
230    editable_range: Range<usize>,
231    inputs: EditPredictionInputs,
232    buffer_snapshotted_at: Instant,
233    received_response_at: Instant,
234    cx: &AsyncApp,
235) -> Task<Result<EditPredictionResult>> {
236    let snapshot = snapshot.clone();
237    let request_id = prediction_response.request_id;
238    let output_excerpt = prediction_response.output_excerpt;
239    cx.spawn(async move |cx| {
240        let output_excerpt: Arc<str> = output_excerpt.into();
241
242        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
243            .background_spawn({
244                let output_excerpt = output_excerpt.clone();
245                let editable_range = editable_range.clone();
246                let snapshot = snapshot.clone();
247                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
248            })
249            .await?
250            .into();
251
252        let id = EditPredictionId(request_id.into());
253        Ok(EditPredictionResult::new(
254            id,
255            &buffer,
256            &snapshot,
257            edits,
258            buffer_snapshotted_at,
259            received_response_at,
260            inputs,
261            cx,
262        )
263        .await)
264    })
265}
266
267fn parse_edits(
268    output_excerpt: Arc<str>,
269    editable_range: Range<usize>,
270    snapshot: &BufferSnapshot,
271) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
272    let content = output_excerpt.replace(CURSOR_MARKER, "");
273
274    let start_markers = content
275        .match_indices(EDITABLE_REGION_START_MARKER)
276        .collect::<Vec<_>>();
277    anyhow::ensure!(
278        start_markers.len() == 1,
279        "expected exactly one start marker, found {}",
280        start_markers.len()
281    );
282
283    let end_markers = content
284        .match_indices(EDITABLE_REGION_END_MARKER)
285        .collect::<Vec<_>>();
286    anyhow::ensure!(
287        end_markers.len() == 1,
288        "expected exactly one end marker, found {}",
289        end_markers.len()
290    );
291
292    let sof_markers = content
293        .match_indices(START_OF_FILE_MARKER)
294        .collect::<Vec<_>>();
295    anyhow::ensure!(
296        sof_markers.len() <= 1,
297        "expected at most one start-of-file marker, found {}",
298        sof_markers.len()
299    );
300
301    let codefence_start = start_markers[0].0;
302    let content = &content[codefence_start..];
303
304    let newline_ix = content.find('\n').context("could not find newline")?;
305    let content = &content[newline_ix + 1..];
306
307    let codefence_end = content
308        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
309        .context("could not find end marker")?;
310    let new_text = &content[..codefence_end];
311
312    let old_text = snapshot
313        .text_for_range(editable_range.clone())
314        .collect::<String>();
315
316    Ok(compute_edits(
317        old_text,
318        new_text,
319        editable_range.start,
320        snapshot,
321    ))
322}
323
324pub fn compute_edits(
325    old_text: String,
326    new_text: &str,
327    offset: usize,
328    snapshot: &BufferSnapshot,
329) -> Vec<(Range<Anchor>, Arc<str>)> {
330    text_diff(&old_text, new_text)
331        .into_iter()
332        .map(|(mut old_range, new_text)| {
333            old_range.start += offset;
334            old_range.end += offset;
335
336            let prefix_len = common_prefix(
337                snapshot.chars_for_range(old_range.clone()),
338                new_text.chars(),
339            );
340            old_range.start += prefix_len;
341
342            let suffix_len = common_prefix(
343                snapshot.reversed_chars_for_range(old_range.clone()),
344                new_text[prefix_len..].chars().rev(),
345            );
346            old_range.end = old_range.end.saturating_sub(suffix_len);
347
348            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
349            let range = if old_range.is_empty() {
350                let anchor = snapshot.anchor_after(old_range.start);
351                anchor..anchor
352            } else {
353                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
354            };
355            (range, new_text)
356        })
357        .collect()
358}
359
/// Returns the UTF-8 byte length of the longest run of leading characters that `a` and `b`
/// have in common. Comparison stops at the first mismatch or when either iterator ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
366
367fn git_info_for_file(
368    project: &Entity<Project>,
369    project_path: &ProjectPath,
370    cx: &App,
371) -> Option<PredictEditsGitInfo> {
372    let git_store = project.read(cx).git_store().read(cx);
373    if let Some((repository, _repo_path)) =
374        git_store.repository_and_path_for_project_path(project_path, cx)
375    {
376        let repository = repository.read(cx);
377        let head_sha = repository
378            .head_commit
379            .as_ref()
380            .map(|head_commit| head_commit.sha.to_string());
381        let remote_origin_url = repository.remote_origin_url.clone();
382        let remote_upstream_url = repository.remote_upstream_url.clone();
383        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
384            return None;
385        }
386        Some(PredictEditsGitInfo {
387            head_sha,
388            remote_origin_url,
389            remote_upstream_url,
390        })
391    } else {
392        None
393    }
394}
395
/// Result of [`gather_context`]: the request body plus the ranges and event count needed to
/// interpret the response.
pub struct GatherContextOutput {
    /// Request body to send to the prediction endpoint.
    pub body: PredictEditsBody,
    /// Context range of the input excerpt, as buffer points.
    pub context_range: Range<Point>,
    /// Editable range of the input excerpt, converted to buffer offsets.
    pub editable_range: Range<usize>,
    /// Number of events that were included in the prompt.
    pub included_events_count: usize,
}
402
403pub fn gather_context(
404    full_path_str: String,
405    snapshot: &BufferSnapshot,
406    cursor_point: language::Point,
407    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
408    trigger: PredictEditsRequestTrigger,
409    cx: &App,
410) -> Task<Result<GatherContextOutput>> {
411    cx.background_spawn({
412        let snapshot = snapshot.clone();
413        async move {
414            let input_excerpt = excerpt_for_cursor_position(
415                cursor_point,
416                &full_path_str,
417                &snapshot,
418                MAX_REWRITE_TOKENS,
419                MAX_CONTEXT_TOKENS,
420            );
421            let (input_events, included_events_count) = prompt_for_events();
422            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
423
424            let body = PredictEditsBody {
425                input_events,
426                input_excerpt: input_excerpt.prompt,
427                can_collect_data: false,
428                diagnostic_groups: None,
429                git_info: None,
430                outline: None,
431                speculated_output: None,
432                trigger,
433            };
434
435            Ok(GatherContextOutput {
436                body,
437                context_range: input_excerpt.context_range,
438                editable_range,
439                included_events_count,
440            })
441        }
442    })
443}
444
445fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
446    let mut result = String::new();
447    for (ix, event) in events.iter().rev().enumerate() {
448        let event_string = format_event(event.as_ref());
449        let event_tokens = guess_token_count(event_string.len());
450        if event_tokens > remaining_tokens {
451            return (result, ix);
452        }
453
454        if !result.is_empty() {
455            result.insert_str(0, "\n\n");
456        }
457        result.insert_str(0, &event_string);
458        remaining_tokens -= event_tokens;
459    }
460    return (result, events.len());
461}
462
463pub fn format_event(event: &Event) -> String {
464    match event {
465        Event::BufferChange {
466            path,
467            old_path,
468            diff,
469            ..
470        } => {
471            let mut prompt = String::new();
472
473            if old_path != path {
474                writeln!(
475                    prompt,
476                    "User renamed {} to {}\n",
477                    old_path.display(),
478                    path.display()
479                )
480                .unwrap();
481            }
482
483            if !diff.is_empty() {
484                write!(
485                    prompt,
486                    "User edited {}:\n```diff\n{}\n```",
487                    path.display(),
488                    diff
489                )
490                .unwrap();
491            }
492
493            prompt
494        }
495    }
496}
497
/// Assumed number of string bytes per token when limiting model input. The value is kept
/// deliberately low so that token estimates err on the high side and inputs stay within
/// their limits.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens a string of `bytes` bytes occupies, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    bytes / BYTES_PER_TOKEN_GUESS
}