thread_store.rs

  1use std::borrow::Cow;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
  8use chrono::{DateTime, Utc};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 12use futures::FutureExt as _;
 13use futures::future::{self, BoxFuture, Shared};
 14use gpui::{
 15    App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
 16    prelude::*,
 17};
 18use heed::Database;
 19use heed::types::SerdeBincode;
 20use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 21use project::Project;
 22use prompt_store::PromptBuilder;
 23use serde::{Deserialize, Serialize};
 24use settings::{Settings as _, SettingsStore};
 25use util::ResultExt as _;
 26
 27use crate::thread::{
 28    DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
 29};
 30
 31pub fn init(cx: &mut App) {
 32    ThreadsDatabase::init(cx);
 33}
 34
 35pub struct ThreadStore {
 36    project: Entity<Project>,
 37    tools: Arc<ToolWorkingSet>,
 38    prompt_builder: Arc<PromptBuilder>,
 39    context_server_manager: Entity<ContextServerManager>,
 40    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 41    threads: Vec<SerializedThreadMetadata>,
 42    _subscriptions: Vec<Subscription>,
 43}
 44
 45impl ThreadStore {
 46    pub fn new(
 47        project: Entity<Project>,
 48        tools: Arc<ToolWorkingSet>,
 49        prompt_builder: Arc<PromptBuilder>,
 50        cx: &mut App,
 51    ) -> Result<Entity<Self>> {
 52        let this = cx.new(|cx| {
 53            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 54            let context_server_manager = cx.new(|cx| {
 55                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 56            });
 57            let settings_subscription =
 58                cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
 59                    this.load_default_profile(cx);
 60                });
 61
 62            let this = Self {
 63                project,
 64                tools,
 65                prompt_builder,
 66                context_server_manager,
 67                context_server_tool_ids: HashMap::default(),
 68                threads: Vec::new(),
 69                _subscriptions: vec![settings_subscription],
 70            };
 71            this.load_default_profile(cx);
 72            this.register_context_server_handlers(cx);
 73            this.reload(cx).detach_and_log_err(cx);
 74
 75            this
 76        });
 77
 78        Ok(this)
 79    }
 80
 81    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
 82        self.context_server_manager.clone()
 83    }
 84
 85    pub fn tools(&self) -> Arc<ToolWorkingSet> {
 86        self.tools.clone()
 87    }
 88
 89    /// Returns the number of threads.
 90    pub fn thread_count(&self) -> usize {
 91        self.threads.len()
 92    }
 93
 94    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 95        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 96        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 97        threads
 98    }
 99
100    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
101        self.threads().into_iter().take(limit).collect()
102    }
103
104    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
105        cx.new(|cx| {
106            Thread::new(
107                self.project.clone(),
108                self.tools.clone(),
109                self.prompt_builder.clone(),
110                cx,
111            )
112        })
113    }
114
115    pub fn open_thread(
116        &self,
117        id: &ThreadId,
118        cx: &mut Context<Self>,
119    ) -> Task<Result<Entity<Thread>>> {
120        let id = id.clone();
121        let database_future = ThreadsDatabase::global_future(cx);
122        cx.spawn(async move |this, cx| {
123            let database = database_future.await.map_err(|err| anyhow!(err))?;
124            let thread = database
125                .try_find_thread(id.clone())
126                .await?
127                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
128
129            let thread = this.update(cx, |this, cx| {
130                cx.new(|cx| {
131                    Thread::deserialize(
132                        id.clone(),
133                        thread,
134                        this.project.clone(),
135                        this.tools.clone(),
136                        this.prompt_builder.clone(),
137                        cx,
138                    )
139                })
140            })?;
141
142            let (system_prompt_context, load_error) = thread
143                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
144                .await;
145            thread.update(cx, |thread, cx| {
146                thread.set_system_prompt_context(system_prompt_context);
147                if let Some(load_error) = load_error {
148                    cx.emit(ThreadEvent::ShowError(load_error));
149                }
150            })?;
151
152            Ok(thread)
153        })
154    }
155
156    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
157        let (metadata, serialized_thread) =
158            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
159
160        let database_future = ThreadsDatabase::global_future(cx);
161        cx.spawn(async move |this, cx| {
162            let serialized_thread = serialized_thread.await?;
163            let database = database_future.await.map_err(|err| anyhow!(err))?;
164            database.save_thread(metadata, serialized_thread).await?;
165
166            this.update(cx, |this, cx| this.reload(cx))?.await
167        })
168    }
169
170    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
171        let id = id.clone();
172        let database_future = ThreadsDatabase::global_future(cx);
173        cx.spawn(async move |this, cx| {
174            let database = database_future.await.map_err(|err| anyhow!(err))?;
175            database.delete_thread(id.clone()).await?;
176
177            this.update(cx, |this, cx| {
178                this.threads.retain(|thread| thread.id != id);
179                cx.notify();
180            })
181        })
182    }
183
184    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
185        let database_future = ThreadsDatabase::global_future(cx);
186        cx.spawn(async move |this, cx| {
187            let threads = database_future
188                .await
189                .map_err(|err| anyhow!(err))?
190                .list_threads()
191                .await?;
192
193            this.update(cx, |this, cx| {
194                this.threads = threads;
195                cx.notify();
196            })
197        })
198    }
199
200    fn load_default_profile(&self, cx: &Context<Self>) {
201        let assistant_settings = AssistantSettings::get_global(cx);
202
203        self.load_profile_by_id(&assistant_settings.default_profile, cx);
204    }
205
206    pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
207        let assistant_settings = AssistantSettings::get_global(cx);
208
209        if let Some(profile) = assistant_settings.profiles.get(profile_id) {
210            self.load_profile(profile, cx);
211        }
212    }
213
214    pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
215        self.tools.disable_all_tools();
216        self.tools.enable(
217            ToolSource::Native,
218            &profile
219                .tools
220                .iter()
221                .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
222                .collect::<Vec<_>>(),
223        );
224
225        if profile.enable_all_context_servers {
226            for context_server in self.context_server_manager.read(cx).all_servers() {
227                self.tools.enable_source(
228                    ToolSource::ContextServer {
229                        id: context_server.id().into(),
230                    },
231                    cx,
232                );
233            }
234        } else {
235            for (context_server_id, preset) in &profile.context_servers {
236                self.tools.enable(
237                    ToolSource::ContextServer {
238                        id: context_server_id.clone().into(),
239                    },
240                    &preset
241                        .tools
242                        .iter()
243                        .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
244                        .collect::<Vec<_>>(),
245                )
246            }
247        }
248    }
249
250    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
251        cx.subscribe(
252            &self.context_server_manager.clone(),
253            Self::handle_context_server_event,
254        )
255        .detach();
256    }
257
258    fn handle_context_server_event(
259        &mut self,
260        context_server_manager: Entity<ContextServerManager>,
261        event: &context_server::manager::Event,
262        cx: &mut Context<Self>,
263    ) {
264        let tool_working_set = self.tools.clone();
265        match event {
266            context_server::manager::Event::ServerStarted { server_id } => {
267                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
268                    let context_server_manager = context_server_manager.clone();
269                    cx.spawn({
270                        let server = server.clone();
271                        let server_id = server_id.clone();
272                        async move |this, cx| {
273                            let Some(protocol) = server.client() else {
274                                return;
275                            };
276
277                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
278                                if let Some(tools) = protocol.list_tools().await.log_err() {
279                                    let tool_ids = tools
280                                        .tools
281                                        .into_iter()
282                                        .map(|tool| {
283                                            log::info!(
284                                                "registering context server tool: {:?}",
285                                                tool.name
286                                            );
287                                            tool_working_set.insert(Arc::new(
288                                                ContextServerTool::new(
289                                                    context_server_manager.clone(),
290                                                    server.id(),
291                                                    tool,
292                                                ),
293                                            ))
294                                        })
295                                        .collect::<Vec<_>>();
296
297                                    this.update(cx, |this, cx| {
298                                        this.context_server_tool_ids.insert(server_id, tool_ids);
299                                        this.load_default_profile(cx);
300                                    })
301                                    .log_err();
302                                }
303                            }
304                        }
305                    })
306                    .detach();
307                }
308            }
309            context_server::manager::Event::ServerStopped { server_id } => {
310                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
311                    tool_working_set.remove(&tool_ids);
312                    self.load_default_profile(cx);
313                }
314            }
315        }
316    }
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct SerializedThreadMetadata {
321    pub id: ThreadId,
322    pub summary: SharedString,
323    pub updated_at: DateTime<Utc>,
324}
325
326#[derive(Serialize, Deserialize, Debug)]
327pub struct SerializedThread {
328    pub version: String,
329    pub summary: SharedString,
330    pub updated_at: DateTime<Utc>,
331    pub messages: Vec<SerializedMessage>,
332    #[serde(default)]
333    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
334    #[serde(default)]
335    pub cumulative_token_usage: TokenUsage,
336    #[serde(default)]
337    pub detailed_summary_state: DetailedSummaryState,
338}
339
340impl SerializedThread {
341    pub const VERSION: &'static str = "0.1.0";
342
343    pub fn from_json(json: &[u8]) -> Result<Self> {
344        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
345        match saved_thread_json.get("version") {
346            Some(serde_json::Value::String(version)) => match version.as_str() {
347                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
348                    saved_thread_json,
349                )?),
350                _ => Err(anyhow!(
351                    "unrecognized serialized thread version: {}",
352                    version
353                )),
354            },
355            None => {
356                let saved_thread =
357                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
358                Ok(saved_thread.upgrade())
359            }
360            version => Err(anyhow!(
361                "unrecognized serialized thread version: {:?}",
362                version
363            )),
364        }
365    }
366}
367
368#[derive(Debug, Serialize, Deserialize)]
369pub struct SerializedMessage {
370    pub id: MessageId,
371    pub role: Role,
372    #[serde(default)]
373    pub segments: Vec<SerializedMessageSegment>,
374    #[serde(default)]
375    pub tool_uses: Vec<SerializedToolUse>,
376    #[serde(default)]
377    pub tool_results: Vec<SerializedToolResult>,
378    #[serde(default)]
379    pub context: String,
380}
381
382#[derive(Debug, Serialize, Deserialize)]
383#[serde(tag = "type")]
384pub enum SerializedMessageSegment {
385    #[serde(rename = "text")]
386    Text { text: String },
387    #[serde(rename = "thinking")]
388    Thinking { text: String },
389}
390
391#[derive(Debug, Serialize, Deserialize)]
392pub struct SerializedToolUse {
393    pub id: LanguageModelToolUseId,
394    pub name: SharedString,
395    pub input: serde_json::Value,
396}
397
398#[derive(Debug, Serialize, Deserialize)]
399pub struct SerializedToolResult {
400    pub tool_use_id: LanguageModelToolUseId,
401    pub is_error: bool,
402    pub content: Arc<str>,
403}
404
405#[derive(Serialize, Deserialize)]
406struct LegacySerializedThread {
407    pub summary: SharedString,
408    pub updated_at: DateTime<Utc>,
409    pub messages: Vec<LegacySerializedMessage>,
410    #[serde(default)]
411    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
412}
413
414impl LegacySerializedThread {
415    pub fn upgrade(self) -> SerializedThread {
416        SerializedThread {
417            version: SerializedThread::VERSION.to_string(),
418            summary: self.summary,
419            updated_at: self.updated_at,
420            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
421            initial_project_snapshot: self.initial_project_snapshot,
422            cumulative_token_usage: TokenUsage::default(),
423            detailed_summary_state: DetailedSummaryState::default(),
424        }
425    }
426}
427
428#[derive(Debug, Serialize, Deserialize)]
429struct LegacySerializedMessage {
430    pub id: MessageId,
431    pub role: Role,
432    pub text: String,
433    #[serde(default)]
434    pub tool_uses: Vec<SerializedToolUse>,
435    #[serde(default)]
436    pub tool_results: Vec<SerializedToolResult>,
437}
438
439impl LegacySerializedMessage {
440    fn upgrade(self) -> SerializedMessage {
441        SerializedMessage {
442            id: self.id,
443            role: self.role,
444            segments: vec![SerializedMessageSegment::Text { text: self.text }],
445            tool_uses: self.tool_uses,
446            tool_results: self.tool_results,
447            context: String::new(),
448        }
449    }
450}
451
452struct GlobalThreadsDatabase(
453    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
454);
455
456impl Global for GlobalThreadsDatabase {}
457
458pub(crate) struct ThreadsDatabase {
459    executor: BackgroundExecutor,
460    env: heed::Env,
461    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
462}
463
464impl heed::BytesEncode<'_> for SerializedThread {
465    type EItem = SerializedThread;
466
467    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
468        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
469    }
470}
471
472impl<'a> heed::BytesDecode<'a> for SerializedThread {
473    type DItem = SerializedThread;
474
475    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
476        // We implement this type manually because we want to call `SerializedThread::from_json`,
477        // instead of the Deserialize trait implementation for `SerializedThread`.
478        SerializedThread::from_json(bytes).map_err(Into::into)
479    }
480}
481
482impl ThreadsDatabase {
483    fn global_future(
484        cx: &mut App,
485    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
486        GlobalThreadsDatabase::global(cx).0.clone()
487    }
488
489    fn init(cx: &mut App) {
490        let executor = cx.background_executor().clone();
491        let database_future = executor
492            .spawn({
493                let executor = executor.clone();
494                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
495                async move { ThreadsDatabase::new(database_path, executor) }
496            })
497            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
498            .boxed()
499            .shared();
500
501        cx.set_global(GlobalThreadsDatabase(database_future));
502    }
503
504    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
505        std::fs::create_dir_all(&path)?;
506
507        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
508        let env = unsafe {
509            heed::EnvOpenOptions::new()
510                .map_size(ONE_GB_IN_BYTES)
511                .max_dbs(1)
512                .open(path)?
513        };
514
515        let mut txn = env.write_txn()?;
516        let threads = env.create_database(&mut txn, Some("threads"))?;
517        txn.commit()?;
518
519        Ok(Self {
520            executor,
521            env,
522            threads,
523        })
524    }
525
526    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
527        let env = self.env.clone();
528        let threads = self.threads;
529
530        self.executor.spawn(async move {
531            let txn = env.read_txn()?;
532            let mut iter = threads.iter(&txn)?;
533            let mut threads = Vec::new();
534            while let Some((key, value)) = iter.next().transpose()? {
535                threads.push(SerializedThreadMetadata {
536                    id: key,
537                    summary: value.summary,
538                    updated_at: value.updated_at,
539                });
540            }
541
542            Ok(threads)
543        })
544    }
545
546    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
547        let env = self.env.clone();
548        let threads = self.threads;
549
550        self.executor.spawn(async move {
551            let txn = env.read_txn()?;
552            let thread = threads.get(&txn, &id)?;
553            Ok(thread)
554        })
555    }
556
557    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
558        let env = self.env.clone();
559        let threads = self.threads;
560
561        self.executor.spawn(async move {
562            let mut txn = env.write_txn()?;
563            threads.put(&mut txn, &id, &thread)?;
564            txn.commit()?;
565            Ok(())
566        })
567    }
568
569    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
570        let env = self.env.clone();
571        let threads = self.threads;
572
573        self.executor.spawn(async move {
574            let mut txn = env.write_txn()?;
575            threads.delete(&mut txn, &id)?;
576            txn.commit()?;
577            Ok(())
578        })
579    }
580}