zeta.rs

  1use crate::{
  2    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  3    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
  4    ZedUpdateRequiredError, cursor_excerpt::compute_excerpt_ranges,
  5    prediction::EditPredictionResult,
  6};
  7use anyhow::Result;
  8use cloud_llm_client::{
  9    AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
 10};
 11use edit_prediction_types::PredictedCursorPosition;
 12use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
 13use language::{
 14    Buffer, BufferSnapshot, ToOffset as _, ToPoint, language_settings::all_language_settings,
 15    text_diff,
 16};
 17use release_channel::AppVersion;
 18use settings::EditPredictionPromptFormat;
 19use text::{Anchor, Bias};
 20use ui::SharedString;
 21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 22use zeta_prompt::ZetaPromptInput;
 23
 24use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
 25use zeta_prompt::{
 26    CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
 27    prompt_input_contains_special_tokens,
 28    zeta1::{self, EDITABLE_REGION_END_MARKER},
 29};
 30
 31use crate::open_ai_compatible::{
 32    load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
 33};
 34
 35pub fn request_prediction_with_zeta(
 36    store: &mut EditPredictionStore,
 37    EditPredictionModelInput {
 38        buffer,
 39        snapshot,
 40        position,
 41        related_files,
 42        events,
 43        debug_tx,
 44        trigger,
 45        project,
 46        can_collect_data,
 47        is_open_source,
 48        ..
 49    }: EditPredictionModelInput,
 50    capture_data: Option<Vec<StoredEvent>>,
 51    cx: &mut Context<EditPredictionStore>,
 52) -> Task<Result<Option<EditPredictionResult>>> {
 53    let settings = &all_language_settings(None, cx).edit_predictions;
 54    let provider = settings.provider;
 55    let custom_server_settings = match provider {
 56        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
 57        settings::EditPredictionProvider::OpenAiCompatibleApi => {
 58            settings.open_ai_compatible_api.clone()
 59        }
 60        _ => None,
 61    };
 62
 63    let http_client = cx.http_client();
 64    let buffer_snapshotted_at = Instant::now();
 65    let raw_config = store.zeta2_raw_config().cloned();
 66    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
 67    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
 68
 69    let excerpt_path: Arc<Path> = snapshot
 70        .file()
 71        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
 72        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 73
 74    let repo_url = if can_collect_data {
 75        let buffer_id = buffer.read(cx).remote_id();
 76        project
 77            .read(cx)
 78            .git_store()
 79            .read(cx)
 80            .repository_and_path_for_buffer_id(buffer_id, cx)
 81            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
 82    } else {
 83        None
 84    };
 85
 86    let client = store.client.clone();
 87    let llm_token = store.llm_token.clone();
 88    let organization_id = store
 89        .user_store
 90        .read(cx)
 91        .current_organization()
 92        .map(|organization| organization.id.clone());
 93    let app_version = AppVersion::global(cx);
 94
 95    struct Prediction {
 96        prompt_input: ZetaPromptInput,
 97        buffer: Entity<Buffer>,
 98        snapshot: BufferSnapshot,
 99        edits: Vec<(Range<Anchor>, Arc<str>)>,
100        cursor_position: Option<PredictedCursorPosition>,
101        received_response_at: Instant,
102        editable_range_in_buffer: Range<usize>,
103        model_version: Option<String>,
104    }
105
106    let request_task = cx.background_spawn({
107        async move {
108            let zeta_version = raw_config
109                .as_ref()
110                .map(|config| config.format)
111                .unwrap_or(ZetaFormat::default());
112
113            let cursor_offset = position.to_offset(&snapshot);
114            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
115                &snapshot,
116                related_files,
117                events,
118                excerpt_path,
119                cursor_offset,
120                preferred_experiment,
121                is_open_source,
122                can_collect_data,
123                repo_url,
124            );
125
126            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
127                return Err(anyhow::anyhow!("prompt contains special tokens"));
128            }
129
130            if let Some(debug_tx) = &debug_tx {
131                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
132                debug_tx
133                    .unbounded_send(DebugEvent::EditPredictionStarted(
134                        EditPredictionStartedDebugEvent {
135                            buffer: buffer.downgrade(),
136                            prompt: Some(prompt),
137                            position,
138                        },
139                    ))
140                    .ok();
141            }
142
143            log::trace!("Sending edit prediction request");
144
145            let (request_id, output, model_version, usage) =
146                if let Some(custom_settings) = &custom_server_settings {
147                    let max_tokens = custom_settings.max_output_tokens * 4;
148
149                    match custom_settings.prompt_format {
150                        EditPredictionPromptFormat::Zeta => {
151                            let ranges = &prompt_input.excerpt_ranges;
152                            let editable_range_in_excerpt = ranges.editable_350.clone();
153                            let prompt = zeta1::format_zeta1_from_input(
154                                &prompt_input,
155                                editable_range_in_excerpt.clone(),
156                                ranges.editable_350_context_150.clone(),
157                            );
158                            let stop_tokens = vec![
159                                EDITABLE_REGION_END_MARKER.to_string(),
160                                format!("{EDITABLE_REGION_END_MARKER}\n"),
161                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
162                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
163                            ];
164
165                            let (response_text, request_id) = send_custom_server_request(
166                                provider,
167                                custom_settings,
168                                prompt,
169                                max_tokens,
170                                stop_tokens,
171                                open_ai_compatible_api_key.clone(),
172                                &http_client,
173                            )
174                            .await?;
175
176                            let request_id = EditPredictionId(request_id.into());
177                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
178
179                            (
180                                request_id,
181                                Some(editable_range_in_excerpt).zip(output_text),
182                                None,
183                                None,
184                            )
185                        }
186                        EditPredictionPromptFormat::Zeta2 => {
187                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
188                            let prefill = get_prefill(&prompt_input, zeta_version);
189                            let prompt = format!("{prompt}{prefill}");
190
191                            let (response_text, request_id) = send_custom_server_request(
192                                provider,
193                                custom_settings,
194                                prompt,
195                                max_tokens,
196                                vec![],
197                                open_ai_compatible_api_key.clone(),
198                                &http_client,
199                            )
200                            .await?;
201
202                            let request_id = EditPredictionId(request_id.into());
203                            let output_text = if response_text.is_empty() {
204                                None
205                            } else {
206                                let output = format!("{prefill}{response_text}");
207                                Some(parse_zeta2_model_output(
208                                    &output,
209                                    zeta_version,
210                                    &prompt_input,
211                                )?)
212                            };
213
214                            (request_id, output_text, None, None)
215                        }
216                        _ => anyhow::bail!("unsupported prompt format"),
217                    }
218                } else if let Some(config) = &raw_config {
219                    let prompt = format_zeta_prompt(&prompt_input, config.format);
220                    let prefill = get_prefill(&prompt_input, config.format);
221                    let prompt = format!("{prompt}{prefill}");
222                    let environment = config
223                        .environment
224                        .clone()
225                        .or_else(|| Some(config.format.to_string().to_lowercase()));
226                    let request = RawCompletionRequest {
227                        model: config.model_id.clone().unwrap_or_default(),
228                        prompt,
229                        temperature: None,
230                        stop: vec![],
231                        max_tokens: Some(2048),
232                        environment,
233                    };
234
235                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
236                        request,
237                        client,
238                        None,
239                        llm_token,
240                        organization_id,
241                        app_version,
242                    )
243                    .await?;
244
245                    let request_id = EditPredictionId(response.id.clone().into());
246                    let output = if let Some(choice) = response.choices.pop() {
247                        let response = &choice.text;
248                        let output = format!("{prefill}{response}");
249                        Some(parse_zeta2_model_output(
250                            &output,
251                            config.format,
252                            &prompt_input,
253                        )?)
254                    } else {
255                        None
256                    };
257
258                    (request_id, output, None, usage)
259                } else {
260                    // Use V3 endpoint - server handles model/version selection and suffix stripping
261                    let (response, usage) = EditPredictionStore::send_v3_request(
262                        prompt_input.clone(),
263                        client,
264                        llm_token,
265                        organization_id,
266                        app_version,
267                        trigger,
268                    )
269                    .await?;
270
271                    let request_id = EditPredictionId(response.request_id.into());
272                    let output_text = Some(response.output).filter(|s| !s.is_empty());
273                    let model_version = response.model_version;
274
275                    (
276                        request_id,
277                        Some(response.editable_range).zip(output_text),
278                        model_version,
279                        usage,
280                    )
281                };
282
283            let received_response_at = Instant::now();
284
285            log::trace!("Got edit prediction response");
286
287            let Some((editable_range_in_excerpt, mut output_text)) = output else {
288                return Ok(((request_id, None), None));
289            };
290
291            let editable_range_in_buffer = editable_range_in_excerpt.start
292                + full_context_offset_range.start
293                ..editable_range_in_excerpt.end + full_context_offset_range.start;
294
295            let mut old_text = snapshot
296                .text_for_range(editable_range_in_buffer.clone())
297                .collect::<String>();
298
299            // Client-side cursor marker processing (applies to both raw and v3 responses)
300            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
301            if let Some(offset) = cursor_offset_in_output {
302                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
303                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
304            }
305
306            if let Some(debug_tx) = &debug_tx {
307                debug_tx
308                    .unbounded_send(DebugEvent::EditPredictionFinished(
309                        EditPredictionFinishedDebugEvent {
310                            buffer: buffer.downgrade(),
311                            position,
312                            model_output: Some(output_text.clone()),
313                        },
314                    ))
315                    .ok();
316            }
317
318            if !output_text.is_empty() && !output_text.ends_with('\n') {
319                output_text.push('\n');
320            }
321            if !old_text.is_empty() && !old_text.ends_with('\n') {
322                old_text.push('\n');
323            }
324
325            let (edits, cursor_position) = compute_edits_and_cursor_position(
326                old_text,
327                &output_text,
328                editable_range_in_buffer.start,
329                cursor_offset_in_output,
330                &snapshot,
331            );
332
333            anyhow::Ok((
334                (
335                    request_id,
336                    Some(Prediction {
337                        prompt_input,
338                        buffer,
339                        snapshot: snapshot.clone(),
340                        edits,
341                        cursor_position,
342                        received_response_at,
343                        editable_range_in_buffer,
344                        model_version,
345                    }),
346                ),
347                usage,
348            ))
349        }
350    });
351
352    cx.spawn(async move |this, cx| {
353        let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;
354
355        let Some(Prediction {
356            prompt_input: inputs,
357            buffer: edited_buffer,
358            snapshot: edited_buffer_snapshot,
359            edits,
360            cursor_position,
361            received_response_at,
362            editable_range_in_buffer,
363            model_version,
364        }) = prediction
365        else {
366            return Ok(Some(EditPredictionResult {
367                id,
368                prediction: Err(EditPredictionRejectReason::Empty),
369            }));
370        };
371
372        if can_collect_data {
373            let weak_this = this.clone();
374            let id = id.clone();
375            let edited_buffer = edited_buffer.clone();
376            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
377            let example_task = capture_data.and_then(|stored_events| {
378                cx.update(|cx| {
379                    crate::capture_example(
380                        project.clone(),
381                        edited_buffer.clone(),
382                        position,
383                        stored_events,
384                        false,
385                        cx,
386                    )
387                })
388            });
389            cx.spawn(async move |cx| {
390                let example_spec = if let Some(task) = example_task {
391                    task.await.ok()
392                } else {
393                    None
394                };
395
396                weak_this
397                    .update(cx, |this, cx| {
398                        this.enqueue_settled_prediction(
399                            id.clone(),
400                            &project,
401                            &edited_buffer,
402                            &edited_buffer_snapshot,
403                            editable_range_in_buffer,
404                            example_spec,
405                            cx,
406                        );
407                    })
408                    .ok();
409            })
410            .detach();
411        }
412
413        Ok(Some(
414            EditPredictionResult::new(
415                id,
416                &edited_buffer,
417                &edited_buffer_snapshot,
418                edits.into(),
419                cursor_position,
420                buffer_snapshotted_at,
421                received_response_at,
422                inputs,
423                model_version,
424                cx,
425            )
426            .await,
427        ))
428    })
429}
430
431fn handle_api_response<T>(
432    this: &WeakEntity<EditPredictionStore>,
433    response: Result<(T, Option<client::EditPredictionUsage>)>,
434    cx: &mut gpui::AsyncApp,
435) -> Result<T> {
436    match response {
437        Ok((data, usage)) => {
438            if let Some(usage) = usage {
439                this.update(cx, |this, cx| {
440                    this.user_store.update(cx, |user_store, cx| {
441                        user_store.update_edit_prediction_usage(usage, cx);
442                    });
443                })
444                .ok();
445            }
446            Ok(data)
447        }
448        Err(err) => {
449            if err.is::<ZedUpdateRequiredError>() {
450                cx.update(|cx| {
451                    this.update(cx, |this, _cx| {
452                        this.update_required = true;
453                    })
454                    .ok();
455
456                    let error_message: SharedString = err.to_string().into();
457                    show_app_notification(
458                        NotificationId::unique::<ZedUpdateRequiredError>(),
459                        cx,
460                        move |cx| {
461                            cx.new(|cx| {
462                                ErrorMessagePrompt::new(error_message.clone(), cx)
463                                    .with_link_button("Update Zed", "https://zed.dev/releases")
464                            })
465                        },
466                    );
467                });
468            }
469            Err(err)
470        }
471    }
472}
473
474pub fn zeta2_prompt_input(
475    snapshot: &language::BufferSnapshot,
476    related_files: Vec<zeta_prompt::RelatedFile>,
477    events: Vec<Arc<zeta_prompt::Event>>,
478    excerpt_path: Arc<Path>,
479    cursor_offset: usize,
480    preferred_experiment: Option<String>,
481    is_open_source: bool,
482    can_collect_data: bool,
483    repo_url: Option<String>,
484) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
485    let cursor_point = cursor_offset.to_point(snapshot);
486
487    let (full_context, full_context_offset_range, excerpt_ranges) =
488        compute_excerpt_ranges(cursor_point, snapshot);
489
490    let full_context_start_offset = full_context_offset_range.start;
491    let full_context_start_row = full_context.start.row;
492
493    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
494
495    let prompt_input = zeta_prompt::ZetaPromptInput {
496        cursor_path: excerpt_path,
497        cursor_excerpt: snapshot
498            .text_for_range(full_context)
499            .collect::<String>()
500            .into(),
501        cursor_offset_in_excerpt,
502        excerpt_start_row: Some(full_context_start_row),
503        events,
504        related_files,
505        excerpt_ranges,
506        experiment: preferred_experiment,
507        in_open_source_repo: is_open_source,
508        can_collect_data,
509        repo_url,
510    };
511    (full_context_offset_range, prompt_input)
512}
513
514pub(crate) fn edit_prediction_accepted(
515    store: &EditPredictionStore,
516    current_prediction: CurrentEditPrediction,
517    cx: &App,
518) {
519    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
520    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
521        return;
522    }
523
524    let request_id = current_prediction.prediction.id.to_string();
525    let model_version = current_prediction.prediction.model_version;
526    let require_auth = custom_accept_url.is_none();
527    let client = store.client.clone();
528    let llm_token = store.llm_token.clone();
529    let organization_id = store
530        .user_store
531        .read(cx)
532        .current_organization()
533        .map(|organization| organization.id.clone());
534    let app_version = AppVersion::global(cx);
535
536    cx.background_spawn(async move {
537        let url = if let Some(accept_edits_url) = custom_accept_url {
538            gpui::http_client::Url::parse(&accept_edits_url)?
539        } else {
540            client
541                .http_client()
542                .build_zed_llm_url("/predict_edits/accept", &[])?
543        };
544
545        let response = EditPredictionStore::send_api_request::<()>(
546            move |builder| {
547                let req = builder.uri(url.as_ref()).body(
548                    serde_json::to_string(&AcceptEditPredictionBody {
549                        request_id: request_id.clone(),
550                        model_version: model_version.clone(),
551                    })?
552                    .into(),
553                );
554                Ok(req?)
555            },
556            client,
557            llm_token,
558            organization_id,
559            app_version,
560            require_auth,
561        )
562        .await;
563
564        response?;
565        anyhow::Ok(())
566    })
567    .detach_and_log_err(cx);
568}
569
570pub fn compute_edits(
571    old_text: String,
572    new_text: &str,
573    offset: usize,
574    snapshot: &BufferSnapshot,
575) -> Vec<(Range<Anchor>, Arc<str>)> {
576    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
577}
578
579pub fn compute_edits_and_cursor_position(
580    old_text: String,
581    new_text: &str,
582    offset: usize,
583    cursor_offset_in_new_text: Option<usize>,
584    snapshot: &BufferSnapshot,
585) -> (
586    Vec<(Range<Anchor>, Arc<str>)>,
587    Option<PredictedCursorPosition>,
588) {
589    let diffs = text_diff(&old_text, new_text);
590
591    // Delta represents the cumulative change in byte count from all preceding edits.
592    // new_offset = old_offset + delta, so old_offset = new_offset - delta
593    let mut delta: isize = 0;
594    let mut cursor_position: Option<PredictedCursorPosition> = None;
595    let buffer_len = snapshot.len();
596
597    let edits = diffs
598        .iter()
599        .map(|(raw_old_range, new_text)| {
600            // Compute cursor position if it falls within or before this edit.
601            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
602                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
603                let edit_end_in_new = edit_start_in_new + new_text.len();
604
605                if cursor_offset < edit_start_in_new {
606                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
607                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
608                    cursor_position = Some(PredictedCursorPosition::at_anchor(
609                        snapshot.anchor_after(buffer_offset),
610                    ));
611                } else if cursor_offset < edit_end_in_new {
612                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
613                    let offset_within_insertion = cursor_offset - edit_start_in_new;
614                    cursor_position = Some(PredictedCursorPosition::new(
615                        snapshot.anchor_before(buffer_offset),
616                        offset_within_insertion,
617                    ));
618                }
619
620                delta += new_text.len() as isize - raw_old_range.len() as isize;
621            }
622
623            // Compute the edit with prefix/suffix trimming.
624            let mut old_range = raw_old_range.clone();
625            let old_slice = &old_text[old_range.clone()];
626
627            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
628            let suffix_len = common_prefix(
629                old_slice[prefix_len..].chars().rev(),
630                new_text[prefix_len..].chars().rev(),
631            );
632
633            old_range.start += offset;
634            old_range.end += offset;
635            old_range.start += prefix_len;
636            old_range.end -= suffix_len;
637
638            old_range.start = old_range.start.min(buffer_len);
639            old_range.end = old_range.end.min(buffer_len);
640
641            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
642            let range = if old_range.is_empty() {
643                let anchor = snapshot.anchor_after(old_range.start);
644                anchor..anchor
645            } else {
646                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
647            };
648            (range, new_text)
649        })
650        .collect();
651
652    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
653        let cursor_in_old = (cursor_offset as isize - delta) as usize;
654        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
655        cursor_position = Some(PredictedCursorPosition::at_anchor(
656            snapshot.anchor_after(buffer_offset),
657        ));
658    }
659
660    (edits, cursor_position)
661}
662
/// Returns the length in UTF-8 bytes of the longest run of equal characters at
/// the start of both iterators.
///
/// Passing reversed character iterators yields the common suffix length.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}