#[cfg(feature = "cli-support")]
use crate::EvalCacheEntryKind;
use crate::prediction::EditPredictionResult;
use crate::{
    CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
    EditPredictionStore,
};
use anyhow::{Result, anyhow};
use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
use gpui::{App, Task, prelude::*};
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;

use std::env;
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::format_zeta_prompt;
use zeta_prompt::{CURSOR_MARKER, ZetaVersion};

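// Token budgets for the excerpt sent to the model: the editable window around
// the cursor is capped at MAX_EDITABLE_TOKENS, and the surrounding context
// excerpt at MAX_CONTEXT_TOKENS.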
pub const MAX_CONTEXT_TOKENS: usize = 350;
pub const MAX_EDITABLE_TOKENS: usize = 150;

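/// Requests an edit prediction from the zeta2 model: builds the prompt from the
/// current buffer snapshot, related files, and recent events; sends a raw
/// completion request; then diffs the model's output against the editable region
/// to produce the predicted edits.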
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        ..
    }: EditPredictionModelInput,
    zeta_version: ZetaVersion,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let buffer_snapshotted_at = Instant::now();
    let url = store.custom_predict_edits_url.clone();

    let Some(excerpt_path) = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    #[cfg(feature = "cli-support")]
    let eval_cache = store.eval_cache.clone();

    let request_task = cx.background_spawn({
        async move {
            let cursor_offset = position.to_offset(&snapshot);
            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
            );

            let prompt = format_zeta_prompt(&prompt_input, zeta_version);

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt.clone()),
                            position,
                        },
                    ))
                    .ok();
            }

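            // Build a raw completion request for the edit-prediction model.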
            let request = RawCompletionRequest {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                prompt,
                temperature: None,
                stop: vec![],
                max_tokens: Some(2048),
            };

            log::trace!("Sending edit prediction request");

            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                url,
                llm_token,
                app_version,
                #[cfg(feature = "cli-support")]
                eval_cache,
                #[cfg(feature = "cli-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            let (mut res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else {
                return Ok((Some((request_id, None)), usage));
            };

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

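            // Re-read the current text of the editable region so it can be diffed
            // against the model's output.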
            let mut old_text = snapshot
                .text_for_range(editable_offset_range.clone())
                .collect::<String>();

            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

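            // `text_diff` yields ranges relative to the editable region; rebase them
            // onto buffer anchors so they track subsequent edits.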
            let edits: Vec<_> = language::text_diff(&old_text, &output_text)
                .into_iter()
                .map(|(range, text)| {
                    (
                        snapshot.anchor_after(editable_offset_range.start + range.start)
                            ..snapshot.anchor_before(editable_offset_range.start + range.end),
                        text,
                    )
                })
                .collect();

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

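    // Handle the API response: an empty prediction is rejected as
    // `EditPredictionRejectReason::Empty`; otherwise build the final result.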
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}

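/// Builds the `ZetaPromptInput` for a prediction request: computes the editable
/// and context ranges around the cursor, drops related-file excerpts that are
/// redundant with the cursor excerpt, and rebases offsets relative to the start
/// of the context excerpt. Returns the editable range (in buffer offsets)
/// alongside the prompt input.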
pub fn zeta2_prompt_input(
    snapshot: &language::BufferSnapshot,
    related_files: Vec<zeta_prompt::RelatedFile>,
    events: Vec<Arc<zeta_prompt::Event>>,
    excerpt_path: Arc<Path>,
    cursor_offset: usize,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
    let cursor_point = cursor_offset.to_point(snapshot);

    let (editable_range, context_range) =
        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
            cursor_point,
            snapshot,
            MAX_EDITABLE_TOKENS,
            MAX_CONTEXT_TOKENS,
        );

    let related_files = crate::filter_redundant_excerpts(
        related_files,
        excerpt_path.as_ref(),
        context_range.start.row..context_range.end.row,
    );

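    // Cursor and editable offsets in the prompt are expressed relative to the
    // start of the context excerpt, not the whole buffer.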
    let context_start_offset = context_range.start.to_offset(snapshot);
    let editable_offset_range = editable_range.to_offset(snapshot);
    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
        ..(editable_offset_range.end - context_start_offset);

    let prompt_input = zeta_prompt::ZetaPromptInput {
        cursor_path: excerpt_path,
        cursor_excerpt: snapshot
            .text_for_range(context_range)
            .collect::<String>()
            .into(),
        editable_range_in_excerpt,
        cursor_offset_in_excerpt,
        events,
        related_files,
    };
    (editable_offset_range, prompt_input)
}

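/// Reports an accepted edit prediction to the `/predict_edits/accept` endpoint,
/// or to `ZED_ACCEPT_PREDICTION_URL` if that override is set. When a custom
/// prediction URL is in use and no accept override is configured, nothing is
/// reported; authentication is only required when using the default endpoint.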
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

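    // Fire-and-forget: send the accepted prediction's request id to the accept
    // endpoint in the background and log any error.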
    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}