1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 path::Path,
5 sync::{Arc, Mutex},
6 time::Duration,
7};
8
9use crate::{
10 ToolMetrics,
11 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
12};
13use agent::{ContextLoadResult, Thread, ThreadEvent};
14use agent_settings::AgentProfileId;
15use anyhow::{Result, anyhow};
16use async_trait::async_trait;
17use buffer_diff::DiffHunkStatus;
18use collections::HashMap;
19use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
20use gpui::{App, AppContext, AsyncApp, Entity};
21use language_model::{LanguageModel, Role, StopReason};
22
/// Longest time `ExampleContext::run_turns` waits between progress events
/// before it gives up and reports the agentic loop as stalled.
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
24
25#[async_trait(?Send)]
26pub trait Example {
27 fn meta(&self) -> ExampleMetadata;
28 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
29 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
30 Vec::new()
31 }
32 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
33 Vec::new()
34 }
35}
36
/// An assertion that is handed to a judge rather than checked directly in code
/// (the judge itself is not visible here — presumably an LLM grader; confirm).
#[derive(Clone, Debug)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Description of what the judge should check.
    pub description: String,
}
42
/// Static description of an example: the repository it runs against and the
/// settings and limits for the run.
#[derive(Clone, Debug)]
pub struct ExampleMetadata {
    /// Name of the example.
    pub name: String,
    /// Repository URL; `repo_name` derives the repository name from its last
    /// path segment.
    pub url: String,
    /// Repository revision to use (presumably a commit SHA or ref — confirm).
    pub revision: String,
    /// Optional language-server settings for this example.
    pub language_server: Option<LanguageServer>,
    /// Upper bound on the number of assertions the example may run. It sizes
    /// the assertions report and is enforced when assertions are logged.
    pub max_assertions: Option<usize>,
    /// Agent profile to run the example with.
    pub profile_id: AgentProfileId,
    /// Optional serialized thread (presumably used to seed the conversation —
    /// confirm against the runner).
    pub existing_thread_json: Option<String>,
    /// Optional cap on agent turns (consumer not visible here — confirm).
    pub max_turns: Option<u32>,
}
54
/// Language-server settings attached to an example.
#[derive(Clone, Debug)]
pub struct LanguageServer {
    /// File extension associated with the language server (presumably used to
    /// select which files to diagnose — confirm).
    pub file_extension: String,
    /// Whether diagnostics that already existed before the run are tolerated
    /// (interpretation from the name — confirm against the consumer).
    pub allow_preexisting_diagnostics: bool,
}
60
61impl ExampleMetadata {
62 pub fn repo_name(&self) -> String {
63 self.url
64 .split('/')
65 .next_back()
66 .unwrap_or(&"")
67 .trim_end_matches(".git")
68 .into()
69 }
70}
71
/// Error raised when an example's assertion does not hold.
///
/// The wrapped string is the assertion message. `Display` shows it verbatim;
/// `Debug` prefixes it with `Assertion failure: `.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl Error for FailedAssertion {}
87
/// Everything an example needs while driving an agent thread: the thread, the
/// model, the async app handle, and the assertions and tool metrics collected
/// along the way.
pub struct ExampleContext {
    /// Metadata of the running example; `max_assertions` is enforced from here.
    meta: ExampleMetadata,
    /// Prefix prepended to every line this context prints.
    log_prefix: String,
    /// The agent thread being driven.
    agent_thread: Entity<agent::Thread>,
    /// Async app handle used to read and update entities.
    app: AsyncApp,
    /// Model the thread is sent to on each `run_turns` call.
    model: Arc<dyn LanguageModel>,
    /// Assertions recorded through the `assert*` methods, in order.
    pub assertions: AssertionsReport,
    /// Tool outcomes recorded (via `ToolMetrics::insert`) when tools finish.
    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
}
97
98impl ExampleContext {
99 pub fn new(
100 meta: ExampleMetadata,
101 log_prefix: String,
102 agent_thread: Entity<agent::Thread>,
103 model: Arc<dyn LanguageModel>,
104 app: AsyncApp,
105 ) -> Self {
106 let assertions = AssertionsReport::new(meta.max_assertions);
107
108 Self {
109 meta,
110 log_prefix,
111 agent_thread,
112 assertions,
113 model,
114 app,
115 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
116 }
117 }
118
119 pub fn push_user_message(&mut self, text: impl ToString) {
120 self.app
121 .update_entity(&self.agent_thread, |thread, cx| {
122 thread.insert_user_message(
123 text.to_string(),
124 ContextLoadResult::default(),
125 None,
126 Vec::new(),
127 cx,
128 );
129 })
130 .unwrap();
131 }
132
133 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
134 let message = message.to_string();
135 self.log_assertion(
136 if expected {
137 Ok(())
138 } else {
139 Err(anyhow::Error::from(FailedAssertion(message.clone())))
140 },
141 message,
142 )
143 }
144
145 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
146 let message = message.to_string();
147 self.log_assertion(
148 match option {
149 Some(value) => Ok(value),
150 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
151 },
152 message,
153 )
154 }
155
156 #[allow(dead_code)]
157 pub fn assert_eq<T: PartialEq + Debug>(
158 &mut self,
159 left: T,
160 right: T,
161 message: impl ToString,
162 ) -> Result<()> {
163 let message = message.to_string();
164 self.log_assertion(
165 if left == right {
166 Ok(())
167 } else {
168 println!(
169 "{}{}",
170 self.log_prefix,
171 pretty_assertions::Comparison::new(&left, &right)
172 );
173 Err(anyhow::Error::from(FailedAssertion(message.clone())))
174 },
175 message,
176 )
177 }
178
179 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
180 if let Some(max) = self.meta.max_assertions {
181 anyhow::ensure!(
182 self.assertions.run_count() <= max,
183 "More assertions were run than the stated max_assertions of {max}"
184 );
185 }
186
187 self.assertions.ran.push(RanAssertion {
188 id: message.clone(),
189 result: Ok(RanAssertionResult {
190 analysis: None,
191 passed: result.is_ok(),
192 }),
193 });
194
195 if result.is_ok() {
196 println!("{}✅ {}", self.log_prefix, message);
197 } else {
198 println!("{}❌ {}", self.log_prefix, message);
199 }
200
201 result
202 }
203
204 pub async fn run_to_end(&mut self) -> Result<Response> {
205 self.run_turns(u32::MAX).await
206 }
207
208 pub async fn run_turn(&mut self) -> Result<Response> {
209 self.run_turns(1).await
210 }
211
212 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
213 let (mut tx, mut rx) = mpsc::channel(1);
214
215 let tool_metrics = self.tool_metrics.clone();
216 let log_prefix = self.log_prefix.clone();
217 let _subscription = self.app.subscribe(
218 &self.agent_thread,
219 move |thread, event: &ThreadEvent, cx| match event {
220 ThreadEvent::ShowError(thread_error) => {
221 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
222 }
223 ThreadEvent::Stopped(reason) => match reason {
224 Ok(StopReason::EndTurn) => {
225 tx.close_channel();
226 }
227 Ok(StopReason::ToolUse) => {
228 if thread.read(cx).remaining_turns() == 0 {
229 tx.close_channel();
230 }
231 }
232 Ok(StopReason::MaxTokens) => {
233 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
234 }
235 Ok(StopReason::Refusal) => {
236 tx.try_send(Err(anyhow!("Model refused to generate content")))
237 .ok();
238 }
239 Err(err) => {
240 tx.try_send(Err(anyhow!(err.clone()))).ok();
241 }
242 },
243 ThreadEvent::NewRequest
244 | ThreadEvent::StreamedAssistantText(_, _)
245 | ThreadEvent::StreamedAssistantThinking(_, _)
246 | ThreadEvent::UsePendingTools { .. }
247 | ThreadEvent::CompletionCanceled => {}
248 ThreadEvent::ToolFinished {
249 tool_use_id,
250 pending_tool_use,
251 ..
252 } => {
253 thread.update(cx, |thread, _cx| {
254 if let Some(tool_use) = pending_tool_use {
255 let mut tool_metrics = tool_metrics.lock().unwrap();
256 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
257 let message = if tool_result.is_error {
258 format!("✖︎ {}", tool_use.name)
259 } else {
260 format!("✔︎ {}", tool_use.name)
261 };
262 println!("{log_prefix}{message}");
263 tool_metrics
264 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
265 } else {
266 let message =
267 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
268 println!("{log_prefix}{message}");
269 tool_metrics.insert(tool_use.name.clone(), true);
270 }
271 }
272 });
273 }
274 ThreadEvent::InvalidToolInput { .. } => {
275 println!("{log_prefix} invalid tool input");
276 }
277 ThreadEvent::MissingToolUse {
278 tool_use_id: _,
279 ui_text,
280 } => {
281 println!("{log_prefix} {ui_text}");
282 }
283 ThreadEvent::ToolConfirmationNeeded => {
284 panic!(
285 "{}Bug: Tool confirmation should not be required in eval",
286 log_prefix
287 );
288 }
289 ThreadEvent::StreamedCompletion
290 | ThreadEvent::MessageAdded(_)
291 | ThreadEvent::MessageEdited(_)
292 | ThreadEvent::MessageDeleted(_)
293 | ThreadEvent::SummaryChanged
294 | ThreadEvent::SummaryGenerated
295 | ThreadEvent::ReceivedTextChunk
296 | ThreadEvent::StreamedToolUse { .. }
297 | ThreadEvent::CheckpointChanged
298 | ThreadEvent::CancelEditing => {
299 tx.try_send(Ok(())).ok();
300 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
301 println!("{}Event: {:#?}", log_prefix, event);
302 }
303 }
304 },
305 );
306
307 let model = self.model.clone();
308
309 let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
310 thread.set_remaining_turns(iterations);
311 thread.send_to_model(model, None, cx);
312 thread.messages().len()
313 })?;
314
315 loop {
316 select_biased! {
317 result = rx.next() => {
318 if let Some(result) = result {
319 result?;
320 } else {
321 break;
322 }
323 }
324 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
325 anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
326 }
327 }
328 }
329
330 let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
331 let mut messages = Vec::new();
332 for message in thread.messages().skip(message_count_before) {
333 messages.push(Message {
334 _role: message.role,
335 text: message.to_string(),
336 tool_use: thread
337 .tool_uses_for_message(message.id, cx)
338 .into_iter()
339 .map(|tool_use| ToolUse {
340 name: tool_use.name.to_string(),
341 value: tool_use.input,
342 })
343 .collect(),
344 });
345 }
346 messages
347 })?;
348
349 let response = Response::new(messages);
350
351 Ok(response)
352 }
353
354 pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
355 self.agent_thread
356 .read_with(&self.app, |thread, cx| {
357 let action_log = thread.action_log().read(cx);
358 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
359 |(buffer, diff)| {
360 let snapshot = buffer.read(cx).snapshot();
361
362 let file = snapshot.file().unwrap();
363 let diff = diff.read(cx);
364 let base_text = diff.base_text().text();
365
366 let hunks = diff
367 .hunks(&snapshot, cx)
368 .map(|hunk| FileEditHunk {
369 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
370 text: snapshot
371 .text_for_range(hunk.range.clone())
372 .collect::<String>(),
373 status: hunk.status(),
374 })
375 .collect();
376
377 (file.path().clone(), FileEdits { hunks })
378 },
379 ))
380 })
381 .unwrap()
382 }
383
384 pub fn agent_thread(&self) -> Entity<Thread> {
385 self.agent_thread.clone()
386 }
387}
388
/// Lets an `ExampleContext` be used wherever a gpui `AppContext` is expected.
/// Every method forwards unchanged to the wrapped `AsyncApp`.
impl AppContext for ExampleContext {
    // Fallible, like the `AsyncApp` results being forwarded.
    type Result<T> = anyhow::Result<T>;

    /// Creates a new entity via the wrapped `AsyncApp`.
    fn new<T: 'static>(
        &mut self,
        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
    ) -> Self::Result<Entity<T>> {
        self.app.new(build_entity)
    }

    /// Reserves an entity slot via the wrapped `AsyncApp`.
    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
        self.app.reserve_entity()
    }

    /// Fills a previously reserved entity slot via the wrapped `AsyncApp`.
    fn insert_entity<T: 'static>(
        &mut self,
        reservation: gpui::Reservation<T>,
        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
    ) -> Self::Result<Entity<T>> {
        self.app.insert_entity(reservation, build_entity)
    }

    /// Updates `handle` via the wrapped `AsyncApp`.
    fn update_entity<T, R>(
        &mut self,
        handle: &Entity<T>,
        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
    ) -> Self::Result<R>
    where
        T: 'static,
    {
        self.app.update_entity(handle, update)
    }

    /// Reads `handle` via the wrapped `AsyncApp`.
    fn read_entity<T, R>(
        &self,
        handle: &Entity<T>,
        read: impl FnOnce(&T, &App) -> R,
    ) -> Self::Result<R>
    where
        T: 'static,
    {
        self.app.read_entity(handle, read)
    }

    /// Updates `window` via the wrapped `AsyncApp`.
    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
    where
        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
    {
        self.app.update_window(window, f)
    }

    /// Reads `window` via the wrapped `AsyncApp`.
    fn read_window<T, R>(
        &self,
        window: &gpui::WindowHandle<T>,
        read: impl FnOnce(Entity<T>, &App) -> R,
    ) -> Result<R>
    where
        T: 'static,
    {
        self.app.read_window(window, read)
    }

    /// Spawns `future` on the background executor via the wrapped `AsyncApp`.
    fn background_spawn<R>(
        &self,
        future: impl std::future::Future<Output = R> + Send + 'static,
    ) -> gpui::Task<R>
    where
        R: Send + 'static,
    {
        self.app.background_spawn(future)
    }

    /// Reads the global of type `G` via the wrapped `AsyncApp`.
    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
    where
        G: gpui::Global,
    {
        self.app.read_global(callback)
    }
}
468
469#[derive(Debug)]
470pub struct Response {
471 messages: Vec<Message>,
472}
473
474impl Response {
475 pub fn new(messages: Vec<Message>) -> Self {
476 Self { messages }
477 }
478
479 pub fn expect_tool(
480 &self,
481 tool_name: &'static str,
482 cx: &mut ExampleContext,
483 ) -> Result<&ToolUse> {
484 let result = self.find_tool_call(tool_name);
485 cx.assert_some(result, format!("called `{}`", tool_name))
486 }
487
488 pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
489 self.messages.iter().rev().find_map(|msg| {
490 msg.tool_use
491 .iter()
492 .find(|tool_use| tool_use.name == tool_name)
493 })
494 }
495
496 #[allow(dead_code)]
497 pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
498 self.messages.iter().flat_map(|msg| &msg.tool_use)
499 }
500
501 pub fn texts(&self) -> impl Iterator<Item = String> {
502 self.messages.iter().map(|message| message.text.clone())
503 }
504}
505
/// One thread message captured by `run_turns` for inspection.
#[derive(Debug)]
pub struct Message {
    /// Author role. Underscore-prefixed because nothing in this file reads it.
    _role: Role,
    /// The message rendered via its `to_string()`.
    text: String,
    /// Tool calls the thread associates with this message.
    tool_use: Vec<ToolUse>,
}
512
/// A tool call captured from the agent thread.
#[derive(Debug)]
pub struct ToolUse {
    /// Name of the called tool.
    pub name: String,
    /// Raw JSON input of the call; decode it with `parse_input`.
    value: serde_json::Value,
}
518
519impl ToolUse {
520 pub fn parse_input<Input>(&self) -> Result<Input>
521 where
522 Input: for<'de> serde::Deserialize<'de>,
523 {
524 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
525 }
526}
527
/// The diff hunks recorded for a single changed file.
#[derive(Debug, Eq, PartialEq)]
pub struct FileEdits {
    /// Hunks in the order the buffer diff produced them.
    pub hunks: Vec<FileEditHunk>,
}
532
/// One diff hunk: the base text it replaces, the new text, and its status.
#[derive(Debug, Eq, PartialEq)]
pub struct FileEditHunk {
    /// Slice of the diff base text covered by this hunk.
    pub base_text: String,
    /// Current buffer text covered by this hunk.
    pub text: String,
    /// Hunk status as reported by the buffer diff.
    pub status: DiffHunkStatus,
}
539
540impl FileEdits {
541 pub fn has_added_line(&self, line: &str) -> bool {
542 self.hunks.iter().any(|hunk| {
543 hunk.status == DiffHunkStatus::added_none()
544 && hunk.base_text.is_empty()
545 && hunk.text.contains(line)
546 })
547 }
548}