zeta.rs

  1use crate::cursor_excerpt::compute_excerpt_ranges;
  2use crate::prediction::EditPredictionResult;
  3use crate::{
  4    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  5    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
  6};
  7use anyhow::{Context as _, Result};
  8use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
  9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 10use edit_prediction_types::PredictedCursorPosition;
 11use futures::AsyncReadExt as _;
 12use gpui::{App, AppContext as _, Task, http_client, prelude::*};
 13use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
 14use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
 15use release_channel::AppVersion;
 16use text::{Anchor, Bias};
 17
 18use std::env;
 19use std::ops::Range;
 20use std::{path::Path, sync::Arc, time::Instant};
 21use zeta_prompt::{
 22    CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
 23    format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
 24    zeta1::{self, EDITABLE_REGION_END_MARKER},
 25};
 26
/// Requests an edit prediction for `position` in `buffer` from one of the
/// Zeta backends, chosen in priority order:
///
/// 1. a user-configured custom server (Ollama or an OpenAI-compatible API),
/// 2. the raw LLM endpoint, when a zeta2 raw config is present,
/// 3. the default cloud v3 endpoint (server picks model/version and strips
///    the suffix itself).
///
/// The model output is diffed against the editable region of the buffer to
/// produce anchored edits plus an optional predicted cursor position, which
/// are packaged into an [`EditPredictionResult`]. Returns `Ok(None)` when no
/// prediction was produced (e.g. the prompt contained special tokens).
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    preferred_model: Option<EditPredictionModelKind>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Custom-server settings only apply to the Ollama / OpenAI-compatible
    // providers; every other provider goes through the cloud paths below.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();

    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // Clone everything the background task needs; `cx` is not available there.
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    let request_task = cx.background_spawn({
        async move {
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once in each of the request branches below.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                zeta_version,
                preferred_model,
                is_open_source,
                can_collect_data,
            );

            // If the user's text contains the model's own special tokens, a
            // prediction could be corrupted/hijacked — silently skip instead.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            let is_zeta1 = preferred_model == Some(EditPredictionModelKind::Zeta1);
            let excerpt_ranges = prompt_input
                .excerpt_ranges
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("excerpt_ranges missing from prompt input"))?;

            // Mirror the prompt to the debug channel before sending, so the
            // debug view can show in-flight requests.
            if let Some(debug_tx) = &debug_tx {
                let prompt = if is_zeta1 {
                    zeta1::format_zeta1_from_input(
                        &prompt_input,
                        excerpt_ranges.editable_350.clone(),
                        excerpt_ranges.editable_350_context_150.clone(),
                    )
                } else {
                    format_zeta_prompt(&prompt_input, zeta_version)
                };
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, model_version, usage) = if let Some(custom_settings) =
                &custom_server_settings
            {
                // NOTE(review): output budget is 4x the configured maximum —
                // presumably headroom for markers/stop tokens; confirm.
                let max_tokens = custom_settings.max_output_tokens * 4;

                if is_zeta1 {
                    // Zeta1 uses the fixed 350-line editable window and relies
                    // on stop tokens to end generation at the region marker.
                    let ranges = excerpt_ranges;
                    let prompt = zeta1::format_zeta1_from_input(
                        &prompt_input,
                        ranges.editable_350.clone(),
                        ranges.editable_350_context_150.clone(),
                    );
                    editable_range_in_excerpt = ranges.editable_350.clone();
                    // Include trailing-newline variants so the server stops on
                    // the marker however the model emits it.
                    let stop_tokens = vec![
                        EDITABLE_REGION_END_MARKER.to_string(),
                        format!("{EDITABLE_REGION_END_MARKER}\n"),
                        format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                        format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                    ];

                    let (response_text, request_id) = send_custom_server_request(
                        provider,
                        custom_settings,
                        prompt,
                        max_tokens,
                        stop_tokens,
                        &http_client,
                    )
                    .await?;

                    let request_id = EditPredictionId(request_id.into());
                    let output_text = zeta1::clean_zeta1_model_output(&response_text);

                    (request_id, output_text, None, None)
                } else {
                    // Zeta2-style prompt: append the prefill so the model
                    // continues from it, and prepend it back onto the output.
                    let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                    let prefill = get_prefill(&prompt_input, zeta_version);
                    let prompt = format!("{prompt}{prefill}");

                    editable_range_in_excerpt = prompt_input
                        .excerpt_ranges
                        .as_ref()
                        .map(|ranges| zeta_prompt::excerpt_range_for_format(zeta_version, ranges).0)
                        .unwrap_or(prompt_input.editable_range_in_excerpt.clone());

                    let (response_text, request_id) = send_custom_server_request(
                        provider,
                        custom_settings,
                        prompt,
                        max_tokens,
                        vec![],
                        &http_client,
                    )
                    .await?;

                    let request_id = EditPredictionId(request_id.into());
                    let output_text = if response_text.is_empty() {
                        None
                    } else {
                        let output = format!("{prefill}{response_text}");
                        Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                    };

                    (request_id, output_text, None, None)
                }
            } else if let Some(config) = &raw_config {
                // Raw LLM endpoint: client formats the prompt itself using the
                // format declared in the raw config.
                let prompt = format_zeta_prompt(&prompt_input, config.format);
                let prefill = get_prefill(&prompt_input, config.format);
                let prompt = format!("{prompt}{prefill}");
                let request = RawCompletionRequest {
                    model: config.model_id.clone().unwrap_or_default(),
                    prompt,
                    temperature: None,
                    stop: vec![],
                    max_tokens: Some(2048),
                    environment: Some(config.format.to_string().to_lowercase()),
                };

                editable_range_in_excerpt = prompt_input
                    .excerpt_ranges
                    .as_ref()
                    .map(|ranges| zeta_prompt::excerpt_range_for_format(config.format, ranges).1)
                    .unwrap_or(prompt_input.editable_range_in_excerpt.clone());

                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                    request,
                    client,
                    None,
                    llm_token,
                    app_version,
                )
                .await?;

                let request_id = EditPredictionId(response.id.clone().into());
                let output_text = response.choices.pop().map(|choice| {
                    let response = &choice.text;
                    let output = format!("{prefill}{response}");
                    clean_zeta2_model_output(&output, config.format).to_string()
                });

                (request_id, output_text, None, usage)
            } else {
                // Use V3 endpoint - server handles model/version selection and suffix stripping
                let (response, usage) = EditPredictionStore::send_v3_request(
                    prompt_input.clone(),
                    client,
                    llm_token,
                    app_version,
                    trigger,
                )
                .await?;

                let request_id = EditPredictionId(response.request_id.into());
                let output_text = if response.output.is_empty() {
                    None
                } else {
                    Some(response.output)
                };
                editable_range_in_excerpt = response.editable_range;
                let model_version = response.model_version;

                (request_id, output_text, model_version, usage)
            };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No text means an empty prediction; still report the request id
            // and usage upstream.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Normalize trailing newlines on both sides so the diff below
            // doesn't report a spurious newline-only edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    // Back on the main context: unwrap the API response, handle the empty
    // prediction case, optionally enqueue data collection, and build the
    // final result.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
        )) = prediction
        else {
            // Request succeeded but produced no edits.
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            this.update(cx, |this, cx| {
                this.enqueue_settled_prediction(
                    id.clone(),
                    &project,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    editable_range_in_buffer,
                    cx,
                );
            })
            .ok();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
369
370pub fn zeta2_prompt_input(
371    snapshot: &language::BufferSnapshot,
372    related_files: Vec<zeta_prompt::RelatedFile>,
373    events: Vec<Arc<zeta_prompt::Event>>,
374    excerpt_path: Arc<Path>,
375    cursor_offset: usize,
376    zeta_format: ZetaFormat,
377    preferred_model: Option<EditPredictionModelKind>,
378    is_open_source: bool,
379    can_collect_data: bool,
380) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
381    let cursor_point = cursor_offset.to_point(snapshot);
382
383    let (full_context, full_context_offset_range, excerpt_ranges) =
384        compute_excerpt_ranges(cursor_point, snapshot);
385
386    let related_files = crate::filter_redundant_excerpts(
387        related_files,
388        excerpt_path.as_ref(),
389        full_context.start.row..full_context.end.row,
390    );
391
392    let full_context_start_offset = full_context_offset_range.start;
393    let full_context_start_row = full_context.start.row;
394
395    let editable_offset_range = match preferred_model {
396        Some(EditPredictionModelKind::Zeta1) => excerpt_ranges.editable_350.clone(),
397        _ => zeta_prompt::excerpt_range_for_format(zeta_format, &excerpt_ranges).0,
398    };
399
400    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
401
402    let prompt_input = zeta_prompt::ZetaPromptInput {
403        cursor_path: excerpt_path,
404        cursor_excerpt: snapshot
405            .text_for_range(full_context)
406            .collect::<String>()
407            .into(),
408        editable_range_in_excerpt: editable_offset_range,
409        cursor_offset_in_excerpt,
410        excerpt_start_row: Some(full_context_start_row),
411        events,
412        related_files,
413        excerpt_ranges: Some(excerpt_ranges),
414        preferred_model,
415        in_open_source_repo: is_open_source,
416        can_collect_data,
417    };
418    (full_context_offset_range, prompt_input)
419}
420
421pub(crate) async fn send_custom_server_request(
422    provider: settings::EditPredictionProvider,
423    settings: &OpenAiCompatibleEditPredictionSettings,
424    prompt: String,
425    max_tokens: u32,
426    stop_tokens: Vec<String>,
427    http_client: &Arc<dyn http_client::HttpClient>,
428) -> Result<(String, String)> {
429    match provider {
430        settings::EditPredictionProvider::Ollama => {
431            let response =
432                ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
433                    .await?;
434            Ok((response.response, response.created_at))
435        }
436        _ => {
437            let request = RawCompletionRequest {
438                model: settings.model.clone(),
439                prompt,
440                max_tokens: Some(max_tokens),
441                temperature: None,
442                stop: stop_tokens
443                    .into_iter()
444                    .map(std::borrow::Cow::Owned)
445                    .collect(),
446                environment: None,
447            };
448
449            let request_body = serde_json::to_string(&request)?;
450            let http_request = http_client::Request::builder()
451                .method(http_client::Method::POST)
452                .uri(settings.api_url.as_ref())
453                .header("Content-Type", "application/json")
454                .body(http_client::AsyncBody::from(request_body))?;
455
456            let mut response = http_client.send(http_request).await?;
457            let status = response.status();
458
459            if !status.is_success() {
460                let mut body = String::new();
461                response.body_mut().read_to_string(&mut body).await?;
462                anyhow::bail!("custom server error: {} - {}", status, body);
463            }
464
465            let mut body = String::new();
466            response.body_mut().read_to_string(&mut body).await?;
467
468            let parsed: RawCompletionResponse =
469                serde_json::from_str(&body).context("Failed to parse completion response")?;
470            let text = parsed
471                .choices
472                .into_iter()
473                .next()
474                .map(|choice| choice.text)
475                .unwrap_or_default();
476            Ok((text, parsed.id))
477        }
478    }
479}
480
/// Fire-and-forget notification that the user accepted a prediction, sent to
/// the cloud `/predict_edits/accept` endpoint — or to the URL in the
/// `ZED_ACCEPT_PREDICTION_URL` env var when it is set.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // With a raw zeta2 config and no override URL, skip the acknowledgement —
    // presumably no cloud request was involved; confirm if this changes.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Auth is only required when talking to the real cloud endpoint.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        // The builder closure may be retried by `send_api_request`, hence the
        // clones of the captured id/version.
        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    // Failures are logged, not surfaced — acceptance reporting is best-effort.
    .detach_and_log_err(cx);
}
530
/// Convenience wrapper around [`compute_edits_and_cursor_position`] for
/// callers that don't need the predicted cursor position.
pub fn compute_edits(
    old_text: String,
    new_text: &str,
    offset: usize,
    snapshot: &BufferSnapshot,
) -> Vec<(Range<Anchor>, Arc<str>)> {
    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
}
539
540pub fn compute_edits_and_cursor_position(
541    old_text: String,
542    new_text: &str,
543    offset: usize,
544    cursor_offset_in_new_text: Option<usize>,
545    snapshot: &BufferSnapshot,
546) -> (
547    Vec<(Range<Anchor>, Arc<str>)>,
548    Option<PredictedCursorPosition>,
549) {
550    let diffs = text_diff(&old_text, new_text);
551
552    // Delta represents the cumulative change in byte count from all preceding edits.
553    // new_offset = old_offset + delta, so old_offset = new_offset - delta
554    let mut delta: isize = 0;
555    let mut cursor_position: Option<PredictedCursorPosition> = None;
556    let buffer_len = snapshot.len();
557
558    let edits = diffs
559        .iter()
560        .map(|(raw_old_range, new_text)| {
561            // Compute cursor position if it falls within or before this edit.
562            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
563                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
564                let edit_end_in_new = edit_start_in_new + new_text.len();
565
566                if cursor_offset < edit_start_in_new {
567                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
568                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
569                    cursor_position = Some(PredictedCursorPosition::at_anchor(
570                        snapshot.anchor_after(buffer_offset),
571                    ));
572                } else if cursor_offset < edit_end_in_new {
573                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
574                    let offset_within_insertion = cursor_offset - edit_start_in_new;
575                    cursor_position = Some(PredictedCursorPosition::new(
576                        snapshot.anchor_before(buffer_offset),
577                        offset_within_insertion,
578                    ));
579                }
580
581                delta += new_text.len() as isize - raw_old_range.len() as isize;
582            }
583
584            // Compute the edit with prefix/suffix trimming.
585            let mut old_range = raw_old_range.clone();
586            let old_slice = &old_text[old_range.clone()];
587
588            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
589            let suffix_len = common_prefix(
590                old_slice[prefix_len..].chars().rev(),
591                new_text[prefix_len..].chars().rev(),
592            );
593
594            old_range.start += offset;
595            old_range.end += offset;
596            old_range.start += prefix_len;
597            old_range.end -= suffix_len;
598
599            old_range.start = old_range.start.min(buffer_len);
600            old_range.end = old_range.end.min(buffer_len);
601
602            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
603            let range = if old_range.is_empty() {
604                let anchor = snapshot.anchor_after(old_range.start);
605                anchor..anchor
606            } else {
607                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
608            };
609            (range, new_text)
610        })
611        .collect();
612
613    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
614        let cursor_in_old = (cursor_offset as isize - delta) as usize;
615        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
616        cursor_position = Some(PredictedCursorPosition::at_anchor(
617            snapshot.anchor_after(buffer_offset),
618        ));
619    }
620
621    (edits, cursor_position)
622}
623
/// Returns the length in BYTES of the longest common prefix of the two
/// character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (x, y) in a.zip(b) {
        if x != y {
            break;
        }
        bytes += x.len_utf8();
    }
    bytes
}