thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::ContextServerId;
 13use futures::channel::{mpsc, oneshot};
 14use futures::future::{self, BoxFuture, Shared};
 15use futures::{FutureExt as _, StreamExt as _};
 16use gpui::{
 17    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 18    Subscription, Task, prelude::*,
 19};
 20use heed::Database;
 21use heed::types::SerdeBincode;
 22use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 23use project::context_server_store::{ContextServerStatus, ContextServerStore};
 24use project::{Project, ProjectItem, ProjectPath, Worktree};
 25use prompt_store::{
 26    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
 27    UserRulesContext, WorktreeContext,
 28};
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings as _, SettingsStore};
 31use util::ResultExt as _;
 32
 33use crate::context_server_tool::ContextServerTool;
 34use crate::thread::{
 35    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 36};
 37
 38const RULES_FILE_NAMES: [&'static str; 6] = [
 39    ".rules",
 40    ".cursorrules",
 41    ".windsurfrules",
 42    ".clinerules",
 43    ".github/copilot-instructions.md",
 44    "CLAUDE.md",
 45];
 46
 47pub fn init(cx: &mut App) {
 48    ThreadsDatabase::init(cx);
 49}
 50
 51/// A system prompt shared by all threads created by this ThreadStore
 52#[derive(Clone, Default)]
 53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 54
 55impl SharedProjectContext {
 56    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 57        self.0.borrow()
 58    }
 59}
 60
 61pub struct ThreadStore {
 62    project: Entity<Project>,
 63    tools: Entity<ToolWorkingSet>,
 64    prompt_builder: Arc<PromptBuilder>,
 65    prompt_store: Option<Entity<PromptStore>>,
 66    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
 67    threads: Vec<SerializedThreadMetadata>,
 68    project_context: SharedProjectContext,
 69    reload_system_prompt_tx: mpsc::Sender<()>,
 70    _reload_system_prompt_task: Task<()>,
 71    _subscriptions: Vec<Subscription>,
 72}
 73
 74pub struct RulesLoadingError {
 75    pub message: SharedString,
 76}
 77
 78impl EventEmitter<RulesLoadingError> for ThreadStore {}
 79
 80impl ThreadStore {
 81    pub fn load(
 82        project: Entity<Project>,
 83        tools: Entity<ToolWorkingSet>,
 84        prompt_store: Option<Entity<PromptStore>>,
 85        prompt_builder: Arc<PromptBuilder>,
 86        cx: &mut App,
 87    ) -> Task<Result<Entity<Self>>> {
 88        cx.spawn(async move |cx| {
 89            let (thread_store, ready_rx) = cx.update(|cx| {
 90                let mut option_ready_rx = None;
 91                let thread_store = cx.new(|cx| {
 92                    let (thread_store, ready_rx) =
 93                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 94                    option_ready_rx = Some(ready_rx);
 95                    thread_store
 96                });
 97                (thread_store, option_ready_rx.take().unwrap())
 98            })?;
 99            ready_rx.await?;
100            Ok(thread_store)
101        })
102    }
103
104    fn new(
105        project: Entity<Project>,
106        tools: Entity<ToolWorkingSet>,
107        prompt_builder: Arc<PromptBuilder>,
108        prompt_store: Option<Entity<PromptStore>>,
109        cx: &mut Context<Self>,
110    ) -> (Self, oneshot::Receiver<()>) {
111        let mut subscriptions = vec![
112            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
113                this.load_default_profile(cx);
114            }),
115            cx.subscribe(&project, Self::handle_project_event),
116        ];
117
118        if let Some(prompt_store) = prompt_store.as_ref() {
119            subscriptions.push(cx.subscribe(
120                prompt_store,
121                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
122                    this.enqueue_system_prompt_reload();
123                },
124            ))
125        }
126
127        // This channel and task prevent concurrent and redundant loading of the system prompt.
128        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
129        let (ready_tx, ready_rx) = oneshot::channel();
130        let mut ready_tx = Some(ready_tx);
131        let reload_system_prompt_task = cx.spawn({
132            let prompt_store = prompt_store.clone();
133            async move |thread_store, cx| {
134                loop {
135                    let Some(reload_task) = thread_store
136                        .update(cx, |thread_store, cx| {
137                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
138                        })
139                        .ok()
140                    else {
141                        return;
142                    };
143                    reload_task.await;
144                    if let Some(ready_tx) = ready_tx.take() {
145                        ready_tx.send(()).ok();
146                    }
147                    reload_system_prompt_rx.next().await;
148                }
149            }
150        });
151
152        let this = Self {
153            project,
154            tools,
155            prompt_builder,
156            prompt_store,
157            context_server_tool_ids: HashMap::default(),
158            threads: Vec::new(),
159            project_context: SharedProjectContext::default(),
160            reload_system_prompt_tx,
161            _reload_system_prompt_task: reload_system_prompt_task,
162            _subscriptions: subscriptions,
163        };
164        this.load_default_profile(cx);
165        this.register_context_server_handlers(cx);
166        this.reload(cx).detach_and_log_err(cx);
167        (this, ready_rx)
168    }
169
170    fn handle_project_event(
171        &mut self,
172        _project: Entity<Project>,
173        event: &project::Event,
174        _cx: &mut Context<Self>,
175    ) {
176        match event {
177            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
178                self.enqueue_system_prompt_reload();
179            }
180            project::Event::WorktreeUpdatedEntries(_, items) => {
181                if items.iter().any(|(path, _, _)| {
182                    RULES_FILE_NAMES
183                        .iter()
184                        .any(|name| path.as_ref() == Path::new(name))
185                }) {
186                    self.enqueue_system_prompt_reload();
187                }
188            }
189            _ => {}
190        }
191    }
192
193    fn enqueue_system_prompt_reload(&mut self) {
194        self.reload_system_prompt_tx.try_send(()).ok();
195    }
196
197    // Note that this should only be called from `reload_system_prompt_task`.
198    fn reload_system_prompt(
199        &self,
200        prompt_store: Option<Entity<PromptStore>>,
201        cx: &mut Context<Self>,
202    ) -> Task<()> {
203        let worktrees = self
204            .project
205            .read(cx)
206            .visible_worktrees(cx)
207            .collect::<Vec<_>>();
208        let worktree_tasks = worktrees
209            .into_iter()
210            .map(|worktree| {
211                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
212            })
213            .collect::<Vec<_>>();
214        let default_user_rules_task = match prompt_store {
215            None => Task::ready(vec![]),
216            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
217                let prompts = prompt_store.default_prompt_metadata();
218                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
219                    let contents = prompt_store.load(prompt_metadata.id, cx);
220                    async move { (contents.await, prompt_metadata) }
221                });
222                cx.background_spawn(future::join_all(load_tasks))
223            }),
224        };
225
226        cx.spawn(async move |this, cx| {
227            let (worktrees, default_user_rules) =
228                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
229
230            let worktrees = worktrees
231                .into_iter()
232                .map(|(worktree, rules_error)| {
233                    if let Some(rules_error) = rules_error {
234                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
235                    }
236                    worktree
237                })
238                .collect::<Vec<_>>();
239
240            let default_user_rules = default_user_rules
241                .into_iter()
242                .flat_map(|(contents, prompt_metadata)| match contents {
243                    Ok(contents) => Some(UserRulesContext {
244                        uuid: match prompt_metadata.id {
245                            PromptId::User { uuid } => uuid,
246                            PromptId::EditWorkflow => return None,
247                        },
248                        title: prompt_metadata.title.map(|title| title.to_string()),
249                        contents,
250                    }),
251                    Err(err) => {
252                        this.update(cx, |_, cx| {
253                            cx.emit(RulesLoadingError {
254                                message: format!("{err:?}").into(),
255                            });
256                        })
257                        .ok();
258                        None
259                    }
260                })
261                .collect::<Vec<_>>();
262
263            this.update(cx, |this, _cx| {
264                *this.project_context.0.borrow_mut() =
265                    Some(ProjectContext::new(worktrees, default_user_rules));
266            })
267            .ok();
268        })
269    }
270
271    fn load_worktree_info_for_system_prompt(
272        worktree: Entity<Worktree>,
273        project: Entity<Project>,
274        cx: &mut App,
275    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
276        let root_name = worktree.read(cx).root_name().into();
277
278        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
279        let Some(rules_task) = rules_task else {
280            return Task::ready((
281                WorktreeContext {
282                    root_name,
283                    rules_file: None,
284                },
285                None,
286            ));
287        };
288
289        cx.spawn(async move |_| {
290            let (rules_file, rules_file_error) = match rules_task.await {
291                Ok(rules_file) => (Some(rules_file), None),
292                Err(err) => (
293                    None,
294                    Some(RulesLoadingError {
295                        message: format!("{err}").into(),
296                    }),
297                ),
298            };
299            let worktree_info = WorktreeContext {
300                root_name,
301                rules_file,
302            };
303            (worktree_info, rules_file_error)
304        })
305    }
306
307    fn load_worktree_rules_file(
308        worktree: Entity<Worktree>,
309        project: Entity<Project>,
310        cx: &mut App,
311    ) -> Option<Task<Result<RulesFileContext>>> {
312        let worktree_ref = worktree.read(cx);
313        let worktree_id = worktree_ref.id();
314        let selected_rules_file = RULES_FILE_NAMES
315            .into_iter()
316            .filter_map(|name| {
317                worktree_ref
318                    .entry_for_path(name)
319                    .filter(|entry| entry.is_file())
320                    .map(|entry| entry.path.clone())
321            })
322            .next();
323
324        // Note that Cline supports `.clinerules` being a directory, but that is not currently
325        // supported. This doesn't seem to occur often in GitHub repositories.
326        selected_rules_file.map(|path_in_worktree| {
327            let project_path = ProjectPath {
328                worktree_id,
329                path: path_in_worktree.clone(),
330            };
331            let buffer_task =
332                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
333            let rope_task = cx.spawn(async move |cx| {
334                buffer_task.await?.read_with(cx, |buffer, cx| {
335                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
336                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
337                })?
338            });
339            // Build a string from the rope on a background thread.
340            cx.background_spawn(async move {
341                let (project_entry_id, rope) = rope_task.await?;
342                anyhow::Ok(RulesFileContext {
343                    path_in_worktree,
344                    text: rope.to_string().trim().to_string(),
345                    project_entry_id: project_entry_id.to_usize(),
346                })
347            })
348        })
349    }
350
351    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
352        &self.prompt_store
353    }
354
355    pub fn tools(&self) -> Entity<ToolWorkingSet> {
356        self.tools.clone()
357    }
358
359    /// Returns the number of threads.
360    pub fn thread_count(&self) -> usize {
361        self.threads.len()
362    }
363
364    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
365        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
366        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
367        threads
368    }
369
370    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
371        cx.new(|cx| {
372            Thread::new(
373                self.project.clone(),
374                self.tools.clone(),
375                self.prompt_builder.clone(),
376                self.project_context.clone(),
377                cx,
378            )
379        })
380    }
381
382    pub fn open_thread(
383        &self,
384        id: &ThreadId,
385        cx: &mut Context<Self>,
386    ) -> Task<Result<Entity<Thread>>> {
387        let id = id.clone();
388        let database_future = ThreadsDatabase::global_future(cx);
389        cx.spawn(async move |this, cx| {
390            let database = database_future.await.map_err(|err| anyhow!(err))?;
391            let thread = database
392                .try_find_thread(id.clone())
393                .await?
394                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
395
396            let thread = this.update(cx, |this, cx| {
397                cx.new(|cx| {
398                    Thread::deserialize(
399                        id.clone(),
400                        thread,
401                        this.project.clone(),
402                        this.tools.clone(),
403                        this.prompt_builder.clone(),
404                        this.project_context.clone(),
405                        cx,
406                    )
407                })
408            })?;
409
410            Ok(thread)
411        })
412    }
413
414    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
415        let (metadata, serialized_thread) =
416            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
417
418        let database_future = ThreadsDatabase::global_future(cx);
419        cx.spawn(async move |this, cx| {
420            let serialized_thread = serialized_thread.await?;
421            let database = database_future.await.map_err(|err| anyhow!(err))?;
422            database.save_thread(metadata, serialized_thread).await?;
423
424            this.update(cx, |this, cx| this.reload(cx))?.await
425        })
426    }
427
428    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
429        let id = id.clone();
430        let database_future = ThreadsDatabase::global_future(cx);
431        cx.spawn(async move |this, cx| {
432            let database = database_future.await.map_err(|err| anyhow!(err))?;
433            database.delete_thread(id.clone()).await?;
434
435            this.update(cx, |this, cx| {
436                this.threads.retain(|thread| thread.id != id);
437                cx.notify();
438            })
439        })
440    }
441
442    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
443        let database_future = ThreadsDatabase::global_future(cx);
444        cx.spawn(async move |this, cx| {
445            let threads = database_future
446                .await
447                .map_err(|err| anyhow!(err))?
448                .list_threads()
449                .await?;
450
451            this.update(cx, |this, cx| {
452                this.threads = threads;
453                cx.notify();
454            })
455        })
456    }
457
458    fn load_default_profile(&self, cx: &mut Context<Self>) {
459        let assistant_settings = AssistantSettings::get_global(cx);
460
461        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
462    }
463
464    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
465        let assistant_settings = AssistantSettings::get_global(cx);
466
467        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
468            self.load_profile(profile.clone(), cx);
469        }
470    }
471
472    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
473        self.tools.update(cx, |tools, cx| {
474            tools.disable_all_tools(cx);
475            tools.enable(
476                ToolSource::Native,
477                &profile
478                    .tools
479                    .iter()
480                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
481                    .collect::<Vec<_>>(),
482                cx,
483            );
484        });
485
486        if profile.enable_all_context_servers {
487            for context_server_id in self
488                .project
489                .read(cx)
490                .context_server_store()
491                .read(cx)
492                .all_server_ids()
493            {
494                self.tools.update(cx, |tools, cx| {
495                    tools.enable_source(
496                        ToolSource::ContextServer {
497                            id: context_server_id.0.into(),
498                        },
499                        cx,
500                    );
501                });
502            }
503            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
504            for (context_server_id, preset) in &profile.context_servers {
505                self.tools.update(cx, |tools, cx| {
506                    tools.disable(
507                        ToolSource::ContextServer {
508                            id: context_server_id.clone().into(),
509                        },
510                        &preset
511                            .tools
512                            .iter()
513                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
514                            .collect::<Vec<_>>(),
515                        cx,
516                    )
517                })
518            }
519        } else {
520            for (context_server_id, preset) in &profile.context_servers {
521                self.tools.update(cx, |tools, cx| {
522                    tools.enable(
523                        ToolSource::ContextServer {
524                            id: context_server_id.clone().into(),
525                        },
526                        &preset
527                            .tools
528                            .iter()
529                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
530                            .collect::<Vec<_>>(),
531                        cx,
532                    )
533                })
534            }
535        }
536    }
537
538    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
539        cx.subscribe(
540            &self.project.read(cx).context_server_store(),
541            Self::handle_context_server_event,
542        )
543        .detach();
544    }
545
546    fn handle_context_server_event(
547        &mut self,
548        context_server_store: Entity<ContextServerStore>,
549        event: &project::context_server_store::Event,
550        cx: &mut Context<Self>,
551    ) {
552        let tool_working_set = self.tools.clone();
553        match event {
554            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
555                match status {
556                    ContextServerStatus::Running => {
557                        if let Some(server) =
558                            context_server_store.read(cx).get_running_server(server_id)
559                        {
560                            let context_server_manager = context_server_store.clone();
561                            cx.spawn({
562                                let server = server.clone();
563                                let server_id = server_id.clone();
564                                async move |this, cx| {
565                                    let Some(protocol) = server.client() else {
566                                        return;
567                                    };
568
569                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
570                                        if let Some(tools) = protocol.list_tools().await.log_err() {
571                                            let tool_ids = tool_working_set
572                                                .update(cx, |tool_working_set, _| {
573                                                    tools
574                                                        .tools
575                                                        .into_iter()
576                                                        .map(|tool| {
577                                                            log::info!(
578                                                                "registering context server tool: {:?}",
579                                                                tool.name
580                                                            );
581                                                            tool_working_set.insert(Arc::new(
582                                                                ContextServerTool::new(
583                                                                    context_server_manager.clone(),
584                                                                    server.id(),
585                                                                    tool,
586                                                                ),
587                                                            ))
588                                                        })
589                                                        .collect::<Vec<_>>()
590                                                })
591                                                .log_err();
592
593                                            if let Some(tool_ids) = tool_ids {
594                                                this.update(cx, |this, cx| {
595                                                    this.context_server_tool_ids
596                                                        .insert(server_id, tool_ids);
597                                                    this.load_default_profile(cx);
598                                                })
599                                                .log_err();
600                                            }
601                                        }
602                                    }
603                                }
604                            })
605                            .detach();
606                        }
607                    }
608                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
609                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
610                            tool_working_set.update(cx, |tool_working_set, _| {
611                                tool_working_set.remove(&tool_ids);
612                            });
613                            self.load_default_profile(cx);
614                        }
615                    }
616                    _ => {}
617                }
618            }
619        }
620    }
621}
622
623#[derive(Debug, Clone, Serialize, Deserialize)]
624pub struct SerializedThreadMetadata {
625    pub id: ThreadId,
626    pub summary: SharedString,
627    pub updated_at: DateTime<Utc>,
628}
629
630#[derive(Serialize, Deserialize, Debug)]
631pub struct SerializedThread {
632    pub version: String,
633    pub summary: SharedString,
634    pub updated_at: DateTime<Utc>,
635    pub messages: Vec<SerializedMessage>,
636    #[serde(default)]
637    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
638    #[serde(default)]
639    pub cumulative_token_usage: TokenUsage,
640    #[serde(default)]
641    pub request_token_usage: Vec<TokenUsage>,
642    #[serde(default)]
643    pub detailed_summary_state: DetailedSummaryState,
644    #[serde(default)]
645    pub exceeded_window_error: Option<ExceededWindowError>,
646    #[serde(default)]
647    pub model: Option<SerializedLanguageModel>,
648}
649
650#[derive(Serialize, Deserialize, Debug)]
651pub struct SerializedLanguageModel {
652    pub provider: String,
653    pub model: String,
654}
655
656impl SerializedThread {
657    pub const VERSION: &'static str = "0.2.0";
658
659    pub fn from_json(json: &[u8]) -> Result<Self> {
660        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
661        match saved_thread_json.get("version") {
662            Some(serde_json::Value::String(version)) => match version.as_str() {
663                SerializedThreadV0_1_0::VERSION => {
664                    let saved_thread =
665                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
666                    Ok(saved_thread.upgrade())
667                }
668                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
669                    saved_thread_json,
670                )?),
671                _ => Err(anyhow!(
672                    "unrecognized serialized thread version: {}",
673                    version
674                )),
675            },
676            None => {
677                let saved_thread =
678                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
679                Ok(saved_thread.upgrade())
680            }
681            version => Err(anyhow!(
682                "unrecognized serialized thread version: {:?}",
683                version
684            )),
685        }
686    }
687}
688
689#[derive(Serialize, Deserialize, Debug)]
690pub struct SerializedThreadV0_1_0(
691    // The structure did not change, so we are reusing the latest SerializedThread.
692    // When making the next version, make sure this points to SerializedThreadV0_2_0
693    SerializedThread,
694);
695
696impl SerializedThreadV0_1_0 {
697    pub const VERSION: &'static str = "0.1.0";
698
699    pub fn upgrade(self) -> SerializedThread {
700        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
701
702        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
703
704        for message in self.0.messages {
705            if message.role == Role::User && !message.tool_results.is_empty() {
706                if let Some(last_message) = messages.last_mut() {
707                    debug_assert!(last_message.role == Role::Assistant);
708
709                    last_message.tool_results = message.tool_results;
710                    continue;
711                }
712            }
713
714            messages.push(message);
715        }
716
717        SerializedThread { messages, ..self.0 }
718    }
719}
720
721#[derive(Debug, Serialize, Deserialize)]
722pub struct SerializedMessage {
723    pub id: MessageId,
724    pub role: Role,
725    #[serde(default)]
726    pub segments: Vec<SerializedMessageSegment>,
727    #[serde(default)]
728    pub tool_uses: Vec<SerializedToolUse>,
729    #[serde(default)]
730    pub tool_results: Vec<SerializedToolResult>,
731    #[serde(default)]
732    pub context: String,
733    #[serde(default)]
734    pub creases: Vec<SerializedCrease>,
735}
736
737#[derive(Debug, Serialize, Deserialize)]
738#[serde(tag = "type")]
739pub enum SerializedMessageSegment {
740    #[serde(rename = "text")]
741    Text {
742        text: String,
743    },
744    #[serde(rename = "thinking")]
745    Thinking {
746        text: String,
747        #[serde(skip_serializing_if = "Option::is_none")]
748        signature: Option<String>,
749    },
750    RedactedThinking {
751        data: Vec<u8>,
752    },
753}
754
755#[derive(Debug, Serialize, Deserialize)]
756pub struct SerializedToolUse {
757    pub id: LanguageModelToolUseId,
758    pub name: SharedString,
759    pub input: serde_json::Value,
760}
761
762#[derive(Debug, Serialize, Deserialize)]
763pub struct SerializedToolResult {
764    pub tool_use_id: LanguageModelToolUseId,
765    pub is_error: bool,
766    pub content: Arc<str>,
767}
768
769#[derive(Serialize, Deserialize)]
770struct LegacySerializedThread {
771    pub summary: SharedString,
772    pub updated_at: DateTime<Utc>,
773    pub messages: Vec<LegacySerializedMessage>,
774    #[serde(default)]
775    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
776}
777
778impl LegacySerializedThread {
779    pub fn upgrade(self) -> SerializedThread {
780        SerializedThread {
781            version: SerializedThread::VERSION.to_string(),
782            summary: self.summary,
783            updated_at: self.updated_at,
784            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
785            initial_project_snapshot: self.initial_project_snapshot,
786            cumulative_token_usage: TokenUsage::default(),
787            request_token_usage: Vec::new(),
788            detailed_summary_state: DetailedSummaryState::default(),
789            exceeded_window_error: None,
790            model: None,
791        }
792    }
793}
794
795#[derive(Debug, Serialize, Deserialize)]
796struct LegacySerializedMessage {
797    pub id: MessageId,
798    pub role: Role,
799    pub text: String,
800    #[serde(default)]
801    pub tool_uses: Vec<SerializedToolUse>,
802    #[serde(default)]
803    pub tool_results: Vec<SerializedToolResult>,
804}
805
806impl LegacySerializedMessage {
807    fn upgrade(self) -> SerializedMessage {
808        SerializedMessage {
809            id: self.id,
810            role: self.role,
811            segments: vec![SerializedMessageSegment::Text { text: self.text }],
812            tool_uses: self.tool_uses,
813            tool_results: self.tool_results,
814            context: String::new(),
815            creases: Vec::new(),
816        }
817    }
818}
819
820#[derive(Debug, Serialize, Deserialize)]
821pub struct SerializedCrease {
822    pub start: usize,
823    pub end: usize,
824    pub icon_path: SharedString,
825    pub label: SharedString,
826}
827
/// App-global slot holding the pending/opened threads database, set by `ThreadsDatabase::init`.
///
/// The future is `Shared` so every awaiter gets a clone of the same output. The error is
/// wrapped in `Arc` because `Shared` requires a `Clone` output and `anyhow::Error` is not
/// `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
833
/// LMDB-backed (via `heed`) persistent store of serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database read and write is spawned.
    executor: BackgroundExecutor,
    /// The opened LMDB environment; cloned into each background task.
    env: heed::Env,
    /// Table keyed by bincode-encoded `ThreadId`. Values use the `heed` codec impls on
    /// `SerializedThread`, which read and write JSON.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
839
840impl heed::BytesEncode<'_> for SerializedThread {
841    type EItem = SerializedThread;
842
843    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
844        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
845    }
846}
847
848impl<'a> heed::BytesDecode<'a> for SerializedThread {
849    type DItem = SerializedThread;
850
851    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
852        // We implement this type manually because we want to call `SerializedThread::from_json`,
853        // instead of the Deserialize trait implementation for `SerializedThread`.
854        SerializedThread::from_json(bytes).map_err(Into::into)
855    }
856}
857
858impl ThreadsDatabase {
859    fn global_future(
860        cx: &mut App,
861    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
862        GlobalThreadsDatabase::global(cx).0.clone()
863    }
864
865    fn init(cx: &mut App) {
866        let executor = cx.background_executor().clone();
867        let database_future = executor
868            .spawn({
869                let executor = executor.clone();
870                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
871                async move { ThreadsDatabase::new(database_path, executor) }
872            })
873            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
874            .boxed()
875            .shared();
876
877        cx.set_global(GlobalThreadsDatabase(database_future));
878    }
879
880    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
881        std::fs::create_dir_all(&path)?;
882
883        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
884        let env = unsafe {
885            heed::EnvOpenOptions::new()
886                .map_size(ONE_GB_IN_BYTES)
887                .max_dbs(1)
888                .open(path)?
889        };
890
891        let mut txn = env.write_txn()?;
892        let threads = env.create_database(&mut txn, Some("threads"))?;
893        txn.commit()?;
894
895        Ok(Self {
896            executor,
897            env,
898            threads,
899        })
900    }
901
902    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
903        let env = self.env.clone();
904        let threads = self.threads;
905
906        self.executor.spawn(async move {
907            let txn = env.read_txn()?;
908            let mut iter = threads.iter(&txn)?;
909            let mut threads = Vec::new();
910            while let Some((key, value)) = iter.next().transpose()? {
911                threads.push(SerializedThreadMetadata {
912                    id: key,
913                    summary: value.summary,
914                    updated_at: value.updated_at,
915                });
916            }
917
918            Ok(threads)
919        })
920    }
921
922    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
923        let env = self.env.clone();
924        let threads = self.threads;
925
926        self.executor.spawn(async move {
927            let txn = env.read_txn()?;
928            let thread = threads.get(&txn, &id)?;
929            Ok(thread)
930        })
931    }
932
933    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
934        let env = self.env.clone();
935        let threads = self.threads;
936
937        self.executor.spawn(async move {
938            let mut txn = env.write_txn()?;
939            threads.put(&mut txn, &id, &thread)?;
940            txn.commit()?;
941            Ok(())
942        })
943    }
944
945    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
946        let env = self.env.clone();
947        let threads = self.threads;
948
949        self.executor.spawn(async move {
950            let mut txn = env.write_txn()?;
951            threads.delete(&mut txn, &id)?;
952            txn.commit()?;
953            Ok(())
954        })
955    }
956}