use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet;
use dap::DapRegistry;
use futures::channel::oneshot;
use gpui::{App, Task};
use language_model::{LanguageModel, StopReason};
use project::Project;
use serde::Deserialize;
use std::io::Write as _;
use std::process::Command;
use std::sync::Arc;
use std::{
    fs,
    path::{Path, PathBuf},
};

use crate::AgentAppState;
18
19#[derive(Debug, Deserialize)]
20pub struct ExampleBase {
21 pub path: PathBuf,
22 pub revision: String,
23}
24
25#[derive(Debug)]
26pub struct Example {
27 pub base: ExampleBase,
28
29 /// Content of the prompt.md file
30 pub prompt: String,
31
32 /// Content of the rubric.md file
33 pub _rubric: String,
34}
35
36impl Example {
37 /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
38 pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
39 let base_path = dir_path.as_ref().join("base.toml");
40 let prompt_path = dir_path.as_ref().join("prompt.md");
41 let rubric_path = dir_path.as_ref().join("rubric.md");
42
43 let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
44 base.path = base.path.canonicalize()?;
45
46 Ok(Example {
47 base,
48 prompt: fs::read_to_string(prompt_path)?,
49 _rubric: fs::read_to_string(rubric_path)?,
50 })
51 }
52
53 /// Set up the example by checking out the specified Git revision
54 pub fn setup(&self) -> Result<()> {
55 // Check if the directory exists
56 let path = Path::new(&self.base.path);
57 anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
58
59 // Change to the project directory and checkout the specified revision
60 let output = Command::new("git")
61 .current_dir(&self.base.path)
62 .arg("checkout")
63 .arg(&self.base.revision)
64 .output()?;
65 anyhow::ensure!(
66 output.status.success(),
67 "Failed to checkout revision {}: {}",
68 self.base.revision,
69 String::from_utf8_lossy(&output.stderr),
70 );
71
72 Ok(())
73 }
74
75 pub fn run(
76 self,
77 model: Arc<dyn LanguageModel>,
78 app_state: Arc<AgentAppState>,
79 cx: &mut App,
80 ) -> Task<Result<()>> {
81 let project = Project::local(
82 app_state.client.clone(),
83 app_state.node_runtime.clone(),
84 app_state.user_store.clone(),
85 app_state.languages.clone(),
86 Arc::new(DapRegistry::default()),
87 app_state.fs.clone(),
88 None,
89 cx,
90 );
91
92 let worktree = project.update(cx, |project, cx| {
93 project.create_worktree(self.base.path, true, cx)
94 });
95
96 let tools = Arc::new(ToolWorkingSet::default());
97 let thread_store =
98 ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
99
100 println!("USER:");
101 println!("{}", self.prompt);
102 println!("ASSISTANT:");
103 cx.spawn(async move |cx| {
104 worktree.await?;
105 let thread_store = thread_store.await;
106 let thread =
107 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
108
109 let (tx, rx) = oneshot::channel();
110 let mut tx = Some(tx);
111
112 let _subscription =
113 cx.subscribe(
114 &thread,
115 move |thread, event: &ThreadEvent, cx| match event {
116 ThreadEvent::Stopped(reason) => match reason {
117 Ok(StopReason::EndTurn) => {
118 if let Some(tx) = tx.take() {
119 tx.send(Ok(())).ok();
120 }
121 }
122 Ok(StopReason::MaxTokens) => {
123 if let Some(tx) = tx.take() {
124 tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
125 }
126 }
127 Ok(StopReason::ToolUse) => {}
128 Err(error) => {
129 if let Some(tx) = tx.take() {
130 tx.send(Err(anyhow!(error.clone()))).ok();
131 }
132 }
133 },
134 ThreadEvent::ShowError(thread_error) => {
135 if let Some(tx) = tx.take() {
136 tx.send(Err(anyhow!(thread_error.clone()))).ok();
137 }
138 }
139 ThreadEvent::StreamedAssistantText(_, chunk) => {
140 print!("{}", chunk);
141 }
142 ThreadEvent::StreamedAssistantThinking(_, chunk) => {
143 print!("{}", chunk);
144 }
145 ThreadEvent::UsePendingTools { tool_uses } => {
146 println!("\n\nUSING TOOLS:");
147 for tool_use in tool_uses {
148 println!("{}: {}", tool_use.name, tool_use.input);
149 }
150 }
151 ThreadEvent::ToolFinished {
152 tool_use_id,
153 pending_tool_use,
154 ..
155 } => {
156 if let Some(tool_use) = pending_tool_use {
157 println!("\nTOOL FINISHED: {}", tool_use.name);
158 }
159 if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
160 println!("\n{}\n", tool_result.content);
161 }
162 }
163 _ => {}
164 },
165 )?;
166
167 thread.update(cx, |thread, cx| {
168 let context = vec![];
169 thread.insert_user_message(self.prompt.clone(), context, None, cx);
170 thread.send_to_model(model, RequestKind::Chat, cx);
171 })?;
172
173 rx.await??;
174
175 Ok(())
176 })
177 }
178}