1use crate::prediction::EditPredictionResult;
2use crate::zeta1::compute_edits;
3use crate::{
4 CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
5 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
6 EditPredictionStore,
7};
8use anyhow::{Result, anyhow};
9use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
10use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
11use gpui::{App, Task, prelude::*};
12use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
13use release_channel::AppVersion;
14
15use std::env;
16use std::{path::Path, sync::Arc, time::Instant};
17use zeta_prompt::format_zeta_prompt;
18use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers};
19
20pub const MAX_CONTEXT_TOKENS: usize = 350;
21
22pub fn max_editable_tokens(version: ZetaVersion) -> usize {
23 match version {
24 ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
25 ZetaVersion::V0114180EditableRegion | ZetaVersion::V0120GitMergeMarkers => 180,
26 }
27}
28
29pub fn request_prediction_with_zeta2(
30 store: &mut EditPredictionStore,
31 EditPredictionModelInput {
32 buffer,
33 snapshot,
34 position,
35 related_files,
36 events,
37 debug_tx,
38 trigger,
39 ..
40 }: EditPredictionModelInput,
41 zeta_version: ZetaVersion,
42 cx: &mut Context<EditPredictionStore>,
43) -> Task<Result<Option<EditPredictionResult>>> {
44 let buffer_snapshotted_at = Instant::now();
45 let custom_url = store.custom_predict_edits_url.clone();
46
47 let Some(excerpt_path) = snapshot
48 .file()
49 .map(|file| -> Arc<Path> { file.full_path(cx).into() })
50 else {
51 return Task::ready(Err(anyhow!("No file path for excerpt")));
52 };
53
54 let client = store.client.clone();
55 let llm_token = store.llm_token.clone();
56 let app_version = AppVersion::global(cx);
57
58 let request_task = cx.background_spawn({
59 async move {
60 let cursor_offset = position.to_offset(&snapshot);
61 let (editable_offset_range, prompt_input) = zeta2_prompt_input(
62 &snapshot,
63 related_files,
64 events,
65 excerpt_path,
66 cursor_offset,
67 zeta_version,
68 );
69
70 if let Some(debug_tx) = &debug_tx {
71 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
72 debug_tx
73 .unbounded_send(DebugEvent::EditPredictionStarted(
74 EditPredictionStartedDebugEvent {
75 buffer: buffer.downgrade(),
76 prompt: Some(prompt),
77 position,
78 },
79 ))
80 .ok();
81 }
82
83 log::trace!("Sending edit prediction request");
84
85 let (request_id, output_text, usage) = if let Some(custom_url) = custom_url {
86 // Use raw endpoint with custom URL
87 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
88 let request = RawCompletionRequest {
89 model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(),
90 prompt,
91 temperature: None,
92 stop: vec![],
93 max_tokens: Some(2048),
94 };
95
96 let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
97 request,
98 client,
99 Some(custom_url),
100 llm_token,
101 app_version,
102 )
103 .await?;
104
105 let request_id = EditPredictionId(response.id.clone().into());
106 let output_text = response.choices.pop().map(|choice| choice.text);
107 (request_id, output_text, usage)
108 } else {
109 let (response, usage) = EditPredictionStore::send_v3_request(
110 prompt_input.clone(),
111 zeta_version,
112 client,
113 llm_token,
114 app_version,
115 trigger,
116 )
117 .await?;
118
119 let request_id = EditPredictionId(response.request_id.into());
120 let output_text = if response.output.is_empty() {
121 None
122 } else {
123 Some(response.output)
124 };
125 (request_id, output_text, usage)
126 };
127
128 let received_response_at = Instant::now();
129
130 log::trace!("Got edit prediction response");
131
132 let Some(mut output_text) = output_text else {
133 return Ok((Some((request_id, None)), usage));
134 };
135
136 if let Some(debug_tx) = &debug_tx {
137 debug_tx
138 .unbounded_send(DebugEvent::EditPredictionFinished(
139 EditPredictionFinishedDebugEvent {
140 buffer: buffer.downgrade(),
141 position,
142 model_output: Some(output_text.clone()),
143 },
144 ))
145 .ok();
146 }
147
148 if output_text.contains(CURSOR_MARKER) {
149 log::trace!("Stripping out {CURSOR_MARKER} from response");
150 output_text = output_text.replace(CURSOR_MARKER, "");
151 }
152
153 if zeta_version == ZetaVersion::V0120GitMergeMarkers {
154 if let Some(stripped) =
155 output_text.strip_suffix(v0120_git_merge_markers::END_MARKER)
156 {
157 output_text = stripped.to_string();
158 }
159 }
160
161 let mut old_text = snapshot
162 .text_for_range(editable_offset_range.clone())
163 .collect::<String>();
164
165 if !output_text.is_empty() && !output_text.ends_with('\n') {
166 output_text.push('\n');
167 }
168 if !old_text.is_empty() && !old_text.ends_with('\n') {
169 old_text.push('\n');
170 }
171
172 let edits = compute_edits(
173 old_text,
174 &output_text,
175 editable_offset_range.start,
176 &snapshot,
177 );
178
179 anyhow::Ok((
180 Some((
181 request_id,
182 Some((
183 prompt_input,
184 buffer,
185 snapshot.clone(),
186 edits,
187 received_response_at,
188 )),
189 )),
190 usage,
191 ))
192 }
193 });
194
195 cx.spawn(async move |this, cx| {
196 let Some((id, prediction)) =
197 EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
198 else {
199 return Ok(None);
200 };
201
202 let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
203 prediction
204 else {
205 return Ok(Some(EditPredictionResult {
206 id,
207 prediction: Err(EditPredictionRejectReason::Empty),
208 }));
209 };
210
211 Ok(Some(
212 EditPredictionResult::new(
213 id,
214 &edited_buffer,
215 &edited_buffer_snapshot,
216 edits.into(),
217 buffer_snapshotted_at,
218 received_response_at,
219 inputs,
220 cx,
221 )
222 .await,
223 ))
224 })
225}
226
227pub fn zeta2_prompt_input(
228 snapshot: &language::BufferSnapshot,
229 related_files: Vec<zeta_prompt::RelatedFile>,
230 events: Vec<Arc<zeta_prompt::Event>>,
231 excerpt_path: Arc<Path>,
232 cursor_offset: usize,
233 zeta_version: ZetaVersion,
234) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
235 let cursor_point = cursor_offset.to_point(snapshot);
236
237 let (editable_range, context_range) =
238 crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
239 cursor_point,
240 snapshot,
241 max_editable_tokens(zeta_version),
242 MAX_CONTEXT_TOKENS,
243 );
244
245 let related_files = crate::filter_redundant_excerpts(
246 related_files,
247 excerpt_path.as_ref(),
248 context_range.start.row..context_range.end.row,
249 );
250
251 let context_start_offset = context_range.start.to_offset(snapshot);
252 let editable_offset_range = editable_range.to_offset(snapshot);
253 let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
254 let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
255 ..(editable_offset_range.end - context_start_offset);
256
257 let prompt_input = zeta_prompt::ZetaPromptInput {
258 cursor_path: excerpt_path,
259 cursor_excerpt: snapshot
260 .text_for_range(context_range)
261 .collect::<String>()
262 .into(),
263 editable_range_in_excerpt,
264 cursor_offset_in_excerpt,
265 events,
266 related_files,
267 };
268 (editable_offset_range, prompt_input)
269}
270
271pub(crate) fn edit_prediction_accepted(
272 store: &EditPredictionStore,
273 current_prediction: CurrentEditPrediction,
274 cx: &App,
275) {
276 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
277 if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
278 return;
279 }
280
281 let request_id = current_prediction.prediction.id.to_string();
282 let require_auth = custom_accept_url.is_none();
283 let client = store.client.clone();
284 let llm_token = store.llm_token.clone();
285 let app_version = AppVersion::global(cx);
286
287 cx.background_spawn(async move {
288 let url = if let Some(accept_edits_url) = custom_accept_url {
289 gpui::http_client::Url::parse(&accept_edits_url)?
290 } else {
291 client
292 .http_client()
293 .build_zed_llm_url("/predict_edits/accept", &[])?
294 };
295
296 let response = EditPredictionStore::send_api_request::<()>(
297 move |builder| {
298 let req = builder.uri(url.as_ref()).body(
299 serde_json::to_string(&AcceptEditPredictionBody {
300 request_id: request_id.clone(),
301 })?
302 .into(),
303 );
304 Ok(req?)
305 },
306 client,
307 llm_token,
308 app_version,
309 require_auth,
310 )
311 .await;
312
313 response?;
314 anyhow::Ok(())
315 })
316 .detach_and_log_err(cx);
317}