1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 sync::{Arc, Mutex},
5 time::Duration,
6};
7
8use crate::{
9 ToolMetrics,
10 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
11};
12use agent::{ContextLoadResult, Thread, ThreadEvent};
13use agent_settings::AgentProfileId;
14use anyhow::{Result, anyhow};
15use async_trait::async_trait;
16use buffer_diff::DiffHunkStatus;
17use cloud_llm_client::CompletionIntent;
18use collections::HashMap;
19use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
20use gpui::{App, AppContext, AsyncApp, Entity};
21use language_model::{LanguageModel, Role, StopReason};
22use util::rel_path::RelPath;
23
/// How long [`ExampleContext::run_turns`] waits for the next thread event before it gives up
/// and reports the agentic loop as stalled (two minutes).
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
25
26#[async_trait(?Send)]
27pub trait Example {
28 fn meta(&self) -> ExampleMetadata;
29 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
30 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
31 Vec::new()
32 }
33 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
34 Vec::new()
35 }
36}
37
/// An assertion identified by `id` and spelled out in `description`.
#[derive(Debug, Clone)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Text describing the assertion.
    pub description: String,
}
43
44#[derive(Clone, Debug)]
45pub struct ExampleMetadata {
46 pub name: String,
47 pub url: String,
48 pub revision: String,
49 pub language_server: Option<LanguageServer>,
50 pub max_assertions: Option<usize>,
51 pub profile_id: AgentProfileId,
52 pub existing_thread_json: Option<String>,
53 pub max_turns: Option<u32>,
54}
55
56#[derive(Clone, Debug)]
57pub struct LanguageServer {
58 pub file_extension: String,
59 pub allow_preexisting_diagnostics: bool,
60}
61
62impl ExampleMetadata {
63 pub fn repo_name(&self) -> String {
64 self.url
65 .split('/')
66 .next_back()
67 .unwrap_or("")
68 .trim_end_matches(".git")
69 .into()
70 }
71}
72
/// Error carrying the message of an assertion that did not hold.
pub struct FailedAssertion(pub String);

/// Debug output prefixes the message with `Assertion failure: `.
impl fmt::Debug for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

/// Display output is the bare message.
impl fmt::Display for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl Error for FailedAssertion {}
88
89pub struct ExampleContext {
90 meta: ExampleMetadata,
91 log_prefix: String,
92 agent_thread: Entity<agent::Thread>,
93 app: AsyncApp,
94 model: Arc<dyn LanguageModel>,
95 pub assertions: AssertionsReport,
96 pub tool_metrics: Arc<Mutex<ToolMetrics>>,
97}
98
99impl ExampleContext {
100 pub fn new(
101 meta: ExampleMetadata,
102 log_prefix: String,
103 agent_thread: Entity<Thread>,
104 model: Arc<dyn LanguageModel>,
105 app: AsyncApp,
106 ) -> Self {
107 let assertions = AssertionsReport::new(meta.max_assertions);
108
109 Self {
110 meta,
111 log_prefix,
112 agent_thread,
113 assertions,
114 model,
115 app,
116 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
117 }
118 }
119
120 pub fn push_user_message(&mut self, text: impl ToString) {
121 self.app
122 .update_entity(&self.agent_thread, |thread, cx| {
123 thread.insert_user_message(
124 text.to_string(),
125 ContextLoadResult::default(),
126 None,
127 Vec::new(),
128 cx,
129 );
130 })
131 .unwrap();
132 }
133
134 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
135 let message = message.to_string();
136 self.log_assertion(
137 if expected {
138 Ok(())
139 } else {
140 Err(anyhow::Error::from(FailedAssertion(message.clone())))
141 },
142 message,
143 )
144 }
145
146 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
147 let message = message.to_string();
148 self.log_assertion(
149 match option {
150 Some(value) => Ok(value),
151 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
152 },
153 message,
154 )
155 }
156
157 #[allow(dead_code)]
158 pub fn assert_eq<T: PartialEq + Debug>(
159 &mut self,
160 left: T,
161 right: T,
162 message: impl ToString,
163 ) -> Result<()> {
164 let message = message.to_string();
165 self.log_assertion(
166 if left == right {
167 Ok(())
168 } else {
169 println!(
170 "{}{}",
171 self.log_prefix,
172 pretty_assertions::Comparison::new(&left, &right)
173 );
174 Err(anyhow::Error::from(FailedAssertion(message.clone())))
175 },
176 message,
177 )
178 }
179
180 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
181 if let Some(max) = self.meta.max_assertions {
182 anyhow::ensure!(
183 self.assertions.run_count() <= max,
184 "More assertions were run than the stated max_assertions of {max}"
185 );
186 }
187
188 self.assertions.ran.push(RanAssertion {
189 id: message.clone(),
190 result: Ok(RanAssertionResult {
191 analysis: None,
192 passed: result.is_ok(),
193 }),
194 });
195
196 if result.is_ok() {
197 println!("{}✅ {}", self.log_prefix, message);
198 } else {
199 println!("{}❌ {}", self.log_prefix, message);
200 }
201
202 result
203 }
204
205 pub async fn run_to_end(&mut self) -> Result<Response> {
206 self.run_turns(u32::MAX).await
207 }
208
209 pub async fn run_turn(&mut self) -> Result<Response> {
210 self.run_turns(1).await
211 }
212
213 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
214 let (mut tx, mut rx) = mpsc::channel(1);
215
216 let tool_metrics = self.tool_metrics.clone();
217 let log_prefix = self.log_prefix.clone();
218 let _subscription = self.app.subscribe(
219 &self.agent_thread,
220 move |thread, event: &ThreadEvent, cx| match event {
221 ThreadEvent::ShowError(thread_error) => {
222 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
223 }
224 ThreadEvent::Stopped(reason) => match reason {
225 Ok(StopReason::EndTurn) => {
226 tx.close_channel();
227 }
228 Ok(StopReason::ToolUse) => {
229 if thread.read(cx).remaining_turns() == 0 {
230 tx.close_channel();
231 }
232 }
233 Ok(StopReason::MaxTokens) => {
234 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
235 }
236 Ok(StopReason::Refusal) => {
237 tx.try_send(Err(anyhow!("Model refused to generate content")))
238 .ok();
239 }
240 Err(err) => {
241 tx.try_send(Err(anyhow!(err.clone()))).ok();
242 }
243 },
244 ThreadEvent::NewRequest
245 | ThreadEvent::StreamedAssistantText(_, _)
246 | ThreadEvent::StreamedAssistantThinking(_, _)
247 | ThreadEvent::UsePendingTools { .. }
248 | ThreadEvent::CompletionCanceled => {}
249 ThreadEvent::ToolUseLimitReached => {}
250 ThreadEvent::ToolFinished {
251 tool_use_id,
252 pending_tool_use,
253 ..
254 } => {
255 thread.update(cx, |thread, _cx| {
256 if let Some(tool_use) = pending_tool_use {
257 let mut tool_metrics = tool_metrics.lock().unwrap();
258 if let Some(tool_result) = thread.tool_result(tool_use_id) {
259 let message = if tool_result.is_error {
260 format!("✖︎ {}", tool_use.name)
261 } else {
262 format!("✔︎ {}", tool_use.name)
263 };
264 println!("{log_prefix}{message}");
265 tool_metrics
266 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
267 } else {
268 let message =
269 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
270 println!("{log_prefix}{message}");
271 tool_metrics.insert(tool_use.name.clone(), true);
272 }
273 }
274 });
275 }
276 ThreadEvent::InvalidToolInput { .. } => {
277 println!("{log_prefix} invalid tool input");
278 }
279 ThreadEvent::MissingToolUse {
280 tool_use_id: _,
281 ui_text,
282 } => {
283 println!("{log_prefix} {ui_text}");
284 }
285 ThreadEvent::ToolConfirmationNeeded => {
286 panic!(
287 "{}Bug: Tool confirmation should not be required in eval",
288 log_prefix
289 );
290 }
291 ThreadEvent::StreamedCompletion
292 | ThreadEvent::MessageAdded(_)
293 | ThreadEvent::MessageEdited(_)
294 | ThreadEvent::MessageDeleted(_)
295 | ThreadEvent::SummaryChanged
296 | ThreadEvent::SummaryGenerated
297 | ThreadEvent::ProfileChanged
298 | ThreadEvent::ReceivedTextChunk
299 | ThreadEvent::StreamedToolUse { .. }
300 | ThreadEvent::CheckpointChanged
301 | ThreadEvent::CancelEditing => {
302 tx.try_send(Ok(())).ok();
303 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
304 println!("{}Event: {:#?}", log_prefix, event);
305 }
306 }
307 },
308 );
309
310 let model = self.model.clone();
311
312 let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
313 thread.set_remaining_turns(iterations);
314 thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
315 thread.messages().len()
316 })?;
317
318 loop {
319 select_biased! {
320 result = rx.next() => {
321 if let Some(result) = result {
322 result?;
323 } else {
324 break;
325 }
326 }
327 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
328 anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
329 }
330 }
331 }
332
333 let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
334 let mut messages = Vec::new();
335 for message in thread.messages().skip(message_count_before) {
336 messages.push(Message {
337 _role: message.role,
338 text: message.to_message_content(),
339 tool_use: thread
340 .tool_uses_for_message(message.id, cx)
341 .into_iter()
342 .map(|tool_use| ToolUse {
343 name: tool_use.name.to_string(),
344 value: tool_use.input,
345 })
346 .collect(),
347 });
348 }
349 messages
350 })?;
351
352 let response = Response::new(messages);
353
354 Ok(response)
355 }
356
357 pub fn edits(&self) -> HashMap<Arc<RelPath>, FileEdits> {
358 self.agent_thread
359 .read_with(&self.app, |thread, cx| {
360 let action_log = thread.action_log().read(cx);
361 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
362 |(buffer, diff)| {
363 let snapshot = buffer.read(cx).snapshot();
364
365 let file = snapshot.file().unwrap();
366 let diff = diff.read(cx);
367 let base_text = diff.base_text().text();
368
369 let hunks = diff
370 .hunks(&snapshot, cx)
371 .map(|hunk| FileEditHunk {
372 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
373 text: snapshot
374 .text_for_range(hunk.range.clone())
375 .collect::<String>(),
376 status: hunk.status(),
377 })
378 .collect();
379
380 (file.path().clone(), FileEdits { hunks })
381 },
382 ))
383 })
384 .unwrap()
385 }
386
387 pub fn agent_thread(&self) -> Entity<Thread> {
388 self.agent_thread.clone()
389 }
390}
391
392impl AppContext for ExampleContext {
393 type Result<T> = anyhow::Result<T>;
394
395 fn new<T: 'static>(
396 &mut self,
397 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
398 ) -> Self::Result<Entity<T>> {
399 self.app.new(build_entity)
400 }
401
402 fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
403 self.app.reserve_entity()
404 }
405
406 fn insert_entity<T: 'static>(
407 &mut self,
408 reservation: gpui::Reservation<T>,
409 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
410 ) -> Self::Result<Entity<T>> {
411 self.app.insert_entity(reservation, build_entity)
412 }
413
414 fn update_entity<T, R>(
415 &mut self,
416 handle: &Entity<T>,
417 update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
418 ) -> Self::Result<R>
419 where
420 T: 'static,
421 {
422 self.app.update_entity(handle, update)
423 }
424
425 fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<gpui::GpuiBorrow<'a, T>>
426 where
427 T: 'static,
428 {
429 self.app.as_mut(handle)
430 }
431
432 fn read_entity<T, R>(
433 &self,
434 handle: &Entity<T>,
435 read: impl FnOnce(&T, &App) -> R,
436 ) -> Self::Result<R>
437 where
438 T: 'static,
439 {
440 self.app.read_entity(handle, read)
441 }
442
443 fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
444 where
445 F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
446 {
447 self.app.update_window(window, f)
448 }
449
450 fn read_window<T, R>(
451 &self,
452 window: &gpui::WindowHandle<T>,
453 read: impl FnOnce(Entity<T>, &App) -> R,
454 ) -> Result<R>
455 where
456 T: 'static,
457 {
458 self.app.read_window(window, read)
459 }
460
461 fn background_spawn<R>(
462 &self,
463 future: impl std::future::Future<Output = R> + Send + 'static,
464 ) -> gpui::Task<R>
465 where
466 R: Send + 'static,
467 {
468 self.app.background_spawn(future)
469 }
470
471 fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
472 where
473 G: gpui::Global,
474 {
475 self.app.read_global(callback)
476 }
477}
478
479#[derive(Debug)]
480pub struct Response {
481 messages: Vec<Message>,
482}
483
484impl Response {
485 pub fn new(messages: Vec<Message>) -> Self {
486 Self { messages }
487 }
488
489 pub fn expect_tool(
490 &self,
491 tool_name: &'static str,
492 cx: &mut ExampleContext,
493 ) -> Result<&ToolUse> {
494 let result = self.find_tool_call(tool_name);
495 cx.assert_some(result, format!("called `{}`", tool_name))
496 }
497
498 pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
499 self.messages.iter().rev().find_map(|msg| {
500 msg.tool_use
501 .iter()
502 .find(|tool_use| tool_use.name == tool_name)
503 })
504 }
505
506 #[allow(dead_code)]
507 pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
508 self.messages.iter().flat_map(|msg| &msg.tool_use)
509 }
510
511 pub fn texts(&self) -> impl Iterator<Item = String> {
512 self.messages.iter().map(|message| message.text.clone())
513 }
514}
515
516#[derive(Debug)]
517pub struct Message {
518 _role: Role,
519 text: String,
520 tool_use: Vec<ToolUse>,
521}
522
523#[derive(Debug)]
524pub struct ToolUse {
525 pub name: String,
526 value: serde_json::Value,
527}
528
529impl ToolUse {
530 pub fn parse_input<Input>(&self) -> Result<Input>
531 where
532 Input: for<'de> serde::Deserialize<'de>,
533 {
534 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
535 }
536}
537
538#[derive(Debug, Eq, PartialEq)]
539pub struct FileEdits {
540 pub hunks: Vec<FileEditHunk>,
541}
542
543#[derive(Debug, Eq, PartialEq)]
544pub struct FileEditHunk {
545 pub base_text: String,
546 pub text: String,
547 pub status: DiffHunkStatus,
548}
549
550impl FileEdits {
551 pub fn has_added_line(&self, line: &str) -> bool {
552 self.hunks.iter().any(|hunk| {
553 hunk.status == DiffHunkStatus::added_none()
554 && hunk.base_text.is_empty()
555 && hunk.text.contains(line)
556 })
557 }
558}