1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 sync::{Arc, Mutex},
5 time::Duration,
6};
7
8use crate::{
9 ToolMetrics,
10 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
11};
12use acp_thread::UserMessageId;
13use agent::{Thread, ThreadEvent, UserMessageContent};
14use agent_client_protocol as acp;
15use agent_settings::AgentProfileId;
16use anyhow::{Result, anyhow};
17use async_trait::async_trait;
18use buffer_diff::DiffHunkStatus;
19use cloud_llm_client::CompletionIntent;
20use collections::HashMap;
21use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
22use gpui::{App, AppContext, AsyncApp, Entity};
23use language_model::{LanguageModel, Role, StopReason};
24use util::rel_path::RelPath;
25
/// How long `ExampleContext::run_turns` waits for the next thread event before
/// bailing out with an "Agentic loop stalled" error (two minutes).
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
27
28#[async_trait(?Send)]
29pub trait Example {
30 fn meta(&self) -> ExampleMetadata;
31 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
32 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
33 Vec::new()
34 }
35 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
36 Vec::new()
37 }
38}
39
/// An assertion returned by `Example::diff_assertions` and
/// `Example::thread_assertions`.
///
/// NOTE(review): the name suggests these are checked by a judge rather than by
/// code in this file — confirm against the consumer.
#[derive(Clone, Debug)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Free-text description of what is being asserted.
    pub description: String,
}
45
/// Static description of an example, returned by `Example::meta`.
#[derive(Clone, Debug)]
pub struct ExampleMetadata {
    /// Name of the example.
    pub name: String,
    /// Repository URL; `repo_name` derives the repository name from its last
    /// `/`-separated segment.
    pub url: String,
    /// Repository revision.
    /// NOTE(review): presumably the revision the example runs against — its
    /// use is not visible here; confirm.
    pub revision: String,
    /// Optional language-server configuration for the example.
    pub language_server: Option<LanguageServer>,
    /// When set, passed to `AssertionsReport::new` and used by
    /// `ExampleContext::log_assertion` to limit how many assertions may run.
    pub max_assertions: Option<usize>,
    /// Identifier of the agent profile.
    /// NOTE(review): how it is applied is not visible here; confirm.
    pub profile_id: AgentProfileId,
    /// Optional JSON for an existing thread.
    /// NOTE(review): presumably a serialized thread to start from; confirm.
    pub existing_thread_json: Option<String>,
    /// Optional turn limit.
    /// NOTE(review): not read in this file; confirm where it is enforced.
    pub max_turns: Option<u32>,
}
57
/// Language-server settings attached to an example through
/// `ExampleMetadata::language_server`.
#[derive(Clone, Debug)]
pub struct LanguageServer {
    /// File extension associated with the language server.
    /// NOTE(review): presumably selects which files are opened to start the
    /// server — not visible here; confirm.
    pub file_extension: String,
    /// Whether diagnostics that already exist are allowed.
    /// NOTE(review): exact enforcement is not visible here; confirm.
    pub allow_preexisting_diagnostics: bool,
}
63
64impl ExampleMetadata {
65 pub fn repo_name(&self) -> String {
66 self.url
67 .split('/')
68 .next_back()
69 .unwrap_or("")
70 .trim_end_matches(".git")
71 .into()
72 }
73}
74
/// Error carrying the message of an assertion that did not hold.
///
/// Created by the `assert*` helpers on `ExampleContext` when a check fails.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    /// Writes the message prefixed with `Assertion failure: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    /// Writes the bare message, ignoring any width or padding flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl Error for FailedAssertion {}
90
/// State handed to an example while it runs: the agent thread being driven,
/// the model used for requests, and the bookkeeping for assertions and tool
/// usage.
pub struct ExampleContext {
    /// Metadata of the example; `max_assertions` is read when logging assertions.
    meta: ExampleMetadata,
    /// Prefix prepended to every line this context prints.
    log_prefix: String,
    /// The agent thread that messages are pushed to and turns are run on.
    agent_thread: Entity<agent::Thread>,
    /// Async application handle used to read and update entities.
    app: AsyncApp,
    /// Language model passed to the thread when a turn is started.
    model: Arc<dyn LanguageModel>,
    /// Record of every assertion run through this context.
    pub assertions: AssertionsReport,
    /// Per-tool success/failure counters, updated as tools finish.
    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
}
100
101impl ExampleContext {
102 pub fn new(
103 meta: ExampleMetadata,
104 log_prefix: String,
105 agent_thread: Entity<Thread>,
106 model: Arc<dyn LanguageModel>,
107 app: AsyncApp,
108 ) -> Self {
109 let assertions = AssertionsReport::new(meta.max_assertions);
110
111 Self {
112 meta,
113 log_prefix,
114 agent_thread,
115 assertions,
116 model,
117 app,
118 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
119 }
120 }
121
122 pub fn push_user_message(&mut self, text: impl ToString) {
123 self.app
124 .update_entity(&self.agent_thread, |thread, cx| {
125 thread.insert_user_message(
126 text.to_string(),
127 ContextLoadResult::default(),
128 None,
129 Vec::new(),
130 cx,
131 );
132 })
133 .unwrap();
134 }
135
136 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
137 let message = message.to_string();
138 self.log_assertion(
139 if expected {
140 Ok(())
141 } else {
142 Err(anyhow::Error::from(FailedAssertion(message.clone())))
143 },
144 message,
145 )
146 }
147
148 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
149 let message = message.to_string();
150 self.log_assertion(
151 match option {
152 Some(value) => Ok(value),
153 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
154 },
155 message,
156 )
157 }
158
159 #[allow(dead_code)]
160 pub fn assert_eq<T: PartialEq + Debug>(
161 &mut self,
162 left: T,
163 right: T,
164 message: impl ToString,
165 ) -> Result<()> {
166 let message = message.to_string();
167 self.log_assertion(
168 if left == right {
169 Ok(())
170 } else {
171 println!(
172 "{}{}",
173 self.log_prefix,
174 pretty_assertions::Comparison::new(&left, &right)
175 );
176 Err(anyhow::Error::from(FailedAssertion(message.clone())))
177 },
178 message,
179 )
180 }
181
182 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
183 if let Some(max) = self.meta.max_assertions {
184 anyhow::ensure!(
185 self.assertions.run_count() <= max,
186 "More assertions were run than the stated max_assertions of {max}"
187 );
188 }
189
190 self.assertions.ran.push(RanAssertion {
191 id: message.clone(),
192 result: Ok(RanAssertionResult {
193 analysis: None,
194 passed: result.is_ok(),
195 }),
196 });
197
198 if result.is_ok() {
199 println!("{}✅ {}", self.log_prefix, message);
200 } else {
201 println!("{}❌ {}", self.log_prefix, message);
202 }
203
204 result
205 }
206
207 pub async fn run_to_end(&mut self) -> Result<Response> {
208 self.run_turns(u32::MAX).await
209 }
210
211 pub async fn run_turn(&mut self) -> Result<Response> {
212 self.run_turns(1).await
213 }
214
215 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
216 let (mut tx, mut rx) = mpsc::channel(1);
217
218 let tool_metrics = self.tool_metrics.clone();
219 let log_prefix = self.log_prefix.clone();
220 let _subscription = self.app.subscribe(
221 &self.agent_thread,
222 move |thread, event: &ThreadEvent, cx| match event {
223 ThreadEvent::ShowError(thread_error) => {
224 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
225 }
226 ThreadEvent::Stopped(reason) => match reason {
227 Ok(StopReason::EndTurn) => {
228 tx.close_channel();
229 }
230 Ok(StopReason::ToolUse) => {
231 if thread.read(cx).remaining_turns() == 0 {
232 tx.close_channel();
233 }
234 }
235 Ok(StopReason::MaxTokens) => {
236 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
237 }
238 Ok(StopReason::Refusal) => {
239 tx.try_send(Err(anyhow!("Model refused to generate content")))
240 .ok();
241 }
242 Err(err) => {
243 tx.try_send(Err(anyhow!(err.clone()))).ok();
244 }
245 },
246 ThreadEvent::NewRequest
247 | ThreadEvent::StreamedAssistantText(_, _)
248 | ThreadEvent::StreamedAssistantThinking(_, _)
249 | ThreadEvent::UsePendingTools { .. }
250 | ThreadEvent::CompletionCanceled => {}
251 ThreadEvent::ToolUseLimitReached => {}
252 ThreadEvent::ToolFinished {
253 tool_use_id,
254 pending_tool_use,
255 ..
256 } => {
257 thread.update(cx, |thread, _cx| {
258 if let Some(tool_use) = pending_tool_use {
259 let mut tool_metrics = tool_metrics.lock().unwrap();
260 if let Some(tool_result) = thread.tool_result(tool_use_id) {
261 let message = if tool_result.is_error {
262 format!("✖︎ {}", tool_use.name)
263 } else {
264 format!("✔︎ {}", tool_use.name)
265 };
266 println!("{log_prefix}{message}");
267 tool_metrics
268 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
269 } else {
270 let message =
271 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
272 println!("{log_prefix}{message}");
273 tool_metrics.insert(tool_use.name.clone(), true);
274 }
275 }
276 });
277 }
278 ThreadEvent::InvalidToolInput { .. } => {
279 println!("{log_prefix} invalid tool input");
280 }
281 ThreadEvent::MissingToolUse {
282 tool_use_id: _,
283 ui_text,
284 } => {
285 println!("{log_prefix} {ui_text}");
286 }
287 ThreadEvent::ToolConfirmationNeeded => {
288 panic!(
289 "{}Bug: Tool confirmation should not be required in eval",
290 log_prefix
291 );
292 }
293 ThreadEvent::StreamedCompletion
294 | ThreadEvent::MessageAdded(_)
295 | ThreadEvent::MessageEdited(_)
296 | ThreadEvent::MessageDeleted(_)
297 | ThreadEvent::SummaryChanged
298 | ThreadEvent::SummaryGenerated
299 | ThreadEvent::ProfileChanged
300 | ThreadEvent::ReceivedTextChunk
301 | ThreadEvent::StreamedToolUse { .. }
302 | ThreadEvent::CheckpointChanged
303 | ThreadEvent::CancelEditing => {
304 tx.try_send(Ok(())).ok();
305 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
306 println!("{}Event: {:#?}", log_prefix, event);
307 }
308 }
309 },
310 );
311
312 let model = self.model.clone();
313
314 let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
315 thread.set_remaining_turns(iterations);
316 thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
317 thread.messages().len()
318 })?;
319
320 loop {
321 select_biased! {
322 result = rx.next() => {
323 if let Some(result) = result {
324 result?;
325 } else {
326 break;
327 }
328 }
329 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
330 anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
331 }
332 }
333 }
334
335 let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
336 let mut messages = Vec::new();
337 for message in thread.messages().skip(message_count_before) {
338 messages.push(Message {
339 _role: message.role,
340 text: message.to_message_content(),
341 tool_use: thread
342 .tool_uses_for_message(message.id, cx)
343 .into_iter()
344 .map(|tool_use| ToolUse {
345 name: tool_use.name.to_string(),
346 value: tool_use.input,
347 })
348 .collect(),
349 });
350 }
351 messages
352 })?;
353
354 let response = Response::new(messages);
355
356 Ok(response)
357 }
358
359 pub fn edits(&self) -> HashMap<Arc<RelPath>, FileEdits> {
360 self.agent_thread
361 .read_with(&self.app, |thread, cx| {
362 let action_log = thread.action_log().read(cx);
363 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
364 |(buffer, diff)| {
365 let snapshot = buffer.read(cx).snapshot();
366
367 let file = snapshot.file().unwrap();
368 let diff = diff.read(cx);
369 let base_text = diff.base_text().text();
370
371 let hunks = diff
372 .hunks(&snapshot, cx)
373 .map(|hunk| FileEditHunk {
374 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
375 text: snapshot
376 .text_for_range(hunk.range.clone())
377 .collect::<String>(),
378 status: hunk.status(),
379 })
380 .collect();
381
382 (file.path().clone(), FileEdits { hunks })
383 },
384 ))
385 })
386 .unwrap()
387 }
388
389 pub fn agent_thread(&self) -> Entity<Thread> {
390 self.agent_thread.clone()
391 }
392}
393
// Lets an `ExampleContext` be used wherever a gpui `AppContext` is expected.
// Every method forwards directly to the wrapped `AsyncApp` (`self.app`)
// without adding behavior of its own.
impl AppContext for ExampleContext {
    // Operations are fallible and report failures as `anyhow` errors.
    type Result<T> = anyhow::Result<T>;

    // Forwards entity creation to the wrapped app.
    fn new<T: 'static>(
        &mut self,
        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
    ) -> Self::Result<Entity<T>> {
        self.app.new(build_entity)
    }

    // Forwards entity reservation to the wrapped app.
    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
        self.app.reserve_entity()
    }

    // Forwards insertion of an entity into a reservation to the wrapped app.
    fn insert_entity<T: 'static>(
        &mut self,
        reservation: gpui::Reservation<T>,
        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
    ) -> Self::Result<Entity<T>> {
        self.app.insert_entity(reservation, build_entity)
    }

    // Forwards a mutable entity update to the wrapped app.
    fn update_entity<T, R>(
        &mut self,
        handle: &Entity<T>,
        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
    ) -> Self::Result<R>
    where
        T: 'static,
    {
        self.app.update_entity(handle, update)
    }

    // Forwards the mutable-borrow request for an entity to the wrapped app.
    fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<gpui::GpuiBorrow<'a, T>>
    where
        T: 'static,
    {
        self.app.as_mut(handle)
    }

    // Forwards a read-only entity access to the wrapped app.
    fn read_entity<T, R>(
        &self,
        handle: &Entity<T>,
        read: impl FnOnce(&T, &App) -> R,
    ) -> Self::Result<R>
    where
        T: 'static,
    {
        self.app.read_entity(handle, read)
    }

    // Forwards a window update to the wrapped app.
    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
    where
        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
    {
        self.app.update_window(window, f)
    }

    // Forwards a read-only window access to the wrapped app.
    fn read_window<T, R>(
        &self,
        window: &gpui::WindowHandle<T>,
        read: impl FnOnce(Entity<T>, &App) -> R,
    ) -> Result<R>
    where
        T: 'static,
    {
        self.app.read_window(window, read)
    }

    // Forwards spawning of a `Send` future to the wrapped app.
    fn background_spawn<R>(
        &self,
        future: impl std::future::Future<Output = R> + Send + 'static,
    ) -> gpui::Task<R>
    where
        R: Send + 'static,
    {
        self.app.background_spawn(future)
    }

    // Forwards a read of a global value to the wrapped app.
    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
    where
        G: gpui::Global,
    {
        self.app.read_global(callback)
    }
}
480
/// The messages produced by one call to `ExampleContext::run_turns`.
#[derive(Debug)]
pub struct Response {
    /// Messages added to the thread during the run, in thread order.
    messages: Vec<Message>,
}
485
486impl Response {
487 pub fn new(messages: Vec<Message>) -> Self {
488 Self { messages }
489 }
490
491 pub fn expect_tool(
492 &self,
493 tool_name: &'static str,
494 cx: &mut ExampleContext,
495 ) -> Result<&ToolUse> {
496 let result = self.find_tool_call(tool_name);
497 cx.assert_some(result, format!("called `{}`", tool_name))
498 }
499
500 pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
501 self.messages.iter().rev().find_map(|msg| {
502 msg.tool_use
503 .iter()
504 .find(|tool_use| tool_use.name == tool_name)
505 })
506 }
507
508 #[allow(dead_code)]
509 pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
510 self.messages.iter().flat_map(|msg| &msg.tool_use)
511 }
512
513 pub fn texts(&self) -> impl Iterator<Item = String> {
514 self.messages.iter().map(|message| message.text.clone())
515 }
516}
517
/// One thread message as captured by `ExampleContext::run_turns`.
#[derive(Debug)]
pub struct Message {
    /// Role of the message; stored but not read in this file.
    _role: Role,
    /// The message content as text.
    text: String,
    /// Tool uses attached to this message.
    tool_use: Vec<ToolUse>,
}
524
/// A tool invocation recorded on a message.
#[derive(Debug)]
pub struct ToolUse {
    /// Name of the tool that was used.
    pub name: String,
    /// The tool's input as JSON; decode it with `parse_input`.
    value: serde_json::Value,
}
530
531impl ToolUse {
532 pub fn parse_input<Input>(&self) -> Result<Input>
533 where
534 Input: for<'de> serde::Deserialize<'de>,
535 {
536 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
537 }
538}
539
/// The edits made to a single file, as returned by `ExampleContext::edits`.
#[derive(Debug, Eq, PartialEq)]
pub struct FileEdits {
    /// One entry per diff hunk of the file.
    pub hunks: Vec<FileEditHunk>,
}
544
/// A single diff hunk of an edited file.
#[derive(Debug, Eq, PartialEq)]
pub struct FileEditHunk {
    /// The portion of the diff's base text covered by this hunk.
    pub base_text: String,
    /// The buffer's current text within the hunk's range.
    pub text: String,
    /// Status reported for the hunk by the diff.
    pub status: DiffHunkStatus,
}
551
552impl FileEdits {
553 pub fn has_added_line(&self, line: &str) -> bool {
554 self.hunks.iter().any(|hunk| {
555 hunk.status == DiffHunkStatus::added_none()
556 && hunk.base_text.is_empty()
557 && hunk.text.contains(line)
558 })
559 }
560}