1use anyhow::Result;
2use async_trait::async_trait;
3use markdown::PathWithRange;
4
5use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer};
6
/// Eval example that asks the agent to show the method bodies of the `Tool` trait and then
/// checks that each fenced code block in the replies opens with a path citation that resolves
/// to a file in the project and whose content can be found in that file.
pub struct CodeBlockCitations;
8
/// Markdown code-fence delimiter (three backticks). Responses are scanned for this marker to
/// locate the opening and closing of each code block.
const FENCE: &str = "```";
10
11#[async_trait(?Send)]
12impl Example for CodeBlockCitations {
13 fn meta(&self) -> ExampleMetadata {
14 ExampleMetadata {
15 name: "code_block_citations".to_string(),
16 url: "https://github.com/zed-industries/zed.git".to_string(),
17 revision: "f69aeb6311dde3c0b8979c293d019d66498d54f2".to_string(),
18 language_server: Some(LanguageServer {
19 file_extension: "rs".to_string(),
20 allow_preexisting_diagnostics: false,
21 }),
22 max_assertions: None,
23 }
24 }
25
26 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
27 const FILENAME: &str = "assistant_tool.rs";
28 cx.push_user_message(format!(
29 r#"
30 Show me the method bodies of all the methods of the `Tool` trait in {FILENAME}.
31
32 Please show each method in a separate code snippet.
33 "#
34 ));
35
36 // Verify that the messages all have the correct formatting.
37 let texts: Vec<String> = cx.run_to_end().await?.texts().collect();
38 let closing_fence = format!("\n{FENCE}");
39
40 for text in texts.iter() {
41 let mut text = text.as_str();
42
43 while let Some(index) = text.find(FENCE) {
44 // Advance text past the opening backticks.
45 text = &text[index + FENCE.len()..];
46
47 // Find the closing backticks.
48 let content_len = text.find(&closing_fence);
49
50 // Verify the citation format - e.g. ```path/to/foo.txt#L123-456
51 if let Some(citation_len) = text.find('\n') {
52 let citation = &text[..citation_len];
53
54 if let Ok(()) =
55 cx.assert(citation.contains("/"), format!("Slash in {citation:?}",))
56 {
57 let path_range = PathWithRange::new(citation);
58 let path = cx
59 .agent_thread()
60 .update(cx, |thread, cx| {
61 thread
62 .project()
63 .read(cx)
64 .find_project_path(path_range.path, cx)
65 })
66 .ok()
67 .flatten();
68
69 if let Ok(path) = cx.assert_some(path, format!("Valid path: {citation:?}"))
70 {
71 let buffer_text = {
72 let buffer = match cx.agent_thread().update(cx, |thread, cx| {
73 thread
74 .project()
75 .update(cx, |project, cx| project.open_buffer(path, cx))
76 }) {
77 Ok(buffer_task) => buffer_task.await.ok(),
78 Err(err) => {
79 cx.assert(
80 false,
81 format!("Expected Ok(buffer), not {err:?}"),
82 )
83 .ok();
84 break;
85 }
86 };
87
88 let Ok(buffer_text) = cx.assert_some(
89 buffer.and_then(|buffer| {
90 buffer.read_with(cx, |buffer, _| buffer.text()).ok()
91 }),
92 "Reading buffer text succeeded",
93 ) else {
94 continue;
95 };
96 buffer_text
97 };
98
99 if let Some(content_len) = content_len {
100 // + 1 because there's a newline character after the citation.
101 let content =
102 &text[(citation.len() + 1)..content_len - (citation.len() + 1)];
103
104 // deindent (trim the start of each line) because sometimes the model
105 // chooses to deindent its code snippets for the sake of readability,
106 // which in markdown is not only reasonable but usually desirable.
107 cx.assert(
108 deindent(&buffer_text)
109 .trim()
110 .contains(deindent(&content).trim()),
111 "Code block content was found in file",
112 )
113 .ok();
114
115 if let Some(range) = path_range.range {
116 let start_line_index = range.start.line.saturating_sub(1);
117 let line_count =
118 range.end.line.saturating_sub(start_line_index);
119 let mut snippet = buffer_text
120 .lines()
121 .skip(start_line_index as usize)
122 .take(line_count as usize)
123 .collect::<Vec<&str>>()
124 .join("\n");
125
126 if let Some(start_col) = range.start.col {
127 snippet = snippet[start_col as usize..].to_string();
128 }
129
130 if let Some(end_col) = range.end.col {
131 let last_line = snippet.lines().last().unwrap();
132 snippet = snippet
133 [..snippet.len() - last_line.len() + end_col as usize]
134 .to_string();
135 }
136
137 // deindent (trim the start of each line) because sometimes the model
138 // chooses to deindent its code snippets for the sake of readability,
139 // which in markdown is not only reasonable but usually desirable.
140 cx.assert_eq(
141 deindent(snippet.as_str()).trim(),
142 deindent(content).trim(),
143 format!(
144 "Code block was at {:?}-{:?}",
145 range.start, range.end
146 ),
147 )
148 .ok();
149 }
150 }
151 }
152 }
153 } else {
154 cx.assert(
155 false,
156 format!("Opening {FENCE} did not have a newline anywhere after it."),
157 )
158 .ok();
159 }
160
161 if let Some(content_len) = content_len {
162 // Advance past the closing backticks
163 text = &text[content_len + FENCE.len()..];
164 } else {
165 // There were no closing backticks associated with these opening backticks.
166 cx.assert(
167 false,
168 "Code block opening had matching closing backticks.".to_string(),
169 )
170 .ok();
171
172 // There are no more code blocks to parse, so we're done.
173 break;
174 }
175 }
176 }
177
178 Ok(())
179 }
180
181 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
182 vec![
183 JudgeAssertion {
184 id: "trait method bodies are shown".to_string(),
185 description:
186 "All method bodies of the Tool trait are shown."
187 .to_string(),
188 },
189 JudgeAssertion {
190 id: "code blocks used".to_string(),
191 description:
192 "All code snippets are rendered inside markdown code blocks (as opposed to any other formatting besides code blocks)."
193 .to_string(),
194 },
195 JudgeAssertion {
196 id: "code blocks use backticks".to_string(),
197 description:
198 format!("All markdown code blocks use backtick fences ({FENCE}) rather than indentation.")
199 }
200 ]
201 }
202}
203
/// Strips the leading whitespace from every line of `as_str` and rejoins the lines with `\n`.
///
/// Line terminators are normalised by `str::lines`, so the result never ends with a newline
/// that merely terminated the last line of the input.
fn deindent(as_str: impl AsRef<str>) -> String {
    let mut deindented = String::new();

    for (line_number, line) in as_str.as_ref().lines().enumerate() {
        // Separator goes between lines only, never before the first or after the last.
        if line_number > 0 {
            deindented.push('\n');
        }
        deindented.push_str(line.trim_start());
    }

    deindented
}