use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

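/// Runs edit prediction for a single example, doing nothing if the example
/// already has predictions. Teacher providers are routed through Anthropic via
/// `predict_anthropic`; all other providers go through the shared
/// `EditPredictionStore`, once per repetition. Panics if `provider` is `None`.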
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    progress: Arc<Progress>,
    mut cx: AsyncApp,
) {
    if !example.predictions.is_empty() {
        return;
    }

    let provider = provider.unwrap();

    run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;

    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = progress.start(Step::Predict, &example.name);

        if example.prompt.is_none() {
            run_format_prompt(
                example,
                PromptFormat::Teacher,
                app_state.clone(),
                progress,
                cx,
            )
            .await;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;

    let _step_progress = progress.start(Step::Predict, &example.name);

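    // Zeta providers need an authenticated client; sign in once per process and
    // share the resulting task across examples.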
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    client
                        .sign_in_with_optional_connect(true, cx)
                        .await
                        .unwrap();
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
        .unwrap();

    ep_store
        .update(&mut cx, |store, _cx| {
            let model = match provider {
                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
                PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                    unreachable!()
                }
            };
            store.set_edit_prediction_model(model);
        })
        .unwrap();
    let state = example.state.as_ref().unwrap();
    let run_dir = RUN_DIR.join(&example.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

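    // Subscribe to debug events so each run's prompt and raw model output are
    // written into the run directory and recorded on the example.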
    let mut debug_rx = ep_store
        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
        .unwrap();
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

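    // Request one prediction per repetition; when repeating, each run gets its
    // own zero-padded subdirectory under the example's run directory.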
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir).unwrap();
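        // Repoint the "latest run" symlink at this run's directory.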
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .unwrap()
            .await
            .unwrap();

        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            _step_progress.set_info(info, style);
        }
    }

    ep_store
        .update(&mut cx, |store, _| {
            store.remove_project(&state.project);
        })
        .unwrap();
    debug_task.await.unwrap();

    *example = Arc::into_inner(updated_example)
        .unwrap()
        .into_inner()
        .unwrap();
}

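/// Generates a prediction for `example` with the Anthropic teacher model. In
/// batched mode the request may only be enqueued for later batch processing,
/// in which case no prediction is recorded yet.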
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.expect("Failed to create LLM client");

    let prompt = example
        .prompt
        .as_ref()
        .unwrap_or_else(|| panic!("Prompt is required for example {}", &example.name));

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await
        .unwrap()
    else {
        // Request stashed for batched processing
        return;
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output);

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
}

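/// Syncs pending Anthropic batch requests for the `Teacher` provider; other
/// providers have nothing to sync.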
pub async fn sync_batches(provider: &PredictionProvider) {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
            llm_client
                .sync_batches()
                .await
                .expect("Failed to sync batches");
        }
        _ => (),
    }
}