1use anyhow::anyhow;
2use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
3use assistant_tool::ToolWorkingSet;
4use client::{Client, UserStore};
5use collections::HashMap;
6use futures::StreamExt;
7use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
8use language::LanguageRegistry;
9use language_model::{
10 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
11 LanguageModelRequest,
12};
13use node_runtime::NodeRuntime;
14use project::{Project, RealFs};
15use prompt_store::PromptBuilder;
16use settings::SettingsStore;
17use smol::channel;
18use std::sync::Arc;
19
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built by [`init`] and shared behind an `Arc`; [`HeadlessAssistant::new`] reads it to
/// create the assistant's local project and thread store.
pub struct HeadlessAppState {
    /// Language registry passed to [`Project::local`].
    pub languages: Arc<LanguageRegistry>,
    /// Client passed to [`Project::local`]; `init` builds it with `Client::production`.
    pub client: Arc<Client>,
    /// User store passed to [`Project::local`].
    pub user_store: Entity<UserStore>,
    /// Filesystem passed to [`Project::local`]; `init` uses a [`RealFs`].
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime passed to [`Project::local`]; `init` sets `NodeRuntime::unavailable()`.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder handed to [`ThreadStore::new`].
    pub prompt_builder: Arc<PromptBuilder>,
}
31
/// Drives a single assistant [`Thread`] without a UI by reacting to its [`ThreadEvent`]s.
///
/// Created by [`HeadlessAssistant::new`], which also returns the receiver paired with
/// `done_tx`.
pub struct HeadlessAssistant {
    /// The thread whose events this assistant handles.
    pub thread: Entity<Thread>,
    /// Local project the thread store was created for.
    pub project: Entity<Project>,
    /// Store that created `thread`.
    // Not read anywhere in this file; presumably held so the store lives as long as the
    // assistant — NOTE(review): confirm, and drop the allow once something reads it.
    #[allow(dead_code)]
    pub thread_store: Entity<ThreadStore>,
    /// Number of finished uses per tool name, incremented for each
    /// `ThreadEvent::ToolFinished` that carries a pending tool use.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    /// Sender for the run's outcome (`Ok(())` on completion, `Err` on a thread error).
    pub done_tx: channel::Sender<anyhow::Result<()>>,
    // Holds the subscription to `thread` events created in `new`; presumably the
    // subscription ends when this value is dropped.
    _subscription: Subscription,
}
41
42impl HeadlessAssistant {
43 pub fn new(
44 app_state: Arc<HeadlessAppState>,
45 cx: &mut App,
46 ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
47 let env = None;
48 let project = Project::local(
49 app_state.client.clone(),
50 app_state.node_runtime.clone(),
51 app_state.user_store.clone(),
52 app_state.languages.clone(),
53 app_state.fs.clone(),
54 env,
55 cx,
56 );
57
58 let tools = Arc::new(ToolWorkingSet::default());
59 let thread_store =
60 ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
61
62 let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
63
64 let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
65
66 let headless_thread = cx.new(move |cx| Self {
67 _subscription: cx.subscribe(&thread, Self::handle_thread_event),
68 thread,
69 project,
70 thread_store,
71 tool_use_counts: HashMap::default(),
72 done_tx,
73 });
74
75 Ok((headless_thread, done_rx))
76 }
77
78 fn handle_thread_event(
79 &mut self,
80 thread: Entity<Thread>,
81 event: &ThreadEvent,
82 cx: &mut Context<Self>,
83 ) {
84 match event {
85 ThreadEvent::ShowError(err) => self
86 .done_tx
87 .send_blocking(Err(anyhow!("{:?}", err)))
88 .unwrap(),
89 ThreadEvent::DoneStreaming => {
90 let thread = thread.read(cx);
91 if let Some(message) = thread.messages().last() {
92 println!("Message: {}", message.to_string());
93 }
94 if thread.all_tools_finished() {
95 self.done_tx.send_blocking(Ok(())).unwrap()
96 }
97 }
98 ThreadEvent::UsePendingTools => {
99 thread.update(cx, |thread, cx| {
100 thread.use_pending_tools(cx);
101 });
102 }
103 ThreadEvent::ToolFinished {
104 tool_use_id,
105 pending_tool_use,
106 ..
107 } => {
108 if let Some(pending_tool_use) = pending_tool_use {
109 println!(
110 "Used tool {} with input: {}",
111 pending_tool_use.name, pending_tool_use.input
112 );
113 *self
114 .tool_use_counts
115 .entry(pending_tool_use.name.clone())
116 .or_insert(0) += 1;
117 }
118 if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
119 println!("Tool result: {:?}", tool_result);
120 }
121 if thread.read(cx).all_tools_finished() {
122 let model_registry = LanguageModelRegistry::read_global(cx);
123 if let Some(model) = model_registry.active_model() {
124 thread.update(cx, |thread, cx| {
125 thread.attach_tool_results(vec![], cx);
126 thread.send_to_model(model, RequestKind::Chat, cx);
127 });
128 }
129 }
130 }
131 _ => {}
132 }
133 }
134}
135
136pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
137 release_channel::init(SemanticVersion::default(), cx);
138 gpui_tokio::init(cx);
139
140 let mut settings_store = SettingsStore::new(cx);
141 settings_store
142 .set_default_settings(settings::default_settings().as_ref(), cx)
143 .unwrap();
144 cx.set_global(settings_store);
145 client::init_settings(cx);
146 Project::init_settings(cx);
147
148 let client = Client::production(cx);
149 cx.set_http_client(client.http_client().clone());
150
151 let git_binary_path = None;
152 let fs = Arc::new(RealFs::new(
153 git_binary_path,
154 cx.background_executor().clone(),
155 ));
156
157 let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
158
159 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
160
161 language::init(cx);
162 language_model::init(client.clone(), cx);
163 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
164 assistant_tools::init(client.http_client().clone(), cx);
165 context_server::init(cx);
166 let stdout_is_a_pty = false;
167 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
168 assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
169
170 Arc::new(HeadlessAppState {
171 languages,
172 client,
173 user_store,
174 fs,
175 node_runtime: NodeRuntime::unavailable(),
176 prompt_builder,
177 })
178}
179
180pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
181 let model_registry = LanguageModelRegistry::read_global(cx);
182 let model = model_registry
183 .available_models(cx)
184 .find(|model| model.id().0 == model_name);
185
186 let Some(model) = model else {
187 return Err(anyhow!(
188 "No language model named {} was available. Available models: {}",
189 model_name,
190 model_registry
191 .available_models(cx)
192 .map(|model| model.id().0.clone())
193 .collect::<Vec<_>>()
194 .join(", ")
195 ));
196 };
197
198 Ok(model)
199}
200
201pub fn authenticate_model_provider(
202 provider_id: LanguageModelProviderId,
203 cx: &mut App,
204) -> Task<std::result::Result<(), AuthenticateError>> {
205 let model_registry = LanguageModelRegistry::read_global(cx);
206 let model_provider = model_registry.provider(&provider_id).unwrap();
207 model_provider.authenticate(cx)
208}
209
210pub async fn send_language_model_request(
211 model: Arc<dyn LanguageModel>,
212 request: LanguageModelRequest,
213 cx: &mut AsyncApp,
214) -> anyhow::Result<String> {
215 match model.stream_completion_text(request, &cx).await {
216 Ok(mut stream) => {
217 let mut full_response = String::new();
218
219 // Process the response stream
220 while let Some(chunk_result) = stream.stream.next().await {
221 match chunk_result {
222 Ok(chunk_str) => {
223 full_response.push_str(&chunk_str);
224 }
225 Err(err) => {
226 return Err(anyhow!(
227 "Error receiving response from language model: {err}"
228 ));
229 }
230 }
231 }
232
233 Ok(full_response)
234 }
235 Err(err) => Err(anyhow!(
236 "Failed to get response from language model. Error was: {err}"
237 )),
238 }
239}