code_block_citations.rs

  1use anyhow::Result;
  2use assistant_settings::AgentProfileId;
  3use async_trait::async_trait;
  4use markdown::PathWithRange;
  5
  6use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer};
  7
/// Eval example that asks the agent to show each method of the `Tool` trait in
/// its own code snippet, then checks the fenced code blocks in its replies: each
/// one must cite a project file (optionally with a line range) whose contents
/// match the block.
pub struct CodeBlockCitations;

/// Markdown code-fence delimiter (three backticks).
const FENCE: &str = "```";
 11
 12#[async_trait(?Send)]
 13impl Example for CodeBlockCitations {
 14    fn meta(&self) -> ExampleMetadata {
 15        ExampleMetadata {
 16            name: "code_block_citations".to_string(),
 17            url: "https://github.com/zed-industries/zed.git".to_string(),
 18            revision: "f69aeb6311dde3c0b8979c293d019d66498d54f2".to_string(),
 19            language_server: Some(LanguageServer {
 20                file_extension: "rs".to_string(),
 21                allow_preexisting_diagnostics: false,
 22            }),
 23            max_assertions: None,
 24            profile_id: AgentProfileId::default(),
 25        }
 26    }
 27
 28    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
 29        const FILENAME: &str = "assistant_tool.rs";
 30        cx.push_user_message(format!(
 31            r#"
 32            Show me the method bodies of all the methods of the `Tool` trait in {FILENAME}.
 33
 34            Please show each method in a separate code snippet.
 35            "#
 36        ));
 37
 38        // Verify that the messages all have the correct formatting.
 39        let texts: Vec<String> = cx.run_to_end().await?.texts().collect();
 40        let closing_fence = format!("\n{FENCE}");
 41
 42        for text in texts.iter() {
 43            let mut text = text.as_str();
 44
 45            while let Some(index) = text.find(FENCE) {
 46                // Advance text past the opening backticks.
 47                text = &text[index + FENCE.len()..];
 48
 49                // Find the closing backticks.
 50                let content_len = text.find(&closing_fence);
 51
 52                // Verify the citation format - e.g. ```path/to/foo.txt#L123-456
 53                if let Some(citation_len) = text.find('\n') {
 54                    let citation = &text[..citation_len];
 55
 56                    if let Ok(()) =
 57                        cx.assert(citation.contains("/"), format!("Slash in {citation:?}",))
 58                    {
 59                        let path_range = PathWithRange::new(citation);
 60                        let path = cx
 61                            .agent_thread()
 62                            .update(cx, |thread, cx| {
 63                                thread
 64                                    .project()
 65                                    .read(cx)
 66                                    .find_project_path(path_range.path, cx)
 67                            })
 68                            .ok()
 69                            .flatten();
 70
 71                        if let Ok(path) = cx.assert_some(path, format!("Valid path: {citation:?}"))
 72                        {
 73                            let buffer_text = {
 74                                let buffer = match cx.agent_thread().update(cx, |thread, cx| {
 75                                    thread
 76                                        .project()
 77                                        .update(cx, |project, cx| project.open_buffer(path, cx))
 78                                }) {
 79                                    Ok(buffer_task) => buffer_task.await.ok(),
 80                                    Err(err) => {
 81                                        cx.assert(
 82                                            false,
 83                                            format!("Expected Ok(buffer), not {err:?}"),
 84                                        )
 85                                        .ok();
 86                                        break;
 87                                    }
 88                                };
 89
 90                                let Ok(buffer_text) = cx.assert_some(
 91                                    buffer.and_then(|buffer| {
 92                                        buffer.read_with(cx, |buffer, _| buffer.text()).ok()
 93                                    }),
 94                                    "Reading buffer text succeeded",
 95                                ) else {
 96                                    continue;
 97                                };
 98                                buffer_text
 99                            };
100
101                            if let Some(content_len) = content_len {
102                                // + 1 because there's a newline character after the citation.
103                                let content =
104                                    &text[(citation.len() + 1)..content_len - (citation.len() + 1)];
105
106                                // deindent (trim the start of each line) because sometimes the model
107                                // chooses to deindent its code snippets for the sake of readability,
108                                // which in markdown is not only reasonable but usually desirable.
109                                cx.assert(
110                                    deindent(&buffer_text)
111                                        .trim()
112                                        .contains(deindent(&content).trim()),
113                                    "Code block content was found in file",
114                                )
115                                .ok();
116
117                                if let Some(range) = path_range.range {
118                                    let start_line_index = range.start.line.saturating_sub(1);
119                                    let line_count =
120                                        range.end.line.saturating_sub(start_line_index);
121                                    let mut snippet = buffer_text
122                                        .lines()
123                                        .skip(start_line_index as usize)
124                                        .take(line_count as usize)
125                                        .collect::<Vec<&str>>()
126                                        .join("\n");
127
128                                    if let Some(start_col) = range.start.col {
129                                        snippet = snippet[start_col as usize..].to_string();
130                                    }
131
132                                    if let Some(end_col) = range.end.col {
133                                        let last_line = snippet.lines().last().unwrap();
134                                        snippet = snippet
135                                            [..snippet.len() - last_line.len() + end_col as usize]
136                                            .to_string();
137                                    }
138
139                                    // deindent (trim the start of each line) because sometimes the model
140                                    // chooses to deindent its code snippets for the sake of readability,
141                                    // which in markdown is not only reasonable but usually desirable.
142                                    cx.assert_eq(
143                                        deindent(snippet.as_str()).trim(),
144                                        deindent(content).trim(),
145                                        format!(
146                                            "Code block was at {:?}-{:?}",
147                                            range.start, range.end
148                                        ),
149                                    )
150                                    .ok();
151                                }
152                            }
153                        }
154                    }
155                } else {
156                    cx.assert(
157                        false,
158                        format!("Opening {FENCE} did not have a newline anywhere after it."),
159                    )
160                    .ok();
161                }
162
163                if let Some(content_len) = content_len {
164                    // Advance past the closing backticks
165                    text = &text[content_len + FENCE.len()..];
166                } else {
167                    // There were no closing backticks associated with these opening backticks.
168                    cx.assert(
169                        false,
170                        "Code block opening had matching closing backticks.".to_string(),
171                    )
172                    .ok();
173
174                    // There are no more code blocks to parse, so we're done.
175                    break;
176                }
177            }
178        }
179
180        Ok(())
181    }
182
183    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
184        vec![
185            JudgeAssertion {
186                id: "trait method bodies are shown".to_string(),
187                description:
188                    "All method bodies of the Tool trait are shown."
189                        .to_string(),
190            },
191            JudgeAssertion {
192                id: "code blocks used".to_string(),
193                description:
194                   "All code snippets are rendered inside markdown code blocks (as opposed to any other formatting besides code blocks)."
195                        .to_string(),
196            },
197            JudgeAssertion {
198              id: "code blocks use backticks".to_string(),
199              description:
200                  format!("All markdown code blocks use backtick fences ({FENCE}) rather than indentation.")
201            }
202        ]
203    }
204}
205
/// Strips leading whitespace from every line of `text` and rejoins the lines
/// with `\n`.
///
/// Lines are split with [`str::lines`], so `\r\n` terminators become `\n` and a
/// trailing line terminator is dropped.
fn deindent(text: impl AsRef<str>) -> String {
    let mut deindented = String::new();
    for (index, line) in text.as_ref().lines().enumerate() {
        if index > 0 {
            deindented.push('\n');
        }
        deindented.push_str(line.trim_start());
    }
    deindented
}