use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

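/// Runs edit prediction for a single example, writing prompts, responses, and
/// patches into the example's run directory. Examples that already contain
/// predictions are skipped.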
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) {
    if !example.predictions.is_empty() {
        return;
    }

    let provider = provider.expect("a prediction provider must be specified");

    run_context_retrieval(example, app_state.clone(), cx.clone()).await;

    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.name);

        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

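    // Non-teacher providers predict against a real project loaded into the headless app.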
    run_load_project(example, app_state.clone(), cx.clone()).await;

    let _step_progress = Progress::global().start(Step::Predict, &example.name);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        // Zeta providers go through the cloud client, so sign in once and share the
        // resulting task across every example in this process.
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    client
                        .sign_in_with_optional_connect(true, cx)
                        .await
                        .unwrap();
                })
                .shared()
            })
            .clone()
            .await;
    }

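    // Point the global edit prediction store at the selected provider's model.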
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
        .unwrap();

    ep_store
        .update(&mut cx, |store, _cx| {
            let model = match provider {
                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
                PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                    unreachable!()
                }
            };
            store.set_edit_prediction_model(model);
        })
        .unwrap();
    let state = example.state.as_ref().unwrap();
    let run_dir = RUN_DIR.join(&example.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

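    // Listen for the store's debug events so each run's prompt and raw model response
    // can be written alongside the prediction itself.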
    let mut debug_rx = ep_store
        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
        .unwrap();
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final run's response has been recorded.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

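    // Request one prediction per repetition, giving each run its own subdirectory
    // when the example is predicted more than once.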
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir).unwrap();
        // Keep LATEST_EXAMPLE_RUN_DIR pointing at the most recent run directory.
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();

        // Push a placeholder prediction for this run; the debug listener fills in the
        // raw model output as events arrive.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .unwrap()
            .await
            .unwrap();

        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            _step_progress.set_info(info, style);
        }
    }

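    // Detach the project from the store, let the debug listener drain, then copy the
    // collected predictions back into the example.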
    ep_store
        .update(&mut cx, |store, _| {
            store.remove_project(&state.project);
        })
        .unwrap();
    debug_task.await.unwrap();

    *example = Arc::into_inner(updated_example)
        .unwrap()
        .into_inner()
        .unwrap();
}

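/// Sends the example's teacher prompt to Anthropic and records the response as a
/// prediction. In batched mode the request may instead be stashed for later batch
/// processing, in which case no prediction is recorded yet.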
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.expect("Failed to create LLM client");

    let prompt = example
        .prompt
        .as_ref()
        .unwrap_or_else(|| panic!("Prompt is required for example {}", &example.name));

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await
        .unwrap()
    else {
        // Request stashed for batched processing
        return;
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output);

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
}

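/// Synchronizes Anthropic batch requests stashed by the `Teacher` provider;
/// other providers have nothing to sync.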
pub async fn sync_batches(provider: &PredictionProvider) {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
            llm_client
                .sync_batches()
                .await
                .expect("Failed to sync batches");
        }
        _ => (),
    }
}