example.rs

  1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    path::Path,
  5    sync::{Arc, Mutex},
  6    time::Duration,
  7};
  8
  9use crate::{
 10    ToolMetrics,
 11    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 12};
 13use agent::{ContextLoadResult, Thread, ThreadEvent};
 14use agent_settings::AgentProfileId;
 15use anyhow::{Result, anyhow};
 16use async_trait::async_trait;
 17use buffer_diff::DiffHunkStatus;
 18use cloud_llm_client::CompletionIntent;
 19use collections::HashMap;
 20use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 21use gpui::{App, AppContext, AsyncApp, Entity};
 22use language_model::{LanguageModel, Role, StopReason};
 23
 24pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 25
 26#[async_trait(?Send)]
 27pub trait Example {
 28    fn meta(&self) -> ExampleMetadata;
 29    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 30    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 31        Vec::new()
 32    }
 33    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 34        Vec::new()
 35    }
 36}
 37
 38#[derive(Clone, Debug)]
 39pub struct JudgeAssertion {
 40    pub id: String,
 41    pub description: String,
 42}
 43
 44#[derive(Clone, Debug)]
 45pub struct ExampleMetadata {
 46    pub name: String,
 47    pub url: String,
 48    pub revision: String,
 49    pub language_server: Option<LanguageServer>,
 50    pub max_assertions: Option<usize>,
 51    pub profile_id: AgentProfileId,
 52    pub existing_thread_json: Option<String>,
 53    pub max_turns: Option<u32>,
 54}
 55
 56#[derive(Clone, Debug)]
 57pub struct LanguageServer {
 58    pub file_extension: String,
 59    pub allow_preexisting_diagnostics: bool,
 60}
 61
 62impl ExampleMetadata {
 63    pub fn repo_name(&self) -> String {
 64        self.url
 65            .split('/')
 66            .next_back()
 67            .unwrap_or("")
 68            .trim_end_matches(".git")
 69            .into()
 70    }
 71}
 72
 73pub struct FailedAssertion(pub String);
 74
 75impl fmt::Debug for FailedAssertion {
 76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 77        write!(f, "Assertion failure: {}", self.0)
 78    }
 79}
 80
 81impl fmt::Display for FailedAssertion {
 82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 83        write!(f, "{}", self.0)
 84    }
 85}
 86
 87impl Error for FailedAssertion {}
 88
 89pub struct ExampleContext {
 90    meta: ExampleMetadata,
 91    log_prefix: String,
 92    agent_thread: Entity<agent::Thread>,
 93    app: AsyncApp,
 94    model: Arc<dyn LanguageModel>,
 95    pub assertions: AssertionsReport,
 96    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 97}
 98
 99impl ExampleContext {
100    pub fn new(
101        meta: ExampleMetadata,
102        log_prefix: String,
103        agent_thread: Entity<Thread>,
104        model: Arc<dyn LanguageModel>,
105        app: AsyncApp,
106    ) -> Self {
107        let assertions = AssertionsReport::new(meta.max_assertions);
108
109        Self {
110            meta,
111            log_prefix,
112            agent_thread,
113            assertions,
114            model,
115            app,
116            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
117        }
118    }
119
120    pub fn push_user_message(&mut self, text: impl ToString) {
121        self.app
122            .update_entity(&self.agent_thread, |thread, cx| {
123                thread.insert_user_message(
124                    text.to_string(),
125                    ContextLoadResult::default(),
126                    None,
127                    Vec::new(),
128                    cx,
129                );
130            })
131            .unwrap();
132    }
133
134    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
135        let message = message.to_string();
136        self.log_assertion(
137            if expected {
138                Ok(())
139            } else {
140                Err(anyhow::Error::from(FailedAssertion(message.clone())))
141            },
142            message,
143        )
144    }
145
146    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
147        let message = message.to_string();
148        self.log_assertion(
149            match option {
150                Some(value) => Ok(value),
151                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
152            },
153            message,
154        )
155    }
156
157    #[allow(dead_code)]
158    pub fn assert_eq<T: PartialEq + Debug>(
159        &mut self,
160        left: T,
161        right: T,
162        message: impl ToString,
163    ) -> Result<()> {
164        let message = message.to_string();
165        self.log_assertion(
166            if left == right {
167                Ok(())
168            } else {
169                println!(
170                    "{}{}",
171                    self.log_prefix,
172                    pretty_assertions::Comparison::new(&left, &right)
173                );
174                Err(anyhow::Error::from(FailedAssertion(message.clone())))
175            },
176            message,
177        )
178    }
179
180    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
181        if let Some(max) = self.meta.max_assertions {
182            anyhow::ensure!(
183                self.assertions.run_count() <= max,
184                "More assertions were run than the stated max_assertions of {max}"
185            );
186        }
187
188        self.assertions.ran.push(RanAssertion {
189            id: message.clone(),
190            result: Ok(RanAssertionResult {
191                analysis: None,
192                passed: result.is_ok(),
193            }),
194        });
195
196        if result.is_ok() {
197            println!("{}{}", self.log_prefix, message);
198        } else {
199            println!("{}{}", self.log_prefix, message);
200        }
201
202        result
203    }
204
205    pub async fn run_to_end(&mut self) -> Result<Response> {
206        self.run_turns(u32::MAX).await
207    }
208
209    pub async fn run_turn(&mut self) -> Result<Response> {
210        self.run_turns(1).await
211    }
212
213    pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
214        let (mut tx, mut rx) = mpsc::channel(1);
215
216        let tool_metrics = self.tool_metrics.clone();
217        let log_prefix = self.log_prefix.clone();
218        let _subscription = self.app.subscribe(
219            &self.agent_thread,
220            move |thread, event: &ThreadEvent, cx| match event {
221                ThreadEvent::ShowError(thread_error) => {
222                    tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
223                }
224                ThreadEvent::Stopped(reason) => match reason {
225                    Ok(StopReason::EndTurn) => {
226                        tx.close_channel();
227                    }
228                    Ok(StopReason::ToolUse) => {
229                        if thread.read(cx).remaining_turns() == 0 {
230                            tx.close_channel();
231                        }
232                    }
233                    Ok(StopReason::MaxTokens) => {
234                        tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
235                    }
236                    Ok(StopReason::Refusal) => {
237                        tx.try_send(Err(anyhow!("Model refused to generate content")))
238                            .ok();
239                    }
240                    Err(err) => {
241                        tx.try_send(Err(anyhow!(err.clone()))).ok();
242                    }
243                },
244                ThreadEvent::NewRequest
245                | ThreadEvent::StreamedAssistantText(_, _)
246                | ThreadEvent::StreamedAssistantThinking(_, _)
247                | ThreadEvent::UsePendingTools { .. }
248                | ThreadEvent::CompletionCanceled => {}
249                ThreadEvent::ToolUseLimitReached => {}
250                ThreadEvent::ToolFinished {
251                    tool_use_id,
252                    pending_tool_use,
253                    ..
254                } => {
255                    thread.update(cx, |thread, _cx| {
256                        if let Some(tool_use) = pending_tool_use {
257                            let mut tool_metrics = tool_metrics.lock().unwrap();
258                            if let Some(tool_result) = thread.tool_result(tool_use_id) {
259                                let message = if tool_result.is_error {
260                                    format!("✖︎ {}", tool_use.name)
261                                } else {
262                                    format!("✔︎ {}", tool_use.name)
263                                };
264                                println!("{log_prefix}{message}");
265                                tool_metrics
266                                    .insert(tool_result.tool_name.clone(), !tool_result.is_error);
267                            } else {
268                                let message =
269                                    format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
270                                println!("{log_prefix}{message}");
271                                tool_metrics.insert(tool_use.name.clone(), true);
272                            }
273                        }
274                    });
275                }
276                ThreadEvent::InvalidToolInput { .. } => {
277                    println!("{log_prefix} invalid tool input");
278                }
279                ThreadEvent::MissingToolUse {
280                    tool_use_id: _,
281                    ui_text,
282                } => {
283                    println!("{log_prefix} {ui_text}");
284                }
285                ThreadEvent::ToolConfirmationNeeded => {
286                    panic!(
287                        "{}Bug: Tool confirmation should not be required in eval",
288                        log_prefix
289                    );
290                }
291                ThreadEvent::StreamedCompletion
292                | ThreadEvent::MessageAdded(_)
293                | ThreadEvent::MessageEdited(_)
294                | ThreadEvent::MessageDeleted(_)
295                | ThreadEvent::SummaryChanged
296                | ThreadEvent::SummaryGenerated
297                | ThreadEvent::ProfileChanged
298                | ThreadEvent::ReceivedTextChunk
299                | ThreadEvent::StreamedToolUse { .. }
300                | ThreadEvent::CheckpointChanged
301                | ThreadEvent::CancelEditing => {
302                    tx.try_send(Ok(())).ok();
303                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
304                        println!("{}Event: {:#?}", log_prefix, event);
305                    }
306                }
307            },
308        );
309
310        let model = self.model.clone();
311
312        let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
313            thread.set_remaining_turns(iterations);
314            thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
315            thread.messages().len()
316        })?;
317
318        loop {
319            select_biased! {
320                result = rx.next() => {
321                    if let Some(result) = result {
322                        result?;
323                    } else {
324                        break;
325                    }
326                }
327                _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
328                    anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
329                }
330            }
331        }
332
333        let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
334            let mut messages = Vec::new();
335            for message in thread.messages().skip(message_count_before) {
336                messages.push(Message {
337                    _role: message.role,
338                    text: message.to_message_content(),
339                    tool_use: thread
340                        .tool_uses_for_message(message.id, cx)
341                        .into_iter()
342                        .map(|tool_use| ToolUse {
343                            name: tool_use.name.to_string(),
344                            value: tool_use.input,
345                        })
346                        .collect(),
347                });
348            }
349            messages
350        })?;
351
352        let response = Response::new(messages);
353
354        Ok(response)
355    }
356
357    pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
358        self.agent_thread
359            .read_with(&self.app, |thread, cx| {
360                let action_log = thread.action_log().read(cx);
361                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
362                    |(buffer, diff)| {
363                        let snapshot = buffer.read(cx).snapshot();
364
365                        let file = snapshot.file().unwrap();
366                        let diff = diff.read(cx);
367                        let base_text = diff.base_text().text();
368
369                        let hunks = diff
370                            .hunks(&snapshot, cx)
371                            .map(|hunk| FileEditHunk {
372                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
373                                text: snapshot
374                                    .text_for_range(hunk.range.clone())
375                                    .collect::<String>(),
376                                status: hunk.status(),
377                            })
378                            .collect();
379
380                        (file.path().clone(), FileEdits { hunks })
381                    },
382                ))
383            })
384            .unwrap()
385    }
386
387    pub fn agent_thread(&self) -> Entity<Thread> {
388        self.agent_thread.clone()
389    }
390}
391
392impl AppContext for ExampleContext {
393    type Result<T> = anyhow::Result<T>;
394
395    fn new<T: 'static>(
396        &mut self,
397        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
398    ) -> Self::Result<Entity<T>> {
399        self.app.new(build_entity)
400    }
401
402    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
403        self.app.reserve_entity()
404    }
405
406    fn insert_entity<T: 'static>(
407        &mut self,
408        reservation: gpui::Reservation<T>,
409        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
410    ) -> Self::Result<Entity<T>> {
411        self.app.insert_entity(reservation, build_entity)
412    }
413
414    fn update_entity<T, R>(
415        &mut self,
416        handle: &Entity<T>,
417        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
418    ) -> Self::Result<R>
419    where
420        T: 'static,
421    {
422        self.app.update_entity(handle, update)
423    }
424
425    fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<gpui::GpuiBorrow<'a, T>>
426    where
427        T: 'static,
428    {
429        self.app.as_mut(handle)
430    }
431
432    fn read_entity<T, R>(
433        &self,
434        handle: &Entity<T>,
435        read: impl FnOnce(&T, &App) -> R,
436    ) -> Self::Result<R>
437    where
438        T: 'static,
439    {
440        self.app.read_entity(handle, read)
441    }
442
443    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
444    where
445        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
446    {
447        self.app.update_window(window, f)
448    }
449
450    fn read_window<T, R>(
451        &self,
452        window: &gpui::WindowHandle<T>,
453        read: impl FnOnce(Entity<T>, &App) -> R,
454    ) -> Result<R>
455    where
456        T: 'static,
457    {
458        self.app.read_window(window, read)
459    }
460
461    fn background_spawn<R>(
462        &self,
463        future: impl std::future::Future<Output = R> + Send + 'static,
464    ) -> gpui::Task<R>
465    where
466        R: Send + 'static,
467    {
468        self.app.background_spawn(future)
469    }
470
471    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
472    where
473        G: gpui::Global,
474    {
475        self.app.read_global(callback)
476    }
477}
478
479#[derive(Debug)]
480pub struct Response {
481    messages: Vec<Message>,
482}
483
484impl Response {
485    pub fn new(messages: Vec<Message>) -> Self {
486        Self { messages }
487    }
488
489    pub fn expect_tool(
490        &self,
491        tool_name: &'static str,
492        cx: &mut ExampleContext,
493    ) -> Result<&ToolUse> {
494        let result = self.find_tool_call(tool_name);
495        cx.assert_some(result, format!("called `{}`", tool_name))
496    }
497
498    pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
499        self.messages.iter().rev().find_map(|msg| {
500            msg.tool_use
501                .iter()
502                .find(|tool_use| tool_use.name == tool_name)
503        })
504    }
505
506    #[allow(dead_code)]
507    pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
508        self.messages.iter().flat_map(|msg| &msg.tool_use)
509    }
510
511    pub fn texts(&self) -> impl Iterator<Item = String> {
512        self.messages.iter().map(|message| message.text.clone())
513    }
514}
515
516#[derive(Debug)]
517pub struct Message {
518    _role: Role,
519    text: String,
520    tool_use: Vec<ToolUse>,
521}
522
523#[derive(Debug)]
524pub struct ToolUse {
525    pub name: String,
526    value: serde_json::Value,
527}
528
529impl ToolUse {
530    pub fn parse_input<Input>(&self) -> Result<Input>
531    where
532        Input: for<'de> serde::Deserialize<'de>,
533    {
534        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
535    }
536}
537
538#[derive(Debug, Eq, PartialEq)]
539pub struct FileEdits {
540    pub hunks: Vec<FileEditHunk>,
541}
542
543#[derive(Debug, Eq, PartialEq)]
544pub struct FileEditHunk {
545    pub base_text: String,
546    pub text: String,
547    pub status: DiffHunkStatus,
548}
549
550impl FileEdits {
551    pub fn has_added_line(&self, line: &str) -> bool {
552        self.hunks.iter().any(|hunk| {
553            hunk.status == DiffHunkStatus::added_none()
554                && hunk.base_text.is_empty()
555                && hunk.text.contains(line)
556        })
557    }
558}