use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
use anyhow::anyhow;
use assistant_tool::ToolWorkingSet;
use client::{Client, UserStore};
use collections::HashMap;
use dap::DapRegistry;
use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, SemanticVersion, Subscription, Task, prelude::*};
use language::LanguageRegistry;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
    LanguageModelRequest,
};
use node_runtime::NodeRuntime;
use project::{Project, RealFs};
use prompt_store::PromptBuilder;
use settings::SettingsStore;
use smol::channel;
use std::sync::Arc;

/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct HeadlessAppState {
    pub languages: Arc<LanguageRegistry>,
    pub client: Arc<Client>,
    pub user_store: Entity<UserStore>,
    pub fs: Arc<dyn fs::Fs>,
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    pub prompt_builder: Arc<PromptBuilder>,
}

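/// Drives a single agent `Thread` against a local `Project` without the Zed UI, reporting
/// completion or failure over `done_tx`.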
pub struct HeadlessAssistant {
    pub thread: Entity<Thread>,
    pub project: Entity<Project>,
    #[allow(dead_code)]
    pub thread_store: Entity<ThreadStore>,
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    pub done_tx: channel::Sender<anyhow::Result<()>>,
    _subscription: Subscription,
}

impl HeadlessAssistant {
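    /// Creates a local project and a fresh thread, returning the assistant entity along with a
    /// receiver that resolves once the thread completes or reports an error.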
    pub fn new(
        app_state: Arc<HeadlessAppState>,
        cx: &mut App,
    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
        let env = None;
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            env,
            cx,
        );

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;

        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));

        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();

        let headless_thread = cx.new(move |cx| Self {
            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
            thread,
            project,
            thread_store,
            tool_use_counts: HashMap::default(),
            done_tx,
        });

        Ok((headless_thread, done_rx))
    }

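    /// Handles events emitted by the thread: prints streamed output, auto-approves and runs
    /// pending tools, tallies tool usage, and signals completion over `done_tx`.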
    fn handle_thread_event(
        &mut self,
        thread: Entity<Thread>,
        event: &ThreadEvent,
        cx: &mut Context<Self>,
    ) {
        match event {
            ThreadEvent::ShowError(err) => self
                .done_tx
                .send_blocking(Err(anyhow!("{:?}", err)))
                .unwrap(),
            ThreadEvent::DoneStreaming => {
                let thread = thread.read(cx);
                if let Some(message) = thread.messages().last() {
                    println!("Message: {}", message.to_string());
                }
                // If no tools are still pending, the turn is complete.
                if thread.all_tools_finished() {
                    self.done_tx.send_blocking(Ok(())).unwrap()
                }
            }
            ThreadEvent::UsePendingTools => {
                thread.update(cx, |thread, cx| {
                    thread.use_pending_tools(cx);
                });
            }
            ThreadEvent::ToolConfirmationNeeded => {
                // Automatically approve all tools that need confirmation in headless mode
                println!("Tool confirmation needed - automatically approving in headless mode");

                // Get the tools needing confirmation
                let tools_needing_confirmation: Vec<_> = thread
                    .read(cx)
                    .tools_needing_confirmation()
                    .cloned()
                    .collect();

                // Run each tool that needs confirmation
                for tool_use in tools_needing_confirmation {
                    if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
                        thread.update(cx, |thread, cx| {
                            println!("Auto-approving tool: {}", tool_use.name);

                            // Create a request to send to the tool
                            let request = thread.to_completion_request(RequestKind::Chat, cx);
                            let messages = Arc::new(request.messages);

                            // Run the tool
                            thread.run_tool(
                                tool_use.id.clone(),
                                tool_use.ui_text.clone(),
                                tool_use.input.clone(),
                                &messages,
                                tool,
                                cx,
                            );
                        });
                    }
                }
            }
            ThreadEvent::ToolFinished {
                tool_use_id,
                pending_tool_use,
                ..
            } => {
                if let Some(pending_tool_use) = pending_tool_use {
                    println!(
                        "Used tool {} with input: {}",
                        pending_tool_use.name, pending_tool_use.input
                    );
                    *self
                        .tool_use_counts
                        .entry(pending_tool_use.name.clone())
                        .or_insert(0) += 1;
                }
                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
                    println!("Tool result: {:?}", tool_result);
                }
                // Once every pending tool has produced a result, hand the results back to the
                // model so the conversation can continue.
                if thread.read(cx).all_tools_finished() {
                    let model_registry = LanguageModelRegistry::read_global(cx);
                    if let Some(model) = model_registry.default_model() {
                        thread.update(cx, |thread, cx| {
                            thread.attach_tool_results(cx);
                            thread.send_to_model(model.model, RequestKind::Chat, cx);
                        });
                    } else {
                        println!(
                            "Warning: No active language model available to continue conversation"
                        );
                    }
                }
            }
            _ => {}
        }
    }
}

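/// Initializes the global state a headless assistant needs: settings, the Zed client, the
/// language and language-model registries, tools, and the prompt builder.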
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);
    Project::init_settings(cx);

    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    language::init(cx);
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    Arc::new(HeadlessAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime: NodeRuntime::unavailable(),
        prompt_builder,
    })
}

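/// Looks up a language model by id in the global `LanguageModelRegistry`, listing the available
/// models in the error message when no match is found.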
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
    let model_registry = LanguageModelRegistry::read_global(cx);
    let model = model_registry
        .available_models(cx)
        .find(|model| model.id().0 == model_name);

    let Some(model) = model else {
        return Err(anyhow!(
            "No language model named {} was available. Available models: {}",
            model_name,
            model_registry
                .available_models(cx)
                .map(|model| model.id().0.clone())
                .collect::<Vec<_>>()
                .join(", ")
        ));
    };

    Ok(model)
}

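/// Kicks off authentication for the given provider; panics if the provider is not registered.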
pub fn authenticate_model_provider(
    provider_id: LanguageModelProviderId,
    cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
    let model_registry = LanguageModelRegistry::read_global(cx);
    let model_provider = model_registry.provider(&provider_id).unwrap();
    model_provider.authenticate(cx)
}

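/// Sends a single request to the model and collects the streamed text chunks into one string.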
pub async fn send_language_model_request(
    model: Arc<dyn LanguageModel>,
    request: LanguageModelRequest,
    cx: &mut AsyncApp,
) -> anyhow::Result<String> {
    match model.stream_completion_text(request, &cx).await {
        Ok(mut stream) => {
            let mut full_response = String::new();

            // Process the response stream
            while let Some(chunk_result) = stream.stream.next().await {
                match chunk_result {
                    Ok(chunk_str) => {
                        full_response.push_str(&chunk_str);
                    }
                    Err(err) => {
                        return Err(anyhow!(
                            "Error receiving response from language model: {err}"
                        ));
                    }
                }
            }

            Ok(full_response)
        }
        Err(err) => Err(anyhow!(
            "Failed to get response from language model. Error was: {err}"
        )),
    }
}