example.rs

  1use agent::{RequestKind, ThreadEvent, ThreadStore};
  2use anyhow::{Result, anyhow};
  3use assistant_tool::ToolWorkingSet;
  4use dap::DapRegistry;
  5use futures::channel::oneshot;
  6use gpui::{App, Task};
  7use language_model::{LanguageModel, StopReason};
  8use project::Project;
  9use serde::Deserialize;
 10use std::process::Command;
 11use std::sync::Arc;
 12use std::{
 13    fs,
 14    path::{Path, PathBuf},
 15};
 16
 17use crate::AgentAppState;
 18
 19#[derive(Debug, Deserialize)]
 20pub struct ExampleBase {
 21    pub path: PathBuf,
 22    pub revision: String,
 23}
 24
 25#[derive(Debug)]
 26pub struct Example {
 27    pub base: ExampleBase,
 28
 29    /// Content of the prompt.md file
 30    pub prompt: String,
 31
 32    /// Content of the rubric.md file
 33    pub _rubric: String,
 34}
 35
 36impl Example {
 37    /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
 38    pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
 39        let base_path = dir_path.as_ref().join("base.toml");
 40        let prompt_path = dir_path.as_ref().join("prompt.md");
 41        let rubric_path = dir_path.as_ref().join("rubric.md");
 42
 43        let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
 44        base.path = base.path.canonicalize()?;
 45
 46        Ok(Example {
 47            base,
 48            prompt: fs::read_to_string(prompt_path)?,
 49            _rubric: fs::read_to_string(rubric_path)?,
 50        })
 51    }
 52
 53    /// Set up the example by checking out the specified Git revision
 54    pub fn setup(&self) -> Result<()> {
 55        // Check if the directory exists
 56        let path = Path::new(&self.base.path);
 57        anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
 58
 59        // Change to the project directory and checkout the specified revision
 60        let output = Command::new("git")
 61            .current_dir(&self.base.path)
 62            .arg("checkout")
 63            .arg(&self.base.revision)
 64            .output()?;
 65        anyhow::ensure!(
 66            output.status.success(),
 67            "Failed to checkout revision {}: {}",
 68            self.base.revision,
 69            String::from_utf8_lossy(&output.stderr),
 70        );
 71
 72        Ok(())
 73    }
 74
 75    pub fn run(
 76        self,
 77        model: Arc<dyn LanguageModel>,
 78        app_state: Arc<AgentAppState>,
 79        cx: &mut App,
 80    ) -> Task<Result<()>> {
 81        let project = Project::local(
 82            app_state.client.clone(),
 83            app_state.node_runtime.clone(),
 84            app_state.user_store.clone(),
 85            app_state.languages.clone(),
 86            Arc::new(DapRegistry::default()),
 87            app_state.fs.clone(),
 88            None,
 89            cx,
 90        );
 91
 92        let worktree = project.update(cx, |project, cx| {
 93            project.create_worktree(self.base.path, true, cx)
 94        });
 95
 96        let tools = Arc::new(ToolWorkingSet::default());
 97        let thread_store =
 98            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
 99
100        println!("USER:");
101        println!("{}", self.prompt);
102        println!("ASSISTANT:");
103        cx.spawn(async move |cx| {
104            worktree.await?;
105            let thread_store = thread_store.await;
106            let thread =
107                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
108
109            let (tx, rx) = oneshot::channel();
110            let mut tx = Some(tx);
111
112            let _subscription =
113                cx.subscribe(
114                    &thread,
115                    move |thread, event: &ThreadEvent, cx| match event {
116                        ThreadEvent::Stopped(reason) => match reason {
117                            Ok(StopReason::EndTurn) => {
118                                if let Some(tx) = tx.take() {
119                                    tx.send(Ok(())).ok();
120                                }
121                            }
122                            Ok(StopReason::MaxTokens) => {
123                                if let Some(tx) = tx.take() {
124                                    tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
125                                }
126                            }
127                            Ok(StopReason::ToolUse) => {}
128                            Err(error) => {
129                                if let Some(tx) = tx.take() {
130                                    tx.send(Err(anyhow!(error.clone()))).ok();
131                                }
132                            }
133                        },
134                        ThreadEvent::ShowError(thread_error) => {
135                            if let Some(tx) = tx.take() {
136                                tx.send(Err(anyhow!(thread_error.clone()))).ok();
137                            }
138                        }
139                        ThreadEvent::StreamedAssistantText(_, chunk) => {
140                            print!("{}", chunk);
141                        }
142                        ThreadEvent::StreamedAssistantThinking(_, chunk) => {
143                            print!("{}", chunk);
144                        }
145                        ThreadEvent::UsePendingTools { tool_uses } => {
146                            println!("\n\nUSING TOOLS:");
147                            for tool_use in tool_uses {
148                                println!("{}: {}", tool_use.name, tool_use.input);
149                            }
150                        }
151                        ThreadEvent::ToolFinished {
152                            tool_use_id,
153                            pending_tool_use,
154                            ..
155                        } => {
156                            if let Some(tool_use) = pending_tool_use {
157                                println!("\nTOOL FINISHED: {}", tool_use.name);
158                            }
159                            if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
160                                println!("\n{}\n", tool_result.content);
161                            }
162                        }
163                        _ => {}
164                    },
165                )?;
166
167            thread.update(cx, |thread, cx| {
168                let context = vec![];
169                thread.insert_user_message(self.prompt.clone(), context, None, cx);
170                thread.send_to_model(model, RequestKind::Chat, cx);
171            })?;
172
173            rx.await??;
174
175            Ok(())
176        })
177    }
178}