thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::manager::ContextServerManager;
 13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 14use fs::Fs;
 15use futures::channel::{mpsc, oneshot};
 16use futures::future::{self, BoxFuture, Shared};
 17use futures::{FutureExt as _, StreamExt as _};
 18use gpui::{
 19    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 20    Subscription, Task, prelude::*,
 21};
 22use heed::Database;
 23use heed::types::SerdeBincode;
 24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 25use project::{Project, Worktree};
 26use prompt_store::{
 27    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
 28    UserRulesContext, WorktreeContext,
 29};
 30use serde::{Deserialize, Serialize};
 31use settings::{Settings as _, SettingsStore};
 32use util::ResultExt as _;
 33
 34use crate::thread::{
 35    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 36};
 37
 38const RULES_FILE_NAMES: [&'static str; 6] = [
 39    ".rules",
 40    ".cursorrules",
 41    ".windsurfrules",
 42    ".clinerules",
 43    ".github/copilot-instructions.md",
 44    "CLAUDE.md",
 45];
 46
/// Sets up this module's global state by delegating to `ThreadsDatabase::init`.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
 50
 51/// A system prompt shared by all threads created by this ThreadStore
 52#[derive(Clone, Default)]
 53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 54
 55impl SharedProjectContext {
 56    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 57        self.0.borrow()
 58    }
 59}
 60
/// Creates, opens, saves and deletes assistant threads for one project, and keeps the shared
/// project context (system prompt inputs) up to date.
pub struct ThreadStore {
    /// Project handed to every thread; also the source of worktree events.
    project: Entity<Project>,
    /// Tool working set handed to every thread and adjusted when a profile is loaded.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Optional store from which default user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each started context server, keyed by server id, so they can
    /// be removed again when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of saved threads, as of the last `reload` or `delete_thread`.
    threads: Vec<SerializedThreadMetadata>,
    /// Context shared (by reference) with every thread created or opened through this store.
    project_context: SharedProjectContext,
    /// Sending on this channel asks the reload task to rebuild the project context.
    reload_system_prompt_tx: mpsc::Sender<()>,
    // NOTE(review): held only to keep the task alive — presumably dropping a `Task` cancels it;
    // confirm against gpui.
    _reload_system_prompt_task: Task<()>,
    // NOTE(review): held only to keep the subscriptions registered — confirm against gpui.
    _subscriptions: Vec<Subscription>,
}
 74
/// Event emitted by `ThreadStore` when a rules file or a default user rule fails to load
/// during a system prompt reload.
pub struct RulesLoadingError {
    /// Text describing the failure, formatted from the underlying error.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 80
 81impl ThreadStore {
 82    pub fn load(
 83        project: Entity<Project>,
 84        tools: Entity<ToolWorkingSet>,
 85        prompt_store: Option<Entity<PromptStore>>,
 86        prompt_builder: Arc<PromptBuilder>,
 87        cx: &mut App,
 88    ) -> Task<Result<Entity<Self>>> {
 89        cx.spawn(async move |cx| {
 90            let (thread_store, ready_rx) = cx.update(|cx| {
 91                let mut option_ready_rx = None;
 92                let thread_store = cx.new(|cx| {
 93                    let (thread_store, ready_rx) =
 94                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 95                    option_ready_rx = Some(ready_rx);
 96                    thread_store
 97                });
 98                (thread_store, option_ready_rx.take().unwrap())
 99            })?;
100            ready_rx.await?;
101            Ok(thread_store)
102        })
103    }
104
    /// Builds the store, wires up its subscriptions and starts the system prompt reload task.
    ///
    /// Returns the store together with a receiver that fires after the first system prompt
    /// load has completed.
    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
        let context_server_manager = cx.new(|cx| {
            ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
        });

        let mut subscriptions = vec![
            // Re-apply the default profile when the global settings store is observed.
            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
                this.load_default_profile(cx);
            }),
            cx.subscribe(&project, Self::handle_project_event),
        ];

        // Updated prompts may change the default user rules, so rebuild the system prompt.
        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        // The channel is created with a buffer of 1 and senders use `try_send`, ignoring
        // failures, so requests made while one is already pending are dropped.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    // Stop looping once the store can no longer be updated.
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    // Only the first completed reload signals readiness.
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    // Wait for the next reload request before going around again.
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_manager,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.load_default_profile(cx);
        this.register_context_server_handlers(cx);
        // Populate `threads` from the database in the background; failures are only logged.
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }
176
177    fn handle_project_event(
178        &mut self,
179        _project: Entity<Project>,
180        event: &project::Event,
181        _cx: &mut Context<Self>,
182    ) {
183        match event {
184            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
185                self.enqueue_system_prompt_reload();
186            }
187            project::Event::WorktreeUpdatedEntries(_, items) => {
188                if items.iter().any(|(path, _, _)| {
189                    RULES_FILE_NAMES
190                        .iter()
191                        .any(|name| path.as_ref() == Path::new(name))
192                }) {
193                    self.enqueue_system_prompt_reload();
194                }
195            }
196            _ => {}
197        }
198    }
199
    /// Asks the reload task to rebuild the system prompt.
    ///
    /// The send is best-effort: a failed `try_send` (for example because a request is already
    /// queued) is deliberately ignored.
    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }
203
    /// Rebuilds the shared `ProjectContext` from the visible worktrees and the default user
    /// rules in `prompt_store`.
    ///
    /// Loading problems do not abort the reload; each one is emitted as a `RulesLoadingError`
    /// event and the affected rules are left out.
    ///
    /// Note that this should only be called from `reload_system_prompt_task`.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        let project = self.project.read(cx);
        // One task per visible worktree, each yielding its context plus an optional error.
        let worktree_tasks = project
            .visible_worktrees(cx)
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(
                    project.fs().clone(),
                    worktree.read(cx),
                    cx,
                )
            })
            .collect::<Vec<_>>();
        // Without a prompt store there are no default user rules to load.
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            // Surface rules-file errors as events while keeping every worktree's context.
            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            // Only user prompts become user rules; this one is skipped.
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            // Publish the new context to every holder of the shared cell.
            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }
277
278    fn load_worktree_info_for_system_prompt(
279        fs: Arc<dyn Fs>,
280        worktree: &Worktree,
281        cx: &App,
282    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
283        let root_name = worktree.root_name().into();
284
285        let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
286        let Some(rules_task) = rules_task else {
287            return Task::ready((
288                WorktreeContext {
289                    root_name,
290                    rules_file: None,
291                },
292                None,
293            ));
294        };
295
296        cx.spawn(async move |_| {
297            let (rules_file, rules_file_error) = match rules_task.await {
298                Ok(rules_file) => (Some(rules_file), None),
299                Err(err) => (
300                    None,
301                    Some(RulesLoadingError {
302                        message: format!("{err}").into(),
303                    }),
304                ),
305            };
306            let worktree_info = WorktreeContext {
307                root_name,
308                rules_file,
309            };
310            (worktree_info, rules_file_error)
311        })
312    }
313
314    fn load_worktree_rules_file(
315        fs: Arc<dyn Fs>,
316        worktree: &Worktree,
317        cx: &App,
318    ) -> Option<Task<Result<RulesFileContext>>> {
319        let selected_rules_file = RULES_FILE_NAMES
320            .into_iter()
321            .filter_map(|name| {
322                worktree
323                    .entry_for_path(name)
324                    .filter(|entry| entry.is_file())
325                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
326            })
327            .next();
328
329        // Note that Cline supports `.clinerules` being a directory, but that is not currently
330        // supported. This doesn't seem to occur often in GitHub repositories.
331        selected_rules_file.map(|(path_in_worktree, abs_path)| {
332            let fs = fs.clone();
333            cx.background_spawn(async move {
334                let abs_path = abs_path?;
335                let text = fs.load(&abs_path).await.with_context(|| {
336                    format!("Failed to load assistant rules file {:?}", abs_path)
337                })?;
338                anyhow::Ok(RulesFileContext {
339                    path_in_worktree,
340                    abs_path: abs_path.into(),
341                    text: text.trim().to_string(),
342                })
343            })
344        })
345    }
346
    /// Returns a new handle to the context server manager owned by this store.
    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
        self.context_server_manager.clone()
    }
350
    /// Returns the prompt store this store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
354
    /// Returns a new handle to the tool working set shared with this store's threads.
    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }
358
    /// Returns the number of threads currently held in the in-memory metadata list.
    ///
    /// That list is only refreshed by `reload` and `delete_thread`, so the count reflects the
    /// database as of the last such update.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
363
364    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
365        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
366        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
367        threads
368    }
369
370    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
371        cx.new(|cx| {
372            Thread::new(
373                self.project.clone(),
374                self.tools.clone(),
375                self.prompt_builder.clone(),
376                self.project_context.clone(),
377                cx,
378            )
379        })
380    }
381
382    pub fn open_thread(
383        &self,
384        id: &ThreadId,
385        cx: &mut Context<Self>,
386    ) -> Task<Result<Entity<Thread>>> {
387        let id = id.clone();
388        let database_future = ThreadsDatabase::global_future(cx);
389        cx.spawn(async move |this, cx| {
390            let database = database_future.await.map_err(|err| anyhow!(err))?;
391            let thread = database
392                .try_find_thread(id.clone())
393                .await?
394                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
395
396            let thread = this.update(cx, |this, cx| {
397                cx.new(|cx| {
398                    Thread::deserialize(
399                        id.clone(),
400                        thread,
401                        this.project.clone(),
402                        this.tools.clone(),
403                        this.prompt_builder.clone(),
404                        this.project_context.clone(),
405                        cx,
406                    )
407                })
408            })?;
409
410            Ok(thread)
411        })
412    }
413
414    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
415        let (metadata, serialized_thread) =
416            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
417
418        let database_future = ThreadsDatabase::global_future(cx);
419        cx.spawn(async move |this, cx| {
420            let serialized_thread = serialized_thread.await?;
421            let database = database_future.await.map_err(|err| anyhow!(err))?;
422            database.save_thread(metadata, serialized_thread).await?;
423
424            this.update(cx, |this, cx| this.reload(cx))?.await
425        })
426    }
427
428    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
429        let id = id.clone();
430        let database_future = ThreadsDatabase::global_future(cx);
431        cx.spawn(async move |this, cx| {
432            let database = database_future.await.map_err(|err| anyhow!(err))?;
433            database.delete_thread(id.clone()).await?;
434
435            this.update(cx, |this, cx| {
436                this.threads.retain(|thread| thread.id != id);
437                cx.notify();
438            })
439        })
440    }
441
442    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
443        let database_future = ThreadsDatabase::global_future(cx);
444        cx.spawn(async move |this, cx| {
445            let threads = database_future
446                .await
447                .map_err(|err| anyhow!(err))?
448                .list_threads()
449                .await?;
450
451            this.update(cx, |this, cx| {
452                this.threads = threads;
453                cx.notify();
454            })
455        })
456    }
457
458    fn load_default_profile(&self, cx: &mut Context<Self>) {
459        let assistant_settings = AssistantSettings::get_global(cx);
460
461        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
462    }
463
464    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
465        let assistant_settings = AssistantSettings::get_global(cx);
466
467        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
468            self.load_profile(profile.clone(), cx);
469        }
470    }
471
472    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
473        self.tools.update(cx, |tools, cx| {
474            tools.disable_all_tools(cx);
475            tools.enable(
476                ToolSource::Native,
477                &profile
478                    .tools
479                    .iter()
480                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
481                    .collect::<Vec<_>>(),
482                cx,
483            );
484        });
485
486        if profile.enable_all_context_servers {
487            for context_server in self.context_server_manager.read(cx).all_servers() {
488                self.tools.update(cx, |tools, cx| {
489                    tools.enable_source(
490                        ToolSource::ContextServer {
491                            id: context_server.id().into(),
492                        },
493                        cx,
494                    );
495                });
496            }
497            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
498            for (context_server_id, preset) in &profile.context_servers {
499                self.tools.update(cx, |tools, cx| {
500                    tools.disable(
501                        ToolSource::ContextServer {
502                            id: context_server_id.clone().into(),
503                        },
504                        &preset
505                            .tools
506                            .iter()
507                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
508                            .collect::<Vec<_>>(),
509                        cx,
510                    )
511                })
512            }
513        } else {
514            for (context_server_id, preset) in &profile.context_servers {
515                self.tools.update(cx, |tools, cx| {
516                    tools.enable(
517                        ToolSource::ContextServer {
518                            id: context_server_id.clone().into(),
519                        },
520                        &preset
521                            .tools
522                            .iter()
523                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
524                            .collect::<Vec<_>>(),
525                        cx,
526                    )
527                })
528            }
529        }
530    }
531
532    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
533        cx.subscribe(
534            &self.context_server_manager.clone(),
535            Self::handle_context_server_event,
536        )
537        .detach();
538    }
539
540    fn handle_context_server_event(
541        &mut self,
542        context_server_manager: Entity<ContextServerManager>,
543        event: &context_server::manager::Event,
544        cx: &mut Context<Self>,
545    ) {
546        let tool_working_set = self.tools.clone();
547        match event {
548            context_server::manager::Event::ServerStarted { server_id } => {
549                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
550                    let context_server_manager = context_server_manager.clone();
551                    cx.spawn({
552                        let server = server.clone();
553                        let server_id = server_id.clone();
554                        async move |this, cx| {
555                            let Some(protocol) = server.client() else {
556                                return;
557                            };
558
559                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
560                                if let Some(tools) = protocol.list_tools().await.log_err() {
561                                    let tool_ids = tool_working_set
562                                        .update(cx, |tool_working_set, _| {
563                                            tools
564                                                .tools
565                                                .into_iter()
566                                                .map(|tool| {
567                                                    log::info!(
568                                                        "registering context server tool: {:?}",
569                                                        tool.name
570                                                    );
571                                                    tool_working_set.insert(Arc::new(
572                                                        ContextServerTool::new(
573                                                            context_server_manager.clone(),
574                                                            server.id(),
575                                                            tool,
576                                                        ),
577                                                    ))
578                                                })
579                                                .collect::<Vec<_>>()
580                                        })
581                                        .log_err();
582
583                                    if let Some(tool_ids) = tool_ids {
584                                        this.update(cx, |this, cx| {
585                                            this.context_server_tool_ids
586                                                .insert(server_id, tool_ids);
587                                            this.load_default_profile(cx);
588                                        })
589                                        .log_err();
590                                    }
591                                }
592                            }
593                        }
594                    })
595                    .detach();
596                }
597            }
598            context_server::manager::Event::ServerStopped { server_id } => {
599                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
600                    tool_working_set.update(cx, |tool_working_set, _| {
601                        tool_working_set.remove(&tool_ids);
602                    });
603                    self.load_default_profile(cx);
604                }
605            }
606        }
607    }
608}
609
/// Lightweight description of a saved thread: its id plus the `summary` and `updated_at`
/// fields copied from the stored `SerializedThread`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}
616
/// Persisted form of a thread in the current schema (see `SerializedThread::VERSION`).
///
/// Fields marked `#[serde(default)]` may be absent from stored data and then take their
/// default value when deserialized.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Schema version tag; `from_json` dispatches on it.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
}
634
635impl SerializedThread {
636    pub const VERSION: &'static str = "0.2.0";
637
638    pub fn from_json(json: &[u8]) -> Result<Self> {
639        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
640        match saved_thread_json.get("version") {
641            Some(serde_json::Value::String(version)) => match version.as_str() {
642                SerializedThreadV0_1_0::VERSION => {
643                    let saved_thread =
644                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
645                    Ok(saved_thread.upgrade())
646                }
647                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
648                    saved_thread_json,
649                )?),
650                _ => Err(anyhow!(
651                    "unrecognized serialized thread version: {}",
652                    version
653                )),
654            },
655            None => {
656                let saved_thread =
657                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
658                Ok(saved_thread.upgrade())
659            }
660            version => Err(anyhow!(
661                "unrecognized serialized thread version: {:?}",
662                version
663            )),
664        }
665    }
666}
667
668#[derive(Serialize, Deserialize, Debug)]
669pub struct SerializedThreadV0_1_0(
670    // The structure did not change, so we are reusing the latest SerializedThread.
671    // When making the next version, make sure this points to SerializedThreadV0_2_0
672    SerializedThread,
673);
674
675impl SerializedThreadV0_1_0 {
676    pub const VERSION: &'static str = "0.1.0";
677
678    pub fn upgrade(self) -> SerializedThread {
679        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
680
681        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
682
683        for message in self.0.messages {
684            if message.role == Role::User && !message.tool_results.is_empty() {
685                if let Some(last_message) = messages.last_mut() {
686                    debug_assert!(last_message.role == Role::Assistant);
687
688                    last_message.tool_results = message.tool_results;
689                    continue;
690                }
691            }
692
693            messages.push(message);
694        }
695
696        SerializedThread { messages, ..self.0 }
697    }
698}
699
/// Persisted form of a single message in a thread.
///
/// Every field except `id` and `role` is `#[serde(default)]`, so it may be missing from
/// stored data and then deserializes to an empty value.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
}
713
/// One piece of a message's content, serialized with the variant stored in a `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        // Left out of the serialized output entirely when there is no signature.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the variants above this one has no `rename`, so serde's default
    // applies and the tag is the variant name itself. Changing that would affect stored data.
    RedactedThinking {
        data: Vec<u8>,
    },
}
731
/// Persisted record of a tool invocation: its id, the tool's name and the JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
738
/// Persisted result of a tool invocation, linked to its request by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
745
746#[derive(Serialize, Deserialize)]
747struct LegacySerializedThread {
748    pub summary: SharedString,
749    pub updated_at: DateTime<Utc>,
750    pub messages: Vec<LegacySerializedMessage>,
751    #[serde(default)]
752    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
753}
754
755impl LegacySerializedThread {
756    pub fn upgrade(self) -> SerializedThread {
757        SerializedThread {
758            version: SerializedThread::VERSION.to_string(),
759            summary: self.summary,
760            updated_at: self.updated_at,
761            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
762            initial_project_snapshot: self.initial_project_snapshot,
763            cumulative_token_usage: TokenUsage::default(),
764            request_token_usage: Vec::new(),
765            detailed_summary_state: DetailedSummaryState::default(),
766            exceeded_window_error: None,
767        }
768    }
769}
770
771#[derive(Debug, Serialize, Deserialize)]
772struct LegacySerializedMessage {
773    pub id: MessageId,
774    pub role: Role,
775    pub text: String,
776    #[serde(default)]
777    pub tool_uses: Vec<SerializedToolUse>,
778    #[serde(default)]
779    pub tool_results: Vec<SerializedToolResult>,
780}
781
782impl LegacySerializedMessage {
783    fn upgrade(self) -> SerializedMessage {
784        SerializedMessage {
785            id: self.id,
786            role: self.role,
787            segments: vec![SerializedMessageSegment::Text { text: self.text }],
788            tool_uses: self.tool_uses,
789            tool_results: self.tool_results,
790            context: String::new(),
791        }
792    }
793}
794
/// App-global holder for the shared future that resolves to the opened threads database, or
/// to the error that prevented opening it.
///
/// Both the database and the error are wrapped in `Arc`, so every clone of the shared future
/// can receive the same result.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
800
/// On-disk store of serialized threads, backed by a heed environment.
pub(crate) struct ThreadsDatabase {
    /// Executor on which all database operations are spawned.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Maps thread ids (bincode-encoded) to threads, which are encoded as JSON by the
    /// `BytesEncode` / `BytesDecode` impls for `SerializedThread`.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
806
/// Stores a `SerializedThread` as JSON bytes.
impl heed::BytesEncode<'_> for SerializedThread {
    type EItem = SerializedThread;

    /// Encodes `item` with `serde_json::to_vec`, returning the bytes as an owned `Cow`.
    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
    }
}
814
/// Reads a `SerializedThread` from JSON bytes, upgrading older schema versions on the way.
impl<'a> heed::BytesDecode<'a> for SerializedThread {
    type DItem = SerializedThread;

    /// Decodes `bytes` through `SerializedThread::from_json`.
    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
        // We implement this type manually because we want to call `SerializedThread::from_json`,
        // instead of the Deserialize trait implementation for `SerializedThread`.
        SerializedThread::from_json(bytes).map_err(Into::into)
    }
}
824
825impl ThreadsDatabase {
    /// Returns a clone of the shared future that resolves to the global threads database.
    ///
    /// NOTE(review): reads the `GlobalThreadsDatabase` global, so this presumably requires
    /// `ThreadsDatabase::init` to have run first — confirm how gpui treats a missing global.
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }
831
832    fn init(cx: &mut App) {
833        let executor = cx.background_executor().clone();
834        let database_future = executor
835            .spawn({
836                let executor = executor.clone();
837                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
838                async move { ThreadsDatabase::new(database_path, executor) }
839            })
840            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
841            .boxed()
842            .shared();
843
844        cx.set_global(GlobalThreadsDatabase(database_future));
845    }
846
847    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
848        std::fs::create_dir_all(&path)?;
849
850        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
851        let env = unsafe {
852            heed::EnvOpenOptions::new()
853                .map_size(ONE_GB_IN_BYTES)
854                .max_dbs(1)
855                .open(path)?
856        };
857
858        let mut txn = env.write_txn()?;
859        let threads = env.create_database(&mut txn, Some("threads"))?;
860        txn.commit()?;
861
862        Ok(Self {
863            executor,
864            env,
865            threads,
866        })
867    }
868
869    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
870        let env = self.env.clone();
871        let threads = self.threads;
872
873        self.executor.spawn(async move {
874            let txn = env.read_txn()?;
875            let mut iter = threads.iter(&txn)?;
876            let mut threads = Vec::new();
877            while let Some((key, value)) = iter.next().transpose()? {
878                threads.push(SerializedThreadMetadata {
879                    id: key,
880                    summary: value.summary,
881                    updated_at: value.updated_at,
882                });
883            }
884
885            Ok(threads)
886        })
887    }
888
889    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
890        let env = self.env.clone();
891        let threads = self.threads;
892
893        self.executor.spawn(async move {
894            let txn = env.read_txn()?;
895            let thread = threads.get(&txn, &id)?;
896            Ok(thread)
897        })
898    }
899
900    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
901        let env = self.env.clone();
902        let threads = self.threads;
903
904        self.executor.spawn(async move {
905            let mut txn = env.write_txn()?;
906            threads.put(&mut txn, &id, &thread)?;
907            txn.commit()?;
908            Ok(())
909        })
910    }
911
912    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
913        let env = self.env.clone();
914        let threads = self.threads;
915
916        self.executor.spawn(async move {
917            let mut txn = env.write_txn()?;
918            threads.delete(&mut txn, &id)?;
919            txn.commit()?;
920            Ok(())
921        })
922    }
923}