//! zeta.rs — Zeta edit-prediction request pipeline.

  1use crate::cursor_excerpt::compute_excerpt_ranges;
  2use crate::prediction::EditPredictionResult;
  3use crate::{
  4    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  5    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
  6};
  7use anyhow::{Context as _, Result};
  8use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
  9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 10use edit_prediction_types::PredictedCursorPosition;
 11use futures::AsyncReadExt as _;
 12use gpui::{App, AppContext as _, Task, http_client, prelude::*};
 13use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
 14use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
 15use release_channel::AppVersion;
 16use settings::EditPredictionPromptFormat;
 17use text::{Anchor, Bias};
 18
 19use std::env;
 20use std::ops::Range;
 21use std::{path::Path, sync::Arc, time::Instant};
 22use zeta_prompt::{
 23    CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
 24    prompt_input_contains_special_tokens,
 25    zeta1::{self, EDITABLE_REGION_END_MARKER},
 26};
 27
/// Requests an edit prediction for `position` in `buffer` and returns a task
/// that resolves to the prediction result (or `None` when the request was
/// skipped or produced nothing usable).
///
/// The request is routed to one of three backends, in priority order:
/// 1. a user-configured custom server (Ollama or an OpenAI-compatible API),
///    selected via the edit-prediction provider setting;
/// 2. the raw Zeta2 completion endpoint, when the store carries a raw config;
/// 3. the hosted V3 endpoint, where the server selects the model/version and
///    strips the suffix itself.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Only the Ollama / OpenAI-compatible providers carry custom-server
    // settings; any other provider falls through to the cloud endpoints.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Start of the end-to-end latency measurement passed into the result.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());

    // Untitled buffers get a placeholder path so the prompt always has one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Prompt assembly and the network round-trip run off the main thread.
    let request_task = cx.background_spawn({
        async move {
            // A raw config can pin a specific prompt format; otherwise the
            // default Zeta format is used.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once in each backend branch below; measured in
            // byte offsets relative to the prompt excerpt.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
            );

            // Refuse to send a prompt that already contains the model's
            // special tokens — the model output could not be parsed reliably.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            // Surface the formatted prompt to any attached debug listener.
            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): looks like a 4-chars-per-token budget
                    // heuristic — confirm.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            // Zeta1-style prompt with explicit editable-region
                            // markers; generation stops at the end marker.
                            let ranges = &prompt_input.excerpt_ranges;
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                ranges.editable_350.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            editable_range_in_excerpt = ranges.editable_350.clone();
                            // Cover trailing-newline variants of the marker.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);

                            (request_id, output_text, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // Zeta2-style prompt; the prefill is appended to
                            // the prompt and re-prepended to the response
                            // before cleaning.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                                zeta_version,
                                &prompt_input.excerpt_ranges,
                            )
                            .0;

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    // Raw Zeta2 endpoint: format the prompt locally and have
                    // the server run the configured model verbatim.
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment: Some(config.format.to_string().to_lowercase()),
                    };

                    editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                        config.format,
                        &prompt_input.excerpt_ranges,
                    )
                    .1;

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    // Take the last choice (pop); re-prepend the prefill
                    // before cleaning, mirroring the Zeta2 branch above.
                    let output_text = response.choices.pop().map(|choice| {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        clean_zeta2_model_output(&output, config.format).to_string()
                    });

                    (request_id, output_text, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = if response.output.is_empty() {
                        None
                    } else {
                        Some(response.output)
                    };
                    editable_range_in_excerpt = response.editable_range;
                    let model_version = response.model_version;

                    (request_id, output_text, model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // An empty response still yields a request id so the caller can
            // record the empty prediction.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Translate the excerpt-relative editable range into buffer
            // offsets via the excerpt's start offset within the buffer.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Normalize trailing newlines on both sides so the diff doesn't
            // report a spurious newline-only edit at the end of the region.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    // Back on the foreground: unpack the response, optionally enqueue it for
    // data collection, and build the final result.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // No payload means the model produced nothing usable for this id.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            this.update(cx, |this, cx| {
                this.enqueue_settled_prediction(
                    id.clone(),
                    &project,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    editable_range_in_buffer,
                    cx,
                );
            })
            .ok();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
358
359pub fn zeta2_prompt_input(
360    snapshot: &language::BufferSnapshot,
361    related_files: Vec<zeta_prompt::RelatedFile>,
362    events: Vec<Arc<zeta_prompt::Event>>,
363    excerpt_path: Arc<Path>,
364    cursor_offset: usize,
365    preferred_experiment: Option<String>,
366    is_open_source: bool,
367    can_collect_data: bool,
368) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
369    let cursor_point = cursor_offset.to_point(snapshot);
370
371    let (full_context, full_context_offset_range, excerpt_ranges) =
372        compute_excerpt_ranges(cursor_point, snapshot);
373
374    let related_files = crate::filter_redundant_excerpts(
375        related_files,
376        excerpt_path.as_ref(),
377        full_context.start.row..full_context.end.row,
378    );
379
380    let full_context_start_offset = full_context_offset_range.start;
381    let full_context_start_row = full_context.start.row;
382
383    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
384
385    let prompt_input = zeta_prompt::ZetaPromptInput {
386        cursor_path: excerpt_path,
387        cursor_excerpt: snapshot
388            .text_for_range(full_context)
389            .collect::<String>()
390            .into(),
391        cursor_offset_in_excerpt,
392        excerpt_start_row: Some(full_context_start_row),
393        events,
394        related_files,
395        excerpt_ranges,
396        experiment: preferred_experiment,
397        in_open_source_repo: is_open_source,
398        can_collect_data,
399    };
400    (full_context_offset_range, prompt_input)
401}
402
403pub(crate) async fn send_custom_server_request(
404    provider: settings::EditPredictionProvider,
405    settings: &OpenAiCompatibleEditPredictionSettings,
406    prompt: String,
407    max_tokens: u32,
408    stop_tokens: Vec<String>,
409    http_client: &Arc<dyn http_client::HttpClient>,
410) -> Result<(String, String)> {
411    match provider {
412        settings::EditPredictionProvider::Ollama => {
413            let response =
414                ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
415                    .await?;
416            Ok((response.response, response.created_at))
417        }
418        _ => {
419            let request = RawCompletionRequest {
420                model: settings.model.clone(),
421                prompt,
422                max_tokens: Some(max_tokens),
423                temperature: None,
424                stop: stop_tokens
425                    .into_iter()
426                    .map(std::borrow::Cow::Owned)
427                    .collect(),
428                environment: None,
429            };
430
431            let request_body = serde_json::to_string(&request)?;
432            let http_request = http_client::Request::builder()
433                .method(http_client::Method::POST)
434                .uri(settings.api_url.as_ref())
435                .header("Content-Type", "application/json")
436                .body(http_client::AsyncBody::from(request_body))?;
437
438            let mut response = http_client.send(http_request).await?;
439            let status = response.status();
440
441            if !status.is_success() {
442                let mut body = String::new();
443                response.body_mut().read_to_string(&mut body).await?;
444                anyhow::bail!("custom server error: {} - {}", status, body);
445            }
446
447            let mut body = String::new();
448            response.body_mut().read_to_string(&mut body).await?;
449
450            let parsed: RawCompletionResponse =
451                serde_json::from_str(&body).context("Failed to parse completion response")?;
452            let text = parsed
453                .choices
454                .into_iter()
455                .next()
456                .map(|choice| choice.text)
457                .unwrap_or_default();
458            Ok((text, parsed.id))
459        }
460    }
461}
462
/// Notifies the prediction service that `current_prediction` was accepted.
///
/// When a raw Zeta2 config is active the report is skipped entirely, unless
/// the `ZED_ACCEPT_PREDICTION_URL` environment variable redirects the report
/// to a custom endpoint (in which case auth is also skipped).
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // Raw-config runs don't report accepts to the hosted endpoint.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Auth is only required when talking to the hosted endpoint.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        // Fire-and-forget POST; failures are logged by `detach_and_log_err`.
        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
512
513pub fn compute_edits(
514    old_text: String,
515    new_text: &str,
516    offset: usize,
517    snapshot: &BufferSnapshot,
518) -> Vec<(Range<Anchor>, Arc<str>)> {
519    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
520}
521
/// Diffs `old_text` against `new_text` and converts the hunks into
/// anchor-ranged buffer edits, optionally mapping a cursor position expressed
/// as a byte offset into `new_text` back onto the buffer.
///
/// `offset` is the byte offset within `snapshot` at which `old_text` begins;
/// all produced edit ranges are buffer-relative via that offset.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // (Delta is only maintained while the cursor is still unplaced;
            // it is never read after `cursor_position` becomes `Some`.)
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this hunk: map it
                    // back into old coordinates and clamp to the buffer.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this hunk's inserted text: record the
                    // offset within the insertion relative to the hunk start.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Trim identical leading/trailing text so the edit is minimal.
            // `common_prefix` counts whole chars, so these byte lengths land
            // on UTF-8 boundaries and the slicing below is valid.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp in case the edited region ran past the end of the buffer.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: use a collapsed anchor pair.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after the last hunk: map it through the accumulated delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
605
/// Returns the byte length (sum of UTF-8 widths) of the longest run of equal
/// leading chars shared by the two streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (x, y) in a.zip(b) {
        if x != y {
            break;
        }
        bytes += x.len_utf8();
    }
    bytes
}