1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 path::Path,
5 sync::{Arc, Mutex},
6 time::Duration,
7};
8
9use crate::{
10 ToolMetrics,
11 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
12};
13use agent::ThreadEvent;
14use anyhow::{Result, anyhow};
15use async_trait::async_trait;
16use buffer_diff::DiffHunkStatus;
17use collections::HashMap;
18use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
19use gpui::{AppContext, AsyncApp, Entity};
20use language_model::{LanguageModel, Role, StopReason};
21
/// How long `ExampleContext::run_turns` waits for the next thread event before
/// giving up and reporting the agentic loop as stalled (two minutes).
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(120);
23
24#[async_trait(?Send)]
25pub trait Example {
26 fn meta(&self) -> ExampleMetadata;
27 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
28 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
29 Vec::new()
30 }
31 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
32 Vec::new()
33 }
34}
35
/// An assertion returned by [`Example::diff_assertions`] or
/// [`Example::thread_assertions`], identified by `id`.
#[derive(Clone, Debug)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Text describing what the assertion checks.
    pub description: String,
}
41
/// Static description of an [`Example`]: the repository it runs against and the
/// limits that apply while it runs.
#[derive(Clone, Debug)]
pub struct ExampleMetadata {
    /// Name of the example.
    pub name: String,
    /// Repository URL; `repo_name` derives the repository name from its last
    /// path segment.
    pub url: String,
    /// Repository revision to run against (NOTE(review): presumably a commit
    /// SHA or ref — confirm with callers).
    pub revision: String,
    /// Optional language-server settings for this example.
    pub language_server: Option<LanguageServer>,
    /// Maximum number of assertions the example may log through
    /// `ExampleContext`; `None` disables the check.
    pub max_assertions: Option<usize>,
}
50
/// Language-server settings attached to an [`ExampleMetadata`].
#[derive(Clone, Debug)]
pub struct LanguageServer {
    /// File extension associated with the language server (NOTE(review):
    /// presumably written without a leading dot — confirm).
    pub file_extension: String,
    /// Whether diagnostics that already exist before the example runs are
    /// acceptable (presumably; verify against the code that consumes it).
    pub allow_preexisting_diagnostics: bool,
}
56
57impl ExampleMetadata {
58 pub fn repo_name(&self) -> String {
59 self.url
60 .split('/')
61 .next_back()
62 .unwrap_or(&"")
63 .trim_end_matches(".git")
64 .into()
65 }
66}
67
/// Error produced when an `ExampleContext` assertion does not hold.
///
/// Wraps the assertion message. `Display` prints the message by itself, while
/// `Debug` prefixes it with `Assertion failure: `.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}

impl Error for FailedAssertion {}
83
/// Runtime state handed to [`Example::conversation`]: the agent thread being
/// driven, the model used to drive it, and the assertion and tool outcomes
/// recorded along the way.
pub struct ExampleContext {
    /// Metadata of the example being run (consulted for `max_assertions`).
    meta: ExampleMetadata,
    /// Prepended to every line this context prints.
    log_prefix: String,
    /// The agent thread that messages are pushed to and read from.
    agent_thread: Entity<agent::Thread>,
    app: AsyncApp,
    /// Model handed to `send_to_model` when running turns.
    model: Arc<dyn LanguageModel>,
    /// Every assertion logged through this context, with pass/fail status.
    pub assertions: AssertionsReport,
    /// Outcome per finished tool (tool name → succeeded), shared with the
    /// thread-event subscription installed while turns run.
    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
}
93
94impl ExampleContext {
95 pub fn new(
96 meta: ExampleMetadata,
97 log_prefix: String,
98 agent_thread: Entity<agent::Thread>,
99 model: Arc<dyn LanguageModel>,
100 app: AsyncApp,
101 ) -> Self {
102 let assertions = AssertionsReport::new(meta.max_assertions);
103
104 Self {
105 meta,
106 log_prefix,
107 agent_thread,
108 assertions,
109 model,
110 app,
111 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
112 }
113 }
114
115 pub fn push_user_message(&mut self, text: impl ToString) {
116 self.app
117 .update_entity(&self.agent_thread, |thread, cx| {
118 thread.insert_user_message(text.to_string(), vec![], None, cx);
119 })
120 .unwrap();
121 }
122
123 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
124 let message = message.to_string();
125 self.log_assertion(
126 if expected {
127 Ok(())
128 } else {
129 Err(anyhow::Error::from(FailedAssertion(message.clone())))
130 },
131 message,
132 )
133 }
134
135 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
136 let message = message.to_string();
137 self.log_assertion(
138 match option {
139 Some(value) => Ok(value),
140 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
141 },
142 message,
143 )
144 }
145
146 #[allow(dead_code)]
147 pub fn assert_eq<T: PartialEq + Debug>(
148 &mut self,
149 left: T,
150 right: T,
151 message: impl ToString,
152 ) -> Result<()> {
153 let message = message.to_string();
154 self.log_assertion(
155 if left == right {
156 Ok(())
157 } else {
158 println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
159 Err(anyhow::Error::from(FailedAssertion(message.clone())))
160 },
161 message,
162 )
163 }
164
165 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
166 if let Some(max) = self.meta.max_assertions {
167 if self.assertions.run_count() > max {
168 return Err(anyhow!(
169 "More assertions were run than the stated max_assertions of {}",
170 max
171 ));
172 }
173 }
174
175 self.assertions.ran.push(RanAssertion {
176 id: message.clone(),
177 result: Ok(RanAssertionResult {
178 analysis: None,
179 passed: result.is_ok(),
180 }),
181 });
182
183 if result.is_ok() {
184 println!("{}✅ {}", self.log_prefix, message);
185 } else {
186 println!("{}❌ {}", self.log_prefix, message);
187 }
188
189 result
190 }
191
192 pub async fn run_to_end(&mut self) -> Result<Response> {
193 self.run_turns(u32::MAX).await
194 }
195
196 pub async fn run_turn(&mut self) -> Result<Response> {
197 self.run_turns(1).await
198 }
199
200 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
201 let (mut tx, mut rx) = mpsc::channel(1);
202
203 let tool_metrics = self.tool_metrics.clone();
204 let log_prefix = self.log_prefix.clone();
205 let _subscription = self.app.subscribe(
206 &self.agent_thread,
207 move |thread, event: &ThreadEvent, cx| match event {
208 ThreadEvent::ShowError(thread_error) => {
209 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
210 }
211 ThreadEvent::Stopped(reason) => match reason {
212 Ok(StopReason::EndTurn) => {
213 tx.close_channel();
214 }
215 Ok(StopReason::ToolUse) => {
216 if thread.read(cx).remaining_turns() == 0 {
217 tx.close_channel();
218 }
219 }
220 Ok(StopReason::MaxTokens) => {
221 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
222 }
223 Err(err) => {
224 tx.try_send(Err(anyhow!(err.clone()))).ok();
225 }
226 },
227 ThreadEvent::StreamedAssistantText(_, _)
228 | ThreadEvent::StreamedAssistantThinking(_, _)
229 | ThreadEvent::UsePendingTools { .. } => {}
230 ThreadEvent::ToolFinished {
231 tool_use_id,
232 pending_tool_use,
233 ..
234 } => {
235 thread.update(cx, |thread, _cx| {
236 if let Some(tool_use) = pending_tool_use {
237 let mut tool_metrics = tool_metrics.lock().unwrap();
238 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
239 let message = if tool_result.is_error {
240 format!("✖︎ {}", tool_use.name)
241 } else {
242 format!("✔︎ {}", tool_use.name)
243 };
244 println!("{log_prefix}{message}");
245 tool_metrics
246 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
247 } else {
248 let message =
249 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
250 println!("{log_prefix}{message}");
251 tool_metrics.insert(tool_use.name.clone(), true);
252 }
253 }
254 });
255 }
256 ThreadEvent::InvalidToolInput { .. } => {
257 println!("{log_prefix} invalid tool input");
258 }
259 ThreadEvent::ToolConfirmationNeeded => {
260 panic!(
261 "{}Bug: Tool confirmation should not be required in eval",
262 log_prefix
263 );
264 }
265 ThreadEvent::StreamedCompletion
266 | ThreadEvent::MessageAdded(_)
267 | ThreadEvent::MessageEdited(_)
268 | ThreadEvent::MessageDeleted(_)
269 | ThreadEvent::SummaryChanged
270 | ThreadEvent::SummaryGenerated
271 | ThreadEvent::ReceivedTextChunk
272 | ThreadEvent::StreamedToolUse { .. }
273 | ThreadEvent::CheckpointChanged
274 | ThreadEvent::UsageUpdated(_) => {
275 tx.try_send(Ok(())).ok();
276 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
277 println!("{}Event: {:#?}", log_prefix, event);
278 }
279 }
280 },
281 );
282
283 let model = self.model.clone();
284
285 let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
286 thread.set_remaining_turns(iterations);
287 thread.send_to_model(model, None, cx);
288 thread.messages().len()
289 })?;
290
291 loop {
292 select_biased! {
293 result = rx.next() => {
294 if let Some(result) = result {
295 result?;
296 } else {
297 break;
298 }
299 }
300 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
301 return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
302 }
303 }
304 }
305
306 let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
307 let mut messages = Vec::new();
308 for message in thread.messages().skip(message_count_before) {
309 messages.push(Message {
310 _role: message.role,
311 _text: message.to_string(),
312 tool_use: thread
313 .tool_uses_for_message(message.id, cx)
314 .into_iter()
315 .map(|tool_use| ToolUse {
316 name: tool_use.name.to_string(),
317 value: tool_use.input,
318 })
319 .collect(),
320 });
321 }
322 messages
323 })?;
324
325 let response = Response::new(messages);
326
327 Ok(response)
328 }
329
330 pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
331 self.app
332 .read_entity(&self.agent_thread, |thread, cx| {
333 let action_log = thread.action_log().read(cx);
334 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
335 |(buffer, diff)| {
336 let snapshot = buffer.read(cx).snapshot();
337
338 let file = snapshot.file().unwrap();
339 let diff = diff.read(cx);
340 let base_text = diff.base_text().text();
341
342 let hunks = diff
343 .hunks(&snapshot, cx)
344 .map(|hunk| FileEditHunk {
345 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
346 text: snapshot
347 .text_for_range(hunk.range.clone())
348 .collect::<String>(),
349 status: hunk.status(),
350 })
351 .collect();
352
353 (file.path().clone(), FileEdits { hunks })
354 },
355 ))
356 })
357 .unwrap()
358 }
359}
360
/// Messages collected by one `ExampleContext::run_turns` call (the messages
/// added to the thread during that run).
#[derive(Debug)]
pub struct Response {
    messages: Vec<Message>,
}
365
366impl Response {
367 pub fn new(messages: Vec<Message>) -> Self {
368 Self { messages }
369 }
370
371 pub fn expect_tool(
372 &self,
373 tool_name: &'static str,
374 cx: &mut ExampleContext,
375 ) -> Result<&ToolUse> {
376 let result = self.messages.iter().find_map(|msg| {
377 msg.tool_use
378 .iter()
379 .find(|tool_use| tool_use.name == tool_name)
380 });
381 cx.assert_some(result, format!("called `{}`", tool_name))
382 }
383
384 pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
385 self.messages.iter().flat_map(|msg| &msg.tool_use)
386 }
387}
388
/// A thread message captured in a [`Response`].
#[derive(Debug)]
pub struct Message {
    /// Role of the message author; not read in this file except via `Debug`.
    _role: Role,
    /// Message rendered with `to_string`; not read in this file except via `Debug`.
    _text: String,
    /// Tool calls attached to this message, in order.
    tool_use: Vec<ToolUse>,
}
395
/// A single tool call recorded in a [`Message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Name of the called tool.
    pub name: String,
    /// JSON input the tool was called with; decode it with `parse_input`.
    value: serde_json::Value,
}
401
402impl ToolUse {
403 pub fn parse_input<Input>(&self) -> Result<Input>
404 where
405 Input: for<'de> serde::Deserialize<'de>,
406 {
407 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
408 }
409}
410
/// Diff hunks for one changed file, as returned by `ExampleContext::edits`.
#[derive(Debug)]
pub struct FileEdits {
    hunks: Vec<FileEditHunk>,
}
415
/// One diff hunk within a [`FileEdits`].
#[derive(Debug)]
struct FileEditHunk {
    /// Text covered by the hunk in the diff's base text.
    base_text: String,
    /// Text covered by the hunk in the current buffer snapshot.
    text: String,
    /// Status reported by the buffer diff for this hunk.
    status: DiffHunkStatus,
}
422
423impl FileEdits {
424 pub fn has_added_line(&self, line: &str) -> bool {
425 self.hunks.iter().any(|hunk| {
426 hunk.status == DiffHunkStatus::added_none()
427 && hunk.base_text.is_empty()
428 && hunk.text.contains(line)
429 })
430 }
431}