headless_assistant.rs

use anyhow::anyhow;
use assistant_tool::ToolWorkingSet;
use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
use client::{Client, UserStore};
use collections::HashMap;
use dap::DapRegistry;
use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, SemanticVersion, Subscription, Task, prelude::*};
use language::LanguageRegistry;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
    LanguageModelRequest,
};
use node_runtime::NodeRuntime;
use project::{Project, RealFs};
use prompt_store::PromptBuilder;
use settings::SettingsStore;
use smol::channel;
use std::sync::Arc;

/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct HeadlessAppState {
    pub languages: Arc<LanguageRegistry>,
    pub client: Arc<Client>,
    pub user_store: Entity<UserStore>,
    pub fs: Arc<dyn fs::Fs>,
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    pub prompt_builder: Arc<PromptBuilder>,
}

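/// Drives a single assistant [`Thread`] with no workspace or UI attached: it listens to
/// thread events, prints progress to stdout, tallies tool invocations, and reports
/// completion (or the first error) through `done_tx`.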
pub struct HeadlessAssistant {
    pub thread: Entity<Thread>,
    pub project: Entity<Project>,
    #[allow(dead_code)]
    pub thread_store: Entity<ThreadStore>,
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    pub done_tx: channel::Sender<anyhow::Result<()>>,
    _subscription: Subscription,
}

impl HeadlessAssistant {
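    /// Creates the assistant entity for a fresh local project and thread, together with
    /// a channel that receives `Ok(())` once the thread finishes (or the first error).
    ///
    /// A minimal usage sketch, not part of the original source; it assumes an
    /// `Arc<HeadlessAppState>` from `init` and a gpui `&mut App`:
    ///
    /// ```ignore
    /// let (assistant, done_rx) = HeadlessAssistant::new(app_state.clone(), cx)?;
    /// // Start a turn on `assistant.read(cx).thread`, then elsewhere:
    /// // done_rx.recv().await??;
    /// ```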
    pub fn new(
        app_state: Arc<HeadlessAppState>,
        cx: &mut App,
    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
        let env = None;
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            env,
            cx,
        );

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;

        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));

        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();

        let headless_thread = cx.new(move |cx| Self {
            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
            thread,
            project,
            thread_store,
            tool_use_counts: HashMap::default(),
            done_tx,
        });

        Ok((headless_thread, done_rx))
    }

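    /// Handles events from the subscribed [`Thread`]: forwards errors over `done_tx`,
    /// prints the final assistant message and tool results, counts tool uses, and sends
    /// the thread back to the active model once all pending tools have finished.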
    fn handle_thread_event(
        &mut self,
        thread: Entity<Thread>,
        event: &ThreadEvent,
        cx: &mut Context<Self>,
    ) {
        match event {
            ThreadEvent::ShowError(err) => self
                .done_tx
                .send_blocking(Err(anyhow!("{:?}", err)))
                .unwrap(),
            ThreadEvent::DoneStreaming => {
                let thread = thread.read(cx);
                if let Some(message) = thread.messages().last() {
                    println!("Message: {}", message.to_string());
                }
                if thread.all_tools_finished() {
                    self.done_tx.send_blocking(Ok(())).unwrap()
                }
            }
            ThreadEvent::UsePendingTools => {
                thread.update(cx, |thread, cx| {
                    thread.use_pending_tools(cx);
                });
            }
            ThreadEvent::ToolFinished {
                tool_use_id,
                pending_tool_use,
                ..
            } => {
                if let Some(pending_tool_use) = pending_tool_use {
                    println!(
                        "Used tool {} with input: {}",
                        pending_tool_use.name, pending_tool_use.input
                    );
                    *self
                        .tool_use_counts
                        .entry(pending_tool_use.name.clone())
                        .or_insert(0) += 1;
                }
                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
                    println!("Tool result: {:?}", tool_result);
                }
                if thread.read(cx).all_tools_finished() {
                    let model_registry = LanguageModelRegistry::read_global(cx);
                    if let Some(model) = model_registry.active_model() {
                        thread.update(cx, |thread, cx| {
                            thread.attach_tool_results(vec![], cx);
                            thread.send_to_model(model, RequestKind::Chat, cx);
                        });
                    }
                }
            }
            _ => {}
        }
    }
}

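/// One-time setup for a headless run: registers settings, the production client, the
/// language and language-model registries, assistant tools, and the assistant itself,
/// then returns the shared [`HeadlessAppState`].
///
/// A minimal usage sketch, not part of the original source; it assumes this is called
/// from the closure handed to gpui's `Application::new().run(..)`:
///
/// ```ignore
/// gpui::Application::new().run(|cx| {
///     let app_state = init(cx);
///     let (_assistant, _done_rx) = HeadlessAssistant::new(app_state, cx).unwrap();
/// });
/// ```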
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);
    Project::init_settings(cx);

    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    language::init(cx);
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    Arc::new(HeadlessAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime: NodeRuntime::unavailable(),
        prompt_builder,
    })
}

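/// Looks up a language model by id in the global [`LanguageModelRegistry`], returning
/// an error that lists the available model ids when nothing matches.
///
/// A minimal usage sketch, not part of the original source; the model id shown is a
/// hypothetical example:
///
/// ```ignore
/// let model = find_model("claude-3-7-sonnet-latest", cx)?;
/// ```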
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
    let model_registry = LanguageModelRegistry::read_global(cx);
    let model = model_registry
        .available_models(cx)
        .find(|model| model.id().0 == model_name);

    let Some(model) = model else {
        return Err(anyhow!(
            "No language model named {} was available. Available models: {}",
            model_name,
            model_registry
                .available_models(cx)
                .map(|model| model.id().0.clone())
                .collect::<Vec<_>>()
                .join(", ")
        ));
    };

    Ok(model)
}

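/// Starts authentication for the provider with the given id and returns the task to
/// await. Panics if no provider with that id is registered.
///
/// A minimal usage sketch, not part of the original source; it assumes an async caller
/// holding an `AsyncApp`, with `provider_id` taken from the chosen model's provider:
///
/// ```ignore
/// let task = cx.update(|cx| authenticate_model_provider(provider_id, cx))?;
/// task.await?;
/// ```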
pub fn authenticate_model_provider(
    provider_id: LanguageModelProviderId,
    cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
    let model_registry = LanguageModelRegistry::read_global(cx);
    let model_provider = model_registry.provider(&provider_id).unwrap();
    model_provider.authenticate(cx)
}

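/// Streams a completion for `request` from `model`, concatenating the streamed text
/// chunks into a single `String` and converting any stream error into an `anyhow` error.
///
/// A minimal usage sketch, not part of the original source; `request` stands in for a
/// `LanguageModelRequest` the caller has already built:
///
/// ```ignore
/// let text = send_language_model_request(model.clone(), request, &mut cx).await?;
/// println!("{text}");
/// ```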
pub async fn send_language_model_request(
    model: Arc<dyn LanguageModel>,
    request: LanguageModelRequest,
    cx: &mut AsyncApp,
) -> anyhow::Result<String> {
    match model.stream_completion_text(request, &cx).await {
        Ok(mut stream) => {
            let mut full_response = String::new();

            // Accumulate the streamed text chunks into the full response.
            while let Some(chunk_result) = stream.stream.next().await {
                match chunk_result {
                    Ok(chunk_str) => {
                        full_response.push_str(&chunk_str);
                    }
                    Err(err) => {
                        return Err(anyhow!(
                            "Error receiving response from language model: {err}"
                        ));
                    }
                }
            }

            Ok(full_response)
        }
        Err(err) => Err(anyhow!(
            "Failed to get response from language model. Error was: {err}"
        )),
    }
}