thread_store.rs

  1use std::borrow::Cow;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
  8use chrono::{DateTime, Utc};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 12use futures::FutureExt as _;
 13use futures::future::{self, BoxFuture, Shared};
 14use gpui::{
 15    App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
 16    prelude::*,
 17};
 18use heed::Database;
 19use heed::types::SerdeBincode;
 20use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 21use project::Project;
 22use prompt_store::PromptBuilder;
 23use serde::{Deserialize, Serialize};
 24use settings::{Settings as _, SettingsStore};
 25use util::ResultExt as _;
 26
 27use crate::thread::{
 28    DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
 29};
 30
 31pub fn init(cx: &mut App) {
 32    ThreadsDatabase::init(cx);
 33}
 34
 35pub struct ThreadStore {
 36    project: Entity<Project>,
 37    tools: Arc<ToolWorkingSet>,
 38    prompt_builder: Arc<PromptBuilder>,
 39    context_server_manager: Entity<ContextServerManager>,
 40    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 41    threads: Vec<SerializedThreadMetadata>,
 42    _subscriptions: Vec<Subscription>,
 43}
 44
 45impl ThreadStore {
 46    pub fn new(
 47        project: Entity<Project>,
 48        tools: Arc<ToolWorkingSet>,
 49        prompt_builder: Arc<PromptBuilder>,
 50        cx: &mut App,
 51    ) -> Result<Entity<Self>> {
 52        let this = cx.new(|cx| {
 53            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 54            let context_server_manager = cx.new(|cx| {
 55                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 56            });
 57            let settings_subscription =
 58                cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
 59                    this.load_default_profile(cx);
 60                });
 61
 62            let this = Self {
 63                project,
 64                tools,
 65                prompt_builder,
 66                context_server_manager,
 67                context_server_tool_ids: HashMap::default(),
 68                threads: Vec::new(),
 69                _subscriptions: vec![settings_subscription],
 70            };
 71            this.load_default_profile(cx);
 72            this.register_context_server_handlers(cx);
 73            this.reload(cx).detach_and_log_err(cx);
 74
 75            this
 76        });
 77
 78        Ok(this)
 79    }
 80
 81    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
 82        self.context_server_manager.clone()
 83    }
 84
 85    pub fn tools(&self) -> Arc<ToolWorkingSet> {
 86        self.tools.clone()
 87    }
 88
 89    /// Returns the number of threads.
 90    pub fn thread_count(&self) -> usize {
 91        self.threads.len()
 92    }
 93
 94    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 95        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 96        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 97        threads
 98    }
 99
100    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
101        self.threads().into_iter().take(limit).collect()
102    }
103
104    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
105        cx.new(|cx| {
106            Thread::new(
107                self.project.clone(),
108                self.tools.clone(),
109                self.prompt_builder.clone(),
110                cx,
111            )
112        })
113    }
114
115    pub fn open_thread(
116        &self,
117        id: &ThreadId,
118        cx: &mut Context<Self>,
119    ) -> Task<Result<Entity<Thread>>> {
120        let id = id.clone();
121        let database_future = ThreadsDatabase::global_future(cx);
122        cx.spawn(async move |this, cx| {
123            let database = database_future.await.map_err(|err| anyhow!(err))?;
124            let thread = database
125                .try_find_thread(id.clone())
126                .await?
127                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
128
129            let thread = this.update(cx, |this, cx| {
130                cx.new(|cx| {
131                    Thread::deserialize(
132                        id.clone(),
133                        thread,
134                        this.project.clone(),
135                        this.tools.clone(),
136                        this.prompt_builder.clone(),
137                        cx,
138                    )
139                })
140            })?;
141
142            let (system_prompt_context, load_error) = thread
143                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
144                .await;
145            thread.update(cx, |thread, cx| {
146                thread.set_system_prompt_context(system_prompt_context);
147                if let Some(load_error) = load_error {
148                    cx.emit(ThreadEvent::ShowError(load_error));
149                }
150            })?;
151
152            Ok(thread)
153        })
154    }
155
156    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
157        let (metadata, serialized_thread) =
158            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
159
160        let database_future = ThreadsDatabase::global_future(cx);
161        cx.spawn(async move |this, cx| {
162            let serialized_thread = serialized_thread.await?;
163            let database = database_future.await.map_err(|err| anyhow!(err))?;
164            database.save_thread(metadata, serialized_thread).await?;
165
166            this.update(cx, |this, cx| this.reload(cx))?.await
167        })
168    }
169
170    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
171        let id = id.clone();
172        let database_future = ThreadsDatabase::global_future(cx);
173        cx.spawn(async move |this, cx| {
174            let database = database_future.await.map_err(|err| anyhow!(err))?;
175            database.delete_thread(id.clone()).await?;
176
177            this.update(cx, |this, _cx| {
178                this.threads.retain(|thread| thread.id != id)
179            })
180        })
181    }
182
183    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
184        let database_future = ThreadsDatabase::global_future(cx);
185        cx.spawn(async move |this, cx| {
186            let threads = database_future
187                .await
188                .map_err(|err| anyhow!(err))?
189                .list_threads()
190                .await?;
191
192            this.update(cx, |this, cx| {
193                this.threads = threads;
194                cx.notify();
195            })
196        })
197    }
198
199    fn load_default_profile(&self, cx: &Context<Self>) {
200        let assistant_settings = AssistantSettings::get_global(cx);
201
202        self.load_profile_by_id(&assistant_settings.default_profile, cx);
203    }
204
205    pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
206        let assistant_settings = AssistantSettings::get_global(cx);
207
208        if let Some(profile) = assistant_settings.profiles.get(profile_id) {
209            self.load_profile(profile, cx);
210        }
211    }
212
213    pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
214        self.tools.disable_all_tools();
215        self.tools.enable(
216            ToolSource::Native,
217            &profile
218                .tools
219                .iter()
220                .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
221                .collect::<Vec<_>>(),
222        );
223
224        if profile.enable_all_context_servers {
225            for context_server in self.context_server_manager.read(cx).all_servers() {
226                self.tools.enable_source(
227                    ToolSource::ContextServer {
228                        id: context_server.id().into(),
229                    },
230                    cx,
231                );
232            }
233        } else {
234            for (context_server_id, preset) in &profile.context_servers {
235                self.tools.enable(
236                    ToolSource::ContextServer {
237                        id: context_server_id.clone().into(),
238                    },
239                    &preset
240                        .tools
241                        .iter()
242                        .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
243                        .collect::<Vec<_>>(),
244                )
245            }
246        }
247    }
248
249    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
250        cx.subscribe(
251            &self.context_server_manager.clone(),
252            Self::handle_context_server_event,
253        )
254        .detach();
255    }
256
257    fn handle_context_server_event(
258        &mut self,
259        context_server_manager: Entity<ContextServerManager>,
260        event: &context_server::manager::Event,
261        cx: &mut Context<Self>,
262    ) {
263        let tool_working_set = self.tools.clone();
264        match event {
265            context_server::manager::Event::ServerStarted { server_id } => {
266                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
267                    let context_server_manager = context_server_manager.clone();
268                    cx.spawn({
269                        let server = server.clone();
270                        let server_id = server_id.clone();
271                        async move |this, cx| {
272                            let Some(protocol) = server.client() else {
273                                return;
274                            };
275
276                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
277                                if let Some(tools) = protocol.list_tools().await.log_err() {
278                                    let tool_ids = tools
279                                        .tools
280                                        .into_iter()
281                                        .map(|tool| {
282                                            log::info!(
283                                                "registering context server tool: {:?}",
284                                                tool.name
285                                            );
286                                            tool_working_set.insert(Arc::new(
287                                                ContextServerTool::new(
288                                                    context_server_manager.clone(),
289                                                    server.id(),
290                                                    tool,
291                                                ),
292                                            ))
293                                        })
294                                        .collect::<Vec<_>>();
295
296                                    this.update(cx, |this, cx| {
297                                        this.context_server_tool_ids.insert(server_id, tool_ids);
298                                        this.load_default_profile(cx);
299                                    })
300                                    .log_err();
301                                }
302                            }
303                        }
304                    })
305                    .detach();
306                }
307            }
308            context_server::manager::Event::ServerStopped { server_id } => {
309                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
310                    tool_working_set.remove(&tool_ids);
311                    self.load_default_profile(cx);
312                }
313            }
314        }
315    }
316}
317
318#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct SerializedThreadMetadata {
320    pub id: ThreadId,
321    pub summary: SharedString,
322    pub updated_at: DateTime<Utc>,
323}
324
325#[derive(Serialize, Deserialize, Debug)]
326pub struct SerializedThread {
327    pub version: String,
328    pub summary: SharedString,
329    pub updated_at: DateTime<Utc>,
330    pub messages: Vec<SerializedMessage>,
331    #[serde(default)]
332    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
333    #[serde(default)]
334    pub cumulative_token_usage: TokenUsage,
335    #[serde(default)]
336    pub detailed_summary_state: DetailedSummaryState,
337}
338
339impl SerializedThread {
340    pub const VERSION: &'static str = "0.1.0";
341
342    pub fn from_json(json: &[u8]) -> Result<Self> {
343        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
344        match saved_thread_json.get("version") {
345            Some(serde_json::Value::String(version)) => match version.as_str() {
346                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
347                    saved_thread_json,
348                )?),
349                _ => Err(anyhow!(
350                    "unrecognized serialized thread version: {}",
351                    version
352                )),
353            },
354            None => {
355                let saved_thread =
356                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
357                Ok(saved_thread.upgrade())
358            }
359            version => Err(anyhow!(
360                "unrecognized serialized thread version: {:?}",
361                version
362            )),
363        }
364    }
365}
366
367#[derive(Debug, Serialize, Deserialize)]
368pub struct SerializedMessage {
369    pub id: MessageId,
370    pub role: Role,
371    #[serde(default)]
372    pub segments: Vec<SerializedMessageSegment>,
373    #[serde(default)]
374    pub tool_uses: Vec<SerializedToolUse>,
375    #[serde(default)]
376    pub tool_results: Vec<SerializedToolResult>,
377}
378
379#[derive(Debug, Serialize, Deserialize)]
380#[serde(tag = "type")]
381pub enum SerializedMessageSegment {
382    #[serde(rename = "text")]
383    Text { text: String },
384    #[serde(rename = "thinking")]
385    Thinking { text: String },
386}
387
388#[derive(Debug, Serialize, Deserialize)]
389pub struct SerializedToolUse {
390    pub id: LanguageModelToolUseId,
391    pub name: SharedString,
392    pub input: serde_json::Value,
393}
394
395#[derive(Debug, Serialize, Deserialize)]
396pub struct SerializedToolResult {
397    pub tool_use_id: LanguageModelToolUseId,
398    pub is_error: bool,
399    pub content: Arc<str>,
400}
401
402#[derive(Serialize, Deserialize)]
403struct LegacySerializedThread {
404    pub summary: SharedString,
405    pub updated_at: DateTime<Utc>,
406    pub messages: Vec<LegacySerializedMessage>,
407    #[serde(default)]
408    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
409}
410
411impl LegacySerializedThread {
412    pub fn upgrade(self) -> SerializedThread {
413        SerializedThread {
414            version: SerializedThread::VERSION.to_string(),
415            summary: self.summary,
416            updated_at: self.updated_at,
417            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
418            initial_project_snapshot: self.initial_project_snapshot,
419            cumulative_token_usage: TokenUsage::default(),
420            detailed_summary_state: DetailedSummaryState::default(),
421        }
422    }
423}
424
425#[derive(Debug, Serialize, Deserialize)]
426struct LegacySerializedMessage {
427    pub id: MessageId,
428    pub role: Role,
429    pub text: String,
430    #[serde(default)]
431    pub tool_uses: Vec<SerializedToolUse>,
432    #[serde(default)]
433    pub tool_results: Vec<SerializedToolResult>,
434}
435
436impl LegacySerializedMessage {
437    fn upgrade(self) -> SerializedMessage {
438        SerializedMessage {
439            id: self.id,
440            role: self.role,
441            segments: vec![SerializedMessageSegment::Text { text: self.text }],
442            tool_uses: self.tool_uses,
443            tool_results: self.tool_results,
444        }
445    }
446}
447
448struct GlobalThreadsDatabase(
449    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
450);
451
452impl Global for GlobalThreadsDatabase {}
453
454pub(crate) struct ThreadsDatabase {
455    executor: BackgroundExecutor,
456    env: heed::Env,
457    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
458}
459
460impl heed::BytesEncode<'_> for SerializedThread {
461    type EItem = SerializedThread;
462
463    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
464        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
465    }
466}
467
468impl<'a> heed::BytesDecode<'a> for SerializedThread {
469    type DItem = SerializedThread;
470
471    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
472        // We implement this type manually because we want to call `SerializedThread::from_json`,
473        // instead of the Deserialize trait implementation for `SerializedThread`.
474        SerializedThread::from_json(bytes).map_err(Into::into)
475    }
476}
477
478impl ThreadsDatabase {
479    fn global_future(
480        cx: &mut App,
481    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
482        GlobalThreadsDatabase::global(cx).0.clone()
483    }
484
485    fn init(cx: &mut App) {
486        let executor = cx.background_executor().clone();
487        let database_future = executor
488            .spawn({
489                let executor = executor.clone();
490                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
491                async move { ThreadsDatabase::new(database_path, executor) }
492            })
493            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
494            .boxed()
495            .shared();
496
497        cx.set_global(GlobalThreadsDatabase(database_future));
498    }
499
500    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
501        std::fs::create_dir_all(&path)?;
502
503        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
504        let env = unsafe {
505            heed::EnvOpenOptions::new()
506                .map_size(ONE_GB_IN_BYTES)
507                .max_dbs(1)
508                .open(path)?
509        };
510
511        let mut txn = env.write_txn()?;
512        let threads = env.create_database(&mut txn, Some("threads"))?;
513        txn.commit()?;
514
515        Ok(Self {
516            executor,
517            env,
518            threads,
519        })
520    }
521
522    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
523        let env = self.env.clone();
524        let threads = self.threads;
525
526        self.executor.spawn(async move {
527            let txn = env.read_txn()?;
528            let mut iter = threads.iter(&txn)?;
529            let mut threads = Vec::new();
530            while let Some((key, value)) = iter.next().transpose()? {
531                threads.push(SerializedThreadMetadata {
532                    id: key,
533                    summary: value.summary,
534                    updated_at: value.updated_at,
535                });
536            }
537
538            Ok(threads)
539        })
540    }
541
542    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
543        let env = self.env.clone();
544        let threads = self.threads;
545
546        self.executor.spawn(async move {
547            let txn = env.read_txn()?;
548            let thread = threads.get(&txn, &id)?;
549            Ok(thread)
550        })
551    }
552
553    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
554        let env = self.env.clone();
555        let threads = self.threads;
556
557        self.executor.spawn(async move {
558            let mut txn = env.write_txn()?;
559            threads.put(&mut txn, &id, &thread)?;
560            txn.commit()?;
561            Ok(())
562        })
563    }
564
565    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
566        let env = self.env.clone();
567        let threads = self.threads;
568
569        self.executor.spawn(async move {
570            let mut txn = env.write_txn()?;
571            threads.delete(&mut txn, &id)?;
572            txn.commit()?;
573            Ok(())
574        })
575    }
576}