1use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
2use anyhow::anyhow;
3use assistant_tool::ToolWorkingSet;
4use client::{Client, UserStore};
5use collections::HashMap;
6use dap::DapRegistry;
7use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
8use language::LanguageRegistry;
9use language_model::{
10 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
11};
12use node_runtime::NodeRuntime;
13use project::{Project, RealFs};
14use prompt_store::PromptBuilder;
15use settings::SettingsStore;
16use smol::channel;
17use std::sync::Arc;
18
19/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
20pub struct HeadlessAppState {
21 pub languages: Arc<LanguageRegistry>,
22 pub client: Arc<Client>,
23 pub user_store: Entity<UserStore>,
24 pub fs: Arc<dyn fs::Fs>,
25 pub node_runtime: NodeRuntime,
26
27 // Additional fields not present in `workspace::AppState`.
28 pub prompt_builder: Arc<PromptBuilder>,
29}
30
31pub struct HeadlessAssistant {
32 pub thread: Entity<Thread>,
33 pub project: Entity<Project>,
34 #[allow(dead_code)]
35 pub thread_store: Entity<ThreadStore>,
36 pub tool_use_counts: HashMap<Arc<str>, u32>,
37 pub done_tx: channel::Sender<anyhow::Result<()>>,
38 _subscription: Subscription,
39}
40
41impl HeadlessAssistant {
42 pub fn new(
43 app_state: Arc<HeadlessAppState>,
44 cx: &mut App,
45 ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
46 let env = None;
47 let project = Project::local(
48 app_state.client.clone(),
49 app_state.node_runtime.clone(),
50 app_state.user_store.clone(),
51 app_state.languages.clone(),
52 Arc::new(DapRegistry::default()),
53 app_state.fs.clone(),
54 env,
55 cx,
56 );
57
58 let tools = Arc::new(ToolWorkingSet::default());
59 let thread_store =
60 ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
61
62 let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
63
64 let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
65
66 let headless_thread = cx.new(move |cx| Self {
67 _subscription: cx.subscribe(&thread, Self::handle_thread_event),
68 thread,
69 project,
70 thread_store,
71 tool_use_counts: HashMap::default(),
72 done_tx,
73 });
74
75 Ok((headless_thread, done_rx))
76 }
77
78 fn handle_thread_event(
79 &mut self,
80 thread: Entity<Thread>,
81 event: &ThreadEvent,
82 cx: &mut Context<Self>,
83 ) {
84 match event {
85 ThreadEvent::ShowError(err) => self
86 .done_tx
87 .send_blocking(Err(anyhow!("{:?}", err)))
88 .unwrap(),
89 ThreadEvent::DoneStreaming => {
90 let thread = thread.read(cx);
91 if let Some(message) = thread.messages().last() {
92 println!("Message: {}", message.to_string());
93 }
94 if thread.all_tools_finished() {
95 self.done_tx.send_blocking(Ok(())).unwrap()
96 }
97 }
98 ThreadEvent::UsePendingTools { .. } => {}
99 ThreadEvent::ToolConfirmationNeeded => {
100 // Automatically approve all tools that need confirmation in headless mode
101 println!("Tool confirmation needed - automatically approving in headless mode");
102
103 // Get the tools needing confirmation
104 let tools_needing_confirmation: Vec<_> = thread
105 .read(cx)
106 .tools_needing_confirmation()
107 .cloned()
108 .collect();
109
110 // Run each tool that needs confirmation
111 for tool_use in tools_needing_confirmation {
112 if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
113 thread.update(cx, |thread, cx| {
114 println!("Auto-approving tool: {}", tool_use.name);
115
116 // Create a request to send to the tool
117 let request = thread.to_completion_request(RequestKind::Chat, cx);
118 let messages = Arc::new(request.messages);
119
120 // Run the tool
121 thread.run_tool(
122 tool_use.id.clone(),
123 tool_use.ui_text.clone(),
124 tool_use.input.clone(),
125 &messages,
126 tool,
127 cx,
128 );
129 });
130 }
131 }
132 }
133 ThreadEvent::ToolFinished {
134 tool_use_id,
135 pending_tool_use,
136 ..
137 } => {
138 if let Some(pending_tool_use) = pending_tool_use {
139 println!(
140 "Used tool {} with input: {}",
141 pending_tool_use.name, pending_tool_use.input
142 );
143 *self
144 .tool_use_counts
145 .entry(pending_tool_use.name.clone())
146 .or_insert(0) += 1;
147 }
148 if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
149 println!("Tool result: {:?}", tool_result);
150 }
151 }
152 _ => {}
153 }
154 }
155}
156
157pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
158 release_channel::init(SemanticVersion::default(), cx);
159 gpui_tokio::init(cx);
160
161 let mut settings_store = SettingsStore::new(cx);
162 settings_store
163 .set_default_settings(settings::default_settings().as_ref(), cx)
164 .unwrap();
165 cx.set_global(settings_store);
166 client::init_settings(cx);
167 Project::init_settings(cx);
168
169 let client = Client::production(cx);
170 cx.set_http_client(client.http_client().clone());
171
172 let git_binary_path = None;
173 let fs = Arc::new(RealFs::new(
174 git_binary_path,
175 cx.background_executor().clone(),
176 ));
177
178 let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
179
180 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
181
182 language::init(cx);
183 language_model::init(client.clone(), cx);
184 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
185 assistant_tools::init(client.http_client().clone(), cx);
186 context_server::init(cx);
187 let stdout_is_a_pty = false;
188 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
189 agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
190
191 Arc::new(HeadlessAppState {
192 languages,
193 client,
194 user_store,
195 fs,
196 node_runtime: NodeRuntime::unavailable(),
197 prompt_builder,
198 })
199}
200
201pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
202 let model_registry = LanguageModelRegistry::read_global(cx);
203 let model = model_registry
204 .available_models(cx)
205 .find(|model| model.id().0 == model_name);
206
207 let Some(model) = model else {
208 return Err(anyhow!(
209 "No language model named {} was available. Available models: {}",
210 model_name,
211 model_registry
212 .available_models(cx)
213 .map(|model| model.id().0.clone())
214 .collect::<Vec<_>>()
215 .join(", ")
216 ));
217 };
218
219 Ok(model)
220}
221
222pub fn authenticate_model_provider(
223 provider_id: LanguageModelProviderId,
224 cx: &mut App,
225) -> Task<std::result::Result<(), AuthenticateError>> {
226 let model_registry = LanguageModelRegistry::read_global(cx);
227 let model_provider = model_registry.provider(&provider_id).unwrap();
228 model_provider.authenticate(cx)
229}