headless_assistant.rs

  1use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
  2use anyhow::anyhow;
  3use assistant_tool::ToolWorkingSet;
  4use client::{Client, UserStore};
  5use collections::HashMap;
  6use dap::DapRegistry;
  7use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
  8use language::LanguageRegistry;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 11};
 12use node_runtime::NodeRuntime;
 13use project::{Project, RealFs};
 14use prompt_store::PromptBuilder;
 15use settings::SettingsStore;
 16use smol::channel;
 17use std::sync::Arc;
 18
 19/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
 20pub struct HeadlessAppState {
 21    pub languages: Arc<LanguageRegistry>,
 22    pub client: Arc<Client>,
 23    pub user_store: Entity<UserStore>,
 24    pub fs: Arc<dyn fs::Fs>,
 25    pub node_runtime: NodeRuntime,
 26
 27    // Additional fields not present in `workspace::AppState`.
 28    pub prompt_builder: Arc<PromptBuilder>,
 29}
 30
 31pub struct HeadlessAssistant {
 32    pub thread: Entity<Thread>,
 33    pub project: Entity<Project>,
 34    #[allow(dead_code)]
 35    pub thread_store: Entity<ThreadStore>,
 36    pub tool_use_counts: HashMap<Arc<str>, u32>,
 37    pub done_tx: channel::Sender<anyhow::Result<()>>,
 38    _subscription: Subscription,
 39}
 40
 41impl HeadlessAssistant {
 42    pub fn new(
 43        app_state: Arc<HeadlessAppState>,
 44        cx: &mut App,
 45    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
 46        let env = None;
 47        let project = Project::local(
 48            app_state.client.clone(),
 49            app_state.node_runtime.clone(),
 50            app_state.user_store.clone(),
 51            app_state.languages.clone(),
 52            Arc::new(DapRegistry::default()),
 53            app_state.fs.clone(),
 54            env,
 55            cx,
 56        );
 57
 58        let tools = Arc::new(ToolWorkingSet::default());
 59        let thread_store =
 60            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
 61
 62        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
 63
 64        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
 65
 66        let headless_thread = cx.new(move |cx| Self {
 67            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
 68            thread,
 69            project,
 70            thread_store,
 71            tool_use_counts: HashMap::default(),
 72            done_tx,
 73        });
 74
 75        Ok((headless_thread, done_rx))
 76    }
 77
 78    fn handle_thread_event(
 79        &mut self,
 80        thread: Entity<Thread>,
 81        event: &ThreadEvent,
 82        cx: &mut Context<Self>,
 83    ) {
 84        match event {
 85            ThreadEvent::ShowError(err) => self
 86                .done_tx
 87                .send_blocking(Err(anyhow!("{:?}", err)))
 88                .unwrap(),
 89            ThreadEvent::DoneStreaming => {
 90                let thread = thread.read(cx);
 91                if let Some(message) = thread.messages().last() {
 92                    println!("Message: {}", message.to_string());
 93                }
 94                if thread.all_tools_finished() {
 95                    self.done_tx.send_blocking(Ok(())).unwrap()
 96                }
 97            }
 98            ThreadEvent::UsePendingTools => {
 99                thread.update(cx, |thread, cx| {
100                    thread.use_pending_tools(cx);
101                });
102            }
103            ThreadEvent::ToolConfirmationNeeded => {
104                // Automatically approve all tools that need confirmation in headless mode
105                println!("Tool confirmation needed - automatically approving in headless mode");
106
107                // Get the tools needing confirmation
108                let tools_needing_confirmation: Vec<_> = thread
109                    .read(cx)
110                    .tools_needing_confirmation()
111                    .cloned()
112                    .collect();
113
114                // Run each tool that needs confirmation
115                for tool_use in tools_needing_confirmation {
116                    if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
117                        thread.update(cx, |thread, cx| {
118                            println!("Auto-approving tool: {}", tool_use.name);
119
120                            // Create a request to send to the tool
121                            let request = thread.to_completion_request(RequestKind::Chat, cx);
122                            let messages = Arc::new(request.messages);
123
124                            // Run the tool
125                            thread.run_tool(
126                                tool_use.id.clone(),
127                                tool_use.ui_text.clone(),
128                                tool_use.input.clone(),
129                                &messages,
130                                tool,
131                                cx,
132                            );
133                        });
134                    }
135                }
136            }
137            ThreadEvent::ToolFinished {
138                tool_use_id,
139                pending_tool_use,
140                ..
141            } => {
142                if let Some(pending_tool_use) = pending_tool_use {
143                    println!(
144                        "Used tool {} with input: {}",
145                        pending_tool_use.name, pending_tool_use.input
146                    );
147                    *self
148                        .tool_use_counts
149                        .entry(pending_tool_use.name.clone())
150                        .or_insert(0) += 1;
151                }
152                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
153                    println!("Tool result: {:?}", tool_result);
154                }
155                if thread.read(cx).all_tools_finished() {
156                    let model_registry = LanguageModelRegistry::read_global(cx);
157                    if let Some(model) = model_registry.default_model() {
158                        thread.update(cx, |thread, cx| {
159                            thread.attach_tool_results(cx);
160                            thread.send_to_model(model.model, RequestKind::Chat, cx);
161                        });
162                    } else {
163                        println!(
164                            "Warning: No active language model available to continue conversation"
165                        );
166                    }
167                }
168            }
169            _ => {}
170        }
171    }
172}
173
174pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
175    release_channel::init(SemanticVersion::default(), cx);
176    gpui_tokio::init(cx);
177
178    let mut settings_store = SettingsStore::new(cx);
179    settings_store
180        .set_default_settings(settings::default_settings().as_ref(), cx)
181        .unwrap();
182    cx.set_global(settings_store);
183    client::init_settings(cx);
184    Project::init_settings(cx);
185
186    let client = Client::production(cx);
187    cx.set_http_client(client.http_client().clone());
188
189    let git_binary_path = None;
190    let fs = Arc::new(RealFs::new(
191        git_binary_path,
192        cx.background_executor().clone(),
193    ));
194
195    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
196
197    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
198
199    language::init(cx);
200    language_model::init(client.clone(), cx);
201    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
202    assistant_tools::init(client.http_client().clone(), cx);
203    context_server::init(cx);
204    let stdout_is_a_pty = false;
205    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
206    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
207
208    Arc::new(HeadlessAppState {
209        languages,
210        client,
211        user_store,
212        fs,
213        node_runtime: NodeRuntime::unavailable(),
214        prompt_builder,
215    })
216}
217
218pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
219    let model_registry = LanguageModelRegistry::read_global(cx);
220    let model = model_registry
221        .available_models(cx)
222        .find(|model| model.id().0 == model_name);
223
224    let Some(model) = model else {
225        return Err(anyhow!(
226            "No language model named {} was available. Available models: {}",
227            model_name,
228            model_registry
229                .available_models(cx)
230                .map(|model| model.id().0.clone())
231                .collect::<Vec<_>>()
232                .join(", ")
233        ));
234    };
235
236    Ok(model)
237}
238
239pub fn authenticate_model_provider(
240    provider_id: LanguageModelProviderId,
241    cx: &mut App,
242) -> Task<std::result::Result<(), AuthenticateError>> {
243    let model_registry = LanguageModelRegistry::read_global(cx);
244    let model_provider = model_registry.provider(&provider_id).unwrap();
245    model_provider.authenticate(cx)
246}