use crate::prediction::EditPredictionResult;
use crate::zeta1::compute_edits_and_cursor_position;
use crate::{
    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
use gpui::{App, Task, prelude::*};
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;

use std::env;
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt};
17
/// Token budget for the context excerpt around the cursor, passed as the
/// context limit to `editable_and_context_ranges_for_cursor_position`.
pub const MAX_CONTEXT_TOKENS: usize = 350;
19
20pub fn max_editable_tokens(format: ZetaFormat) -> usize {
21 match format {
22 ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
23 ZetaFormat::V0114180EditableRegion => 180,
24 ZetaFormat::V0120GitMergeMarkers => 180,
25 ZetaFormat::V0131GitMergeMarkersPrefix => 180,
26 }
27}
28
29pub fn request_prediction_with_zeta2(
30 store: &mut EditPredictionStore,
31 EditPredictionModelInput {
32 buffer,
33 snapshot,
34 position,
35 related_files,
36 events,
37 debug_tx,
38 trigger,
39 ..
40 }: EditPredictionModelInput,
41 cx: &mut Context<EditPredictionStore>,
42) -> Task<Result<Option<EditPredictionResult>>> {
43 let buffer_snapshotted_at = Instant::now();
44 let raw_config = store.zeta2_raw_config().cloned();
45
46 let Some(excerpt_path) = snapshot
47 .file()
48 .map(|file| -> Arc<Path> { file.full_path(cx).into() })
49 else {
50 return Task::ready(Err(anyhow!("No file path for excerpt")));
51 };
52
53 let client = store.client.clone();
54 let llm_token = store.llm_token.clone();
55 let app_version = AppVersion::global(cx);
56
57 let request_task = cx.background_spawn({
58 async move {
59 let zeta_version = raw_config
60 .as_ref()
61 .map(|config| config.format)
62 .unwrap_or(ZetaFormat::default());
63
64 let cursor_offset = position.to_offset(&snapshot);
65 let (editable_offset_range, prompt_input) = zeta2_prompt_input(
66 &snapshot,
67 related_files,
68 events,
69 excerpt_path,
70 cursor_offset,
71 zeta_version,
72 );
73
74 if let Some(debug_tx) = &debug_tx {
75 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
76 debug_tx
77 .unbounded_send(DebugEvent::EditPredictionStarted(
78 EditPredictionStartedDebugEvent {
79 buffer: buffer.downgrade(),
80 prompt: Some(prompt),
81 position,
82 },
83 ))
84 .ok();
85 }
86
87 log::trace!("Sending edit prediction request");
88
89 let (request_id, output_text, usage) = if let Some(config) = &raw_config {
90 let prompt = format_zeta_prompt(&prompt_input, config.format);
91 let request = RawCompletionRequest {
92 model: config.model_id.clone().unwrap_or_default(),
93 prompt,
94 temperature: None,
95 stop: vec![],
96 max_tokens: Some(2048),
97 environment: Some(config.format.to_string().to_lowercase()),
98 };
99
100 let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
101 request,
102 client,
103 None,
104 llm_token,
105 app_version,
106 )
107 .await?;
108
109 let request_id = EditPredictionId(response.id.clone().into());
110 let output_text = response.choices.pop().map(|choice| {
111 clean_zeta2_model_output(&choice.text, config.format).to_string()
112 });
113
114 (request_id, output_text, usage)
115 } else {
116 // Use V3 endpoint - server handles model/version selection and suffix stripping
117 let (response, usage) = EditPredictionStore::send_v3_request(
118 prompt_input.clone(),
119 client,
120 llm_token,
121 app_version,
122 trigger,
123 )
124 .await?;
125
126 let request_id = EditPredictionId(response.request_id.into());
127 let output_text = if response.output.is_empty() {
128 None
129 } else {
130 Some(response.output)
131 };
132 (request_id, output_text, usage)
133 };
134
135 let received_response_at = Instant::now();
136
137 log::trace!("Got edit prediction response");
138
139 let Some(mut output_text) = output_text else {
140 return Ok((Some((request_id, None)), usage));
141 };
142
143 // Client-side cursor marker processing (applies to both raw and v3 responses)
144 let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
145 if let Some(offset) = cursor_offset_in_output {
146 log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
147 output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
148 }
149
150 if let Some(debug_tx) = &debug_tx {
151 debug_tx
152 .unbounded_send(DebugEvent::EditPredictionFinished(
153 EditPredictionFinishedDebugEvent {
154 buffer: buffer.downgrade(),
155 position,
156 model_output: Some(output_text.clone()),
157 },
158 ))
159 .ok();
160 }
161
162 let mut old_text = snapshot
163 .text_for_range(editable_offset_range.clone())
164 .collect::<String>();
165
166 if !output_text.is_empty() && !output_text.ends_with('\n') {
167 output_text.push('\n');
168 }
169 if !old_text.is_empty() && !old_text.ends_with('\n') {
170 old_text.push('\n');
171 }
172
173 let (edits, cursor_position) = compute_edits_and_cursor_position(
174 old_text,
175 &output_text,
176 editable_offset_range.start,
177 cursor_offset_in_output,
178 &snapshot,
179 );
180
181 anyhow::Ok((
182 Some((
183 request_id,
184 Some((
185 prompt_input,
186 buffer,
187 snapshot.clone(),
188 edits,
189 cursor_position,
190 received_response_at,
191 )),
192 )),
193 usage,
194 ))
195 }
196 });
197
198 cx.spawn(async move |this, cx| {
199 let Some((id, prediction)) =
200 EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
201 else {
202 return Ok(None);
203 };
204
205 let Some((
206 inputs,
207 edited_buffer,
208 edited_buffer_snapshot,
209 edits,
210 cursor_position,
211 received_response_at,
212 )) = prediction
213 else {
214 return Ok(Some(EditPredictionResult {
215 id,
216 prediction: Err(EditPredictionRejectReason::Empty),
217 }));
218 };
219
220 Ok(Some(
221 EditPredictionResult::new(
222 id,
223 &edited_buffer,
224 &edited_buffer_snapshot,
225 edits.into(),
226 cursor_position,
227 buffer_snapshotted_at,
228 received_response_at,
229 inputs,
230 cx,
231 )
232 .await,
233 ))
234 })
235}
236
237pub fn zeta2_prompt_input(
238 snapshot: &language::BufferSnapshot,
239 related_files: Vec<zeta_prompt::RelatedFile>,
240 events: Vec<Arc<zeta_prompt::Event>>,
241 excerpt_path: Arc<Path>,
242 cursor_offset: usize,
243 zeta_format: ZetaFormat,
244) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
245 let cursor_point = cursor_offset.to_point(snapshot);
246
247 let (editable_range, context_range) =
248 crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
249 cursor_point,
250 snapshot,
251 max_editable_tokens(zeta_format),
252 MAX_CONTEXT_TOKENS,
253 );
254
255 let related_files = crate::filter_redundant_excerpts(
256 related_files,
257 excerpt_path.as_ref(),
258 context_range.start.row..context_range.end.row,
259 );
260
261 let context_start_offset = context_range.start.to_offset(snapshot);
262 let context_start_row = context_range.start.row;
263 let editable_offset_range = editable_range.to_offset(snapshot);
264 let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
265 let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
266 ..(editable_offset_range.end - context_start_offset);
267
268 let prompt_input = zeta_prompt::ZetaPromptInput {
269 cursor_path: excerpt_path,
270 cursor_excerpt: snapshot
271 .text_for_range(context_range)
272 .collect::<String>()
273 .into(),
274 editable_range_in_excerpt,
275 cursor_offset_in_excerpt,
276 excerpt_start_row: Some(context_start_row),
277 events,
278 related_files,
279 };
280 (editable_offset_range, prompt_input)
281}
282
283pub(crate) fn edit_prediction_accepted(
284 store: &EditPredictionStore,
285 current_prediction: CurrentEditPrediction,
286 cx: &App,
287) {
288 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
289 if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
290 return;
291 }
292
293 let request_id = current_prediction.prediction.id.to_string();
294 let require_auth = custom_accept_url.is_none();
295 let client = store.client.clone();
296 let llm_token = store.llm_token.clone();
297 let app_version = AppVersion::global(cx);
298
299 cx.background_spawn(async move {
300 let url = if let Some(accept_edits_url) = custom_accept_url {
301 gpui::http_client::Url::parse(&accept_edits_url)?
302 } else {
303 client
304 .http_client()
305 .build_zed_llm_url("/predict_edits/accept", &[])?
306 };
307
308 let response = EditPredictionStore::send_api_request::<()>(
309 move |builder| {
310 let req = builder.uri(url.as_ref()).body(
311 serde_json::to_string(&AcceptEditPredictionBody {
312 request_id: request_id.clone(),
313 })?
314 .into(),
315 );
316 Ok(req?)
317 },
318 client,
319 llm_token,
320 app_version,
321 require_auth,
322 )
323 .await;
324
325 response?;
326 anyhow::Ok(())
327 })
328 .detach_and_log_err(cx);
329}