1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 sync::{Arc, Mutex},
5 time::Duration,
6 u32,
7};
8
9use crate::{
10 ToolMetrics,
11 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
12};
13use acp_thread::UserMessageId;
14use agent::{Thread, ThreadEvent, UserMessageContent};
15use agent_client_protocol as acp;
16use agent_settings::AgentProfileId;
17use anyhow::{Result, anyhow};
18use async_trait::async_trait;
19use buffer_diff::DiffHunkStatus;
20use collections::HashMap;
21use futures::{FutureExt as _, StreamExt, select_biased};
22use gpui::{App, AppContext, AsyncApp, Entity};
23use language_model::Role;
24use util::rel_path::RelPath;
25
/// How long to wait on the agent thread's event stream before reporting that the
/// agentic loop stalled (two minutes).
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(120);
27
28#[async_trait(?Send)]
29pub trait Example {
30 fn meta(&self) -> ExampleMetadata;
31 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
32 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
33 Vec::new()
34 }
35 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
36 Vec::new()
37 }
38}
39
/// An assertion expressed in natural language, identified by `id`.
#[derive(Debug, Clone)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Natural-language description of what should hold.
    pub description: String,
}
45
46#[derive(Clone, Debug)]
47pub struct ExampleMetadata {
48 pub name: String,
49 pub url: String,
50 pub revision: String,
51 pub language_server: Option<LanguageServer>,
52 pub max_assertions: Option<usize>,
53 pub profile_id: AgentProfileId,
54 pub existing_thread_json: Option<String>,
55 pub max_turns: Option<u32>,
56}
57
/// Language-server settings attached to an example.
#[derive(Debug, Clone)]
pub struct LanguageServer {
    /// File extension associated with the language server (format unverified here).
    pub file_extension: String,
    /// Whether diagnostics that existed before the run are tolerated.
    pub allow_preexisting_diagnostics: bool,
}
63
64impl ExampleMetadata {
65 pub fn repo_name(&self) -> String {
66 self.url
67 .split('/')
68 .next_back()
69 .unwrap_or("")
70 .trim_end_matches(".git")
71 .into()
72 }
73}
74
/// Error value carrying the message of an assertion that did not hold.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    /// Formats as `Assertion failure: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    /// Formats as the bare assertion message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

// Marker impl so the type can be wrapped in `anyhow::Error`; uses the default methods.
impl Error for FailedAssertion {}
90
91pub struct ExampleContext {
92 meta: ExampleMetadata,
93 log_prefix: String,
94 agent_thread: Entity<agent::Thread>,
95 app: AsyncApp,
96 pub assertions: AssertionsReport,
97 pub tool_metrics: Arc<Mutex<ToolMetrics>>,
98}
99
100impl ExampleContext {
101 pub fn new(
102 meta: ExampleMetadata,
103 log_prefix: String,
104 agent_thread: Entity<Thread>,
105 app: AsyncApp,
106 ) -> Self {
107 let assertions = AssertionsReport::new(meta.max_assertions);
108
109 Self {
110 meta,
111 log_prefix,
112 agent_thread,
113 assertions,
114 app,
115 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
116 }
117 }
118
119 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
120 let message = message.to_string();
121 self.log_assertion(
122 if expected {
123 Ok(())
124 } else {
125 Err(anyhow::Error::from(FailedAssertion(message.clone())))
126 },
127 message,
128 )
129 }
130
131 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
132 let message = message.to_string();
133 self.log_assertion(
134 match option {
135 Some(value) => Ok(value),
136 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
137 },
138 message,
139 )
140 }
141
142 #[allow(dead_code)]
143 pub fn assert_eq<T: PartialEq + Debug>(
144 &mut self,
145 left: T,
146 right: T,
147 message: impl ToString,
148 ) -> Result<()> {
149 let message = message.to_string();
150 self.log_assertion(
151 if left == right {
152 Ok(())
153 } else {
154 println!(
155 "{}{}",
156 self.log_prefix,
157 pretty_assertions::Comparison::new(&left, &right)
158 );
159 Err(anyhow::Error::from(FailedAssertion(message.clone())))
160 },
161 message,
162 )
163 }
164
165 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
166 if let Some(max) = self.meta.max_assertions {
167 anyhow::ensure!(
168 self.assertions.run_count() <= max,
169 "More assertions were run than the stated max_assertions of {max}"
170 );
171 }
172
173 self.assertions.ran.push(RanAssertion {
174 id: message.clone(),
175 result: Ok(RanAssertionResult {
176 analysis: None,
177 passed: result.is_ok(),
178 }),
179 });
180
181 if result.is_ok() {
182 println!("{}✅ {}", self.log_prefix, message);
183 } else {
184 println!("{}❌ {}", self.log_prefix, message);
185 }
186
187 result
188 }
189
190 pub async fn prompt(&mut self, prompt: impl Into<String>) -> Result<Response> {
191 self.prompt_with_max_turns(prompt, u32::MAX).await
192 }
193
194 pub async fn prompt_with_max_turns(
195 &mut self,
196 prompt: impl Into<String>,
197 max_turns: u32,
198 ) -> Result<Response> {
199 let content = vec![UserMessageContent::Text(prompt.into())];
200 self.run_turns(Some(content), max_turns).await
201 }
202
203 pub async fn proceed_with_max_turns(&mut self, max_turns: u32) -> Result<Response> {
204 self.run_turns(None, max_turns).await
205 }
206
207 async fn run_turns(
208 &mut self,
209 prompt: Option<Vec<UserMessageContent>>,
210 max_turns: u32,
211 ) -> Result<Response> {
212 let tool_metrics = self.tool_metrics.clone();
213 let log_prefix = self.log_prefix.clone();
214
215 let mut remaining_turns = max_turns;
216
217 let mut event_stream = self.agent_thread.update(&mut self.app, |thread, cx| {
218 if let Some(prompt) = prompt {
219 let id = UserMessageId::new();
220 thread.send(id, prompt, cx)
221 } else {
222 thread.proceed(cx)
223 }
224 })??;
225
226 let task = self.app.background_spawn(async move {
227 let mut messages = Vec::new();
228 let mut tool_uses_by_id = HashMap::default();
229 while let Some(event) = event_stream.next().await {
230 match event? {
231 ThreadEvent::UserMessage(user_message) => {
232 messages.push(Message {
233 role: Role::User,
234 text: user_message.to_markdown(),
235 tool_use: Vec::new(),
236 });
237 }
238 ThreadEvent::AgentThinking(text) | ThreadEvent::AgentText(text) => {
239 if matches!(
240 messages.last(),
241 Some(Message {
242 role: Role::Assistant,
243 ..
244 })
245 ) {
246 messages.last_mut().unwrap().text.push_str(&text);
247 } else {
248 messages.push(Message {
249 role: Role::Assistant,
250 text,
251 tool_use: Vec::new(),
252 });
253 }
254 }
255 ThreadEvent::ToolCall(tool_call) => {
256 let meta = tool_call.meta.expect("Missing meta field in tool_call");
257 let tool_name = meta
258 .get("tool_name")
259 .expect("Missing tool_name field in meta")
260 .as_str()
261 .expect("Unknown tool_name content in meta");
262
263 tool_uses_by_id.insert(
264 tool_call.id,
265 ToolUse {
266 name: tool_name.to_string(),
267 value: tool_call.raw_input.unwrap_or_default(),
268 },
269 );
270 if matches!(
271 tool_call.status,
272 acp::ToolCallStatus::Completed | acp::ToolCallStatus::Failed
273 ) {
274 panic!("Tool call completed without update");
275 }
276 }
277 ThreadEvent::ToolCallUpdate(tool_call_update) => {
278 if let acp_thread::ToolCallUpdate::UpdateFields(update) = tool_call_update {
279 if let Some(raw_input) = update.fields.raw_input {
280 if let Some(tool_use) = tool_uses_by_id.get_mut(&update.id) {
281 tool_use.value = raw_input;
282 }
283 }
284
285 if matches!(
286 update.fields.status,
287 Some(acp::ToolCallStatus::Completed | acp::ToolCallStatus::Failed)
288 ) {
289 let succeeded =
290 update.fields.status == Some(acp::ToolCallStatus::Completed);
291
292 let tool_use = tool_uses_by_id
293 .remove(&update.id)
294 .expect("Unrecognized tool call completed");
295
296 let log_message = if succeeded {
297 format!("✔︎ {}", tool_use.name)
298 } else {
299 format!("✖︎ {}", tool_use.name)
300 };
301 println!("{log_prefix}{log_message}");
302
303 tool_metrics
304 .lock()
305 .unwrap()
306 .insert(tool_use.name.clone().into(), succeeded);
307
308 if let Some(message) = messages.last_mut() {
309 message.tool_use.push(tool_use);
310 } else {
311 messages.push(Message {
312 role: Role::Assistant,
313 text: "".to_string(),
314 tool_use: vec![tool_use],
315 });
316 }
317
318 remaining_turns -= 1;
319 if remaining_turns == 0 {
320 return Ok(messages);
321 }
322 }
323 }
324 }
325 ThreadEvent::ToolCallAuthorization(_) => panic!(
326 "{}Bug: Tool confirmation should not be required in eval",
327 log_prefix
328 ),
329 ThreadEvent::Retry(status) => {
330 println!("{log_prefix} Got retry: {status:?}");
331 }
332 ThreadEvent::Stop(stop_reason) => match stop_reason {
333 acp::StopReason::EndTurn => {}
334 acp::StopReason::MaxTokens => {
335 return Err(anyhow!("Exceeded maximum tokens"));
336 }
337 acp::StopReason::MaxTurnRequests => {
338 return Err(anyhow!("Exceeded maximum turn requests"));
339 }
340 acp::StopReason::Refusal => {
341 return Err(anyhow!("Refusal"));
342 }
343 acp::StopReason::Cancelled => return Err(anyhow!("Cancelled")),
344 },
345 }
346 }
347 Ok(messages)
348 });
349
350 select_biased! {
351 result = task.fuse() => {
352 Ok(Response::new(result?))
353 }
354 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
355 anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
356 }
357 }
358 }
359
360 pub fn edits(&self) -> HashMap<Arc<RelPath>, FileEdits> {
361 self.agent_thread
362 .read_with(&self.app, |thread, cx| {
363 let action_log = thread.action_log().read(cx);
364 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
365 |(buffer, diff)| {
366 let snapshot = buffer.read(cx).snapshot();
367
368 let file = snapshot.file().unwrap();
369 let diff = diff.read(cx);
370 let base_text = diff.base_text().text();
371
372 let hunks = diff
373 .hunks(&snapshot, cx)
374 .map(|hunk| FileEditHunk {
375 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
376 text: snapshot
377 .text_for_range(hunk.range.clone())
378 .collect::<String>(),
379 status: hunk.status(),
380 })
381 .collect();
382
383 (file.path().clone(), FileEdits { hunks })
384 },
385 ))
386 })
387 .unwrap()
388 }
389
390 pub fn agent_thread(&self) -> Entity<Thread> {
391 self.agent_thread.clone()
392 }
393}
394
395impl AppContext for ExampleContext {
396 type Result<T> = anyhow::Result<T>;
397
398 fn new<T: 'static>(
399 &mut self,
400 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
401 ) -> Self::Result<Entity<T>> {
402 self.app.new(build_entity)
403 }
404
405 fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
406 self.app.reserve_entity()
407 }
408
409 fn insert_entity<T: 'static>(
410 &mut self,
411 reservation: gpui::Reservation<T>,
412 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
413 ) -> Self::Result<Entity<T>> {
414 self.app.insert_entity(reservation, build_entity)
415 }
416
417 fn update_entity<T, R>(
418 &mut self,
419 handle: &Entity<T>,
420 update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
421 ) -> Self::Result<R>
422 where
423 T: 'static,
424 {
425 self.app.update_entity(handle, update)
426 }
427
428 fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<gpui::GpuiBorrow<'a, T>>
429 where
430 T: 'static,
431 {
432 self.app.as_mut(handle)
433 }
434
435 fn read_entity<T, R>(
436 &self,
437 handle: &Entity<T>,
438 read: impl FnOnce(&T, &App) -> R,
439 ) -> Self::Result<R>
440 where
441 T: 'static,
442 {
443 self.app.read_entity(handle, read)
444 }
445
446 fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
447 where
448 F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
449 {
450 self.app.update_window(window, f)
451 }
452
453 fn read_window<T, R>(
454 &self,
455 window: &gpui::WindowHandle<T>,
456 read: impl FnOnce(Entity<T>, &App) -> R,
457 ) -> Result<R>
458 where
459 T: 'static,
460 {
461 self.app.read_window(window, read)
462 }
463
464 fn background_spawn<R>(
465 &self,
466 future: impl std::future::Future<Output = R> + Send + 'static,
467 ) -> gpui::Task<R>
468 where
469 R: Send + 'static,
470 {
471 self.app.background_spawn(future)
472 }
473
474 fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
475 where
476 G: gpui::Global,
477 {
478 self.app.read_global(callback)
479 }
480}
481
482#[derive(Debug)]
483pub struct Response {
484 messages: Vec<Message>,
485}
486
487impl Response {
488 pub fn new(messages: Vec<Message>) -> Self {
489 Self { messages }
490 }
491
492 pub fn expect_tool_call(
493 &self,
494 tool_name: &'static str,
495 cx: &mut ExampleContext,
496 ) -> Result<&ToolUse> {
497 let result = self.find_tool_call(tool_name);
498 cx.assert_some(result, format!("called `{}`", tool_name))
499 }
500
501 pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
502 self.messages.iter().rev().find_map(|msg| {
503 msg.tool_use
504 .iter()
505 .find(|tool_use| tool_use.name == tool_name)
506 })
507 }
508
509 pub fn tool_calls(&self) -> impl Iterator<Item = &ToolUse> {
510 self.messages.iter().flat_map(|msg| &msg.tool_use)
511 }
512
513 pub fn texts(&self) -> impl Iterator<Item = String> {
514 self.messages.iter().map(|message| message.text.clone())
515 }
516}
517
518#[derive(Debug)]
519pub struct Message {
520 role: Role,
521 text: String,
522 tool_use: Vec<ToolUse>,
523}
524
525#[derive(Debug)]
526pub struct ToolUse {
527 pub name: String,
528 value: serde_json::Value,
529}
530
531impl ToolUse {
532 pub fn parse_input<Input>(&self) -> Result<Input>
533 where
534 Input: for<'de> serde::Deserialize<'de>,
535 {
536 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
537 }
538}
539
540#[derive(Debug, Eq, PartialEq)]
541pub struct FileEdits {
542 pub hunks: Vec<FileEditHunk>,
543}
544
545#[derive(Debug, Eq, PartialEq)]
546pub struct FileEditHunk {
547 pub base_text: String,
548 pub text: String,
549 pub status: DiffHunkStatus,
550}
551
552impl FileEdits {
553 pub fn has_added_line(&self, line: &str) -> bool {
554 self.hunks.iter().any(|hunk| {
555 hunk.status == DiffHunkStatus::added_none()
556 && hunk.base_text.is_empty()
557 && hunk.text.contains(line)
558 })
559 }
560}