zeta2.rs

  1use crate::prediction::EditPredictionResult;
  2use crate::zeta1::compute_edits_and_cursor_position;
  3use crate::{
  4    CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
  5    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  6    EditPredictionStore,
  7};
  8use anyhow::{Result, anyhow};
  9use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 10use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 11use gpui::{App, Task, prelude::*};
 12use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 13use release_channel::AppVersion;
 14
 15use std::env;
 16use std::{path::Path, sync::Arc, time::Instant};
 17use zeta_prompt::format_zeta_prompt;
 18use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers};
 19
/// Token limit passed as the final argument to
/// `crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position`
/// by `zeta2_prompt_input`; presumably the budget for the context excerpt that
/// surrounds the editable region (confirm against that helper).
pub const MAX_CONTEXT_TOKENS: usize = 350;
 21
 22pub fn max_editable_tokens(version: ZetaVersion) -> usize {
 23    match version {
 24        ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
 25        ZetaVersion::V0114180EditableRegion => 180,
 26        ZetaVersion::V0120GitMergeMarkers => 180,
 27        ZetaVersion::V0131GitMergeMarkersPrefix => 180,
 28    }
 29}
 30
 31pub fn request_prediction_with_zeta2(
 32    store: &mut EditPredictionStore,
 33    EditPredictionModelInput {
 34        buffer,
 35        snapshot,
 36        position,
 37        related_files,
 38        events,
 39        debug_tx,
 40        trigger,
 41        ..
 42    }: EditPredictionModelInput,
 43    zeta_version: ZetaVersion,
 44    cx: &mut Context<EditPredictionStore>,
 45) -> Task<Result<Option<EditPredictionResult>>> {
 46    let buffer_snapshotted_at = Instant::now();
 47    let custom_url = store.custom_predict_edits_url.clone();
 48
 49    let Some(excerpt_path) = snapshot
 50        .file()
 51        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
 52    else {
 53        return Task::ready(Err(anyhow!("No file path for excerpt")));
 54    };
 55
 56    let client = store.client.clone();
 57    let llm_token = store.llm_token.clone();
 58    let app_version = AppVersion::global(cx);
 59
 60    let request_task = cx.background_spawn({
 61        async move {
 62            let cursor_offset = position.to_offset(&snapshot);
 63            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
 64                &snapshot,
 65                related_files,
 66                events,
 67                excerpt_path,
 68                cursor_offset,
 69                zeta_version,
 70            );
 71
 72            if let Some(debug_tx) = &debug_tx {
 73                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 74                debug_tx
 75                    .unbounded_send(DebugEvent::EditPredictionStarted(
 76                        EditPredictionStartedDebugEvent {
 77                            buffer: buffer.downgrade(),
 78                            prompt: Some(prompt),
 79                            position,
 80                        },
 81                    ))
 82                    .ok();
 83            }
 84
 85            log::trace!("Sending edit prediction request");
 86
 87            let (request_id, output_text, usage) = if let Some(custom_url) = custom_url {
 88                // Use raw endpoint with custom URL
 89                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 90                let request = RawCompletionRequest {
 91                    model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(),
 92                    prompt,
 93                    temperature: None,
 94                    stop: vec![],
 95                    max_tokens: Some(2048),
 96                };
 97
 98                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
 99                    request,
100                    client,
101                    Some(custom_url),
102                    llm_token,
103                    app_version,
104                )
105                .await?;
106
107                let request_id = EditPredictionId(response.id.clone().into());
108                let output_text = response.choices.pop().map(|choice| choice.text);
109                (request_id, output_text, usage)
110            } else {
111                let (response, usage) = EditPredictionStore::send_v3_request(
112                    prompt_input.clone(),
113                    zeta_version,
114                    client,
115                    llm_token,
116                    app_version,
117                    trigger,
118                )
119                .await?;
120
121                let request_id = EditPredictionId(response.request_id.into());
122                let output_text = if response.output.is_empty() {
123                    None
124                } else {
125                    Some(response.output)
126                };
127                (request_id, output_text, usage)
128            };
129
130            let received_response_at = Instant::now();
131
132            log::trace!("Got edit prediction response");
133
134            let Some(mut output_text) = output_text else {
135                return Ok((Some((request_id, None)), usage));
136            };
137
138            if let Some(debug_tx) = &debug_tx {
139                debug_tx
140                    .unbounded_send(DebugEvent::EditPredictionFinished(
141                        EditPredictionFinishedDebugEvent {
142                            buffer: buffer.downgrade(),
143                            position,
144                            model_output: Some(output_text.clone()),
145                        },
146                    ))
147                    .ok();
148            }
149
150            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
151            if let Some(offset) = cursor_offset_in_output {
152                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
153                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
154            }
155
156            if zeta_version == ZetaVersion::V0120GitMergeMarkers {
157                if let Some(stripped) =
158                    output_text.strip_suffix(v0120_git_merge_markers::END_MARKER)
159                {
160                    output_text = stripped.to_string();
161                }
162            }
163
164            let mut old_text = snapshot
165                .text_for_range(editable_offset_range.clone())
166                .collect::<String>();
167
168            if !output_text.is_empty() && !output_text.ends_with('\n') {
169                output_text.push('\n');
170            }
171            if !old_text.is_empty() && !old_text.ends_with('\n') {
172                old_text.push('\n');
173            }
174
175            let (edits, cursor_position) = compute_edits_and_cursor_position(
176                old_text,
177                &output_text,
178                editable_offset_range.start,
179                cursor_offset_in_output,
180                &snapshot,
181            );
182
183            anyhow::Ok((
184                Some((
185                    request_id,
186                    Some((
187                        prompt_input,
188                        buffer,
189                        snapshot.clone(),
190                        edits,
191                        cursor_position,
192                        received_response_at,
193                    )),
194                )),
195                usage,
196            ))
197        }
198    });
199
200    cx.spawn(async move |this, cx| {
201        let Some((id, prediction)) =
202            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
203        else {
204            return Ok(None);
205        };
206
207        let Some((
208            inputs,
209            edited_buffer,
210            edited_buffer_snapshot,
211            edits,
212            cursor_position,
213            received_response_at,
214        )) = prediction
215        else {
216            return Ok(Some(EditPredictionResult {
217                id,
218                prediction: Err(EditPredictionRejectReason::Empty),
219            }));
220        };
221
222        Ok(Some(
223            EditPredictionResult::new(
224                id,
225                &edited_buffer,
226                &edited_buffer_snapshot,
227                edits.into(),
228                cursor_position,
229                buffer_snapshotted_at,
230                received_response_at,
231                inputs,
232                cx,
233            )
234            .await,
235        ))
236    })
237}
238
239pub fn zeta2_prompt_input(
240    snapshot: &language::BufferSnapshot,
241    related_files: Vec<zeta_prompt::RelatedFile>,
242    events: Vec<Arc<zeta_prompt::Event>>,
243    excerpt_path: Arc<Path>,
244    cursor_offset: usize,
245    zeta_version: ZetaVersion,
246) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
247    let cursor_point = cursor_offset.to_point(snapshot);
248
249    let (editable_range, context_range) =
250        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
251            cursor_point,
252            snapshot,
253            max_editable_tokens(zeta_version),
254            MAX_CONTEXT_TOKENS,
255        );
256
257    let related_files = crate::filter_redundant_excerpts(
258        related_files,
259        excerpt_path.as_ref(),
260        context_range.start.row..context_range.end.row,
261    );
262
263    let context_start_offset = context_range.start.to_offset(snapshot);
264    let context_start_row = context_range.start.row;
265    let editable_offset_range = editable_range.to_offset(snapshot);
266    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
267    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
268        ..(editable_offset_range.end - context_start_offset);
269
270    let prompt_input = zeta_prompt::ZetaPromptInput {
271        cursor_path: excerpt_path,
272        cursor_excerpt: snapshot
273            .text_for_range(context_range)
274            .collect::<String>()
275            .into(),
276        editable_range_in_excerpt,
277        cursor_offset_in_excerpt,
278        excerpt_start_row: Some(context_start_row),
279        events,
280        related_files,
281    };
282    (editable_offset_range, prompt_input)
283}
284
285pub(crate) fn edit_prediction_accepted(
286    store: &EditPredictionStore,
287    current_prediction: CurrentEditPrediction,
288    cx: &App,
289) {
290    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
291    if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
292        return;
293    }
294
295    let request_id = current_prediction.prediction.id.to_string();
296    let require_auth = custom_accept_url.is_none();
297    let client = store.client.clone();
298    let llm_token = store.llm_token.clone();
299    let app_version = AppVersion::global(cx);
300
301    cx.background_spawn(async move {
302        let url = if let Some(accept_edits_url) = custom_accept_url {
303            gpui::http_client::Url::parse(&accept_edits_url)?
304        } else {
305            client
306                .http_client()
307                .build_zed_llm_url("/predict_edits/accept", &[])?
308        };
309
310        let response = EditPredictionStore::send_api_request::<()>(
311            move |builder| {
312                let req = builder.uri(url.as_ref()).body(
313                    serde_json::to_string(&AcceptEditPredictionBody {
314                        request_id: request_id.clone(),
315                    })?
316                    .into(),
317                );
318                Ok(req?)
319            },
320            client,
321            llm_token,
322            app_version,
323            require_auth,
324        )
325        .await;
326
327        response?;
328        anyhow::Ok(())
329    })
330    .detach_and_log_err(cx);
331}