1use crate::{
2 FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
3 anthropic_client::AnthropicClient,
4 example::{Example, ExamplePrediction, ExamplePrompt},
5 format_prompt::{TeacherPrompt, run_format_prompt},
6 headless::EpAppState,
7 load_project::run_load_project,
8 openai_client::OpenAiClient,
9 paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
10 progress::{ExampleProgress, InfoStyle, Step},
11 retrieve_context::run_context_retrieval,
12};
13use anyhow::Context as _;
14use edit_prediction::{DebugEvent, EditPredictionStore};
15use futures::{FutureExt as _, StreamExt as _, future::Shared};
16use gpui::{AppContext as _, AsyncApp, Task};
17use std::{
18 fs,
19 sync::{
20 Arc, Mutex, OnceLock,
21 atomic::{AtomicUsize, Ordering::SeqCst},
22 },
23};
24
// Lazily-initialized, process-wide LLM clients shared by the teacher
// prediction paths and `sync_batches` below (`OnceLock` guarantees each is
// constructed at most once per process).
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
27
/// Runs edit prediction for a single `example`, recording `args.repetitions`
/// predictions (plus on-disk prompt/response artifacts) under the example's
/// run directory.
///
/// Existing predictions are reused when compatible (see below). Teacher
/// providers are delegated to `predict_teacher`; all other providers go
/// through the global `EditPredictionStore`.
///
/// Errors if no provider is given and no prediction exists, if the
/// `EditPredictionStore` global is missing, or if any I/O / request step fails.
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    // Reuse existing predictions when possible:
    // - no --provider given: keep whatever is stored;
    // - same provider: keep;
    // - different provider: clear and re-predict below.
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    // Teacher providers bypass the EditPredictionStore entirely: format a
    // prompt, then call the LLM backend directly.
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    // Zeta providers need a signed-in client. Sign in once per process and
    // share the task so concurrent examples all await the same sign-in.
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    // Failure is reported but not fatal here; the prediction
                    // request below will surface any real auth problem.
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        // Map the CLI provider onto the store's model enum.
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                // Teacher providers returned earlier; Repair never reaches
                // this function.
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    // Shared with the background debug listener below, which fills prompt and
    // raw-output data into the prediction slot for the run index that is
    // current when each event arrives.
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                // Use a per-repetition subdirectory (e.g. "000") only when
                // running more than once.
                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        // The main loop pushes a prediction slot before each
                        // request, so the vec must hold run_ix + 1 entries.
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                // Capture the first prompt seen so the example
                                // records what was actually sent to the model.
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // NOTE(review): `run_ix` is stored from
                        // 0..repetition_count, so it never reaches
                        // repetition_count and this break looks unreachable;
                        // the task instead ends when `debug_rx` closes after
                        // `remove_project` below. Confirm whether
                        // `run_ix + 1 >= repetition_count` was intended.
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        // Repoint the "latest example run" symlink at this run's directory.
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        // Reserve the prediction slot that the debug listener and the code
        // below fill in for this repetition.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor_offset: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        // Render the predicted edits as a unified diff; `None` (or an empty
        // diff) means the model produced no usable prediction.
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        // Only the final repetition's outcome is reported in the progress UI.
        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    // Removing the project presumably closes the debug stream so debug_task
    // can complete — TODO confirm against EditPredictionStore::remove_project.
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    // Move the accumulated results back into the caller's example. Both
    // unwraps can only fail if a clone of the Arc or a poisoned lock escaped,
    // which would be a bug.
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
260
261async fn predict_teacher(
262 example: &mut Example,
263 backend: TeacherBackend,
264 batched: bool,
265 repetition_count: usize,
266 cache_only: bool,
267) -> anyhow::Result<()> {
268 match backend {
269 TeacherBackend::Sonnet45 => {
270 predict_anthropic(example, backend, batched, repetition_count, cache_only).await
271 }
272 TeacherBackend::Gpt52 => {
273 predict_openai(example, backend, batched, repetition_count, cache_only).await
274 }
275 }
276}
277
278async fn predict_anthropic(
279 example: &mut Example,
280 backend: TeacherBackend,
281 batched: bool,
282 repetition_count: usize,
283 cache_only: bool,
284) -> anyhow::Result<()> {
285 let llm_model_name = backend.model_name();
286 let max_tokens = 16384;
287 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
288 let client = if batched {
289 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
290 } else {
291 AnthropicClient::plain()
292 };
293 client.expect("Failed to create Anthropic client")
294 });
295
296 let prompt = example.prompt.as_ref().context("Prompt is required")?;
297
298 for ix in 0..repetition_count {
299 let messages = vec![anthropic::Message {
300 role: anthropic::Role::User,
301 content: vec![anthropic::RequestContent::Text {
302 text: prompt.input.clone(),
303 cache_control: None,
304 }],
305 }];
306
307 let seed = if repetition_count > 1 { Some(ix) } else { None };
308 let Some(response) = llm_client
309 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
310 .await?
311 else {
312 // Request stashed for batched processing
313 return Ok(());
314 };
315
316 let actual_output = response
317 .content
318 .into_iter()
319 .filter_map(|content| match content {
320 anthropic::ResponseContent::Text { text } => Some(text),
321 _ => None,
322 })
323 .collect::<Vec<String>>()
324 .join("\n");
325
326 let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, &actual_output)?;
327
328 let prediction = ExamplePrediction {
329 actual_patch: Some(actual_patch),
330 actual_output,
331 actual_cursor_offset,
332 error: None,
333 provider: if batched {
334 PredictionProvider::Teacher(backend)
335 } else {
336 PredictionProvider::TeacherNonBatching(backend)
337 },
338 };
339
340 example.predictions.push(prediction);
341 }
342 Ok(())
343}
344
345async fn predict_openai(
346 example: &mut Example,
347 backend: TeacherBackend,
348 batched: bool,
349 repetition_count: usize,
350 cache_only: bool,
351) -> anyhow::Result<()> {
352 let llm_model_name = backend.model_name();
353 let max_tokens = 16384;
354 let llm_client = OPENAI_CLIENT.get_or_init(|| {
355 let client = if batched {
356 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
357 } else {
358 OpenAiClient::plain()
359 };
360 client.expect("Failed to create OpenAI client")
361 });
362
363 let prompt = example.prompt.as_ref().context("Prompt is required")?;
364
365 for ix in 0..repetition_count {
366 let messages = vec![open_ai::RequestMessage::User {
367 content: open_ai::MessageContent::Plain(prompt.input.clone()),
368 }];
369
370 let seed = if repetition_count > 1 { Some(ix) } else { None };
371 let Some(response) = llm_client
372 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
373 .await?
374 else {
375 // Request stashed for batched processing
376 return Ok(());
377 };
378
379 let actual_output = response
380 .choices
381 .into_iter()
382 .filter_map(|choice| match choice.message {
383 open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
384 open_ai::MessageContent::Plain(text) => text,
385 open_ai::MessageContent::Multipart(parts) => parts
386 .into_iter()
387 .filter_map(|p| match p {
388 open_ai::MessagePart::Text { text } => Some(text),
389 _ => None,
390 })
391 .collect::<Vec<_>>()
392 .join(""),
393 }),
394 _ => None,
395 })
396 .collect::<Vec<String>>()
397 .join("\n");
398
399 let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, &actual_output)?;
400
401 let prediction = ExamplePrediction {
402 actual_patch: Some(actual_patch),
403 actual_output,
404 actual_cursor_offset,
405 error: None,
406 provider: if batched {
407 PredictionProvider::Teacher(backend)
408 } else {
409 PredictionProvider::TeacherNonBatching(backend)
410 },
411 };
412
413 example.predictions.push(prediction);
414 }
415 Ok(())
416}
417
418pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
419 match provider {
420 Some(PredictionProvider::Teacher(backend)) => match backend {
421 TeacherBackend::Sonnet45 => {
422 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
423 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
424 .expect("Failed to create Anthropic client")
425 });
426 llm_client
427 .sync_batches()
428 .await
429 .context("Failed to sync Anthropic batches")?;
430 }
431 TeacherBackend::Gpt52 => {
432 let llm_client = OPENAI_CLIENT.get_or_init(|| {
433 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
434 .expect("Failed to create OpenAI client")
435 });
436 llm_client
437 .sync_batches()
438 .await
439 .context("Failed to sync OpenAI batches")?;
440 }
441 },
442 _ => (),
443 };
444 Ok(())
445}