zeta2.rs

  1use crate::prediction::EditPredictionResult;
  2use crate::zeta1::compute_edits;
  3use crate::{
  4    CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
  5    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  6    EditPredictionStore,
  7};
  8use anyhow::{Result, anyhow};
  9use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 10use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 11use gpui::{App, Task, prelude::*};
 12use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 13use release_channel::AppVersion;
 14
 15use std::env;
 16use std::{path::Path, sync::Arc, time::Instant};
 17use zeta_prompt::format_zeta_prompt;
 18use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers};
 19
 20pub const MAX_CONTEXT_TOKENS: usize = 350;
 21
 22pub fn max_editable_tokens(version: ZetaVersion) -> usize {
 23    match version {
 24        ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
 25        ZetaVersion::V0114180EditableRegion | ZetaVersion::V0120GitMergeMarkers => 180,
 26    }
 27}
 28
 29pub fn request_prediction_with_zeta2(
 30    store: &mut EditPredictionStore,
 31    EditPredictionModelInput {
 32        buffer,
 33        snapshot,
 34        position,
 35        related_files,
 36        events,
 37        debug_tx,
 38        trigger,
 39        ..
 40    }: EditPredictionModelInput,
 41    zeta_version: ZetaVersion,
 42    cx: &mut Context<EditPredictionStore>,
 43) -> Task<Result<Option<EditPredictionResult>>> {
 44    let buffer_snapshotted_at = Instant::now();
 45    let custom_url = store.custom_predict_edits_url.clone();
 46
 47    let Some(excerpt_path) = snapshot
 48        .file()
 49        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
 50    else {
 51        return Task::ready(Err(anyhow!("No file path for excerpt")));
 52    };
 53
 54    let client = store.client.clone();
 55    let llm_token = store.llm_token.clone();
 56    let app_version = AppVersion::global(cx);
 57
 58    let request_task = cx.background_spawn({
 59        async move {
 60            let cursor_offset = position.to_offset(&snapshot);
 61            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
 62                &snapshot,
 63                related_files,
 64                events,
 65                excerpt_path,
 66                cursor_offset,
 67                zeta_version,
 68            );
 69
 70            if let Some(debug_tx) = &debug_tx {
 71                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 72                debug_tx
 73                    .unbounded_send(DebugEvent::EditPredictionStarted(
 74                        EditPredictionStartedDebugEvent {
 75                            buffer: buffer.downgrade(),
 76                            prompt: Some(prompt),
 77                            position,
 78                        },
 79                    ))
 80                    .ok();
 81            }
 82
 83            log::trace!("Sending edit prediction request");
 84
 85            let (request_id, output_text, usage) = if let Some(custom_url) = custom_url {
 86                // Use raw endpoint with custom URL
 87                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 88                let request = RawCompletionRequest {
 89                    model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(),
 90                    prompt,
 91                    temperature: None,
 92                    stop: vec![],
 93                    max_tokens: Some(2048),
 94                };
 95
 96                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
 97                    request,
 98                    client,
 99                    Some(custom_url),
100                    llm_token,
101                    app_version,
102                )
103                .await?;
104
105                let request_id = EditPredictionId(response.id.clone().into());
106                let output_text = response.choices.pop().map(|choice| choice.text);
107                (request_id, output_text, usage)
108            } else {
109                let (response, usage) = EditPredictionStore::send_v3_request(
110                    prompt_input.clone(),
111                    zeta_version,
112                    client,
113                    llm_token,
114                    app_version,
115                    trigger,
116                )
117                .await?;
118
119                let request_id = EditPredictionId(response.request_id.into());
120                let output_text = if response.output.is_empty() {
121                    None
122                } else {
123                    Some(response.output)
124                };
125                (request_id, output_text, usage)
126            };
127
128            let received_response_at = Instant::now();
129
130            log::trace!("Got edit prediction response");
131
132            let Some(mut output_text) = output_text else {
133                return Ok((Some((request_id, None)), usage));
134            };
135
136            if let Some(debug_tx) = &debug_tx {
137                debug_tx
138                    .unbounded_send(DebugEvent::EditPredictionFinished(
139                        EditPredictionFinishedDebugEvent {
140                            buffer: buffer.downgrade(),
141                            position,
142                            model_output: Some(output_text.clone()),
143                        },
144                    ))
145                    .ok();
146            }
147
148            if output_text.contains(CURSOR_MARKER) {
149                log::trace!("Stripping out {CURSOR_MARKER} from response");
150                output_text = output_text.replace(CURSOR_MARKER, "");
151            }
152
153            if zeta_version == ZetaVersion::V0120GitMergeMarkers {
154                if let Some(stripped) =
155                    output_text.strip_suffix(v0120_git_merge_markers::END_MARKER)
156                {
157                    output_text = stripped.to_string();
158                }
159            }
160
161            let mut old_text = snapshot
162                .text_for_range(editable_offset_range.clone())
163                .collect::<String>();
164
165            if !output_text.is_empty() && !output_text.ends_with('\n') {
166                output_text.push('\n');
167            }
168            if !old_text.is_empty() && !old_text.ends_with('\n') {
169                old_text.push('\n');
170            }
171
172            let edits = compute_edits(
173                old_text,
174                &output_text,
175                editable_offset_range.start,
176                &snapshot,
177            );
178
179            anyhow::Ok((
180                Some((
181                    request_id,
182                    Some((
183                        prompt_input,
184                        buffer,
185                        snapshot.clone(),
186                        edits,
187                        received_response_at,
188                    )),
189                )),
190                usage,
191            ))
192        }
193    });
194
195    cx.spawn(async move |this, cx| {
196        let Some((id, prediction)) =
197            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
198        else {
199            return Ok(None);
200        };
201
202        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
203            prediction
204        else {
205            return Ok(Some(EditPredictionResult {
206                id,
207                prediction: Err(EditPredictionRejectReason::Empty),
208            }));
209        };
210
211        Ok(Some(
212            EditPredictionResult::new(
213                id,
214                &edited_buffer,
215                &edited_buffer_snapshot,
216                edits.into(),
217                buffer_snapshotted_at,
218                received_response_at,
219                inputs,
220                cx,
221            )
222            .await,
223        ))
224    })
225}
226
227pub fn zeta2_prompt_input(
228    snapshot: &language::BufferSnapshot,
229    related_files: Vec<zeta_prompt::RelatedFile>,
230    events: Vec<Arc<zeta_prompt::Event>>,
231    excerpt_path: Arc<Path>,
232    cursor_offset: usize,
233    zeta_version: ZetaVersion,
234) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
235    let cursor_point = cursor_offset.to_point(snapshot);
236
237    let (editable_range, context_range) =
238        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
239            cursor_point,
240            snapshot,
241            max_editable_tokens(zeta_version),
242            MAX_CONTEXT_TOKENS,
243        );
244
245    let related_files = crate::filter_redundant_excerpts(
246        related_files,
247        excerpt_path.as_ref(),
248        context_range.start.row..context_range.end.row,
249    );
250
251    let context_start_offset = context_range.start.to_offset(snapshot);
252    let editable_offset_range = editable_range.to_offset(snapshot);
253    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
254    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
255        ..(editable_offset_range.end - context_start_offset);
256
257    let prompt_input = zeta_prompt::ZetaPromptInput {
258        cursor_path: excerpt_path,
259        cursor_excerpt: snapshot
260            .text_for_range(context_range)
261            .collect::<String>()
262            .into(),
263        editable_range_in_excerpt,
264        cursor_offset_in_excerpt,
265        events,
266        related_files,
267    };
268    (editable_offset_range, prompt_input)
269}
270
271pub(crate) fn edit_prediction_accepted(
272    store: &EditPredictionStore,
273    current_prediction: CurrentEditPrediction,
274    cx: &App,
275) {
276    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
277    if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
278        return;
279    }
280
281    let request_id = current_prediction.prediction.id.to_string();
282    let require_auth = custom_accept_url.is_none();
283    let client = store.client.clone();
284    let llm_token = store.llm_token.clone();
285    let app_version = AppVersion::global(cx);
286
287    cx.background_spawn(async move {
288        let url = if let Some(accept_edits_url) = custom_accept_url {
289            gpui::http_client::Url::parse(&accept_edits_url)?
290        } else {
291            client
292                .http_client()
293                .build_zed_llm_url("/predict_edits/accept", &[])?
294        };
295
296        let response = EditPredictionStore::send_api_request::<()>(
297            move |builder| {
298                let req = builder.uri(url.as_ref()).body(
299                    serde_json::to_string(&AcceptEditPredictionBody {
300                        request_id: request_id.clone(),
301                    })?
302                    .into(),
303                );
304                Ok(req?)
305            },
306            client,
307            llm_token,
308            app_version,
309            require_auth,
310        )
311        .await;
312
313        response?;
314        anyhow::Ok(())
315    })
316    .detach_and_log_err(cx);
317}