headless_assistant.rs

  1use anyhow::anyhow;
  2use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
  3use assistant_tool::ToolWorkingSet;
  4use client::{Client, UserStore};
  5use collections::HashMap;
  6use futures::StreamExt;
  7use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
  8use language::LanguageRegistry;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 11    LanguageModelRequest,
 12};
 13use node_runtime::NodeRuntime;
 14use project::{Project, RealFs};
 15use prompt_store::PromptBuilder;
 16use settings::SettingsStore;
 17use smol::channel;
 18use std::sync::Arc;
 19
 20/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
 21pub struct HeadlessAppState {
 22    pub languages: Arc<LanguageRegistry>,
 23    pub client: Arc<Client>,
 24    pub user_store: Entity<UserStore>,
 25    pub fs: Arc<dyn fs::Fs>,
 26    pub node_runtime: NodeRuntime,
 27
 28    // Additional fields not present in `workspace::AppState`.
 29    pub prompt_builder: Arc<PromptBuilder>,
 30}
 31
 32pub struct HeadlessAssistant {
 33    pub thread: Entity<Thread>,
 34    pub project: Entity<Project>,
 35    #[allow(dead_code)]
 36    pub thread_store: Entity<ThreadStore>,
 37    pub tool_use_counts: HashMap<Arc<str>, u32>,
 38    pub done_tx: channel::Sender<anyhow::Result<()>>,
 39    _subscription: Subscription,
 40}
 41
 42impl HeadlessAssistant {
 43    pub fn new(
 44        app_state: Arc<HeadlessAppState>,
 45        cx: &mut App,
 46    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
 47        let env = None;
 48        let project = Project::local(
 49            app_state.client.clone(),
 50            app_state.node_runtime.clone(),
 51            app_state.user_store.clone(),
 52            app_state.languages.clone(),
 53            app_state.fs.clone(),
 54            env,
 55            cx,
 56        );
 57
 58        let tools = Arc::new(ToolWorkingSet::default());
 59        let thread_store =
 60            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
 61
 62        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
 63
 64        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
 65
 66        let headless_thread = cx.new(move |cx| Self {
 67            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
 68            thread,
 69            project,
 70            thread_store,
 71            tool_use_counts: HashMap::default(),
 72            done_tx,
 73        });
 74
 75        Ok((headless_thread, done_rx))
 76    }
 77
 78    fn handle_thread_event(
 79        &mut self,
 80        thread: Entity<Thread>,
 81        event: &ThreadEvent,
 82        cx: &mut Context<Self>,
 83    ) {
 84        match event {
 85            ThreadEvent::ShowError(err) => self
 86                .done_tx
 87                .send_blocking(Err(anyhow!("{:?}", err)))
 88                .unwrap(),
 89            ThreadEvent::DoneStreaming => {
 90                let thread = thread.read(cx);
 91                if let Some(message) = thread.messages().last() {
 92                    println!("Message: {}", message.text,);
 93                }
 94                if thread.all_tools_finished() {
 95                    self.done_tx.send_blocking(Ok(())).unwrap()
 96                }
 97            }
 98            ThreadEvent::UsePendingTools => {
 99                thread.update(cx, |thread, cx| {
100                    thread.use_pending_tools(cx);
101                });
102            }
103            ThreadEvent::ToolFinished {
104                tool_use_id,
105                pending_tool_use,
106                ..
107            } => {
108                if let Some(pending_tool_use) = pending_tool_use {
109                    println!(
110                        "Used tool {} with input: {}",
111                        pending_tool_use.name, pending_tool_use.input
112                    );
113                    *self
114                        .tool_use_counts
115                        .entry(pending_tool_use.name.clone())
116                        .or_insert(0) += 1;
117                }
118                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
119                    println!("Tool result: {:?}", tool_result);
120                }
121                if thread.read(cx).all_tools_finished() {
122                    let model_registry = LanguageModelRegistry::read_global(cx);
123                    if let Some(model) = model_registry.active_model() {
124                        thread.update(cx, |thread, cx| {
125                            thread.attach_tool_results(vec![], cx);
126                            thread.send_to_model(model, RequestKind::Chat, cx);
127                        });
128                    }
129                }
130            }
131            ThreadEvent::StreamedCompletion
132            | ThreadEvent::SummaryChanged
133            | ThreadEvent::StreamedAssistantText(_, _)
134            | ThreadEvent::MessageAdded(_)
135            | ThreadEvent::MessageEdited(_)
136            | ThreadEvent::MessageDeleted(_) => {}
137        }
138    }
139}
140
141pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
142    release_channel::init(SemanticVersion::default(), cx);
143    gpui_tokio::init(cx);
144
145    let mut settings_store = SettingsStore::new(cx);
146    settings_store
147        .set_default_settings(settings::default_settings().as_ref(), cx)
148        .unwrap();
149    cx.set_global(settings_store);
150    client::init_settings(cx);
151    Project::init_settings(cx);
152
153    let client = Client::production(cx);
154    cx.set_http_client(client.http_client().clone());
155
156    let git_binary_path = None;
157    let fs = Arc::new(RealFs::new(git_binary_path));
158
159    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
160
161    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
162
163    language::init(cx);
164    language_model::init(client.clone(), cx);
165    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
166    assistant_tools::init(client.http_client().clone(), cx);
167    context_server::init(cx);
168    let stdout_is_a_pty = false;
169    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
170    assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
171
172    Arc::new(HeadlessAppState {
173        languages,
174        client,
175        user_store,
176        fs,
177        node_runtime: NodeRuntime::unavailable(),
178        prompt_builder,
179    })
180}
181
182pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
183    let model_registry = LanguageModelRegistry::read_global(cx);
184    let model = model_registry
185        .available_models(cx)
186        .find(|model| model.id().0 == model_name);
187
188    let Some(model) = model else {
189        return Err(anyhow!(
190            "No language model named {} was available. Available models: {}",
191            model_name,
192            model_registry
193                .available_models(cx)
194                .map(|model| model.id().0.clone())
195                .collect::<Vec<_>>()
196                .join(", ")
197        ));
198    };
199
200    Ok(model)
201}
202
203pub fn authenticate_model_provider(
204    provider_id: LanguageModelProviderId,
205    cx: &mut App,
206) -> Task<std::result::Result<(), AuthenticateError>> {
207    let model_registry = LanguageModelRegistry::read_global(cx);
208    let model_provider = model_registry.provider(&provider_id).unwrap();
209    model_provider.authenticate(cx)
210}
211
212pub async fn send_language_model_request(
213    model: Arc<dyn LanguageModel>,
214    request: LanguageModelRequest,
215    cx: AsyncApp,
216) -> anyhow::Result<String> {
217    match model.stream_completion_text(request, &cx).await {
218        Ok(mut stream) => {
219            let mut full_response = String::new();
220
221            // Process the response stream
222            while let Some(chunk_result) = stream.stream.next().await {
223                match chunk_result {
224                    Ok(chunk_str) => {
225                        full_response.push_str(&chunk_str);
226                    }
227                    Err(err) => {
228                        return Err(anyhow!(
229                            "Error receiving response from language model: {err}"
230                        ));
231                    }
232                }
233            }
234
235            Ok(full_response)
236        }
237        Err(err) => Err(anyhow!(
238            "Failed to get response from language model. Error was: {err}"
239        )),
240    }
241}