headless_assistant.rs

  1use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
  2use anyhow::anyhow;
  3use assistant_tool::ToolWorkingSet;
  4use client::{Client, UserStore};
  5use collections::HashMap;
  6use dap::DapRegistry;
  7use futures::StreamExt;
  8use gpui::{App, AsyncApp, Entity, SemanticVersion, Subscription, Task, prelude::*};
  9use language::LanguageRegistry;
 10use language_model::{
 11    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 12    LanguageModelRequest,
 13};
 14use node_runtime::NodeRuntime;
 15use project::{Project, RealFs};
 16use prompt_store::PromptBuilder;
 17use settings::SettingsStore;
 18use smol::channel;
 19use std::sync::Arc;
 20
 21/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
 22pub struct HeadlessAppState {
 23    pub languages: Arc<LanguageRegistry>,
 24    pub client: Arc<Client>,
 25    pub user_store: Entity<UserStore>,
 26    pub fs: Arc<dyn fs::Fs>,
 27    pub node_runtime: NodeRuntime,
 28
 29    // Additional fields not present in `workspace::AppState`.
 30    pub prompt_builder: Arc<PromptBuilder>,
 31}
 32
 33pub struct HeadlessAssistant {
 34    pub thread: Entity<Thread>,
 35    pub project: Entity<Project>,
 36    #[allow(dead_code)]
 37    pub thread_store: Entity<ThreadStore>,
 38    pub tool_use_counts: HashMap<Arc<str>, u32>,
 39    pub done_tx: channel::Sender<anyhow::Result<()>>,
 40    _subscription: Subscription,
 41}
 42
 43impl HeadlessAssistant {
 44    pub fn new(
 45        app_state: Arc<HeadlessAppState>,
 46        cx: &mut App,
 47    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
 48        let env = None;
 49        let project = Project::local(
 50            app_state.client.clone(),
 51            app_state.node_runtime.clone(),
 52            app_state.user_store.clone(),
 53            app_state.languages.clone(),
 54            Arc::new(DapRegistry::default()),
 55            app_state.fs.clone(),
 56            env,
 57            cx,
 58        );
 59
 60        let tools = Arc::new(ToolWorkingSet::default());
 61        let thread_store =
 62            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
 63
 64        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
 65
 66        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
 67
 68        let headless_thread = cx.new(move |cx| Self {
 69            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
 70            thread,
 71            project,
 72            thread_store,
 73            tool_use_counts: HashMap::default(),
 74            done_tx,
 75        });
 76
 77        Ok((headless_thread, done_rx))
 78    }
 79
 80    fn handle_thread_event(
 81        &mut self,
 82        thread: Entity<Thread>,
 83        event: &ThreadEvent,
 84        cx: &mut Context<Self>,
 85    ) {
 86        match event {
 87            ThreadEvent::ShowError(err) => self
 88                .done_tx
 89                .send_blocking(Err(anyhow!("{:?}", err)))
 90                .unwrap(),
 91            ThreadEvent::DoneStreaming => {
 92                let thread = thread.read(cx);
 93                if let Some(message) = thread.messages().last() {
 94                    println!("Message: {}", message.to_string());
 95                }
 96                if thread.all_tools_finished() {
 97                    self.done_tx.send_blocking(Ok(())).unwrap()
 98                }
 99            }
100            ThreadEvent::UsePendingTools => {
101                thread.update(cx, |thread, cx| {
102                    thread.use_pending_tools(cx);
103                });
104            }
105            ThreadEvent::ToolConfirmationNeeded => {
106                // Automatically approve all tools that need confirmation in headless mode
107                println!("Tool confirmation needed - automatically approving in headless mode");
108
109                // Get the tools needing confirmation
110                let tools_needing_confirmation: Vec<_> = thread
111                    .read(cx)
112                    .tools_needing_confirmation()
113                    .cloned()
114                    .collect();
115
116                // Run each tool that needs confirmation
117                for tool_use in tools_needing_confirmation {
118                    if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
119                        thread.update(cx, |thread, cx| {
120                            println!("Auto-approving tool: {}", tool_use.name);
121
122                            // Create a request to send to the tool
123                            let request = thread.to_completion_request(RequestKind::Chat, cx);
124                            let messages = Arc::new(request.messages);
125
126                            // Run the tool
127                            thread.run_tool(
128                                tool_use.id.clone(),
129                                tool_use.ui_text.clone(),
130                                tool_use.input.clone(),
131                                &messages,
132                                tool,
133                                cx,
134                            );
135                        });
136                    }
137                }
138            }
139            ThreadEvent::ToolFinished {
140                tool_use_id,
141                pending_tool_use,
142                ..
143            } => {
144                if let Some(pending_tool_use) = pending_tool_use {
145                    println!(
146                        "Used tool {} with input: {}",
147                        pending_tool_use.name, pending_tool_use.input
148                    );
149                    *self
150                        .tool_use_counts
151                        .entry(pending_tool_use.name.clone())
152                        .or_insert(0) += 1;
153                }
154                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
155                    println!("Tool result: {:?}", tool_result);
156                }
157                if thread.read(cx).all_tools_finished() {
158                    let model_registry = LanguageModelRegistry::read_global(cx);
159                    if let Some(model) = model_registry.active_model() {
160                        thread.update(cx, |thread, cx| {
161                            thread.attach_tool_results(vec![], cx);
162                            thread.send_to_model(model, RequestKind::Chat, cx);
163                        });
164                    } else {
165                        println!(
166                            "Warning: No active language model available to continue conversation"
167                        );
168                    }
169                }
170            }
171            _ => {}
172        }
173    }
174}
175
176pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
177    release_channel::init(SemanticVersion::default(), cx);
178    gpui_tokio::init(cx);
179
180    let mut settings_store = SettingsStore::new(cx);
181    settings_store
182        .set_default_settings(settings::default_settings().as_ref(), cx)
183        .unwrap();
184    cx.set_global(settings_store);
185    client::init_settings(cx);
186    Project::init_settings(cx);
187
188    let client = Client::production(cx);
189    cx.set_http_client(client.http_client().clone());
190
191    let git_binary_path = None;
192    let fs = Arc::new(RealFs::new(
193        git_binary_path,
194        cx.background_executor().clone(),
195    ));
196
197    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
198
199    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
200
201    language::init(cx);
202    language_model::init(client.clone(), cx);
203    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
204    assistant_tools::init(client.http_client().clone(), cx);
205    context_server::init(cx);
206    let stdout_is_a_pty = false;
207    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
208    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
209
210    Arc::new(HeadlessAppState {
211        languages,
212        client,
213        user_store,
214        fs,
215        node_runtime: NodeRuntime::unavailable(),
216        prompt_builder,
217    })
218}
219
220pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
221    let model_registry = LanguageModelRegistry::read_global(cx);
222    let model = model_registry
223        .available_models(cx)
224        .find(|model| model.id().0 == model_name);
225
226    let Some(model) = model else {
227        return Err(anyhow!(
228            "No language model named {} was available. Available models: {}",
229            model_name,
230            model_registry
231                .available_models(cx)
232                .map(|model| model.id().0.clone())
233                .collect::<Vec<_>>()
234                .join(", ")
235        ));
236    };
237
238    Ok(model)
239}
240
241pub fn authenticate_model_provider(
242    provider_id: LanguageModelProviderId,
243    cx: &mut App,
244) -> Task<std::result::Result<(), AuthenticateError>> {
245    let model_registry = LanguageModelRegistry::read_global(cx);
246    let model_provider = model_registry.provider(&provider_id).unwrap();
247    model_provider.authenticate(cx)
248}
249
250pub async fn send_language_model_request(
251    model: Arc<dyn LanguageModel>,
252    request: LanguageModelRequest,
253    cx: &mut AsyncApp,
254) -> anyhow::Result<String> {
255    match model.stream_completion_text(request, &cx).await {
256        Ok(mut stream) => {
257            let mut full_response = String::new();
258
259            // Process the response stream
260            while let Some(chunk_result) = stream.stream.next().await {
261                match chunk_result {
262                    Ok(chunk_str) => {
263                        full_response.push_str(&chunk_str);
264                    }
265                    Err(err) => {
266                        return Err(anyhow!(
267                            "Error receiving response from language model: {err}"
268                        ));
269                    }
270                }
271            }
272
273            Ok(full_response)
274        }
275        Err(err) => Err(anyhow!(
276            "Failed to get response from language model. Error was: {err}"
277        )),
278    }
279}