1use agent::{RequestKind, ThreadEvent, ThreadStore};
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::ToolWorkingSet;
4use client::proto::LspWorkProgress;
5use collections::HashMap;
6use dap::DapRegistry;
7use futures::channel::mpsc;
8use futures::{FutureExt, StreamExt as _, select_biased};
9use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
10use handlebars::Handlebars;
11use language::{DiagnosticSeverity, OffsetRangeExt};
12use language_model::{
13 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
14 MessageContent, Role, StopReason, TokenUsage,
15};
16use project::{LspStore, Project, ProjectPath};
17use serde::{Deserialize, Serialize};
18use std::cell::RefCell;
19use std::fmt::Write as _;
20use std::fs::File;
21use std::io::Write as _;
22use std::rc::Rc;
23use std::sync::{Arc, Mutex};
24use std::time::Duration;
25use std::{
26 fs,
27 path::{Path, PathBuf},
28};
29use unindent::Unindent as _;
30use util::ResultExt as _;
31use util::command::new_smol_command;
32use util::serde::default_true;
33
34use crate::AgentAppState;
35
/// Location of the eval example definitions, relative to the current working directory.
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Location of the shared repository checkouts (one per repository URL), relative to the
/// current working directory.
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Location under which a git worktree is created for each example, relative to the current
/// working directory.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

/// Longest the agentic loop may go without emitting a single `ThreadEvent` before the run is
/// considered stalled.
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
41
/// Contents of an example's `base.toml`: which repository and revision to evaluate against,
/// plus options controlling the language-server diagnostics check.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// Git URL of the repository the example runs against.
    pub url: String,
    /// Revision to check out; passed to `git rev-parse`, `git fetch`, and `git checkout`.
    pub revision: String,
    /// File extension (without the leading dot) used to pick a file whose opening starts the
    /// language server. Required when `require_lsp` is true.
    pub language_extension: Option<String>,
    /// Not read in this module. NOTE(review): confirm its purpose with callers.
    pub insert_id: Option<String>,
    /// Whether the example needs a running language server; defaults to `true`.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
    /// When `false` (the default), the run fails if error/warning diagnostics already exist
    /// before the agent starts.
    #[serde(default)]
    pub allow_preexisting_diagnostics: bool,
}
53
54impl ExampleBase {
55 pub fn repo_name(&self) -> String {
56 self.url
57 .split('/')
58 .next_back()
59 .unwrap_or(&"")
60 .trim_end_matches(".git")
61 .into()
62 }
63}
64
/// A single eval example loaded from disk, plus the per-run settings used to execute it.
#[derive(Clone, Debug)]
pub struct Example {
    /// Example name: the example directory's name, suffixed with `-N` when
    /// `set_repetition_number` is called with `N > 0`.
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `diff_criteria.md`
    pub diff_criteria: String,
    /// Content of `thread_criteria.md`, if that file exists (it's optional)
    pub thread_criteria: Option<String>,
    /// Path to the directory containing the requests and responses for the agentic loop
    pub run_directory_path: PathBuf,
    /// Prefix used for logging that identifies this example
    pub log_prefix: String,
}
81
/// Everything collected from one agentic run of an example; this is the input to judging.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// Staged `git diff` of the worktree after the agent finished.
    pub repository_diff: String,
    /// Mirrors `ExampleBase::require_lsp`: whether a language server was required for the run.
    pub ran_diagnostics_check: bool,
    /// Error/warning diagnostics before the agent ran, or `None` if there were none.
    pub diagnostics_before: Option<String>,
    /// Error/warning diagnostics after the agent ran, or `None` if there were none.
    pub diagnostics_after: Option<String>,
    /// Number of assistant messages in the thread.
    pub response_count: usize,
    /// Cumulative token usage reported by the thread.
    pub token_usage: TokenUsage,
    /// Count of finished tool uses that produced a result, keyed by tool name.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    /// The final request sent to the model; used to reconstruct the conversation for judging.
    pub last_request: LanguageModelRequest,
}
93
/// Template input for the diff judge prompt (`judge_diff_prompt.hbs`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeDiffInput {
    /// Staged diff produced by the agent.
    pub repository_diff: String,
    /// Whether diagnostics were meaningful for this example (a language server was required).
    pub ran_diagnostics_check: bool,
    /// Omitted when `None` so the template can detect that there were no diagnostics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub diagnostics_before: Option<String>,
    /// Omitted when `None` so the template can detect that there were no diagnostics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub diagnostics_after: Option<String>,
    /// Contents of the example's `diff_criteria.md`.
    pub criteria: String,
}
104
/// Template input for the thread judge prompt (`judge_thread_prompt.hbs`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeThreadInput {
    /// The conversation rendered as Markdown from the run's last request.
    pub messages: String,
    /// Contents of the example's `thread_criteria.md`.
    pub criteria: String,
}
110
/// A judge's parsed verdict.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeResponse {
    /// Text found between `<analysis>` tags.
    pub analysis: String,
    /// Integer found between `<score>` tags.
    pub score: u32,
}
116
/// Result of one judge pass over a run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Thread verdict, or `None` when the example has no `thread_criteria.md`.
    pub thread: Option<JudgeResponse>,
    /// Diff verdict (always produced).
    pub diff: JudgeResponse,
}
122
123impl Example {
124 /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
125 pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
126 let name = Self::name_from_path(dir_path);
127 let base_path = dir_path.join("base.toml");
128 let prompt_path = dir_path.join("prompt.md");
129 let diff_criteria_path = dir_path.join("diff_criteria.md");
130 let thread_criteria_path = dir_path.join("thread_criteria.md");
131 let thread_criteria = if thread_criteria_path.exists() {
132 Some(fs::read_to_string(thread_criteria_path.clone())?)
133 } else {
134 None
135 };
136
137 Ok(Example {
138 name: name.clone(),
139 base: toml::from_str(&fs::read_to_string(&base_path)?)?,
140 prompt: fs::read_to_string(prompt_path.clone())?,
141 thread_criteria,
142 diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
143 run_directory_path: run_dir.to_path_buf(),
144 log_prefix: name,
145 })
146 }
147
148 pub fn set_repetition_number(&mut self, repetition_number: u32) {
149 if repetition_number > 0 {
150 self.name = format!("{}-{}", self.name, repetition_number);
151 }
152 }
153
154 pub fn example_output_directory(&self) -> PathBuf {
155 self.run_directory_path.join(&self.name)
156 }
157
158 pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
159 self.log_prefix = format!(
160 "{}{:<width$}\x1b[0m | ",
161 color,
162 self.name,
163 width = name_width
164 );
165 }
166
167 pub fn name_from_path(path: &Path) -> String {
168 path.file_name().unwrap().to_string_lossy().to_string()
169 }
170
171 pub fn worktree_path(&self) -> PathBuf {
172 Path::new(WORKTREES_DIR)
173 .canonicalize()
174 .context(format!("No such directory {WORKTREES_DIR}"))
175 .unwrap()
176 .join(&self.name)
177 .join(self.base.repo_name())
178 }
179
180 /// Set up the example by checking out the specified Git revision
181 pub async fn setup(&mut self) -> Result<()> {
182 let repo_path = repo_path_for_url(&self.base.url);
183
184 let revision_exists = run_git(
185 &repo_path,
186 &["rev-parse", &format!("{}^{{commit}}", self.base.revision)],
187 )
188 .await
189 .is_ok();
190
191 if !revision_exists {
192 println!(
193 "{}Fetching revision {}",
194 self.log_prefix, &self.base.revision
195 );
196 run_git(
197 &repo_path,
198 &["fetch", "--depth", "1", "origin", &self.base.revision],
199 )
200 .await?;
201 }
202
203 let worktree_path = self.worktree_path();
204
205 if worktree_path.is_dir() {
206 println!("{}Resetting existing worktree", self.log_prefix);
207
208 // TODO: consider including "-x" to remove ignored files. The downside of this is that
209 // it will also remove build artifacts, and so prevent incremental reuse there.
210 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
211 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
212 run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
213 } else {
214 println!("{}Creating worktree", self.log_prefix);
215
216 let worktree_path_string = worktree_path.to_string_lossy().to_string();
217
218 run_git(
219 &repo_path,
220 &[
221 "worktree",
222 "add",
223 "-f",
224 &worktree_path_string,
225 &self.base.revision,
226 ],
227 )
228 .await?;
229 }
230
231 std::fs::create_dir_all(self.example_output_directory())?;
232
233 Ok(())
234 }
235
    /// Runs the example's prompt through the agent against this example's worktree.
    ///
    /// Steps, in order:
    /// 1. Creates a local `Project` containing the worktree and waits for its initial scan.
    /// 2. If `base.require_lsp` is set, opens a file with `base.language_extension` so a
    ///    language server starts. It then waits a fixed 5 seconds plus until the servers report
    ///    no pending work, and fails if no server is registered for that buffer.
    /// 3. Queries diagnostics; fails if any exist unless `allow_preexisting_diagnostics` is set.
    /// 4. Creates a thread and logs every request/response to the example output directory.
    ///    It then sends the prompt and drives the event loop until the model ends its turn.
    /// 5. Writes the staged repository diff and before/after diagnostics to the output
    ///    directory and returns them in a [`RunOutput`].
    ///
    /// If the `ZED_EVAL_SETUP_ONLY` environment variable is set, the task returns an error right
    /// after step 3.
    ///
    /// # Errors
    ///
    /// The returned task fails in these cases:
    /// - setup or LSP problems, or pre-existing diagnostics;
    /// - hitting max tokens, or a model or thread error;
    /// - a stalled loop (no event within `THREAD_EVENT_TIMEOUT`);
    /// - no request was ever sent.
    ///
    /// # Panics
    ///
    /// Panics if the agent asks for tool confirmation, or if writing a messages/tools log file
    /// fails.
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = cx.new(|_| ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        // The spawned task outlives `&self`, so it works on an owned copy.
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            let lsp_open_handle_and_store = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language to cause LSP to start.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for example language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                // The registration handle is kept in `lsp_open_handle_and_store` until the end
                // of the run (presumably keeping the buffer registered with its servers).
                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
                    (
                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
                        project.lsp_store().clone(),
                    )
                })?;

                // TODO: remove this once the diagnostics tool waits for new diagnostics
                cx.background_executor().timer(Duration::new(5, 0)).await;
                wait_for_lang_server(&lsp_store, this.log_prefix.clone(), cx).await?;

                // Fail fast when the opened buffer never got a language server attached.
                lsp_store.update(cx, |lsp_store, cx| {
                    lsp_open_handle.update(cx, |buffer, cx| {
                        buffer.update(cx, |buffer, cx| {
                            let has_language_server = lsp_store
                                .language_servers_for_local_buffer(buffer, cx)
                                .next()
                                .is_some();
                            if has_language_server {
                                Ok(())
                            } else {
                                Err(anyhow!(
                                    "`{:?}` was opened to cause the language server to start, \
                                    but no language servers are registered for its buffer. \
                                    Set `require_lsp = false` in `base.toml` to skip this.",
                                    language_file
                                ))
                            }
                        })
                    })
                })??;

                Some((lsp_open_handle, lsp_store))
            } else {
                None
            };

            let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
            if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
                return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
            }

            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await?;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
            // Filled in by the request callback; the last value is returned in `RunOutput`.
            let last_request = Rc::new(RefCell::new(None));

            // Log each request/response pair to `<n>.messages.md` and `last.messages.md`, and
            // the tool definitions from the first request to `tools.md`.
            thread.update(cx, |thread, _cx| {
                let mut request_count = 0;
                let example_dir_path = this.example_output_directory();

                let last_request = Rc::clone(&last_request);
                thread.set_request_callback(move |request, response_events| {
                    *last_request.borrow_mut() = Some(request.clone());

                    request_count += 1;
                    let messages_file_path = example_dir_path.join(format!("{request_count}.messages.md"));
                    let last_messages_file_path = example_dir_path.join("last.messages.md");
                    let request_markdown = RequestMarkdown::new(request);
                    let response_events_markdown = response_events_to_markdown(response_events);

                    let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
                    fs::write(messages_file_path, messages.clone()).expect("failed to write messages file");
                    fs::write(last_messages_file_path, messages).expect("failed to write last messages file");

                    if request_count == 1 {
                        let tools_file_path = example_dir_path.join("tools.md");
                        fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
                    }
                });
            })?;

            // Incremented by the event handler task, read when building `RunOutput`.
            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                Mutex::new(HashMap::default()).into();

            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();

            // Forward thread events into a channel consumed by the event handler task below.
            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
                thread_event_tx.unbounded_send(event.clone()).log_err();
            });

            // Drives the agentic loop. It resolves `Ok` on `EndTurn`, and errors on max tokens,
            // model/thread errors, a closed channel, or a gap longer than `THREAD_EVENT_TIMEOUT`.
            let event_handler_task = cx.spawn({
                let log_prefix = this.log_prefix.clone();
                let tool_use_counts = tool_use_counts.clone();
                let thread = thread.downgrade();
                async move |cx| {
                    loop {
                        // A fresh timer each iteration: the timeout bounds the gap between
                        // consecutive events, not the total run time.
                        let event = select_biased! {
                            event = thread_event_rx.next() => event,
                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
                            }
                        };
                        let Some(event) = event else {
                            return Err(anyhow!("ThreadEvent channel ended early"));
                        };

                        match event {
                            ThreadEvent::Stopped(reason) => match reason {
                                Ok(StopReason::EndTurn) => {
                                    return Ok(());
                                }
                                Ok(StopReason::MaxTokens) => {
                                    return Err(anyhow!("Exceeded maximum tokens"));
                                }
                                Ok(StopReason::ToolUse) => {
                                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
                                        println!("{}StopReason: Tool use", log_prefix);
                                    }
                                }
                                Err(error) => {
                                    return Err(anyhow!(error.clone()));
                                }
                            },
                            ThreadEvent::ShowError(thread_error) => {
                                break Err(anyhow!(thread_error.clone()));
                            }
                            ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
                            }
                            ThreadEvent::ToolFinished {
                                tool_use_id,
                                pending_tool_use,
                                ..
                            } => {
                                // Log the outcome and count tool uses that produced a result.
                                thread.update(cx, |thread, _cx| {
                                    if let Some(tool_use) = pending_tool_use {
                                        if let Some(tool_result) = thread.tool_result(&tool_use_id) {
                                            let message = if tool_result.is_error {
                                                format!("TOOL FAILED: {}", tool_use.name)
                                            } else {
                                                format!("TOOL FINISHED: {}", tool_use.name)
                                            };
                                            println!("{log_prefix}{message}");
                                            let mut tool_use_counts = tool_use_counts.lock().unwrap();
                                            *tool_use_counts
                                                .entry(tool_result.tool_name.clone())
                                                .or_insert(0) += 1;
                                        } else {
                                            let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
                                            println!("{log_prefix}{message}");
                                        }
                                    }
                                })?;
                            }
                            ThreadEvent::ToolConfirmationNeeded => {
                                panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
                            },
                            ThreadEvent::StreamedCompletion |
                            ThreadEvent::MessageAdded(_) |
                            ThreadEvent::MessageEdited(_) |
                            ThreadEvent::MessageDeleted(_) |
                            ThreadEvent::SummaryChanged |
                            ThreadEvent::SummaryGenerated |
                            ThreadEvent::CheckpointChanged |
                            ThreadEvent::UsageUpdated(_) => {
                                if std::env::var("ZED_EVAL_DEBUG").is_ok() {
                                    println!("{}Event: {:#?}", log_prefix, event);
                                }
                            }
                        }
                    }
                }
            });

            // Start the conversation only after the event handler is listening.
            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            event_handler_task.await?;

            println!("{}Stopped", this.log_prefix);

            // Wait for the language server to go idle again before reading post-run diagnostics.
            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                wait_for_lang_server(lsp_store, this.log_prefix.clone(), cx).await?;
            }

            println!("{}Getting repository diff", this.log_prefix);
            let repository_diff = this.repository_diff().await?;

            let example_output_dir = this.example_output_directory();
            let repository_diff_path = example_output_dir.join("patch.diff");
            let mut repository_diff_output_file = File::create(&repository_diff_path)?;
            writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();

            println!("{}Getting diagnostics", this.log_prefix);
            let diagnostics_after = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;
            println!("{}Got diagnostics", this.log_prefix);

            let Some(last_request) = last_request.borrow_mut().take() else {
                return Err(anyhow!("No requests ran."));
            };

            drop(subscription);
            drop(lsp_open_handle_and_store);

            if let Some(diagnostics_before) = &diagnostics_before {
                fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
            }

            if let Some(diagnostics_after) = &diagnostics_after {
                fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
            }


            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    ran_diagnostics_check: this.base.require_lsp,
                    diagnostics_before,
                    diagnostics_after,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                    last_request,
                }
            })
        })
    }
535
536 async fn judge_diff(
537 &self,
538 model: Arc<dyn LanguageModel>,
539 run_output: &RunOutput,
540 judge_number: u32,
541 cx: &AsyncApp,
542 ) -> Result<(String, JudgeResponse)> {
543 let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
544 let judge_diff_prompt_name = "judge_diff_prompt";
545 let mut hbs = Handlebars::new();
546 hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
547
548 let diff_prompt = hbs.render(
549 judge_diff_prompt_name,
550 &JudgeDiffInput {
551 repository_diff: run_output.repository_diff.clone(),
552 ran_diagnostics_check: run_output.ran_diagnostics_check,
553 diagnostics_before: run_output.diagnostics_before.clone(),
554 diagnostics_after: run_output.diagnostics_after.clone(),
555 criteria: self.diff_criteria.clone(),
556 },
557 )?;
558
559 let request = LanguageModelRequest {
560 thread_id: None,
561 prompt_id: None,
562 messages: vec![LanguageModelRequestMessage {
563 role: Role::User,
564 content: vec![MessageContent::Text(diff_prompt)],
565 cache: false,
566 }],
567 temperature: None,
568 tools: Vec::new(),
569 stop: Vec::new(),
570 };
571
572 let diff_response = send_language_model_request(model, request, cx).await?;
573 let diff_output = JudgeResponse::parse(&diff_response)?;
574
575 println!(
576 "{}Judge #{judge_number} - Diff score: {}",
577 self.log_prefix, diff_output.score
578 );
579
580 Ok((diff_response, diff_output))
581 }
582
583 async fn judge_thread(
584 &self,
585 model: Arc<dyn LanguageModel>,
586 run_output: &RunOutput,
587 judge_number: u32,
588 cx: &AsyncApp,
589 ) -> Result<(String, Option<JudgeResponse>)> {
590 if let Some(criteria) = self.thread_criteria.clone() {
591 let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
592 let judge_thread_prompt_name = "judge_thread_prompt";
593 let mut hbs = Handlebars::new();
594 hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
595
596 let request_markdown = RequestMarkdown::new(&run_output.last_request);
597 let thread_prompt = hbs.render(
598 judge_thread_prompt_name,
599 &JudgeThreadInput {
600 messages: request_markdown.messages,
601 criteria,
602 },
603 )?;
604
605 let request = LanguageModelRequest {
606 thread_id: None,
607 prompt_id: None,
608 messages: vec![LanguageModelRequestMessage {
609 role: Role::User,
610 content: vec![MessageContent::Text(thread_prompt)],
611 cache: false,
612 }],
613 temperature: None,
614 tools: Vec::new(),
615 stop: Vec::new(),
616 };
617
618 let thread_response = send_language_model_request(model, request, cx).await?;
619 let thread_output = JudgeResponse::parse(&thread_response)?;
620
621 println!(
622 "{}Judge #{judge_number} - Thread score: {}",
623 self.log_prefix, thread_output.score
624 );
625
626 Ok((thread_response, Some(thread_output)))
627 } else {
628 let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string();
629 Ok((msg, None))
630 }
631 }
632
633 pub async fn judge(
634 &self,
635 model: Arc<dyn LanguageModel>,
636 run_output: &RunOutput,
637 judge_number: u32,
638 cx: &AsyncApp,
639 ) -> Result<JudgeOutput> {
640 let mut output_file = File::create(
641 self.example_output_directory()
642 .join(format!("judge_{}.md", judge_number)),
643 )
644 .expect("failed to create judge.md");
645
646 println!("{}Running judge #{judge_number}", self.log_prefix);
647
648 let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx);
649 let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx);
650
651 let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
652
653 let (diff_response, diff_output) = diff_result?;
654 let (thread_response, thread_output) = thread_result?;
655
656 writeln!(
657 &mut output_file,
658 "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
659 )
660 .log_err();
661
662 Ok(JudgeOutput {
663 thread: thread_output,
664 diff: diff_output,
665 })
666 }
667
668 async fn repository_diff(&self) -> Result<String> {
669 let worktree_path = self.worktree_path();
670 run_git(&worktree_path, &["add", "."]).await?;
671 run_git(&worktree_path, &["diff", "--staged"]).await
672 }
673}
674
675fn wait_for_lang_server(
676 lsp_store: &Entity<LspStore>,
677 log_prefix: String,
678 cx: &mut AsyncApp,
679) -> Task<Result<()>> {
680 if cx
681 .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
682 .unwrap()
683 || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
684 {
685 return Task::ready(anyhow::Ok(()));
686 }
687
688 println!("{}⏵ Waiting for language server", log_prefix);
689
690 let (mut tx, mut rx) = mpsc::channel(1);
691
692 let subscription =
693 cx.subscribe(&lsp_store, {
694 let log_prefix = log_prefix.clone();
695 move |lsp_store, event, cx| {
696 match event {
697 project::LspStoreEvent::LanguageServerUpdate {
698 message:
699 client::proto::update_language_server::Variant::WorkProgress(
700 LspWorkProgress {
701 message: Some(message),
702 ..
703 },
704 ),
705 ..
706 } => println!("{}⟲ {message}", log_prefix),
707 _ => {}
708 }
709
710 if !has_pending_lang_server_work(&lsp_store, cx) {
711 tx.try_send(()).ok();
712 }
713 }
714 });
715
716 cx.spawn(async move |cx| {
717 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
718 let result = futures::select! {
719 _ = rx.next() => {
720 println!("{}⚑ Language server idle", log_prefix);
721 anyhow::Ok(())
722 },
723 _ = timeout.fuse() => {
724 Err(anyhow!("LSP wait timed out after 5 minutes"))
725 }
726 };
727 drop(subscription);
728 result
729 })
730}
731
732fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
733 lsp_store
734 .read(cx)
735 .language_server_statuses()
736 .any(|(_, status)| !status.pending_work.is_empty())
737}
738
739async fn query_lsp_diagnostics(
740 project: Entity<Project>,
741 cx: &mut AsyncApp,
742) -> Result<Option<String>> {
743 let paths_with_diagnostics = project.update(cx, |project, cx| {
744 project
745 .diagnostic_summaries(true, cx)
746 .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
747 .map(|(project_path, _, _)| project_path)
748 .collect::<Vec<_>>()
749 })?;
750
751 if paths_with_diagnostics.is_empty() {
752 return Ok(None);
753 }
754
755 let mut output = String::new();
756 for project_path in paths_with_diagnostics {
757 let buffer = project
758 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
759 .await?;
760 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
761
762 for (_, group) in snapshot.diagnostic_groups(None) {
763 let entry = &group.entries[group.primary_ix];
764 let range = entry.range.to_point(&snapshot);
765 let severity = match entry.diagnostic.severity {
766 DiagnosticSeverity::ERROR => "error",
767 DiagnosticSeverity::WARNING => "warning",
768 _ => continue,
769 };
770
771 writeln!(
772 output,
773 "{} at line {}: {}",
774 severity,
775 range.start.row + 1,
776 entry.diagnostic.message
777 )?;
778 }
779 }
780 anyhow::Ok(Some(output))
781}
782
783impl JudgeResponse {
784 fn parse(response: &str) -> Result<Self> {
785 let analysis = get_tag("analysis", response)?.to_string();
786 let score = get_tag("score", response)?
787 .parse()
788 .context("error parsing score")?;
789
790 Ok(Self { analysis, score })
791 }
792}
793
794fn get_tag(name: &'static str, response: &str) -> Result<String> {
795 let start_tag = format!("<{}>", name);
796 let end_tag = format!("</{}>", name);
797
798 let start_ix = response
799 .find(&start_tag)
800 .context(format!("{} start tag not found", name))?;
801 let content_start_ix = start_ix + start_tag.len();
802
803 let end_ix = content_start_ix
804 + response[content_start_ix..]
805 .find(&end_tag)
806 .context(format!("{} end tag not found", name))?;
807
808 let content = response[content_start_ix..end_ix].trim().unindent();
809
810 anyhow::Ok(content)
811}
812
813pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
814 let repo_name = repo_url
815 .trim_start_matches("https://")
816 .replace(|c: char| !c.is_alphanumeric(), "-");
817 Path::new(REPOS_DIR)
818 .canonicalize()
819 .context(format!("No such directory {REPOS_DIR}"))
820 .unwrap()
821 .join(repo_name)
822}
823
824pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
825 let output = new_smol_command("git")
826 .current_dir(repo_path)
827 .args(args)
828 .output()
829 .await?;
830
831 if output.status.success() {
832 Ok(String::from_utf8(output.stdout)?.trim().to_string())
833 } else {
834 Err(anyhow!(
835 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
836 args.join(" "),
837 repo_path.display(),
838 output.status,
839 String::from_utf8_lossy(&output.stderr),
840 String::from_utf8_lossy(&output.stdout),
841 ))
842 }
843}
844
845pub async fn send_language_model_request(
846 model: Arc<dyn LanguageModel>,
847 request: LanguageModelRequest,
848 cx: &AsyncApp,
849) -> anyhow::Result<String> {
850 match model.stream_completion_text(request, &cx).await {
851 Ok(mut stream) => {
852 let mut full_response = String::new();
853 while let Some(chunk_result) = stream.stream.next().await {
854 match chunk_result {
855 Ok(chunk_str) => {
856 full_response.push_str(&chunk_str);
857 }
858 Err(err) => {
859 return Err(anyhow!(
860 "Error receiving response from language model: {err}"
861 ));
862 }
863 }
864 }
865 Ok(full_response)
866 }
867 Err(err) => Err(anyhow!(
868 "Failed to get response from language model. Error was: {err}"
869 )),
870 }
871}
872
/// Markdown rendering of a `LanguageModelRequest`, split into tool definitions and the
/// conversation so they can be written to separate log files.
struct RequestMarkdown {
    /// One section per tool: name, description, and pretty-printed JSON input schema.
    tools: String,
    /// One section per message, headed by its role.
    messages: String,
}
877
878impl RequestMarkdown {
879 fn new(request: &LanguageModelRequest) -> Self {
880 let mut tools = String::new();
881 let mut messages = String::new();
882
883 // Print the tools
884 if !request.tools.is_empty() {
885 for tool in &request.tools {
886 write!(&mut tools, "# {}\n\n", tool.name).unwrap();
887 write!(&mut tools, "{}\n\n", tool.description).unwrap();
888 write!(
889 &mut tools,
890 "```json\n{}\n```\n\n",
891 serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default()
892 )
893 .unwrap();
894 }
895 }
896
897 // Print the messages
898 for message in &request.messages {
899 let role_str = match message.role {
900 Role::User => "👤 USER",
901 Role::Assistant => "🤖 ASSISTANT",
902 Role::System => "⚙️ SYSTEM",
903 };
904
905 messages.push_str(&format!("# {}\n\n", role_str));
906
907 for content in &message.content {
908 match content {
909 MessageContent::Text(text) => {
910 messages.push_str(text);
911 messages.push_str("\n\n");
912 }
913 MessageContent::Image(_) => {
914 messages.push_str("[IMAGE DATA]\n\n");
915 }
916 MessageContent::ToolUse(tool_use) => {
917 messages.push_str(&format!(
918 "**Tool Use**: {} (ID: {})\n",
919 tool_use.name, tool_use.id
920 ));
921 messages.push_str(&format!("```json\n{}\n```\n\n", tool_use.input));
922 }
923 MessageContent::ToolResult(tool_result) => {
924 messages.push_str(&format!(
925 "**Tool Result**: {} (ID: {})\n\n",
926 tool_result.tool_name, tool_result.tool_use_id
927 ));
928 if tool_result.is_error {
929 messages.push_str("**ERROR:**\n");
930 }
931 messages.push_str(&format!("{}\n", tool_result.content));
932 }
933 }
934 }
935 }
936
937 Self { tools, messages }
938 }
939}
940
941fn response_events_to_markdown(
942 response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
943) -> String {
944 let mut response = String::new();
945 // Print the response events if any
946 response.push_str("# Response\n\n");
947 let mut text_buffer = String::new();
948 let mut thinking_buffer = String::new();
949
950 let flush_buffers =
951 |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
952 if !text_buffer.is_empty() {
953 output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
954 text_buffer.clear();
955 }
956 if !thinking_buffer.is_empty() {
957 output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
958 thinking_buffer.clear();
959 }
960 };
961
962 for event in response_events {
963 match event {
964 Ok(LanguageModelCompletionEvent::Text(text)) => {
965 text_buffer.push_str(text);
966 }
967 Ok(LanguageModelCompletionEvent::Thinking(text)) => {
968 thinking_buffer.push_str(text);
969 }
970 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
971 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
972 response.push_str(&format!("**Stop**: {:?}\n\n", reason));
973 }
974 Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
975 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
976 response.push_str(&format!(
977 "**Tool Use**: {} (ID: {})\n",
978 tool_use.name, tool_use.id
979 ));
980 response.push_str(&format!("```json\n{}\n```\n\n", tool_use.input));
981 }
982 Ok(
983 LanguageModelCompletionEvent::UsageUpdate(_)
984 | LanguageModelCompletionEvent::StartMessage { .. },
985 ) => {}
986 Err(error) => {
987 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
988 response.push_str(&format!("**Error**: {}\n\n", error));
989 }
990 }
991 }
992
993 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
994
995 response
996}
997
/// Tests for judge-reply parsing and for rendering `judge_diff_prompt.hbs` under each
/// combination of diagnostics inputs.
#[cfg(test)]
mod test {
    use super::*;
    use handlebars::Handlebars;

    #[test]
    fn test_parse_judge_output() {
        // Minimal reply: just the two tags.
        let response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = JudgeResponse::parse(&response).unwrap();
        assert_eq!(
            output.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(output.score, 3);

        // Text outside the tags is ignored and multi-line analysis is trimmed.
        let response = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = JudgeResponse::parse(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }

    #[test]
    fn test_judge_prompt_with_diagnostics() {
        // Case 1: Both diagnostics before and after are present
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: true,
            diagnostics_before: Some("Error at line 10: variable not found".to_string()),
            diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            Error at line 10: variable not found
            </diagnostics_before>

            <diagnostics_after>
            Error at line 15: missing semicolon
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_with_empty_diagnostics() {
        // Case 2: Diagnostics check run but no diagnostics found
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: true,
            diagnostics_before: None,
            diagnostics_after: None,
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            No diagnostics before applying the edits.
            </diagnostics_before>

            <diagnostics_after>
            No diagnostics after applying the edits.
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_with_mixed_diagnostics() {
        let templates = templates();

        // Case 3: Before diagnostics present, after diagnostics absent
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: true,
            diagnostics_before: Some("Error at line 10: variable not found".to_string()),
            diagnostics_after: None,
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            Error at line 10: variable not found
            </diagnostics_before>

            <diagnostics_after>
            No diagnostics after applying the edits.
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));

        // Case 4: Before diagnostics absent, after diagnostics present
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: true,
            diagnostics_before: None,
            diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            No diagnostics before applying the edits.
            </diagnostics_before>

            <diagnostics_after>
            Error at line 15: missing semicolon
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_without_diagnostics() {
        let templates = templates();

        // Case 5: No diagnostics check run
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: false,
            diagnostics_before: None,
            diagnostics_after: None,
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();

        // Check for the message when no diagnostics were performed
        let diagnostics_message = "No diagnostic checks were performed.";

        assert!(rendered.contains(diagnostics_message));
        assert!(!rendered.contains("<diagnostics_before>"));
        assert!(!rendered.contains("<diagnostics_after>"));
    }

    const JUDGE_PROMPT_NAME: &str = "judge_prompt";

    /// Registers `judge_diff_prompt.hbs` under `JUDGE_PROMPT_NAME`. Line endings are
    /// normalized first so the expected strings above (which use `\n`) match even if the
    /// template file was checked out with CRLF endings.
    fn templates() -> Handlebars<'static> {
        let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string();
        language::LineEnding::normalize(&mut judge_prompt);
        let mut handlebars = Handlebars::new();
        handlebars
            .register_template_string(JUDGE_PROMPT_NAME, judge_prompt)
            .unwrap();
        handlebars
    }
}