1use crate::{
2 FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
3 anthropic_client::AnthropicClient,
4 example::{Example, ExamplePrediction, ExamplePrompt},
5 format_prompt::{TeacherPrompt, run_format_prompt},
6 headless::EpAppState,
7 load_project::run_load_project,
8 openai_client::OpenAiClient,
9 paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
10 progress::{ExampleProgress, InfoStyle, Step},
11 retrieve_context::run_context_retrieval,
12};
13use anyhow::Context as _;
14use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
15use futures::{FutureExt as _, StreamExt as _, future::Shared};
16use gpui::{AppContext as _, AsyncApp, Task};
17use std::{
18 fs,
19 sync::{
20 Arc, Mutex, OnceLock,
21 atomic::{AtomicUsize, Ordering::SeqCst},
22 },
23};
24use zeta_prompt::ZetaFormat;
25
/// Process-wide Anthropic client, created lazily on first use.
/// Note: whichever caller initializes it first fixes the batched/plain mode
/// for the rest of the process (see `predict_anthropic` / `sync_batches`).
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
/// Process-wide OpenAI client; same lazy first-caller-wins initialization
/// as `ANTHROPIC_CLIENT`.
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
28
/// Runs edit prediction for `example`, recording one `ExamplePrediction` per
/// repetition (`args.repetitions`) onto `example.predictions`.
///
/// Short-circuits when the example already carries a usable prediction from
/// the requested provider (or when no provider was requested at all). Teacher
/// providers format a prompt and call the LLM backend directly; every other
/// provider loads the project headlessly and requests predictions through the
/// global `EditPredictionStore`.
///
/// Per-run artifacts (prompt / raw response) are written under `RUN_DIR`,
/// in zero-padded per-repetition subdirectories when `repetitions > 1`.
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    // Reuse an existing prediction when possible: keep it if no provider was
    // requested, or if it came from the same provider; otherwise clear all
    // predictions and re-run with the newly requested provider.
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    // Past this point a provider is mandatory: there is nothing cached to
    // fall back on.
    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    // Teacher providers bypass the EditPredictionStore entirely: format the
    // prompt, then call the LLM backend (batched or not) and return.
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    // Zeta providers talk to Zed's cloud service; sign in once per process.
    // Failures are logged but not fatal here — the request itself will fail
    // later if auth is actually required.
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    // Select the model matching the requested provider. Teacher/Repair were
    // handled (or are unreachable) above.
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If user specified a non-default Zeta2 format, configure the raw
        // endpoint. The ZED_ZETA_MODEL env var optionally overrides the
        // model id.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    // Shared with the debug-event task below; the main loop pushes a
    // placeholder prediction per run, and the task fills in prompt/output.
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    // Background task that persists prompt/response artifacts as debug events
    // stream in from the store.
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                // Mirror the per-repetition subdirectory layout used by the
                // main loop below.
                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        // The main loop pushes a placeholder prediction before
                        // requesting, so the entry for this run must exist.
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // NOTE(review): `run_ix` ranges over
                        // 0..repetition_count, so this condition can never be
                        // true and the break never fires — possibly intended
                        // `run_ix + 1 >= repetition_count`. The task instead
                        // terminates when `debug_rx` closes (after
                        // `remove_project` below). TODO confirm.
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        // Point the "latest run" symlink at this run's directory.
        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        // Placeholder entry for this run; the debug task and the code below
        // fill in `actual_output` / `actual_patch`.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        // A prediction may come back empty (no edits) or fail; both yield
        // `None` here.
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        // Only report status once, after the final repetition.
        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    // Dropping the project closes the debug stream, letting the task finish.
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    // Move the accumulated results back into the caller's example. Both
    // unwraps only fail if another clone of the Arc is still alive.
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
268
269async fn predict_teacher(
270 example: &mut Example,
271 backend: TeacherBackend,
272 batched: bool,
273 repetition_count: usize,
274 cache_only: bool,
275) -> anyhow::Result<()> {
276 match backend {
277 TeacherBackend::Sonnet45 => {
278 predict_anthropic(example, backend, batched, repetition_count, cache_only).await
279 }
280 TeacherBackend::Gpt52 => {
281 predict_openai(example, backend, batched, repetition_count, cache_only).await
282 }
283 }
284}
285
286async fn predict_anthropic(
287 example: &mut Example,
288 backend: TeacherBackend,
289 batched: bool,
290 repetition_count: usize,
291 cache_only: bool,
292) -> anyhow::Result<()> {
293 let llm_model_name = backend.model_name();
294 let max_tokens = 16384;
295 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
296 let client = if batched {
297 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
298 } else {
299 AnthropicClient::plain()
300 };
301 client.expect("Failed to create Anthropic client")
302 });
303
304 let prompt = example.prompt.as_ref().context("Prompt is required")?;
305
306 for ix in 0..repetition_count {
307 let messages = vec![anthropic::Message {
308 role: anthropic::Role::User,
309 content: vec![anthropic::RequestContent::Text {
310 text: prompt.input.clone(),
311 cache_control: None,
312 }],
313 }];
314
315 let seed = if repetition_count > 1 { Some(ix) } else { None };
316 let Some(response) = llm_client
317 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
318 .await?
319 else {
320 // Request stashed for batched processing
321 return Ok(());
322 };
323
324 let actual_output = response
325 .content
326 .into_iter()
327 .filter_map(|content| match content {
328 anthropic::ResponseContent::Text { text } => Some(text),
329 _ => None,
330 })
331 .collect::<Vec<String>>()
332 .join("\n");
333
334 let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
335
336 let prediction = ExamplePrediction {
337 actual_patch: Some(actual_patch),
338 actual_output,
339 actual_cursor,
340 error: None,
341 provider: if batched {
342 PredictionProvider::Teacher(backend)
343 } else {
344 PredictionProvider::TeacherNonBatching(backend)
345 },
346 };
347
348 example.predictions.push(prediction);
349 }
350 Ok(())
351}
352
353async fn predict_openai(
354 example: &mut Example,
355 backend: TeacherBackend,
356 batched: bool,
357 repetition_count: usize,
358 cache_only: bool,
359) -> anyhow::Result<()> {
360 let llm_model_name = backend.model_name();
361 let max_tokens = 16384;
362 let llm_client = OPENAI_CLIENT.get_or_init(|| {
363 let client = if batched {
364 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
365 } else {
366 OpenAiClient::plain()
367 };
368 client.expect("Failed to create OpenAI client")
369 });
370
371 let prompt = example.prompt.as_ref().context("Prompt is required")?;
372
373 for ix in 0..repetition_count {
374 let messages = vec![open_ai::RequestMessage::User {
375 content: open_ai::MessageContent::Plain(prompt.input.clone()),
376 }];
377
378 let seed = if repetition_count > 1 { Some(ix) } else { None };
379 let Some(response) = llm_client
380 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
381 .await?
382 else {
383 // Request stashed for batched processing
384 return Ok(());
385 };
386
387 let actual_output = response
388 .choices
389 .into_iter()
390 .filter_map(|choice| match choice.message {
391 open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
392 open_ai::MessageContent::Plain(text) => text,
393 open_ai::MessageContent::Multipart(parts) => parts
394 .into_iter()
395 .filter_map(|p| match p {
396 open_ai::MessagePart::Text { text } => Some(text),
397 _ => None,
398 })
399 .collect::<Vec<_>>()
400 .join(""),
401 }),
402 _ => None,
403 })
404 .collect::<Vec<String>>()
405 .join("\n");
406
407 let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
408
409 let prediction = ExamplePrediction {
410 actual_patch: Some(actual_patch),
411 actual_output,
412 actual_cursor,
413 error: None,
414 provider: if batched {
415 PredictionProvider::Teacher(backend)
416 } else {
417 PredictionProvider::TeacherNonBatching(backend)
418 },
419 };
420
421 example.predictions.push(prediction);
422 }
423 Ok(())
424}
425
426pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
427 match provider {
428 Some(PredictionProvider::Teacher(backend)) => match backend {
429 TeacherBackend::Sonnet45 => {
430 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
431 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
432 .expect("Failed to create Anthropic client")
433 });
434 llm_client
435 .sync_batches()
436 .await
437 .context("Failed to sync Anthropic batches")?;
438 }
439 TeacherBackend::Gpt52 => {
440 let llm_client = OPENAI_CLIENT.get_or_init(|| {
441 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
442 .expect("Failed to create OpenAI client")
443 });
444 llm_client
445 .sync_batches()
446 .await
447 .context("Failed to sync OpenAI batches")?;
448 }
449 },
450 _ => (),
451 };
452 Ok(())
453}