zeta2.rs

  1use crate::prediction::EditPredictionResult;
  2use crate::zeta1::compute_edits_and_cursor_position;
  3use crate::{
  4    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  5    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
  6};
  7use anyhow::{Result, anyhow};
  8use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
  9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 10use gpui::{App, Task, prelude::*};
 11use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 12use release_channel::AppVersion;
 13
 14use std::env;
 15use std::{path::Path, sync::Arc, time::Instant};
 16use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt};
 17
 18pub const MAX_CONTEXT_TOKENS: usize = 350;
 19
 20pub fn max_editable_tokens(format: ZetaFormat) -> usize {
 21    match format {
 22        ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
 23        ZetaFormat::V0114180EditableRegion => 180,
 24        ZetaFormat::V0120GitMergeMarkers => 180,
 25        ZetaFormat::V0131GitMergeMarkersPrefix => 180,
 26    }
 27}
 28
 29pub fn request_prediction_with_zeta2(
 30    store: &mut EditPredictionStore,
 31    EditPredictionModelInput {
 32        buffer,
 33        snapshot,
 34        position,
 35        related_files,
 36        events,
 37        debug_tx,
 38        trigger,
 39        ..
 40    }: EditPredictionModelInput,
 41    cx: &mut Context<EditPredictionStore>,
 42) -> Task<Result<Option<EditPredictionResult>>> {
 43    let buffer_snapshotted_at = Instant::now();
 44    let raw_config = store.zeta2_raw_config().cloned();
 45
 46    let Some(excerpt_path) = snapshot
 47        .file()
 48        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
 49    else {
 50        return Task::ready(Err(anyhow!("No file path for excerpt")));
 51    };
 52
 53    let client = store.client.clone();
 54    let llm_token = store.llm_token.clone();
 55    let app_version = AppVersion::global(cx);
 56
 57    let request_task = cx.background_spawn({
 58        async move {
 59            let zeta_version = raw_config
 60                .as_ref()
 61                .map(|config| config.format)
 62                .unwrap_or(ZetaFormat::default());
 63
 64            let cursor_offset = position.to_offset(&snapshot);
 65            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
 66                &snapshot,
 67                related_files,
 68                events,
 69                excerpt_path,
 70                cursor_offset,
 71                zeta_version,
 72            );
 73
 74            if let Some(debug_tx) = &debug_tx {
 75                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 76                debug_tx
 77                    .unbounded_send(DebugEvent::EditPredictionStarted(
 78                        EditPredictionStartedDebugEvent {
 79                            buffer: buffer.downgrade(),
 80                            prompt: Some(prompt),
 81                            position,
 82                        },
 83                    ))
 84                    .ok();
 85            }
 86
 87            log::trace!("Sending edit prediction request");
 88
 89            let (request_id, output_text, usage) = if let Some(config) = &raw_config {
 90                let prompt = format_zeta_prompt(&prompt_input, config.format);
 91                let request = RawCompletionRequest {
 92                    model: config.model_id.clone().unwrap_or_default(),
 93                    prompt,
 94                    temperature: None,
 95                    stop: vec![],
 96                    max_tokens: Some(2048),
 97                    environment: Some(config.format.to_string().to_lowercase()),
 98                };
 99
100                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
101                    request,
102                    client,
103                    None,
104                    llm_token,
105                    app_version,
106                )
107                .await?;
108
109                let request_id = EditPredictionId(response.id.clone().into());
110                let output_text = response.choices.pop().map(|choice| {
111                    clean_zeta2_model_output(&choice.text, config.format).to_string()
112                });
113
114                (request_id, output_text, usage)
115            } else {
116                // Use V3 endpoint - server handles model/version selection and suffix stripping
117                let (response, usage) = EditPredictionStore::send_v3_request(
118                    prompt_input.clone(),
119                    client,
120                    llm_token,
121                    app_version,
122                    trigger,
123                )
124                .await?;
125
126                let request_id = EditPredictionId(response.request_id.into());
127                let output_text = if response.output.is_empty() {
128                    None
129                } else {
130                    Some(response.output)
131                };
132                (request_id, output_text, usage)
133            };
134
135            let received_response_at = Instant::now();
136
137            log::trace!("Got edit prediction response");
138
139            let Some(mut output_text) = output_text else {
140                return Ok((Some((request_id, None)), usage));
141            };
142
143            // Client-side cursor marker processing (applies to both raw and v3 responses)
144            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
145            if let Some(offset) = cursor_offset_in_output {
146                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
147                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
148            }
149
150            if let Some(debug_tx) = &debug_tx {
151                debug_tx
152                    .unbounded_send(DebugEvent::EditPredictionFinished(
153                        EditPredictionFinishedDebugEvent {
154                            buffer: buffer.downgrade(),
155                            position,
156                            model_output: Some(output_text.clone()),
157                        },
158                    ))
159                    .ok();
160            }
161
162            let mut old_text = snapshot
163                .text_for_range(editable_offset_range.clone())
164                .collect::<String>();
165
166            if !output_text.is_empty() && !output_text.ends_with('\n') {
167                output_text.push('\n');
168            }
169            if !old_text.is_empty() && !old_text.ends_with('\n') {
170                old_text.push('\n');
171            }
172
173            let (edits, cursor_position) = compute_edits_and_cursor_position(
174                old_text,
175                &output_text,
176                editable_offset_range.start,
177                cursor_offset_in_output,
178                &snapshot,
179            );
180
181            anyhow::Ok((
182                Some((
183                    request_id,
184                    Some((
185                        prompt_input,
186                        buffer,
187                        snapshot.clone(),
188                        edits,
189                        cursor_position,
190                        received_response_at,
191                    )),
192                )),
193                usage,
194            ))
195        }
196    });
197
198    cx.spawn(async move |this, cx| {
199        let Some((id, prediction)) =
200            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
201        else {
202            return Ok(None);
203        };
204
205        let Some((
206            inputs,
207            edited_buffer,
208            edited_buffer_snapshot,
209            edits,
210            cursor_position,
211            received_response_at,
212        )) = prediction
213        else {
214            return Ok(Some(EditPredictionResult {
215                id,
216                prediction: Err(EditPredictionRejectReason::Empty),
217            }));
218        };
219
220        Ok(Some(
221            EditPredictionResult::new(
222                id,
223                &edited_buffer,
224                &edited_buffer_snapshot,
225                edits.into(),
226                cursor_position,
227                buffer_snapshotted_at,
228                received_response_at,
229                inputs,
230                cx,
231            )
232            .await,
233        ))
234    })
235}
236
237pub fn zeta2_prompt_input(
238    snapshot: &language::BufferSnapshot,
239    related_files: Vec<zeta_prompt::RelatedFile>,
240    events: Vec<Arc<zeta_prompt::Event>>,
241    excerpt_path: Arc<Path>,
242    cursor_offset: usize,
243    zeta_format: ZetaFormat,
244) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
245    let cursor_point = cursor_offset.to_point(snapshot);
246
247    let (editable_range, context_range) =
248        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
249            cursor_point,
250            snapshot,
251            max_editable_tokens(zeta_format),
252            MAX_CONTEXT_TOKENS,
253        );
254
255    let related_files = crate::filter_redundant_excerpts(
256        related_files,
257        excerpt_path.as_ref(),
258        context_range.start.row..context_range.end.row,
259    );
260
261    let context_start_offset = context_range.start.to_offset(snapshot);
262    let context_start_row = context_range.start.row;
263    let editable_offset_range = editable_range.to_offset(snapshot);
264    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
265    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
266        ..(editable_offset_range.end - context_start_offset);
267
268    let prompt_input = zeta_prompt::ZetaPromptInput {
269        cursor_path: excerpt_path,
270        cursor_excerpt: snapshot
271            .text_for_range(context_range)
272            .collect::<String>()
273            .into(),
274        editable_range_in_excerpt,
275        cursor_offset_in_excerpt,
276        excerpt_start_row: Some(context_start_row),
277        events,
278        related_files,
279    };
280    (editable_offset_range, prompt_input)
281}
282
283pub(crate) fn edit_prediction_accepted(
284    store: &EditPredictionStore,
285    current_prediction: CurrentEditPrediction,
286    cx: &App,
287) {
288    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
289    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
290        return;
291    }
292
293    let request_id = current_prediction.prediction.id.to_string();
294    let require_auth = custom_accept_url.is_none();
295    let client = store.client.clone();
296    let llm_token = store.llm_token.clone();
297    let app_version = AppVersion::global(cx);
298
299    cx.background_spawn(async move {
300        let url = if let Some(accept_edits_url) = custom_accept_url {
301            gpui::http_client::Url::parse(&accept_edits_url)?
302        } else {
303            client
304                .http_client()
305                .build_zed_llm_url("/predict_edits/accept", &[])?
306        };
307
308        let response = EditPredictionStore::send_api_request::<()>(
309            move |builder| {
310                let req = builder.uri(url.as_ref()).body(
311                    serde_json::to_string(&AcceptEditPredictionBody {
312                        request_id: request_id.clone(),
313                    })?
314                    .into(),
315                );
316                Ok(req?)
317            },
318            client,
319            llm_token,
320            app_version,
321            require_auth,
322        )
323        .await;
324
325        response?;
326        anyhow::Ok(())
327    })
328    .detach_and_log_err(cx);
329}