1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 path::Path,
5 sync::{Arc, Mutex},
6 time::Duration,
7};
8
9use crate::{
10 ToolMetrics,
11 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
12};
13use agent::{ContextLoadResult, Thread, ThreadEvent};
14use anyhow::{Result, anyhow};
15use async_trait::async_trait;
16use buffer_diff::DiffHunkStatus;
17use collections::HashMap;
18use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
19use gpui::{App, AppContext, AsyncApp, Entity};
20use language_model::{LanguageModel, Role, StopReason};
21
/// Longest period the agentic loop may go without receiving any thread event before the run
/// is reported as stalled (two minutes).
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
23
24#[async_trait(?Send)]
25pub trait Example {
26 fn meta(&self) -> ExampleMetadata;
27 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
28 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
29 Vec::new()
30 }
31 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
32 Vec::new()
33 }
34}
35
/// An assertion identified by `id` and described in free text by `description`.
#[derive(Debug, Clone)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Human-readable statement of what is being asserted.
    pub description: String,
}
41
42#[derive(Clone, Debug)]
43pub struct ExampleMetadata {
44 pub name: String,
45 pub url: String,
46 pub revision: String,
47 pub language_server: Option<LanguageServer>,
48 pub max_assertions: Option<usize>,
49}
50
/// Language-server settings attached to an example.
#[derive(Debug, Clone)]
pub struct LanguageServer {
    /// File extension associated with the language server.
    // NOTE(review): presumably used to pick a file that triggers the server — confirm.
    pub file_extension: String,
    /// Whether diagnostics that exist before the example runs are tolerated.
    // NOTE(review): semantics inferred from the field name — confirm against the runner.
    pub allow_preexisting_diagnostics: bool,
}
56
57impl ExampleMetadata {
58 pub fn repo_name(&self) -> String {
59 self.url
60 .split('/')
61 .next_back()
62 .unwrap_or(&"")
63 .trim_end_matches(".git")
64 .into()
65 }
66}
67
/// Error value carrying the message of an assertion that did not hold.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    /// Writes the message prefixed with `Assertion failure: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    /// Writes the bare assertion message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

// Marker impl so the type can be wrapped by error containers such as `anyhow::Error`.
impl Error for FailedAssertion {}
83
84pub struct ExampleContext {
85 meta: ExampleMetadata,
86 log_prefix: String,
87 agent_thread: Entity<agent::Thread>,
88 app: AsyncApp,
89 model: Arc<dyn LanguageModel>,
90 pub assertions: AssertionsReport,
91 pub tool_metrics: Arc<Mutex<ToolMetrics>>,
92}
93
94impl ExampleContext {
95 pub fn new(
96 meta: ExampleMetadata,
97 log_prefix: String,
98 agent_thread: Entity<agent::Thread>,
99 model: Arc<dyn LanguageModel>,
100 app: AsyncApp,
101 ) -> Self {
102 let assertions = AssertionsReport::new(meta.max_assertions);
103
104 Self {
105 meta,
106 log_prefix,
107 agent_thread,
108 assertions,
109 model,
110 app,
111 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
112 }
113 }
114
115 pub fn push_user_message(&mut self, text: impl ToString) {
116 self.app
117 .update_entity(&self.agent_thread, |thread, cx| {
118 thread.insert_user_message(
119 text.to_string(),
120 ContextLoadResult::default(),
121 None,
122 Vec::new(),
123 cx,
124 );
125 })
126 .unwrap();
127 }
128
129 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
130 let message = message.to_string();
131 self.log_assertion(
132 if expected {
133 Ok(())
134 } else {
135 Err(anyhow::Error::from(FailedAssertion(message.clone())))
136 },
137 message,
138 )
139 }
140
141 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
142 let message = message.to_string();
143 self.log_assertion(
144 match option {
145 Some(value) => Ok(value),
146 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
147 },
148 message,
149 )
150 }
151
152 #[allow(dead_code)]
153 pub fn assert_eq<T: PartialEq + Debug>(
154 &mut self,
155 left: T,
156 right: T,
157 message: impl ToString,
158 ) -> Result<()> {
159 let message = message.to_string();
160 self.log_assertion(
161 if left == right {
162 Ok(())
163 } else {
164 println!(
165 "{}{}",
166 self.log_prefix,
167 pretty_assertions::Comparison::new(&left, &right)
168 );
169 Err(anyhow::Error::from(FailedAssertion(message.clone())))
170 },
171 message,
172 )
173 }
174
175 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
176 if let Some(max) = self.meta.max_assertions {
177 if self.assertions.run_count() > max {
178 return Err(anyhow!(
179 "More assertions were run than the stated max_assertions of {}",
180 max
181 ));
182 }
183 }
184
185 self.assertions.ran.push(RanAssertion {
186 id: message.clone(),
187 result: Ok(RanAssertionResult {
188 analysis: None,
189 passed: result.is_ok(),
190 }),
191 });
192
193 if result.is_ok() {
194 println!("{}✅ {}", self.log_prefix, message);
195 } else {
196 println!("{}❌ {}", self.log_prefix, message);
197 }
198
199 result
200 }
201
202 pub async fn run_to_end(&mut self) -> Result<Response> {
203 self.run_turns(u32::MAX).await
204 }
205
206 pub async fn run_turn(&mut self) -> Result<Response> {
207 self.run_turns(1).await
208 }
209
210 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
211 let (mut tx, mut rx) = mpsc::channel(1);
212
213 let tool_metrics = self.tool_metrics.clone();
214 let log_prefix = self.log_prefix.clone();
215 let _subscription = self.app.subscribe(
216 &self.agent_thread,
217 move |thread, event: &ThreadEvent, cx| match event {
218 ThreadEvent::ShowError(thread_error) => {
219 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
220 }
221 ThreadEvent::Stopped(reason) => match reason {
222 Ok(StopReason::EndTurn) => {
223 tx.close_channel();
224 }
225 Ok(StopReason::ToolUse) => {
226 if thread.read(cx).remaining_turns() == 0 {
227 tx.close_channel();
228 }
229 }
230 Ok(StopReason::MaxTokens) => {
231 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
232 }
233 Err(err) => {
234 tx.try_send(Err(anyhow!(err.clone()))).ok();
235 }
236 },
237 ThreadEvent::NewRequest
238 | ThreadEvent::StreamedAssistantText(_, _)
239 | ThreadEvent::StreamedAssistantThinking(_, _)
240 | ThreadEvent::UsePendingTools { .. }
241 | ThreadEvent::CompletionCanceled => {}
242 ThreadEvent::ToolFinished {
243 tool_use_id,
244 pending_tool_use,
245 ..
246 } => {
247 thread.update(cx, |thread, _cx| {
248 if let Some(tool_use) = pending_tool_use {
249 let mut tool_metrics = tool_metrics.lock().unwrap();
250 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
251 let message = if tool_result.is_error {
252 format!("✖︎ {}", tool_use.name)
253 } else {
254 format!("✔︎ {}", tool_use.name)
255 };
256 println!("{log_prefix}{message}");
257 tool_metrics
258 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
259 } else {
260 let message =
261 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
262 println!("{log_prefix}{message}");
263 tool_metrics.insert(tool_use.name.clone(), true);
264 }
265 }
266 });
267 }
268 ThreadEvent::InvalidToolInput { .. } => {
269 println!("{log_prefix} invalid tool input");
270 }
271 ThreadEvent::ToolConfirmationNeeded => {
272 panic!(
273 "{}Bug: Tool confirmation should not be required in eval",
274 log_prefix
275 );
276 }
277 ThreadEvent::StreamedCompletion
278 | ThreadEvent::MessageAdded(_)
279 | ThreadEvent::MessageEdited(_)
280 | ThreadEvent::MessageDeleted(_)
281 | ThreadEvent::SummaryChanged
282 | ThreadEvent::SummaryGenerated
283 | ThreadEvent::ReceivedTextChunk
284 | ThreadEvent::StreamedToolUse { .. }
285 | ThreadEvent::CheckpointChanged
286 | ThreadEvent::UsageUpdated(_)
287 | ThreadEvent::CancelEditing => {
288 tx.try_send(Ok(())).ok();
289 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
290 println!("{}Event: {:#?}", log_prefix, event);
291 }
292 }
293 },
294 );
295
296 let model = self.model.clone();
297
298 let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
299 thread.set_remaining_turns(iterations);
300 thread.send_to_model(model, None, cx);
301 thread.messages().len()
302 })?;
303
304 loop {
305 select_biased! {
306 result = rx.next() => {
307 if let Some(result) = result {
308 result?;
309 } else {
310 break;
311 }
312 }
313 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
314 return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
315 }
316 }
317 }
318
319 let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
320 let mut messages = Vec::new();
321 for message in thread.messages().skip(message_count_before) {
322 messages.push(Message {
323 _role: message.role,
324 text: message.to_string(),
325 tool_use: thread
326 .tool_uses_for_message(message.id, cx)
327 .into_iter()
328 .map(|tool_use| ToolUse {
329 name: tool_use.name.to_string(),
330 value: tool_use.input,
331 })
332 .collect(),
333 });
334 }
335 messages
336 })?;
337
338 let response = Response::new(messages);
339
340 Ok(response)
341 }
342
343 pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
344 self.agent_thread
345 .read_with(&self.app, |thread, cx| {
346 let action_log = thread.action_log().read(cx);
347 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
348 |(buffer, diff)| {
349 let snapshot = buffer.read(cx).snapshot();
350
351 let file = snapshot.file().unwrap();
352 let diff = diff.read(cx);
353 let base_text = diff.base_text().text();
354
355 let hunks = diff
356 .hunks(&snapshot, cx)
357 .map(|hunk| FileEditHunk {
358 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
359 text: snapshot
360 .text_for_range(hunk.range.clone())
361 .collect::<String>(),
362 status: hunk.status(),
363 })
364 .collect();
365
366 (file.path().clone(), FileEdits { hunks })
367 },
368 ))
369 })
370 .unwrap()
371 }
372
373 pub fn agent_thread(&self) -> Entity<Thread> {
374 self.agent_thread.clone()
375 }
376}
377
378impl AppContext for ExampleContext {
379 type Result<T> = anyhow::Result<T>;
380
381 fn new<T: 'static>(
382 &mut self,
383 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
384 ) -> Self::Result<Entity<T>> {
385 self.app.new(build_entity)
386 }
387
388 fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
389 self.app.reserve_entity()
390 }
391
392 fn insert_entity<T: 'static>(
393 &mut self,
394 reservation: gpui::Reservation<T>,
395 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
396 ) -> Self::Result<Entity<T>> {
397 self.app.insert_entity(reservation, build_entity)
398 }
399
400 fn update_entity<T, R>(
401 &mut self,
402 handle: &Entity<T>,
403 update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
404 ) -> Self::Result<R>
405 where
406 T: 'static,
407 {
408 self.app.update_entity(handle, update)
409 }
410
411 fn read_entity<T, R>(
412 &self,
413 handle: &Entity<T>,
414 read: impl FnOnce(&T, &App) -> R,
415 ) -> Self::Result<R>
416 where
417 T: 'static,
418 {
419 self.app.read_entity(handle, read)
420 }
421
422 fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
423 where
424 F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
425 {
426 self.app.update_window(window, f)
427 }
428
429 fn read_window<T, R>(
430 &self,
431 window: &gpui::WindowHandle<T>,
432 read: impl FnOnce(Entity<T>, &App) -> R,
433 ) -> Result<R>
434 where
435 T: 'static,
436 {
437 self.app.read_window(window, read)
438 }
439
440 fn background_spawn<R>(
441 &self,
442 future: impl std::future::Future<Output = R> + Send + 'static,
443 ) -> gpui::Task<R>
444 where
445 R: Send + 'static,
446 {
447 self.app.background_spawn(future)
448 }
449
450 fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
451 where
452 G: gpui::Global,
453 {
454 self.app.read_global(callback)
455 }
456}
457
458#[derive(Debug)]
459pub struct Response {
460 messages: Vec<Message>,
461}
462
463impl Response {
464 pub fn new(messages: Vec<Message>) -> Self {
465 Self { messages }
466 }
467
468 pub fn expect_tool(
469 &self,
470 tool_name: &'static str,
471 cx: &mut ExampleContext,
472 ) -> Result<&ToolUse> {
473 let result = self.messages.iter().find_map(|msg| {
474 msg.tool_use
475 .iter()
476 .find(|tool_use| tool_use.name == tool_name)
477 });
478 cx.assert_some(result, format!("called `{}`", tool_name))
479 }
480
481 #[allow(dead_code)]
482 pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
483 self.messages.iter().flat_map(|msg| &msg.tool_use)
484 }
485
486 pub fn texts(&self) -> impl Iterator<Item = String> {
487 self.messages.iter().map(|message| message.text.clone())
488 }
489}
490
491#[derive(Debug)]
492pub struct Message {
493 _role: Role,
494 text: String,
495 tool_use: Vec<ToolUse>,
496}
497
498#[derive(Debug)]
499pub struct ToolUse {
500 pub name: String,
501 value: serde_json::Value,
502}
503
504impl ToolUse {
505 pub fn parse_input<Input>(&self) -> Result<Input>
506 where
507 Input: for<'de> serde::Deserialize<'de>,
508 {
509 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
510 }
511}
512
513#[derive(Debug, Eq, PartialEq)]
514pub struct FileEdits {
515 pub hunks: Vec<FileEditHunk>,
516}
517
518#[derive(Debug, Eq, PartialEq)]
519pub struct FileEditHunk {
520 pub base_text: String,
521 pub text: String,
522 pub status: DiffHunkStatus,
523}
524
525impl FileEdits {
526 pub fn has_added_line(&self, line: &str) -> bool {
527 self.hunks.iter().any(|hunk| {
528 hunk.status == DiffHunkStatus::added_none()
529 && hunk.base_text.is_empty()
530 && hunk.text.contains(line)
531 })
532 }
533}