zeta2.rs

#[cfg(feature = "eval-support")]
use crate::EvalCacheEntryKind;
use crate::prediction::EditPredictionResult;
use crate::{
    DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
    EditPredictionRequestedDebugEvent, EditPredictionStore,
};
use anyhow::{Result, anyhow, bail};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
use cloud_zeta2_prompt::CURSOR_MARKER;
use edit_prediction_context::{EditPredictionExcerpt, Line, RelatedExcerpt, RelatedFile};
use futures::channel::oneshot;
use gpui::{Entity, Task, prelude::*};
use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint};
use project::{Project, ProjectItem as _};
use release_channel::AppVersion;
use std::{
    env,
    path::Path,
    sync::Arc,
    time::{Duration, Instant},
};

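/// Requests an edit prediction from the Zeta2 model for the cursor position in
/// `active_buffer`, sending the surrounding excerpt and related files as context.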
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    project: &Entity<Project>,
    active_buffer: &Entity<Buffer>,
    active_snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<Event>>,
    mut included_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let options = store.options.clone();
    let buffer_snapshotted_at = Instant::now();

    let Some((excerpt_path, active_project_path)) = active_snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .zip(active_buffer.read(cx).project_path(cx))
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let debug_tx = store.debug_tx.clone();

    let file = active_buffer.read(cx).file();

    let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));

    // TODO data collection
    let can_collect_data = file
        .as_ref()
        .map_or(false, |file| store.can_collect_file(project, file, cx));

    #[cfg(feature = "eval-support")]
    let eval_cache = store.eval_cache.clone();

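    // Build the request and query the model on a background thread; the
    // foreground task at the end of this function consumes the result.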
    let request_task = cx.background_spawn({
        let active_buffer = active_buffer.clone();
        async move {
            let cursor_offset = position.to_offset(&active_snapshot);
            let cursor_point = cursor_offset.to_point(&active_snapshot);

            let before_retrieval = Instant::now();

            let excerpt_options = options.context;

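            // Select an excerpt window around the cursor. If no excerpt fits the
            // configured limits, resolve to "no prediction" rather than an error.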
            let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                cursor_point,
                &active_snapshot,
                &excerpt_options,
            ) else {
                return Ok((None, None));
            };

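            // Describe the cursor excerpt in the same shape as other retrieved
            // context so it can be merged with the included files below.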
            let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                ..active_snapshot.anchor_before(excerpt.range.end);
            let related_excerpt = RelatedExcerpt {
                anchor_range: excerpt_anchor_range.clone(),
                point_range: Point::new(excerpt.line_range.start.0, 0)
                    ..Point::new(excerpt.line_range.end.0, 0),
                text: active_snapshot.as_rope().slice(excerpt.range),
            };

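            // If the active buffer already has an entry in the included files,
            // merge the cursor excerpt into it and move that entry to the end of
            // the list; otherwise, append a fresh entry for the active file.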
            if let Some(buffer_ix) = included_files
                .iter()
                .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
            {
                let file = &mut included_files[buffer_ix];
                file.excerpts.push(related_excerpt);
                file.merge_excerpts();
                let last_ix = included_files.len() - 1;
                included_files.swap(buffer_ix, last_ix);
            } else {
                let active_file = RelatedFile {
                    path: active_project_path,
                    buffer: active_buffer.downgrade(),
                    excerpts: vec![related_excerpt],
                    max_row: active_snapshot.max_point().row,
                };
                included_files.push(active_file);
            }

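            // Convert the collected context into the predict_edits_v3 wire format.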
            let included_files = included_files
                .iter()
                .map(|related_file| predict_edits_v3::RelatedFile {
                    path: Arc::from(related_file.path.path.as_std_path()),
                    max_row: Line(related_file.max_row),
                    excerpts: related_file
                        .excerpts
                        .iter()
                        .map(|excerpt| predict_edits_v3::Excerpt {
                            start_line: Line(excerpt.point_range.start.row),
                            text: excerpt.text.to_string().into(),
                        })
                        .collect(),
                })
                .collect::<Vec<_>>();

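            // The standalone excerpt fields are left empty: the cursor excerpt is
            // sent as the final entry of `related_files` instead.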
            let cloud_request = predict_edits_v3::PredictEditsRequest {
                excerpt_path,
                excerpt: String::new(),
                excerpt_line_range: Line(0)..Line(0),
                excerpt_range: 0..0,
                cursor_point: predict_edits_v3::Point {
                    line: predict_edits_v3::Line(cursor_point.row),
                    column: cursor_point.column,
                },
                related_files: included_files,
                events,
                can_collect_data,
                debug_info: debug_tx.is_some(),
                prompt_max_bytes: Some(options.max_prompt_bytes),
                prompt_format: options.prompt_format,
                excerpt_parent: None,
                git_info: None,
                trigger,
            };

            let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

            let inputs = EditPredictionInputs {
                included_files: cloud_request.related_files,
                events: cloud_request.events,
                cursor_point: cloud_request.cursor_point,
                cursor_path: cloud_request.excerpt_path,
            };

            let retrieval_time = Instant::now() - before_retrieval;

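            // If a debug listener is attached, report the request along with a
            // oneshot channel through which it will receive the model's response.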
            let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                let (response_tx, response_rx) = oneshot::channel();

                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionRequested(
                        EditPredictionRequestedDebugEvent {
                            inputs: inputs.clone(),
                            retrieval_time,
                            buffer: active_buffer.downgrade(),
                            local_prompt: match prompt_result.as_ref() {
                                Ok(prompt) => Ok(prompt.clone()),
                                Err(err) => Err(err.to_string()),
                            },
                            position,
                            response_rx,
                        },
                    ))
                    .ok();
                Some(response_tx)
            } else {
                None
            };

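            // Debug-only escape hatch: bail before any network traffic when
            // ZED_ZETA2_SKIP_REQUEST is set.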
            if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((Err("Request skipped".to_string()), Duration::ZERO))
                        .ok();
                }
                bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
            }

            let prompt = prompt_result?;
            let generation_params =
                cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
            let request = open_ai::Request {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: generation_params.stop.unwrap_or_default(),
                temperature: generation_params.temperature.unwrap_or(0.7),
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            log::trace!("Sending edit prediction request");

            let before_request = Instant::now();
            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache,
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();
            let request_time = received_response_at - before_request;

            log::trace!("Got edit prediction response");

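            // Forward the raw response (or its error, stringified) and the request
            // latency to the debug listener, if any.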
            if let Some(debug_response_tx) = debug_response_tx {
                debug_response_tx
                    .send((
                        response
                            .as_ref()
                            .map_err(|err| err.to_string())
                            .map(|response| response.0.clone()),
                        request_time,
                    ))
                    .ok();
            }

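            // Surface request errors now that the debug listener has been notified;
            // a response with no text is treated as an empty prediction, not an error.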
            let (res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            let Some(mut output_text) = text_from_response(res) else {
                return Ok((Some((request_id, None)), usage));
            };

            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

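            // Resolve a path from the model's output back to the snapshot we sent;
            // paths other than the active file's resolve to `None`.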
            let get_buffer_from_context = |path: &Path| {
                if Some(path) == active_file_full_path.as_deref() {
                    Some((
                        &active_snapshot,
                        std::slice::from_ref(&excerpt_anchor_range),
                    ))
                } else {
                    None
                }
            };

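            // Parse the output according to the prompt format: the minimal formats
            // reply with a unified diff (or an explicit "No edits" marker), while
            // OldTextNewText replies with old-text/new-text blocks.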
            let (_, edits) = match options.prompt_format {
                PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
                    if output_text.contains("--- a/\n+++ b/\nNo edits") {
                        let edits = vec![];
                        (&active_snapshot, edits)
                    } else {
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                }
                PromptFormat::OldTextNewText => {
                    crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
                }
                _ => {
                    bail!("unsupported prompt format {}", options.prompt_format)
                }
            };

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        inputs,
                        active_buffer,
                        active_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

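    // Back on the foreground: convert the parsed edits into an
    // EditPredictionResult, mapping an absent prediction to an Empty rejection.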
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}

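/// Extracts the assistant's text from a completion response, logging and
/// returning `None` when the response does not have the expected shape.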
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No output from Baseten completion response");
                return None;
            }

            // Only the first part is consumed; predictions are expected to
            // arrive as a single text part.
            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}