use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

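/// Runs edit prediction for a single example, writing per-run artifacts under
/// `RUN_DIR`. Examples that already have predictions are skipped; teacher
/// providers are routed to Anthropic, all others go through `EditPredictionStore`.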
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    if !example.predictions.is_empty() {
        return Ok(());
    }

    let provider = provider.context("provider is required")?;

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

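    // Teacher providers don't need a loaded project; they format a prompt and
    // call Anthropic directly.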
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

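    // Zeta providers require a signed-in client; authenticate once and share
    // the task so later examples reuse the same sign-in.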
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

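    // Pick the model on the global store. Teacher variants returned above, so
    // they're unreachable here.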
    let ep_store = cx.update(|cx| {
        EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
    })??;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    })?;
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx =
        ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
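    // Mirror debug events to disk: each run's prompt and raw model output are
    // written into its run directory as they arrive.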
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
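                        // Stop once the final repetition has finished.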
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

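    // Request one prediction per repetition, each recorded in its own run
    // directory when `repetition_count > 1`.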
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
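        // Repoint the "latest run" symlink at this run's directory.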
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

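        // Push a placeholder prediction; the debug task and the code below
        // fill in the output and patch.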
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })?
            .await?;

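        // Render the predicted edits as a unified diff; an empty patch means
        // no prediction was produced.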
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            _step_progress.set_info(info, style);
        }
    }

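    // Tear down the store's project state, wait for the debug task, and move
    // the collected predictions back into `example`.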
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    })?;
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

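/// Sends the example's formatted prompt to Anthropic and records the response
/// as a prediction. In batched mode the request may only be stashed, in which
/// case nothing is recorded until the batch is synced.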
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
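    // Batched requests are stashed (backed by `LLM_CACHE_DB`) for later
    // submission via `sync_batches`; plain requests hit the API immediately.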
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.context("Failed to create LLM client")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

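    // Keep only the text blocks of the response; other content types are
    // ignored.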
    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
    Ok(())
}

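/// Syncs any pending Anthropic batch requests for the `Teacher` provider;
/// all other providers are a no-op.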
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}