agent.rs

  1use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
  2use anyhow::anyhow;
  3use assistant_tool::ToolWorkingSet;
  4use client::{Client, UserStore};
  5use collections::HashMap;
  6use dap::DapRegistry;
  7use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
  8use language::LanguageRegistry;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 11};
 12use node_runtime::NodeRuntime;
 13use project::{Project, RealFs};
 14use prompt_store::PromptBuilder;
 15use settings::SettingsStore;
 16use smol::channel;
 17use std::sync::Arc;
 18
 19/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
 20pub struct AgentAppState {
 21    pub languages: Arc<LanguageRegistry>,
 22    pub client: Arc<Client>,
 23    pub user_store: Entity<UserStore>,
 24    pub fs: Arc<dyn fs::Fs>,
 25    pub node_runtime: NodeRuntime,
 26
 27    // Additional fields not present in `workspace::AppState`.
 28    pub prompt_builder: Arc<PromptBuilder>,
 29}
 30
 31pub struct Agent {
 32    // pub thread: Entity<Thread>,
 33    // pub project: Entity<Project>,
 34    #[allow(dead_code)]
 35    pub thread_store: Entity<ThreadStore>,
 36    pub tool_use_counts: HashMap<Arc<str>, u32>,
 37    pub done_tx: channel::Sender<anyhow::Result<()>>,
 38    _subscription: Subscription,
 39}
 40
 41impl Agent {
 42    pub fn new(
 43        app_state: Arc<AgentAppState>,
 44        cx: &mut App,
 45    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
 46        let env = None;
 47        let project = Project::local(
 48            app_state.client.clone(),
 49            app_state.node_runtime.clone(),
 50            app_state.user_store.clone(),
 51            app_state.languages.clone(),
 52            Arc::new(DapRegistry::default()),
 53            app_state.fs.clone(),
 54            env,
 55            cx,
 56        );
 57
 58        let tools = Arc::new(ToolWorkingSet::default());
 59        let thread_store =
 60            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
 61
 62        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
 63
 64        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
 65
 66        let headless_thread = cx.new(move |cx| Self {
 67            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
 68            // thread,
 69            // project,
 70            thread_store,
 71            tool_use_counts: HashMap::default(),
 72            done_tx,
 73        });
 74
 75        Ok((headless_thread, done_rx))
 76    }
 77
 78    fn handle_thread_event(
 79        &mut self,
 80        thread: Entity<Thread>,
 81        event: &ThreadEvent,
 82        cx: &mut Context<Self>,
 83    ) {
 84        match event {
 85            ThreadEvent::ShowError(err) => self
 86                .done_tx
 87                .send_blocking(Err(anyhow!("{:?}", err)))
 88                .unwrap(),
 89            ThreadEvent::DoneStreaming => {
 90                let thread = thread.read(cx);
 91                if let Some(message) = thread.messages().last() {
 92                    println!("Message: {}", message.to_string());
 93                }
 94                if thread.all_tools_finished() {
 95                    self.done_tx.send_blocking(Ok(())).unwrap()
 96                }
 97            }
 98            ThreadEvent::UsePendingTools { .. } => {}
 99            ThreadEvent::ToolConfirmationNeeded => {
100                // Automatically approve all tools that need confirmation in headless mode
101                println!("Tool confirmation needed - automatically approving in headless mode");
102
103                // Get the tools needing confirmation
104                let tools_needing_confirmation: Vec<_> = thread
105                    .read(cx)
106                    .tools_needing_confirmation()
107                    .cloned()
108                    .collect();
109
110                // Run each tool that needs confirmation
111                for tool_use in tools_needing_confirmation {
112                    if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
113                        thread.update(cx, |thread, cx| {
114                            println!("Auto-approving tool: {}", tool_use.name);
115
116                            // Create a request to send to the tool
117                            let request = thread.to_completion_request(RequestKind::Chat, cx);
118                            let messages = Arc::new(request.messages);
119
120                            // Run the tool
121                            thread.run_tool(
122                                tool_use.id.clone(),
123                                tool_use.ui_text.clone(),
124                                tool_use.input.clone(),
125                                &messages,
126                                tool,
127                                cx,
128                            );
129                        });
130                    }
131                }
132            }
133            ThreadEvent::ToolFinished {
134                tool_use_id,
135                pending_tool_use,
136                ..
137            } => {
138                if let Some(pending_tool_use) = pending_tool_use {
139                    println!(
140                        "Used tool {} with input: {}",
141                        pending_tool_use.name, pending_tool_use.input
142                    );
143                    *self
144                        .tool_use_counts
145                        .entry(pending_tool_use.name.clone())
146                        .or_insert(0) += 1;
147                }
148                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
149                    println!("Tool result: {:?}", tool_result);
150                }
151            }
152            _ => {}
153        }
154    }
155}
156
157pub fn init(cx: &mut App) -> Arc<AgentAppState> {
158    release_channel::init(SemanticVersion::default(), cx);
159    gpui_tokio::init(cx);
160
161    let mut settings_store = SettingsStore::new(cx);
162    settings_store
163        .set_default_settings(settings::default_settings().as_ref(), cx)
164        .unwrap();
165    cx.set_global(settings_store);
166    client::init_settings(cx);
167    Project::init_settings(cx);
168
169    let client = Client::production(cx);
170    cx.set_http_client(client.http_client().clone());
171
172    let git_binary_path = None;
173    let fs = Arc::new(RealFs::new(
174        git_binary_path,
175        cx.background_executor().clone(),
176    ));
177
178    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
179
180    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
181
182    language::init(cx);
183    language_model::init(client.clone(), cx);
184    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
185    assistant_tools::init(client.http_client().clone(), cx);
186    context_server::init(cx);
187    let stdout_is_a_pty = false;
188    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
189    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
190
191    Arc::new(AgentAppState {
192        languages,
193        client,
194        user_store,
195        fs,
196        node_runtime: NodeRuntime::unavailable(),
197        prompt_builder,
198    })
199}
200
201pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
202    let model_registry = LanguageModelRegistry::read_global(cx);
203    let model = model_registry
204        .available_models(cx)
205        .find(|model| model.id().0 == model_name);
206
207    let Some(model) = model else {
208        return Err(anyhow!(
209            "No language model named {} was available. Available models: {}",
210            model_name,
211            model_registry
212                .available_models(cx)
213                .map(|model| model.id().0.clone())
214                .collect::<Vec<_>>()
215                .join(", ")
216        ));
217    };
218
219    Ok(model)
220}
221
222pub fn authenticate_model_provider(
223    provider_id: LanguageModelProviderId,
224    cx: &mut App,
225) -> Task<std::result::Result<(), AuthenticateError>> {
226    let model_registry = LanguageModelRegistry::read_global(cx);
227    let model_provider = model_registry.provider(&provider_id).unwrap();
228    model_provider.authenticate(cx)
229}