// zeta.rs

  1use crate::{
  2    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  3    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
  4    ZedUpdateRequiredError, cursor_excerpt::compute_excerpt_ranges,
  5    prediction::EditPredictionResult,
  6};
  7use anyhow::Result;
  8use cloud_llm_client::{
  9    AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
 10};
 11use edit_prediction_types::PredictedCursorPosition;
 12use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
 13use language::{
 14    Buffer, BufferSnapshot, ToOffset as _, ToPoint, language_settings::all_language_settings,
 15    text_diff,
 16};
 17use release_channel::AppVersion;
 18use settings::EditPredictionPromptFormat;
 19use text::{Anchor, Bias};
 20use ui::SharedString;
 21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 22use zeta_prompt::{ParsedOutput, ZetaPromptInput};
 23
 24use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
 25use zeta_prompt::{
 26    CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
 27    prompt_input_contains_special_tokens,
 28    zeta1::{self, EDITABLE_REGION_END_MARKER},
 29};
 30
 31use crate::open_ai_compatible::{
 32    load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
 33};
 34
/// Requests an edit prediction for `buffer` using one of the Zeta backends.
///
/// Three request paths are supported, selected in this order:
/// 1. A custom server (Ollama / OpenAI-compatible API) configured in
///    settings, formatted with either the Zeta1 or Zeta2 prompt format.
/// 2. A raw zeta2 config from `store.zeta2_raw_config()`, sent as a
///    `RawCompletionRequest` through the LLM endpoint.
/// 3. The default V3 endpoint, where the server handles model/version
///    selection and suffix stripping.
///
/// The model output is parsed, the cursor marker is stripped, and the text is
/// diffed against the editable region of `snapshot` to produce anchored
/// edits. When `can_collect_data` is set, a training example may be captured
/// and the settled prediction enqueued for later reporting.
///
/// Returns `Ok(None)`-like behavior via `EditPredictionRejectReason::Empty`
/// inside the `EditPredictionResult` when the model produced no output.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Custom-server settings are only present for the two self-hosted
    // providers; `None` means we fall through to raw-config or V3 below.
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Timestamp of the snapshot used for latency accounting in the result.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // Repo URL is only attached when data collection is permitted.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Intermediate result produced by the background request task and
    // consumed by the foreground task below.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        received_response_at: Instant,
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    let request_task = cx.background_spawn({
        async move {
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send buffers that already contain the model's own
            // control tokens — they would corrupt parsing of the response.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Dispatch to one of the three backends; each arm yields the same
            // tuple shape: (request id, parsed output, model version, usage).
            let (request_id, output, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // Scaled up; presumably headroom for token-count
                    // estimation mismatches — TODO confirm rationale.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            // Zeta1: the model rewrites an explicit editable
                            // region; generation stops at the end marker.
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                            });

                            (request_id, parsed_output, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // Zeta2: the prefill is appended to the prompt and
                            // must be re-prepended to the response for parsing.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    // Raw zeta2 config path: build the completion request by
                    // hand and send it through the generic LLM endpoint.
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    (request_id, output, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    let parsed_output = ParsedOutput {
                        new_editable_region: output_text.unwrap_or_default(),
                        range_in_excerpt: response.editable_range,
                    };

                    (request_id, Some(parsed_output), model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No parsed output at all: report the id with an empty prediction.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
            }) = output
            else {
                return Ok(((request_id, None), None));
            };

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize trailing newlines on both sides so the diff below
            // doesn't produce a spurious edit at the region boundary.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                (
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                        model_version,
                    }),
                ),
                usage,
            ))
        }
    });

    // Foreground: record usage / surface errors, then package the result.
    cx.spawn(async move |this, cx| {
        let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;

        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            let weak_this = this.clone();
            let id = id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            // NOTE(review): `cx.update` usually returns a `Result` — confirm
            // this coerces to the `Option` that `and_then` expects (an `.ok()`
            // may be intended here).
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            // Fire-and-forget: enqueue the settled prediction (optionally with
            // a captured example) without blocking the result below.
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            example_spec,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
432
433fn handle_api_response<T>(
434    this: &WeakEntity<EditPredictionStore>,
435    response: Result<(T, Option<client::EditPredictionUsage>)>,
436    cx: &mut gpui::AsyncApp,
437) -> Result<T> {
438    match response {
439        Ok((data, usage)) => {
440            if let Some(usage) = usage {
441                this.update(cx, |this, cx| {
442                    this.user_store.update(cx, |user_store, cx| {
443                        user_store.update_edit_prediction_usage(usage, cx);
444                    });
445                })
446                .ok();
447            }
448            Ok(data)
449        }
450        Err(err) => {
451            if err.is::<ZedUpdateRequiredError>() {
452                cx.update(|cx| {
453                    this.update(cx, |this, _cx| {
454                        this.update_required = true;
455                    })
456                    .ok();
457
458                    let error_message: SharedString = err.to_string().into();
459                    show_app_notification(
460                        NotificationId::unique::<ZedUpdateRequiredError>(),
461                        cx,
462                        move |cx| {
463                            cx.new(|cx| {
464                                ErrorMessagePrompt::new(error_message.clone(), cx)
465                                    .with_link_button("Update Zed", "https://zed.dev/releases")
466                            })
467                        },
468                    );
469                });
470            }
471            Err(err)
472        }
473    }
474}
475
476pub fn zeta2_prompt_input(
477    snapshot: &language::BufferSnapshot,
478    related_files: Vec<zeta_prompt::RelatedFile>,
479    events: Vec<Arc<zeta_prompt::Event>>,
480    excerpt_path: Arc<Path>,
481    cursor_offset: usize,
482    preferred_experiment: Option<String>,
483    is_open_source: bool,
484    can_collect_data: bool,
485    repo_url: Option<String>,
486) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
487    let cursor_point = cursor_offset.to_point(snapshot);
488
489    let (full_context, full_context_offset_range, excerpt_ranges) =
490        compute_excerpt_ranges(cursor_point, snapshot);
491
492    let full_context_start_offset = full_context_offset_range.start;
493    let full_context_start_row = full_context.start.row;
494
495    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
496
497    let prompt_input = zeta_prompt::ZetaPromptInput {
498        cursor_path: excerpt_path,
499        cursor_excerpt: snapshot
500            .text_for_range(full_context)
501            .collect::<String>()
502            .into(),
503        cursor_offset_in_excerpt,
504        excerpt_start_row: Some(full_context_start_row),
505        events,
506        related_files,
507        excerpt_ranges,
508        experiment: preferred_experiment,
509        in_open_source_repo: is_open_source,
510        can_collect_data,
511        repo_url,
512    };
513    (full_context_offset_range, prompt_input)
514}
515
516pub(crate) fn edit_prediction_accepted(
517    store: &EditPredictionStore,
518    current_prediction: CurrentEditPrediction,
519    cx: &App,
520) {
521    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
522    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
523        return;
524    }
525
526    let request_id = current_prediction.prediction.id.to_string();
527    let model_version = current_prediction.prediction.model_version;
528    let require_auth = custom_accept_url.is_none();
529    let client = store.client.clone();
530    let llm_token = store.llm_token.clone();
531    let organization_id = store
532        .user_store
533        .read(cx)
534        .current_organization()
535        .map(|organization| organization.id.clone());
536    let app_version = AppVersion::global(cx);
537
538    cx.background_spawn(async move {
539        let url = if let Some(accept_edits_url) = custom_accept_url {
540            gpui::http_client::Url::parse(&accept_edits_url)?
541        } else {
542            client
543                .http_client()
544                .build_zed_llm_url("/predict_edits/accept", &[])?
545        };
546
547        let response = EditPredictionStore::send_api_request::<()>(
548            move |builder| {
549                let req = builder.uri(url.as_ref()).body(
550                    serde_json::to_string(&AcceptEditPredictionBody {
551                        request_id: request_id.clone(),
552                        model_version: model_version.clone(),
553                    })?
554                    .into(),
555                );
556                Ok(req?)
557            },
558            client,
559            llm_token,
560            organization_id,
561            app_version,
562            require_auth,
563        )
564        .await;
565
566        response?;
567        anyhow::Ok(())
568    })
569    .detach_and_log_err(cx);
570}
571
572pub fn compute_edits(
573    old_text: String,
574    new_text: &str,
575    offset: usize,
576    snapshot: &BufferSnapshot,
577) -> Vec<(Range<Anchor>, Arc<str>)> {
578    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
579}
580
/// Diffs `old_text` against `new_text` and converts the differences into
/// anchored buffer edits, optionally tracking where the cursor should land.
///
/// - `offset`: byte offset in `snapshot` where `old_text` begins; all diff
///   ranges are translated by this amount into buffer coordinates.
/// - `cursor_offset_in_new_text`: byte offset of the cursor within
///   `new_text`, if a predicted cursor position should be computed.
///
/// Returns the edits (with common prefixes/suffixes trimmed away) paired
/// with the predicted cursor position, if one was requested.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // delta only needs to stay current while the cursor is still
            // unresolved, so the accumulation below is deliberately inside
            // this branch and stops once cursor_position is set.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor is in unchanged text before this edit: map it back
                    // to old-text coordinates and anchor it directly.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor is inside this edit's inserted text: anchor at the
                    // edit start and record the offset within the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the replacement to the genuinely-changed middle: drop the
            // shared leading prefix, then the shared trailing suffix of what
            // remains (the suffix scan starts after the prefix so the two
            // regions can never overlap).
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp in case the editable window extends past the buffer end.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: a single anchor keeps the edit zero-width.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor falls after every edit: map it through the final delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
664
/// Returns the byte length (UTF-8) of the longest common prefix of the two
/// character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    // Walk both streams in lockstep, accumulating the encoded length of every
    // matching leading char; stop at the first mismatch or when either ends.
    let mut bytes = 0;
    for (x, y) in a.zip(b) {
        if x != y {
            break;
        }
        bytes += x.len_utf8();
    }
    bytes
}