1#[cfg(feature = "cli-support")]
2use crate::EvalCacheEntryKind;
3use crate::prediction::EditPredictionResult;
4use crate::{
5 CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
6 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
7 EditPredictionStore,
8};
9use anyhow::{Result, anyhow};
10use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
11use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
12use gpui::{App, Task, prelude::*};
13use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
14use release_channel::AppVersion;
15
16use std::env;
17use std::{path::Path, sync::Arc, time::Instant};
18use zeta_prompt::CURSOR_MARKER;
19use zeta_prompt::format_zeta_prompt;
20
/// Token budget for the editable region around the cursor; passed to
/// `cursor_excerpt::editable_and_context_ranges_for_cursor_position`.
pub const MAX_EDITABLE_TOKENS: usize = 150;
/// Token budget for the surrounding context excerpt; passed to
/// `cursor_excerpt::editable_and_context_ranges_for_cursor_position`.
pub const MAX_CONTEXT_TOKENS: usize = 350;
23
24pub fn request_prediction_with_zeta2(
25 store: &mut EditPredictionStore,
26 EditPredictionModelInput {
27 buffer,
28 snapshot,
29 position,
30 related_files,
31 events,
32 debug_tx,
33 ..
34 }: EditPredictionModelInput,
35 cx: &mut Context<EditPredictionStore>,
36) -> Task<Result<Option<EditPredictionResult>>> {
37 let buffer_snapshotted_at = Instant::now();
38 let url = store.custom_predict_edits_url.clone();
39
40 let Some(excerpt_path) = snapshot
41 .file()
42 .map(|file| -> Arc<Path> { file.full_path(cx).into() })
43 else {
44 return Task::ready(Err(anyhow!("No file path for excerpt")));
45 };
46
47 let client = store.client.clone();
48 let llm_token = store.llm_token.clone();
49 let app_version = AppVersion::global(cx);
50
51 #[cfg(feature = "cli-support")]
52 let eval_cache = store.eval_cache.clone();
53
54 let request_task = cx.background_spawn({
55 async move {
56 let cursor_offset = position.to_offset(&snapshot);
57 let (editable_offset_range, prompt_input) = zeta2_prompt_input(
58 &snapshot,
59 related_files,
60 events,
61 excerpt_path,
62 cursor_offset,
63 );
64
65 let prompt = format_zeta_prompt(&prompt_input);
66
67 if let Some(debug_tx) = &debug_tx {
68 debug_tx
69 .unbounded_send(DebugEvent::EditPredictionStarted(
70 EditPredictionStartedDebugEvent {
71 buffer: buffer.downgrade(),
72 prompt: Some(prompt.clone()),
73 position,
74 },
75 ))
76 .ok();
77 }
78
79 let request = RawCompletionRequest {
80 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
81 prompt,
82 temperature: None,
83 stop: vec![],
84 max_tokens: 1024,
85 };
86
87 log::trace!("Sending edit prediction request");
88
89 let response = EditPredictionStore::send_raw_llm_request(
90 request,
91 client,
92 url,
93 llm_token,
94 app_version,
95 #[cfg(feature = "cli-support")]
96 eval_cache,
97 #[cfg(feature = "cli-support")]
98 EvalCacheEntryKind::Prediction,
99 )
100 .await;
101 let received_response_at = Instant::now();
102
103 log::trace!("Got edit prediction response");
104
105 let (mut res, usage) = response?;
106 let request_id = EditPredictionId(res.id.clone().into());
107 let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else {
108 return Ok((Some((request_id, None)), usage));
109 };
110
111 if let Some(debug_tx) = &debug_tx {
112 debug_tx
113 .unbounded_send(DebugEvent::EditPredictionFinished(
114 EditPredictionFinishedDebugEvent {
115 buffer: buffer.downgrade(),
116 position,
117 model_output: Some(output_text.clone()),
118 },
119 ))
120 .ok();
121 }
122
123 if output_text.contains(CURSOR_MARKER) {
124 log::trace!("Stripping out {CURSOR_MARKER} from response");
125 output_text = output_text.replace(CURSOR_MARKER, "");
126 }
127
128 let old_text = snapshot
129 .text_for_range(editable_offset_range.clone())
130 .collect::<String>();
131 let edits: Vec<_> = language::text_diff(&old_text, &output_text)
132 .into_iter()
133 .map(|(range, text)| {
134 (
135 snapshot.anchor_after(editable_offset_range.start + range.start)
136 ..snapshot.anchor_before(editable_offset_range.start + range.end),
137 text,
138 )
139 })
140 .collect();
141
142 anyhow::Ok((
143 Some((
144 request_id,
145 Some((
146 prompt_input,
147 buffer,
148 snapshot.clone(),
149 edits,
150 received_response_at,
151 )),
152 )),
153 usage,
154 ))
155 }
156 });
157
158 cx.spawn(async move |this, cx| {
159 let Some((id, prediction)) =
160 EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
161 else {
162 return Ok(None);
163 };
164
165 let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
166 prediction
167 else {
168 return Ok(Some(EditPredictionResult {
169 id,
170 prediction: Err(EditPredictionRejectReason::Empty),
171 }));
172 };
173
174 Ok(Some(
175 EditPredictionResult::new(
176 id,
177 &edited_buffer,
178 &edited_buffer_snapshot,
179 edits.into(),
180 buffer_snapshotted_at,
181 received_response_at,
182 inputs,
183 cx,
184 )
185 .await,
186 ))
187 })
188}
189
190pub fn zeta2_prompt_input(
191 snapshot: &language::BufferSnapshot,
192 related_files: Arc<[zeta_prompt::RelatedFile]>,
193 events: Vec<Arc<zeta_prompt::Event>>,
194 excerpt_path: Arc<Path>,
195 cursor_offset: usize,
196) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
197 let cursor_point = cursor_offset.to_point(snapshot);
198
199 let (editable_range, context_range) =
200 crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
201 cursor_point,
202 snapshot,
203 MAX_EDITABLE_TOKENS,
204 MAX_CONTEXT_TOKENS,
205 );
206
207 let context_start_offset = context_range.start.to_offset(snapshot);
208 let editable_offset_range = editable_range.to_offset(snapshot);
209 let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
210 let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
211 ..(editable_offset_range.end - context_start_offset);
212
213 let prompt_input = zeta_prompt::ZetaPromptInput {
214 cursor_path: excerpt_path,
215 cursor_excerpt: snapshot
216 .text_for_range(context_range)
217 .collect::<String>()
218 .into(),
219 editable_range_in_excerpt,
220 cursor_offset_in_excerpt,
221 events,
222 related_files,
223 };
224 (editable_offset_range, prompt_input)
225}
226
227pub(crate) fn edit_prediction_accepted(
228 store: &EditPredictionStore,
229 current_prediction: CurrentEditPrediction,
230 cx: &App,
231) {
232 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
233 if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
234 return;
235 }
236
237 let request_id = current_prediction.prediction.id.to_string();
238 let require_auth = custom_accept_url.is_none();
239 let client = store.client.clone();
240 let llm_token = store.llm_token.clone();
241 let app_version = AppVersion::global(cx);
242
243 cx.background_spawn(async move {
244 let url = if let Some(accept_edits_url) = custom_accept_url {
245 gpui::http_client::Url::parse(&accept_edits_url)?
246 } else {
247 client
248 .http_client()
249 .build_zed_llm_url("/predict_edits/accept", &[])?
250 };
251
252 let response = EditPredictionStore::send_api_request::<()>(
253 move |builder| {
254 let req = builder.uri(url.as_ref()).body(
255 serde_json::to_string(&AcceptEditPredictionBody {
256 request_id: request_id.clone(),
257 })?
258 .into(),
259 );
260 Ok(req?)
261 },
262 client,
263 llm_token,
264 app_version,
265 require_auth,
266 )
267 .await;
268
269 response?;
270 anyhow::Ok(())
271 })
272 .detach_and_log_err(cx);
273}