1use anyhow::anyhow;
2use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
3use assistant_tool::ToolWorkingSet;
4use client::{Client, UserStore};
5use collections::HashMap;
6use futures::StreamExt;
7use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
8use language::LanguageRegistry;
9use language_model::{
10 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
11 LanguageModelRequest,
12};
13use node_runtime::NodeRuntime;
14use project::{Project, RealFs};
15use prompt_store::PromptBuilder;
16use settings::SettingsStore;
17use smol::channel;
18use std::sync::Arc;
19
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Constructed by [`init`] and consumed by [`HeadlessAssistant::new`], which uses the first
/// five fields to create a local [`Project`] and `prompt_builder` to create the `ThreadStore`.
pub struct HeadlessAppState {
    /// Language registry handed to `Project::local`.
    pub languages: Arc<LanguageRegistry>,
    /// Client handed to `Project::local` (`init` builds it with `Client::production`).
    pub client: Arc<Client>,
    /// User store handed to `Project::local`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handed to `Project::local` (`init` supplies a `RealFs`).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime handed to `Project::local` (`init` supplies `NodeRuntime::unavailable()`).
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder passed to `ThreadStore::new`.
    pub prompt_builder: Arc<PromptBuilder>,
}
31
/// Drives an assistant [`Thread`] without a UI.
///
/// Thread events are handled by `HeadlessAssistant::handle_thread_event`, and the outcome is
/// reported through `done_tx`.
pub struct HeadlessAssistant {
    /// The conversation being driven.
    pub thread: Entity<Thread>,
    /// The local project the thread store was created for.
    pub project: Entity<Project>,
    // Never read in this file. NOTE(review): presumably stored so the store lives as long as
    // the assistant does — confirm before removing.
    #[allow(dead_code)]
    pub thread_store: Entity<ThreadStore>,
    /// Number of finished uses per tool name, incremented on `ThreadEvent::ToolFinished`.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    /// Sending half of the channel whose receiver `HeadlessAssistant::new` returns.
    pub done_tx: channel::Sender<anyhow::Result<()>>,
    /// Subscription routing `thread` events to `handle_thread_event`; held for the
    /// lifetime of this struct.
    _subscription: Subscription,
}
41
42impl HeadlessAssistant {
43 pub fn new(
44 app_state: Arc<HeadlessAppState>,
45 cx: &mut App,
46 ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
47 let env = None;
48 let project = Project::local(
49 app_state.client.clone(),
50 app_state.node_runtime.clone(),
51 app_state.user_store.clone(),
52 app_state.languages.clone(),
53 app_state.fs.clone(),
54 env,
55 cx,
56 );
57
58 let tools = Arc::new(ToolWorkingSet::default());
59 let thread_store =
60 ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
61
62 let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
63
64 let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
65
66 let headless_thread = cx.new(move |cx| Self {
67 _subscription: cx.subscribe(&thread, Self::handle_thread_event),
68 thread,
69 project,
70 thread_store,
71 tool_use_counts: HashMap::default(),
72 done_tx,
73 });
74
75 Ok((headless_thread, done_rx))
76 }
77
78 fn handle_thread_event(
79 &mut self,
80 thread: Entity<Thread>,
81 event: &ThreadEvent,
82 cx: &mut Context<Self>,
83 ) {
84 match event {
85 ThreadEvent::ShowError(err) => self
86 .done_tx
87 .send_blocking(Err(anyhow!("{:?}", err)))
88 .unwrap(),
89 ThreadEvent::DoneStreaming => {
90 let thread = thread.read(cx);
91 if let Some(message) = thread.messages().last() {
92 println!("Message: {}", message.to_string());
93 }
94 if thread.all_tools_finished() {
95 self.done_tx.send_blocking(Ok(())).unwrap()
96 }
97 }
98 ThreadEvent::UsePendingTools => {
99 thread.update(cx, |thread, cx| {
100 thread.use_pending_tools(cx);
101 });
102 }
103 ThreadEvent::ToolFinished {
104 tool_use_id,
105 pending_tool_use,
106 ..
107 } => {
108 if let Some(pending_tool_use) = pending_tool_use {
109 println!(
110 "Used tool {} with input: {}",
111 pending_tool_use.name, pending_tool_use.input
112 );
113 *self
114 .tool_use_counts
115 .entry(pending_tool_use.name.clone())
116 .or_insert(0) += 1;
117 }
118 if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
119 println!("Tool result: {:?}", tool_result);
120 }
121 if thread.read(cx).all_tools_finished() {
122 let model_registry = LanguageModelRegistry::read_global(cx);
123 if let Some(model) = model_registry.active_model() {
124 thread.update(cx, |thread, cx| {
125 thread.attach_tool_results(vec![], cx);
126 thread.send_to_model(model, RequestKind::Chat, cx);
127 });
128 }
129 }
130 }
131 _ => {}
132 }
133 }
134}
135
136pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
137 release_channel::init(SemanticVersion::default(), cx);
138 gpui_tokio::init(cx);
139
140 let mut settings_store = SettingsStore::new(cx);
141 settings_store
142 .set_default_settings(settings::default_settings().as_ref(), cx)
143 .unwrap();
144 cx.set_global(settings_store);
145 client::init_settings(cx);
146 Project::init_settings(cx);
147
148 let client = Client::production(cx);
149 cx.set_http_client(client.http_client().clone());
150
151 let git_binary_path = None;
152 let fs = Arc::new(RealFs::new(git_binary_path));
153
154 let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
155
156 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
157
158 language::init(cx);
159 language_model::init(client.clone(), cx);
160 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
161 assistant_tools::init(client.http_client().clone(), cx);
162 context_server::init(cx);
163 let stdout_is_a_pty = false;
164 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
165 assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
166
167 Arc::new(HeadlessAppState {
168 languages,
169 client,
170 user_store,
171 fs,
172 node_runtime: NodeRuntime::unavailable(),
173 prompt_builder,
174 })
175}
176
177pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
178 let model_registry = LanguageModelRegistry::read_global(cx);
179 let model = model_registry
180 .available_models(cx)
181 .find(|model| model.id().0 == model_name);
182
183 let Some(model) = model else {
184 return Err(anyhow!(
185 "No language model named {} was available. Available models: {}",
186 model_name,
187 model_registry
188 .available_models(cx)
189 .map(|model| model.id().0.clone())
190 .collect::<Vec<_>>()
191 .join(", ")
192 ));
193 };
194
195 Ok(model)
196}
197
198pub fn authenticate_model_provider(
199 provider_id: LanguageModelProviderId,
200 cx: &mut App,
201) -> Task<std::result::Result<(), AuthenticateError>> {
202 let model_registry = LanguageModelRegistry::read_global(cx);
203 let model_provider = model_registry.provider(&provider_id).unwrap();
204 model_provider.authenticate(cx)
205}
206
207pub async fn send_language_model_request(
208 model: Arc<dyn LanguageModel>,
209 request: LanguageModelRequest,
210 cx: &mut AsyncApp,
211) -> anyhow::Result<String> {
212 match model.stream_completion_text(request, &cx).await {
213 Ok(mut stream) => {
214 let mut full_response = String::new();
215
216 // Process the response stream
217 while let Some(chunk_result) = stream.stream.next().await {
218 match chunk_result {
219 Ok(chunk_str) => {
220 full_response.push_str(&chunk_str);
221 }
222 Err(err) => {
223 return Err(anyhow!(
224 "Error receiving response from language model: {err}"
225 ));
226 }
227 }
228 }
229
230 Ok(full_response)
231 }
232 Err(err) => Err(anyhow!(
233 "Failed to get response from language model. Error was: {err}"
234 )),
235 }
236}