zeta2.rs

  1use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets};
  2use crate::prediction::EditPredictionResult;
  3use crate::zeta1::compute_edits_and_cursor_position;
  4use crate::{
  5    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  6    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
  7};
  8use anyhow::Result;
  9use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 10use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 11use gpui::{App, Task, prelude::*};
 12use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 13use release_channel::AppVersion;
 14
 15use std::env;
 16use std::{path::Path, sync::Arc, time::Instant};
 17use zeta_prompt::{
 18    CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
 19    format_zeta_prompt, get_prefill,
 20};
 21
/// Upper bound on context tokens for zeta2 prompts. Not referenced within
/// this file — presumably consumed by excerpt sizing elsewhere in the crate
/// (see `compute_excerpt_ranges`); confirm against callers before changing.
pub const MAX_CONTEXT_TOKENS: usize = 350;
 23
 24pub fn max_editable_tokens(format: ZetaFormat) -> usize {
 25    match format {
 26        ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
 27        ZetaFormat::V0114180EditableRegion => 180,
 28        ZetaFormat::V0120GitMergeMarkers => 180,
 29        ZetaFormat::V0131GitMergeMarkersPrefix => 180,
 30        ZetaFormat::V0211Prefill => 180,
 31        ZetaFormat::V0211SeedCoder => 180,
 32    }
 33}
 34
/// Requests an edit prediction for `position` in `buffer` using the zeta2
/// model family.
///
/// Two transports are used:
/// - If a raw zeta2 config is present, the prompt is formatted locally
///   (including prefill) and sent via `send_raw_llm_request`; the model
///   output is then cleaned client-side.
/// - Otherwise the v3 endpoint is used; the server handles model/version
///   selection and suffix stripping.
///
/// Returns `Ok(None)` when `handle_api_response` yields no prediction, and
/// an `EditPredictionResult` rejected with `Empty` when the model produced
/// no output text.
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        ..
    }: EditPredictionModelInput,
    preferred_model: Option<EditPredictionModelKind>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();

    // Untitled buffers get a placeholder path so the prompt always has one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Open-source status requires *everything* involved — the current file,
    // the event history, and all related files — to be in open-source repos.
    let is_open_source = snapshot
        .file()
        .map_or(false, |file| store.is_file_open_source(&project, file, cx))
        && events.iter().all(|event| event.in_open_source_repo())
        && related_files.iter().all(|file| file.in_open_source_repo);

    let can_collect_data = is_open_source && store.is_data_collection_enabled(cx);

    let request_task = cx.background_spawn({
        async move {
            // A raw config's format wins; otherwise fall back to the default.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                zeta_version,
                preferred_model,
                is_open_source,
                can_collect_data,
            );

            // Surface the formatted prompt to the debug channel, if attached;
            // a closed channel is ignored.
            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, usage) = if let Some(config) = &raw_config {
                // Raw mode: format prompt + prefill locally and call the raw
                // completion endpoint directly.
                let prompt = format_zeta_prompt(&prompt_input, config.format);
                let prefill = get_prefill(&prompt_input, config.format);
                let prompt = format!("{prompt}{prefill}");
                let request = RawCompletionRequest {
                    model: config.model_id.clone().unwrap_or_default(),
                    prompt,
                    temperature: None,
                    stop: vec![],
                    max_tokens: Some(2048),
                    environment: Some(config.format.to_string().to_lowercase()),
                };

                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                    request,
                    client,
                    None,
                    llm_token,
                    app_version,
                )
                .await?;

                let request_id = EditPredictionId(response.id.clone().into());
                // The completion continues from the prefill, so re-attach it
                // before cleaning the model output.
                let output_text = response.choices.pop().map(|choice| {
                    let response = &choice.text;
                    let output = format!("{prefill}{response}");
                    clean_zeta2_model_output(&output, config.format).to_string()
                });

                (request_id, output_text, usage)
            } else {
                // Use V3 endpoint - server handles model/version selection and suffix stripping
                let (response, usage) = EditPredictionStore::send_v3_request(
                    prompt_input.clone(),
                    client,
                    llm_token,
                    app_version,
                    trigger,
                )
                .await?;

                let request_id = EditPredictionId(response.request_id.into());
                let output_text = if response.output.is_empty() {
                    None
                } else {
                    Some(response.output)
                };
                (request_id, output_text, usage)
            };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No output at all: propagate the request id with no prediction,
            // so the caller can still reject it as `Empty`.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None)), usage));
            };

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            let mut old_text = snapshot
                .text_for_range(editable_offset_range.clone())
                .collect::<String>();

            // Normalize trailing newlines on both sides — presumably so the
            // edit computation below treats the final line uniformly.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_offset_range.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // `prediction == None` means the model returned no output text.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}
257
258pub fn zeta2_prompt_input(
259    snapshot: &language::BufferSnapshot,
260    related_files: Vec<zeta_prompt::RelatedFile>,
261    events: Vec<Arc<zeta_prompt::Event>>,
262    excerpt_path: Arc<Path>,
263    cursor_offset: usize,
264    zeta_format: ZetaFormat,
265    preferred_model: Option<EditPredictionModelKind>,
266    is_open_source: bool,
267    can_collect_data: bool,
268) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
269    let cursor_point = cursor_offset.to_point(snapshot);
270
271    let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot);
272
273    let related_files = crate::filter_redundant_excerpts(
274        related_files,
275        excerpt_path.as_ref(),
276        full_context.start.row..full_context.end.row,
277    );
278
279    let full_context_start_offset = full_context.start.to_offset(snapshot);
280    let full_context_start_row = full_context.start.row;
281
282    let excerpt_ranges =
283        excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot);
284
285    let editable_range = match preferred_model {
286        Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350,
287        _ => match zeta_format {
288            ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150,
289            _ => &range_points.editable_180,
290        },
291    };
292
293    let editable_offset_range = editable_range.to_offset(snapshot);
294    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
295    let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset)
296        ..(editable_offset_range.end - full_context_start_offset);
297
298    let prompt_input = zeta_prompt::ZetaPromptInput {
299        cursor_path: excerpt_path,
300        cursor_excerpt: snapshot
301            .text_for_range(full_context)
302            .collect::<String>()
303            .into(),
304        editable_range_in_excerpt,
305        cursor_offset_in_excerpt,
306        excerpt_start_row: Some(full_context_start_row),
307        events,
308        related_files,
309        excerpt_ranges: Some(excerpt_ranges),
310        preferred_model,
311        in_open_source_repo: is_open_source,
312        can_collect_data,
313    };
314    (editable_offset_range, prompt_input)
315}
316
317pub(crate) fn edit_prediction_accepted(
318    store: &EditPredictionStore,
319    current_prediction: CurrentEditPrediction,
320    cx: &App,
321) {
322    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
323    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
324        return;
325    }
326
327    let request_id = current_prediction.prediction.id.to_string();
328    let require_auth = custom_accept_url.is_none();
329    let client = store.client.clone();
330    let llm_token = store.llm_token.clone();
331    let app_version = AppVersion::global(cx);
332
333    cx.background_spawn(async move {
334        let url = if let Some(accept_edits_url) = custom_accept_url {
335            gpui::http_client::Url::parse(&accept_edits_url)?
336        } else {
337            client
338                .http_client()
339                .build_zed_llm_url("/predict_edits/accept", &[])?
340        };
341
342        let response = EditPredictionStore::send_api_request::<()>(
343            move |builder| {
344                let req = builder.uri(url.as_ref()).body(
345                    serde_json::to_string(&AcceptEditPredictionBody {
346                        request_id: request_id.clone(),
347                    })?
348                    .into(),
349                );
350                Ok(req?)
351            },
352            client,
353            llm_token,
354            app_version,
355            require_auth,
356        )
357        .await;
358
359        response?;
360        anyhow::Ok(())
361    })
362    .detach_and_log_err(cx);
363}