use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

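/// Runs edit predictions for `example` and records them in `example.predictions`,
/// writing prompts and responses under the run directory. Teacher providers are
/// routed directly to Anthropic via [`predict_anthropic`]; all other providers go
/// through the global [`EditPredictionStore`]. Examples that already have
/// predictions are skipped.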
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) {
    if !example.predictions.is_empty() {
        return;
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await;

    let provider = provider.unwrap();

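    // Teacher providers bypass the EditPredictionStore entirely: format the
    // teacher prompt if it is missing, then query Anthropic directly.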
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await;

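    // Zeta providers need an authenticated client; sign in once and share the
    // task so every example run waits on the same authentication.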
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    client
                        .sign_in_with_optional_connect(true, cx)
                        .await
                        .unwrap();
                })
                .shared()
            })
            .clone()
            .await;
    }

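    // Point the global EditPredictionStore at the requested model. Teacher
    // providers returned earlier, hence the unreachable!() arms below.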
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
        .unwrap();

    ep_store
        .update(&mut cx, |store, _cx| {
            let model = match provider {
                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
                PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                    unreachable!()
                }
            };
            store.set_edit_prediction_model(model);
        })
        .unwrap();
    let state = example.state.as_ref().unwrap();
    let run_dir = RUN_DIR.join(&example.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

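    // Listen for debug events so each run's prompt and raw model response can
    // be written to disk and recorded on the in-progress prediction.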
    let mut debug_rx = ep_store
        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
        .unwrap();
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

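    // One iteration per repetition: each run gets its own numbered directory
    // (when repeating) and the "latest example run" symlink is repointed at it.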
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir).unwrap();
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();

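        // Push an empty prediction up front; the debug task fills in
        // `actual_output`, and `actual_patch` is set from the prediction result below.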
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .unwrap()
            .await
            .unwrap();

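        // Store the prediction as a unified diff; a missing or failed
        // prediction leaves the patch empty.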
        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();
    }

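    // Tear down: remove the project from the store, wait for the debug task to
    // finish, then copy the collected predictions back into `example`.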
    ep_store
        .update(&mut cx, |store, _| {
            store.remove_project(&state.project);
        })
        .unwrap();
    debug_task.await.unwrap();

    *example = Arc::into_inner(updated_example)
        .unwrap()
        .into_inner()
        .unwrap();
}

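/// Requests a single teacher prediction from Anthropic (claude-sonnet-4-5) and
/// appends it to `example.predictions`. In batched mode the request may instead
/// be stashed for later batch processing, in which case nothing is appended.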
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.expect("Failed to create LLM client");

    let prompt = example
        .prompt
        .as_ref()
        .unwrap_or_else(|| panic!("Prompt is required for example {}", &example.name));

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await
        .unwrap()
    else {
        // Request stashed for batched processing
        return;
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output);

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
}

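/// Flushes any pending Anthropic batch requests for the Teacher provider; other
/// providers have nothing to sync.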
pub async fn sync_batches(provider: &PredictionProvider) {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
            llm_client
                .sync_batches()
                .await
                .expect("Failed to sync batches");
        }
        _ => (),
    }
}