1use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets};
2use crate::prediction::EditPredictionResult;
3use crate::zeta1::compute_edits_and_cursor_position;
4use crate::{
5 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
6 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
7};
8use anyhow::Result;
9use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
10use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
11use gpui::{App, Task, prelude::*};
12use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
13use release_channel::AppVersion;
14
15use std::env;
16use std::{path::Path, sync::Arc, time::Instant};
17use zeta_prompt::{
18 CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
19 format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
20};
21
22pub const MAX_CONTEXT_TOKENS: usize = 350;
23
24pub fn max_editable_tokens(format: ZetaFormat) -> usize {
25 match format {
26 ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
27 ZetaFormat::V0114180EditableRegion => 180,
28 ZetaFormat::V0120GitMergeMarkers => 180,
29 ZetaFormat::V0131GitMergeMarkersPrefix => 180,
30 ZetaFormat::V0211Prefill => 180,
31 ZetaFormat::V0211SeedCoder => 180,
32 }
33}
34
35pub fn request_prediction_with_zeta2(
36 store: &mut EditPredictionStore,
37 EditPredictionModelInput {
38 buffer,
39 snapshot,
40 position,
41 related_files,
42 events,
43 debug_tx,
44 trigger,
45 project,
46 ..
47 }: EditPredictionModelInput,
48 preferred_model: Option<EditPredictionModelKind>,
49 cx: &mut Context<EditPredictionStore>,
50) -> Task<Result<Option<EditPredictionResult>>> {
51 let buffer_snapshotted_at = Instant::now();
52 let raw_config = store.zeta2_raw_config().cloned();
53
54 let excerpt_path: Arc<Path> = snapshot
55 .file()
56 .map(|file| -> Arc<Path> { file.full_path(cx).into() })
57 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
58
59 let client = store.client.clone();
60 let llm_token = store.llm_token.clone();
61 let app_version = AppVersion::global(cx);
62
63 let is_open_source = snapshot
64 .file()
65 .map_or(false, |file| store.is_file_open_source(&project, file, cx))
66 && events.iter().all(|event| event.in_open_source_repo())
67 && related_files.iter().all(|file| file.in_open_source_repo);
68
69 let can_collect_data = is_open_source && store.is_data_collection_enabled(cx);
70
71 let request_task = cx.background_spawn({
72 async move {
73 let zeta_version = raw_config
74 .as_ref()
75 .map(|config| config.format)
76 .unwrap_or(ZetaFormat::default());
77
78 let cursor_offset = position.to_offset(&snapshot);
79 let (editable_offset_range, prompt_input) = zeta2_prompt_input(
80 &snapshot,
81 related_files,
82 events,
83 excerpt_path,
84 cursor_offset,
85 zeta_version,
86 preferred_model,
87 is_open_source,
88 can_collect_data,
89 );
90
91 if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
92 return Ok((None, None));
93 }
94
95 if let Some(debug_tx) = &debug_tx {
96 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
97 debug_tx
98 .unbounded_send(DebugEvent::EditPredictionStarted(
99 EditPredictionStartedDebugEvent {
100 buffer: buffer.downgrade(),
101 prompt: Some(prompt),
102 position,
103 },
104 ))
105 .ok();
106 }
107
108 log::trace!("Sending edit prediction request");
109
110 let (request_id, output_text, usage) = if let Some(config) = &raw_config {
111 let prompt = format_zeta_prompt(&prompt_input, config.format);
112 let prefill = get_prefill(&prompt_input, config.format);
113 let prompt = format!("{prompt}{prefill}");
114 let request = RawCompletionRequest {
115 model: config.model_id.clone().unwrap_or_default(),
116 prompt,
117 temperature: None,
118 stop: vec![],
119 max_tokens: Some(2048),
120 environment: Some(config.format.to_string().to_lowercase()),
121 };
122
123 let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
124 request,
125 client,
126 None,
127 llm_token,
128 app_version,
129 )
130 .await?;
131
132 let request_id = EditPredictionId(response.id.clone().into());
133 let output_text = response.choices.pop().map(|choice| {
134 let response = &choice.text;
135 let output = format!("{prefill}{response}");
136 clean_zeta2_model_output(&output, config.format).to_string()
137 });
138
139 (request_id, output_text, usage)
140 } else {
141 // Use V3 endpoint - server handles model/version selection and suffix stripping
142 let (response, usage) = EditPredictionStore::send_v3_request(
143 prompt_input.clone(),
144 client,
145 llm_token,
146 app_version,
147 trigger,
148 )
149 .await?;
150
151 let request_id = EditPredictionId(response.request_id.into());
152 let output_text = if response.output.is_empty() {
153 None
154 } else {
155 Some(response.output)
156 };
157 (request_id, output_text, usage)
158 };
159
160 let received_response_at = Instant::now();
161
162 log::trace!("Got edit prediction response");
163
164 let Some(mut output_text) = output_text else {
165 return Ok((Some((request_id, None)), usage));
166 };
167
168 // Client-side cursor marker processing (applies to both raw and v3 responses)
169 let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
170 if let Some(offset) = cursor_offset_in_output {
171 log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
172 output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
173 }
174
175 if let Some(debug_tx) = &debug_tx {
176 debug_tx
177 .unbounded_send(DebugEvent::EditPredictionFinished(
178 EditPredictionFinishedDebugEvent {
179 buffer: buffer.downgrade(),
180 position,
181 model_output: Some(output_text.clone()),
182 },
183 ))
184 .ok();
185 }
186
187 let mut old_text = snapshot
188 .text_for_range(editable_offset_range.clone())
189 .collect::<String>();
190
191 if !output_text.is_empty() && !output_text.ends_with('\n') {
192 output_text.push('\n');
193 }
194 if !old_text.is_empty() && !old_text.ends_with('\n') {
195 old_text.push('\n');
196 }
197
198 let (edits, cursor_position) = compute_edits_and_cursor_position(
199 old_text,
200 &output_text,
201 editable_offset_range.start,
202 cursor_offset_in_output,
203 &snapshot,
204 );
205
206 anyhow::Ok((
207 Some((
208 request_id,
209 Some((
210 prompt_input,
211 buffer,
212 snapshot.clone(),
213 edits,
214 cursor_position,
215 received_response_at,
216 )),
217 )),
218 usage,
219 ))
220 }
221 });
222
223 cx.spawn(async move |this, cx| {
224 let Some((id, prediction)) =
225 EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
226 else {
227 return Ok(None);
228 };
229
230 let Some((
231 inputs,
232 edited_buffer,
233 edited_buffer_snapshot,
234 edits,
235 cursor_position,
236 received_response_at,
237 )) = prediction
238 else {
239 return Ok(Some(EditPredictionResult {
240 id,
241 prediction: Err(EditPredictionRejectReason::Empty),
242 }));
243 };
244
245 Ok(Some(
246 EditPredictionResult::new(
247 id,
248 &edited_buffer,
249 &edited_buffer_snapshot,
250 edits.into(),
251 cursor_position,
252 buffer_snapshotted_at,
253 received_response_at,
254 inputs,
255 cx,
256 )
257 .await,
258 ))
259 })
260}
261
262pub fn zeta2_prompt_input(
263 snapshot: &language::BufferSnapshot,
264 related_files: Vec<zeta_prompt::RelatedFile>,
265 events: Vec<Arc<zeta_prompt::Event>>,
266 excerpt_path: Arc<Path>,
267 cursor_offset: usize,
268 zeta_format: ZetaFormat,
269 preferred_model: Option<EditPredictionModelKind>,
270 is_open_source: bool,
271 can_collect_data: bool,
272) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
273 let cursor_point = cursor_offset.to_point(snapshot);
274
275 let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot);
276
277 let related_files = crate::filter_redundant_excerpts(
278 related_files,
279 excerpt_path.as_ref(),
280 full_context.start.row..full_context.end.row,
281 );
282
283 let full_context_start_offset = full_context.start.to_offset(snapshot);
284 let full_context_start_row = full_context.start.row;
285
286 let excerpt_ranges =
287 excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot);
288
289 let editable_range = match preferred_model {
290 Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350,
291 _ => match zeta_format {
292 ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150,
293 _ => &range_points.editable_180,
294 },
295 };
296
297 let editable_offset_range = editable_range.to_offset(snapshot);
298 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
299 let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset)
300 ..(editable_offset_range.end - full_context_start_offset);
301
302 let prompt_input = zeta_prompt::ZetaPromptInput {
303 cursor_path: excerpt_path,
304 cursor_excerpt: snapshot
305 .text_for_range(full_context)
306 .collect::<String>()
307 .into(),
308 editable_range_in_excerpt,
309 cursor_offset_in_excerpt,
310 excerpt_start_row: Some(full_context_start_row),
311 events,
312 related_files,
313 excerpt_ranges: Some(excerpt_ranges),
314 preferred_model,
315 in_open_source_repo: is_open_source,
316 can_collect_data,
317 };
318 (editable_offset_range, prompt_input)
319}
320
321pub(crate) fn edit_prediction_accepted(
322 store: &EditPredictionStore,
323 current_prediction: CurrentEditPrediction,
324 cx: &App,
325) {
326 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
327 if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
328 return;
329 }
330
331 let request_id = current_prediction.prediction.id.to_string();
332 let require_auth = custom_accept_url.is_none();
333 let client = store.client.clone();
334 let llm_token = store.llm_token.clone();
335 let app_version = AppVersion::global(cx);
336
337 cx.background_spawn(async move {
338 let url = if let Some(accept_edits_url) = custom_accept_url {
339 gpui::http_client::Url::parse(&accept_edits_url)?
340 } else {
341 client
342 .http_client()
343 .build_zed_llm_url("/predict_edits/accept", &[])?
344 };
345
346 let response = EditPredictionStore::send_api_request::<()>(
347 move |builder| {
348 let req = builder.uri(url.as_ref()).body(
349 serde_json::to_string(&AcceptEditPredictionBody {
350 request_id: request_id.clone(),
351 })?
352 .into(),
353 );
354 Ok(req?)
355 },
356 client,
357 llm_token,
358 app_version,
359 require_auth,
360 )
361 .await;
362
363 response?;
364 anyhow::Ok(())
365 })
366 .detach_and_log_err(cx);
367}