headless_assistant.rs

  1use anyhow::anyhow;
  2use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
  3use assistant_tool::ToolWorkingSet;
  4use client::{Client, UserStore};
  5use collections::HashMap;
  6use futures::StreamExt;
  7use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
  8use language::LanguageRegistry;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 11    LanguageModelRequest,
 12};
 13use node_runtime::NodeRuntime;
 14use project::{Project, RealFs};
 15use prompt_store::PromptBuilder;
 16use settings::SettingsStore;
 17use smol::channel;
 18use std::sync::Arc;
 19
 20/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
 21pub struct HeadlessAppState {
 22    pub languages: Arc<LanguageRegistry>,
 23    pub client: Arc<Client>,
 24    pub user_store: Entity<UserStore>,
 25    pub fs: Arc<dyn fs::Fs>,
 26    pub node_runtime: NodeRuntime,
 27
 28    // Additional fields not present in `workspace::AppState`.
 29    pub prompt_builder: Arc<PromptBuilder>,
 30}
 31
 32pub struct HeadlessAssistant {
 33    pub thread: Entity<Thread>,
 34    pub project: Entity<Project>,
 35    #[allow(dead_code)]
 36    pub thread_store: Entity<ThreadStore>,
 37    pub tool_use_counts: HashMap<Arc<str>, u32>,
 38    pub done_tx: channel::Sender<anyhow::Result<()>>,
 39    _subscription: Subscription,
 40}
 41
 42impl HeadlessAssistant {
 43    pub fn new(
 44        app_state: Arc<HeadlessAppState>,
 45        cx: &mut App,
 46    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
 47        let env = None;
 48        let project = Project::local(
 49            app_state.client.clone(),
 50            app_state.node_runtime.clone(),
 51            app_state.user_store.clone(),
 52            app_state.languages.clone(),
 53            app_state.fs.clone(),
 54            env,
 55            cx,
 56        );
 57
 58        let tools = Arc::new(ToolWorkingSet::default());
 59        let thread_store =
 60            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
 61
 62        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
 63
 64        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
 65
 66        let headless_thread = cx.new(move |cx| Self {
 67            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
 68            thread,
 69            project,
 70            thread_store,
 71            tool_use_counts: HashMap::default(),
 72            done_tx,
 73        });
 74
 75        Ok((headless_thread, done_rx))
 76    }
 77
 78    fn handle_thread_event(
 79        &mut self,
 80        thread: Entity<Thread>,
 81        event: &ThreadEvent,
 82        cx: &mut Context<Self>,
 83    ) {
 84        match event {
 85            ThreadEvent::ShowError(err) => self
 86                .done_tx
 87                .send_blocking(Err(anyhow!("{:?}", err)))
 88                .unwrap(),
 89            ThreadEvent::DoneStreaming => {
 90                let thread = thread.read(cx);
 91                if let Some(message) = thread.messages().last() {
 92                    println!("Message: {}", message.to_string());
 93                }
 94                if thread.all_tools_finished() {
 95                    self.done_tx.send_blocking(Ok(())).unwrap()
 96                }
 97            }
 98            ThreadEvent::UsePendingTools => {
 99                thread.update(cx, |thread, cx| {
100                    thread.use_pending_tools(cx);
101                });
102            }
103            ThreadEvent::ToolFinished {
104                tool_use_id,
105                pending_tool_use,
106                ..
107            } => {
108                if let Some(pending_tool_use) = pending_tool_use {
109                    println!(
110                        "Used tool {} with input: {}",
111                        pending_tool_use.name, pending_tool_use.input
112                    );
113                    *self
114                        .tool_use_counts
115                        .entry(pending_tool_use.name.clone())
116                        .or_insert(0) += 1;
117                }
118                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
119                    println!("Tool result: {:?}", tool_result);
120                }
121                if thread.read(cx).all_tools_finished() {
122                    let model_registry = LanguageModelRegistry::read_global(cx);
123                    if let Some(model) = model_registry.active_model() {
124                        thread.update(cx, |thread, cx| {
125                            thread.attach_tool_results(vec![], cx);
126                            thread.send_to_model(model, RequestKind::Chat, cx);
127                        });
128                    }
129                }
130            }
131            _ => {}
132        }
133    }
134}
135
136pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
137    release_channel::init(SemanticVersion::default(), cx);
138    gpui_tokio::init(cx);
139
140    let mut settings_store = SettingsStore::new(cx);
141    settings_store
142        .set_default_settings(settings::default_settings().as_ref(), cx)
143        .unwrap();
144    cx.set_global(settings_store);
145    client::init_settings(cx);
146    Project::init_settings(cx);
147
148    let client = Client::production(cx);
149    cx.set_http_client(client.http_client().clone());
150
151    let git_binary_path = None;
152    let fs = Arc::new(RealFs::new(git_binary_path));
153
154    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
155
156    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
157
158    language::init(cx);
159    language_model::init(client.clone(), cx);
160    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
161    assistant_tools::init(client.http_client().clone(), cx);
162    context_server::init(cx);
163    let stdout_is_a_pty = false;
164    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
165    assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
166
167    Arc::new(HeadlessAppState {
168        languages,
169        client,
170        user_store,
171        fs,
172        node_runtime: NodeRuntime::unavailable(),
173        prompt_builder,
174    })
175}
176
177pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
178    let model_registry = LanguageModelRegistry::read_global(cx);
179    let model = model_registry
180        .available_models(cx)
181        .find(|model| model.id().0 == model_name);
182
183    let Some(model) = model else {
184        return Err(anyhow!(
185            "No language model named {} was available. Available models: {}",
186            model_name,
187            model_registry
188                .available_models(cx)
189                .map(|model| model.id().0.clone())
190                .collect::<Vec<_>>()
191                .join(", ")
192        ));
193    };
194
195    Ok(model)
196}
197
198pub fn authenticate_model_provider(
199    provider_id: LanguageModelProviderId,
200    cx: &mut App,
201) -> Task<std::result::Result<(), AuthenticateError>> {
202    let model_registry = LanguageModelRegistry::read_global(cx);
203    let model_provider = model_registry.provider(&provider_id).unwrap();
204    model_provider.authenticate(cx)
205}
206
207pub async fn send_language_model_request(
208    model: Arc<dyn LanguageModel>,
209    request: LanguageModelRequest,
210    cx: &mut AsyncApp,
211) -> anyhow::Result<String> {
212    match model.stream_completion_text(request, &cx).await {
213        Ok(mut stream) => {
214            let mut full_response = String::new();
215
216            // Process the response stream
217            while let Some(chunk_result) = stream.stream.next().await {
218                match chunk_result {
219                    Ok(chunk_str) => {
220                        full_response.push_str(&chunk_str);
221                    }
222                    Err(err) => {
223                        return Err(anyhow!(
224                            "Error receiving response from language model: {err}"
225                        ));
226                    }
227                }
228            }
229
230            Ok(full_response)
231        }
232        Err(err) => Err(anyhow!(
233            "Failed to get response from language model. Error was: {err}"
234        )),
235    }
236}