zeta2.rs

#[cfg(feature = "cli-support")]
use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
    CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
    EditPredictionStore,
};
use anyhow::{Result, anyhow};
use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
use gpui::{App, Task, prelude::*};
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;

use std::env;
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::CURSOR_MARKER;
use zeta_prompt::format_zeta_prompt;

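// Token budgets for the excerpt sent to the model: the context window around the
// cursor and the region the model may edit.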
pub const MAX_CONTEXT_TOKENS: usize = 350;
pub const MAX_EDITABLE_TOKENS: usize = 150;

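/// Requests an edit prediction from the Zeta2 model for the given buffer
/// position, returning edits anchored to the provided snapshot.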
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let buffer_snapshotted_at = Instant::now();

    let Some(excerpt_path) = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    #[cfg(feature = "cli-support")]
    let eval_cache = store.eval_cache.clone();

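    // Build the prompt and send the completion request on a background thread.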
    let request_task = cx.background_spawn({
        async move {
            let cursor_offset = position.to_offset(&snapshot);
            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
            );

            let prompt = format_zeta_prompt(&prompt_input);

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt.clone()),
                            position,
                        },
                    ))
                    .ok();
            }

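            // A single, non-streaming chat completion request containing the formatted prompt.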
            let request = open_ai::Request {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: Default::default(),
                temperature: Default::default(),
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            log::trace!("Sending edit prediction request");

            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "cli-support")]
                eval_cache,
                #[cfg(feature = "cli-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

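            // A completion with no text is treated as an empty prediction rather than an error.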
            let (res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            let Some(mut output_text) = text_from_response(res) else {
                return Ok((Some((request_id, None)), usage));
            };

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

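            // The model may echo the cursor marker back; strip it before diffing.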
            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

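            // Diff the old editable region against the model output and convert the
            // resulting byte ranges into buffer anchors.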
            let old_text = snapshot
                .text_for_range(editable_offset_range.clone())
                .collect::<String>();
            let edits: Vec<_> = language::text_diff(&old_text, &output_text)
                .into_iter()
                .map(|(range, text)| {
                    (
                        snapshot.anchor_after(editable_offset_range.start + range.start)
                            ..snapshot.anchor_before(editable_offset_range.start + range.end),
                        text,
                    )
                })
                .collect();

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

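    // Await the background request, then build the final EditPredictionResult
    // (or report an empty prediction).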
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}

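/// Builds the prompt input for the given cursor position, returning the byte
/// range of the editable region in the buffer alongside the prompt input.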
pub fn zeta2_prompt_input(
    snapshot: &language::BufferSnapshot,
    related_files: Arc<[zeta_prompt::RelatedFile]>,
    events: Vec<Arc<zeta_prompt::Event>>,
    excerpt_path: Arc<Path>,
    cursor_offset: usize,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
    let cursor_point = cursor_offset.to_point(snapshot);

    let (editable_range, context_range) =
        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
            cursor_point,
            snapshot,
            MAX_EDITABLE_TOKENS,
            MAX_CONTEXT_TOKENS,
        );

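    // Translate buffer-wide offsets into offsets relative to the excerpt that is
    // sent to the model.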
    let context_start_offset = context_range.start.to_offset(snapshot);
    let editable_offset_range = editable_range.to_offset(snapshot);
    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
        ..(editable_offset_range.end - context_start_offset);

    let prompt_input = zeta_prompt::ZetaPromptInput {
        cursor_path: excerpt_path,
        cursor_excerpt: snapshot
            .text_for_range(context_range)
            .collect::<String>()
            .into(),
        editable_range_in_excerpt,
        cursor_offset_in_excerpt,
        events,
        related_files,
    };
    (editable_offset_range, prompt_input)
}

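/// Reports an accepted prediction to the predict-edits service, or to the URL
/// given in `ZED_ACCEPT_PREDICTION_URL` when that variable is set.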
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
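    // When a custom prediction endpoint is configured, only report acceptance if a
    // custom accept URL was also provided.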
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}

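/// Applies a unified-diff patch to the editable region described by `input`,
/// returning the new text of that region.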
#[cfg(feature = "cli-support")]
pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
    let old_editable_region = &input.cursor_excerpt[input.editable_range_in_excerpt.clone()];
    let new_editable_region = crate::udiff::apply_diff_to_string(patch, old_editable_region)?;
    Ok(new_editable_region)
}