1use anyhow::anyhow;
2use assistant_tool::ToolWorkingSet;
3use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
4use client::{Client, UserStore};
5use collections::HashMap;
6use dap::DapRegistry;
7use futures::StreamExt;
8use gpui::{App, AsyncApp, Entity, SemanticVersion, Subscription, Task, prelude::*};
9use language::LanguageRegistry;
10use language_model::{
11 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
12 LanguageModelRequest,
13};
14use node_runtime::NodeRuntime;
15use project::{Project, RealFs};
16use prompt_store::PromptBuilder;
17use settings::SettingsStore;
18use smol::channel;
19use std::sync::Arc;
20
21/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
22pub struct HeadlessAppState {
23 pub languages: Arc<LanguageRegistry>,
24 pub client: Arc<Client>,
25 pub user_store: Entity<UserStore>,
26 pub fs: Arc<dyn fs::Fs>,
27 pub node_runtime: NodeRuntime,
28
29 // Additional fields not present in `workspace::AppState`.
30 pub prompt_builder: Arc<PromptBuilder>,
31}
32
33pub struct HeadlessAssistant {
34 pub thread: Entity<Thread>,
35 pub project: Entity<Project>,
36 #[allow(dead_code)]
37 pub thread_store: Entity<ThreadStore>,
38 pub tool_use_counts: HashMap<Arc<str>, u32>,
39 pub done_tx: channel::Sender<anyhow::Result<()>>,
40 _subscription: Subscription,
41}
42
43impl HeadlessAssistant {
44 pub fn new(
45 app_state: Arc<HeadlessAppState>,
46 cx: &mut App,
47 ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
48 let env = None;
49 let project = Project::local(
50 app_state.client.clone(),
51 app_state.node_runtime.clone(),
52 app_state.user_store.clone(),
53 app_state.languages.clone(),
54 Arc::new(DapRegistry::default()),
55 app_state.fs.clone(),
56 env,
57 cx,
58 );
59
60 let tools = Arc::new(ToolWorkingSet::default());
61 let thread_store =
62 ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
63
64 let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
65
66 let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
67
68 let headless_thread = cx.new(move |cx| Self {
69 _subscription: cx.subscribe(&thread, Self::handle_thread_event),
70 thread,
71 project,
72 thread_store,
73 tool_use_counts: HashMap::default(),
74 done_tx,
75 });
76
77 Ok((headless_thread, done_rx))
78 }
79
80 fn handle_thread_event(
81 &mut self,
82 thread: Entity<Thread>,
83 event: &ThreadEvent,
84 cx: &mut Context<Self>,
85 ) {
86 match event {
87 ThreadEvent::ShowError(err) => self
88 .done_tx
89 .send_blocking(Err(anyhow!("{:?}", err)))
90 .unwrap(),
91 ThreadEvent::DoneStreaming => {
92 let thread = thread.read(cx);
93 if let Some(message) = thread.messages().last() {
94 println!("Message: {}", message.to_string());
95 }
96 if thread.all_tools_finished() {
97 self.done_tx.send_blocking(Ok(())).unwrap()
98 }
99 }
100 ThreadEvent::UsePendingTools => {
101 thread.update(cx, |thread, cx| {
102 thread.use_pending_tools(cx);
103 });
104 }
105 ThreadEvent::ToolFinished {
106 tool_use_id,
107 pending_tool_use,
108 ..
109 } => {
110 if let Some(pending_tool_use) = pending_tool_use {
111 println!(
112 "Used tool {} with input: {}",
113 pending_tool_use.name, pending_tool_use.input
114 );
115 *self
116 .tool_use_counts
117 .entry(pending_tool_use.name.clone())
118 .or_insert(0) += 1;
119 }
120 if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
121 println!("Tool result: {:?}", tool_result);
122 }
123 if thread.read(cx).all_tools_finished() {
124 let model_registry = LanguageModelRegistry::read_global(cx);
125 if let Some(model) = model_registry.active_model() {
126 thread.update(cx, |thread, cx| {
127 thread.attach_tool_results(vec![], cx);
128 thread.send_to_model(model, RequestKind::Chat, cx);
129 });
130 }
131 }
132 }
133 _ => {}
134 }
135 }
136}
137
138pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
139 release_channel::init(SemanticVersion::default(), cx);
140 gpui_tokio::init(cx);
141
142 let mut settings_store = SettingsStore::new(cx);
143 settings_store
144 .set_default_settings(settings::default_settings().as_ref(), cx)
145 .unwrap();
146 cx.set_global(settings_store);
147 client::init_settings(cx);
148 Project::init_settings(cx);
149
150 let client = Client::production(cx);
151 cx.set_http_client(client.http_client().clone());
152
153 let git_binary_path = None;
154 let fs = Arc::new(RealFs::new(
155 git_binary_path,
156 cx.background_executor().clone(),
157 ));
158
159 let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
160
161 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
162
163 language::init(cx);
164 language_model::init(client.clone(), cx);
165 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
166 assistant_tools::init(client.http_client().clone(), cx);
167 context_server::init(cx);
168 let stdout_is_a_pty = false;
169 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
170 assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
171
172 Arc::new(HeadlessAppState {
173 languages,
174 client,
175 user_store,
176 fs,
177 node_runtime: NodeRuntime::unavailable(),
178 prompt_builder,
179 })
180}
181
182pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
183 let model_registry = LanguageModelRegistry::read_global(cx);
184 let model = model_registry
185 .available_models(cx)
186 .find(|model| model.id().0 == model_name);
187
188 let Some(model) = model else {
189 return Err(anyhow!(
190 "No language model named {} was available. Available models: {}",
191 model_name,
192 model_registry
193 .available_models(cx)
194 .map(|model| model.id().0.clone())
195 .collect::<Vec<_>>()
196 .join(", ")
197 ));
198 };
199
200 Ok(model)
201}
202
203pub fn authenticate_model_provider(
204 provider_id: LanguageModelProviderId,
205 cx: &mut App,
206) -> Task<std::result::Result<(), AuthenticateError>> {
207 let model_registry = LanguageModelRegistry::read_global(cx);
208 let model_provider = model_registry.provider(&provider_id).unwrap();
209 model_provider.authenticate(cx)
210}
211
212pub async fn send_language_model_request(
213 model: Arc<dyn LanguageModel>,
214 request: LanguageModelRequest,
215 cx: &mut AsyncApp,
216) -> anyhow::Result<String> {
217 match model.stream_completion_text(request, &cx).await {
218 Ok(mut stream) => {
219 let mut full_response = String::new();
220
221 // Process the response stream
222 while let Some(chunk_result) = stream.stream.next().await {
223 match chunk_result {
224 Ok(chunk_str) => {
225 full_response.push_str(&chunk_str);
226 }
227 Err(err) => {
228 return Err(anyhow!(
229 "Error receiving response from language model: {err}"
230 ));
231 }
232 }
233 }
234
235 Ok(full_response)
236 }
237 Err(err) => Err(anyhow!(
238 "Failed to get response from language model. Error was: {err}"
239 )),
240 }
241}