1use std::cell::{Ref, RefCell};
2use std::path::{Path, PathBuf};
3use std::rc::Rc;
4use std::sync::{Arc, Mutex};
5
6use agent_settings::{AgentProfileId, CompletionMode};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ToolId, ToolWorkingSet};
9use chrono::{DateTime, Utc};
10use collections::HashMap;
11use context_server::ContextServerId;
12use futures::channel::{mpsc, oneshot};
13use futures::future::{self, BoxFuture, Shared};
14use futures::{FutureExt as _, StreamExt as _};
15use gpui::{
16 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
17 Subscription, Task, prelude::*,
18};
19
20use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
21use project::context_server_store::{ContextServerStatus, ContextServerStore};
22use project::{Project, ProjectItem, ProjectPath, Worktree};
23use prompt_store::{
24 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
25 UserRulesContext, WorktreeContext,
26};
27use serde::{Deserialize, Serialize};
28use ui::Window;
29use util::ResultExt as _;
30
31use crate::context_server_tool::ContextServerTool;
32use crate::thread::{
33 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
34};
35use indoc::indoc;
36use sqlez::{
37 bindable::{Bind, Column},
38 connection::Connection,
39 statement::Statement,
40};
41
42#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
43pub enum DataType {
44 #[serde(rename = "json")]
45 Json,
46 #[serde(rename = "zstd")]
47 Zstd,
48}
49
50impl Bind for DataType {
51 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
52 let value = match self {
53 DataType::Json => "json",
54 DataType::Zstd => "zstd",
55 };
56 value.bind(statement, start_index)
57 }
58}
59
60impl Column for DataType {
61 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
62 let (value, next_index) = String::column(statement, start_index)?;
63 let data_type = match value.as_str() {
64 "json" => DataType::Json,
65 "zstd" => DataType::Zstd,
66 _ => anyhow::bail!("Unknown data type: {}", value),
67 };
68 Ok((data_type, next_index))
69 }
70}
71
/// Worktree-relative paths that are recognized as rules files.
///
/// Order matters: when looking for a worktree's rules file the list is scanned
/// front to back and the first existing file wins.
const RULES_FILE_NAMES: [&str; 8] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
];
82
/// Initializes this module's global state by setting up the threads database global.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
86
/// A system prompt shared by all threads created by this ThreadStore
///
/// The inner value is `None` until the store has finished loading the project
/// context for the first time. Cloning shares the same underlying cell.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);

impl SharedProjectContext {
    /// Immutably borrows the current project context.
    ///
    /// This is a `RefCell` borrow, so it panics if the value is mutably borrowed
    /// at the same time.
    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
        self.0.borrow()
    }
}
96
/// Alias for the context store provided by `assistant_context_editor`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
98
/// Owns the list of saved threads for a project and the shared state
/// (tools, prompt builder, project context) handed to every thread it creates.
pub struct ThreadStore {
    /// Project whose worktrees and context servers this store observes.
    project: Entity<Project>,
    /// Tool working set shared with threads; context server tools are added/removed here.
    tools: Entity<ToolWorkingSet>,
    /// Prompt builder passed to each created thread.
    prompt_builder: Arc<PromptBuilder>,
    /// Optional prompt store from which default user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered per context server, so they can be removed when the
    /// server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Metadata of saved threads, as last listed from the database.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with every thread created by this store.
    project_context: SharedProjectContext,
    /// Wakes the system prompt reload loop; the channel has a buffer size of 1.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Handle of the reload loop, stored so it lives as long as the store
    /// (presumably dropping the task would cancel it — confirm against gpui).
    _reload_system_prompt_task: Task<()>,
    /// Subscriptions to project and prompt store events.
    _subscriptions: Vec<Subscription>,
}
111
/// Event emitted by [`ThreadStore`] when a rules file or a default user rule
/// could not be loaded while rebuilding the project context.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
117
118impl ThreadStore {
    /// Creates a [`ThreadStore`] entity and resolves once its first system prompt
    /// reload has completed.
    ///
    /// Errors from `cx.update` and from awaiting the readiness channel are
    /// propagated through the returned task.
    pub fn load(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_store: Option<Entity<PromptStore>>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        cx.spawn(async move |cx| {
            let (thread_store, ready_rx) = cx.update(|cx| {
                // `cx.new` only hands back the entity, so the readiness receiver
                // produced by `Self::new` is carried out through this Option.
                let mut option_ready_rx = None;
                let thread_store = cx.new(|cx| {
                    let (thread_store, ready_rx) =
                        Self::new(project, tools, prompt_builder, prompt_store, cx);
                    option_ready_rx = Some(ready_rx);
                    thread_store
                });
                (thread_store, option_ready_rx.take().unwrap())
            })?;
            // Wait until the initial system prompt has been loaded.
            ready_rx.await?;
            Ok(thread_store)
        })
    }
141
    /// Builds the store, subscribes to project and prompt store events, and
    /// starts the system prompt reload loop.
    ///
    /// Returns the store together with a receiver that fires after the first
    /// system prompt reload has finished.
    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];

        // Any prompt store update triggers a system prompt reload.
        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    // Stop looping once the store entity can no longer be updated.
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    // Signal readiness only after the very first reload.
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    // Wait for the next reload request before looping again.
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.register_context_server_handlers(cx);
        // Populate the thread list from the database in the background.
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }
201
202 fn handle_project_event(
203 &mut self,
204 _project: Entity<Project>,
205 event: &project::Event,
206 _cx: &mut Context<Self>,
207 ) {
208 match event {
209 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
210 self.enqueue_system_prompt_reload();
211 }
212 project::Event::WorktreeUpdatedEntries(_, items) => {
213 if items.iter().any(|(path, _, _)| {
214 RULES_FILE_NAMES
215 .iter()
216 .any(|name| path.as_ref() == Path::new(name))
217 }) {
218 self.enqueue_system_prompt_reload();
219 }
220 }
221 _ => {}
222 }
223 }
224
    /// Asks the reload loop to rebuild the system prompt.
    ///
    /// Uses a non-blocking `try_send` and deliberately ignores its result, so a
    /// request that cannot be queued (e.g. because one is already pending) is dropped.
    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }
228
    // Note that this should only be called from `reload_system_prompt_task`.
    /// Rebuilds the shared [`ProjectContext`] from the visible worktrees and the
    /// default user rules of `prompt_store`.
    ///
    /// Loading failures do not abort the reload; each one is emitted as a
    /// [`RulesLoadingError`] event and the affected rules are left out.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .collect::<Vec<_>>();
        // One task per worktree, each yielding its context and an optional rules error.
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
            })
            .collect::<Vec<_>>();
        // Without a prompt store there are no default user rules to load.
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            // Surface rules file errors as events while keeping every worktree context.
            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            // Keep only successfully loaded user prompts; `EditWorkflow` prompts are skipped.
            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            // Publish the new context to every holder of the shared cell.
            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }
302
303 fn load_worktree_info_for_system_prompt(
304 worktree: Entity<Worktree>,
305 project: Entity<Project>,
306 cx: &mut App,
307 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
308 let tree = worktree.read(cx);
309 let root_name = tree.root_name().into();
310 let abs_path = tree.abs_path();
311
312 let mut context = WorktreeContext {
313 root_name,
314 abs_path,
315 rules_file: None,
316 };
317
318 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
319 let Some(rules_task) = rules_task else {
320 return Task::ready((context, None));
321 };
322
323 cx.spawn(async move |_| {
324 let (rules_file, rules_file_error) = match rules_task.await {
325 Ok(rules_file) => (Some(rules_file), None),
326 Err(err) => (
327 None,
328 Some(RulesLoadingError {
329 message: format!("{err}").into(),
330 }),
331 ),
332 };
333 context.rules_file = rules_file;
334 (context, rules_file_error)
335 })
336 }
337
338 fn load_worktree_rules_file(
339 worktree: Entity<Worktree>,
340 project: Entity<Project>,
341 cx: &mut App,
342 ) -> Option<Task<Result<RulesFileContext>>> {
343 let worktree = worktree.read(cx);
344 let worktree_id = worktree.id();
345 let selected_rules_file = RULES_FILE_NAMES
346 .into_iter()
347 .filter_map(|name| {
348 worktree
349 .entry_for_path(name)
350 .filter(|entry| entry.is_file())
351 .map(|entry| entry.path.clone())
352 })
353 .next();
354
355 // Note that Cline supports `.clinerules` being a directory, but that is not currently
356 // supported. This doesn't seem to occur often in GitHub repositories.
357 selected_rules_file.map(|path_in_worktree| {
358 let project_path = ProjectPath {
359 worktree_id,
360 path: path_in_worktree.clone(),
361 };
362 let buffer_task =
363 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
364 let rope_task = cx.spawn(async move |cx| {
365 buffer_task.await?.read_with(cx, |buffer, cx| {
366 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
367 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
368 })?
369 });
370 // Build a string from the rope on a background thread.
371 cx.background_spawn(async move {
372 let (project_entry_id, rope) = rope_task.await?;
373 anyhow::Ok(RulesFileContext {
374 path_in_worktree,
375 text: rope.to_string().trim().to_string(),
376 project_entry_id: project_entry_id.to_usize(),
377 })
378 })
379 })
380 }
381
    /// Returns the prompt store this thread store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
385
    /// Returns a handle to the shared tool working set.
    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }
389
    /// Returns the number of threads.
    ///
    /// This counts the in-memory metadata list, i.e. the threads known as of the
    /// last database reload or local deletion.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
394
    /// Iterates over the metadata of the known threads, most recently updated first.
    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // ordering is from "ORDER BY" in `list_threads`
        self.threads.iter()
    }
399
    /// Creates a new, empty thread entity that shares this store's project,
    /// tools, prompt builder and project context.
    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::new(
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                cx,
            )
        })
    }
411
    /// Creates a thread entity from previously serialized data.
    ///
    /// The thread gets a freshly generated [`ThreadId`] and is deserialized
    /// without a window (`None` is passed for the window argument).
    pub fn create_thread_from_serialized(
        &mut self,
        serialized: SerializedThread,
        cx: &mut Context<Self>,
    ) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::deserialize(
                ThreadId::new(),
                serialized,
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                None,
                cx,
            )
        })
    }
430
    /// Loads the saved thread with the given `id` from the database and
    /// deserializes it into a thread entity bound to `window`.
    ///
    /// The returned task fails if the database is unavailable, the lookup fails,
    /// no thread with that id exists, or the store/window can no longer be updated.
    pub fn open_thread(
        &self,
        id: &ThreadId,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        let this = cx.weak_entity();
        window.spawn(cx, async move |cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let thread = database
                .try_find_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            // Keep the stored id (unlike `create_thread_from_serialized`).
            let thread = this.update_in(cx, |this, window, cx| {
                cx.new(|cx| {
                    Thread::deserialize(
                        id.clone(),
                        thread,
                        this.project.clone(),
                        this.tools.clone(),
                        this.prompt_builder.clone(),
                        this.project_context.clone(),
                        Some(window),
                        cx,
                    )
                })
            })?;

            Ok(thread)
        })
    }
465
466 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
467 let (metadata, serialized_thread) =
468 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
469
470 let database_future = ThreadsDatabase::global_future(cx);
471 cx.spawn(async move |this, cx| {
472 let serialized_thread = serialized_thread.await?;
473 let database = database_future.await.map_err(|err| anyhow!(err))?;
474 database.save_thread(metadata, serialized_thread).await?;
475
476 this.update(cx, |this, cx| this.reload(cx))?.await
477 })
478 }
479
    /// Deletes the thread with the given `id` from the database and, once that
    /// succeeded, removes it from the in-memory list and notifies observers.
    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.delete_thread(id.clone()).await?;

            // Update the cached list directly instead of re-listing from the database.
            this.update(cx, |this, cx| {
                this.threads.retain(|thread| thread.id != id);
                cx.notify();
            })
        })
    }
493
    /// Replaces the in-memory thread list with the result of listing all
    /// threads from the database, then notifies observers.
    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let threads = database_future
                .await
                .map_err(|err| anyhow!(err))?
                .list_threads()
                .await?;

            this.update(cx, |this, cx| {
                this.threads = threads;
                cx.notify();
            })
        })
    }
509
    /// Subscribes to context server status events of the project and loads the
    /// tools of every server that is already running.
    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        let context_server_store = self.project.read(cx).context_server_store();
        cx.subscribe(&context_server_store, Self::handle_context_server_event)
            .detach();

        // Check for any servers that were already running before the handler was registered
        for server in context_server_store.read(cx).running_servers() {
            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
        }
    }
520
    /// Reacts to a context server status change: loads the server's tools when
    /// it starts running and removes its registered tools when it stops or errors.
    fn handle_context_server_event(
        &mut self,
        context_server_store: Entity<ContextServerStore>,
        event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        let tool_working_set = self.tools.clone();
        match event {
            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
                match status {
                    // Nothing to do until the server is actually running.
                    ContextServerStatus::Starting => {}
                    ContextServerStatus::Running => {
                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
                    }
                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                        // Forget the server's tool ids and drop them from the working set.
                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                            tool_working_set.update(cx, |tool_working_set, _| {
                                tool_working_set.remove(&tool_ids);
                            });
                        }
                    }
                }
            }
        }
    }
546
    /// Asks a running context server for its tools and registers them in the
    /// tool working set, remembering the resulting ids under `server_id`.
    ///
    /// Does nothing if the server is not running, has no client, or does not
    /// advertise the tools capability. Request failures are logged, not returned.
    /// The work runs in a detached task.
    fn load_context_server_tools(
        &self,
        server_id: ContextServerId,
        context_server_store: Entity<ContextServerStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
            return;
        };
        let tool_working_set = self.tools.clone();
        cx.spawn(async move |this, cx| {
            let Some(protocol) = server.client() else {
                return;
            };

            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                if let Some(response) = protocol
                    .request::<context_server::types::requests::ListTools>(())
                    .await
                    .log_err()
                {
                    // Insert each listed tool and collect the ids returned by the working set.
                    let tool_ids = tool_working_set
                        .update(cx, |tool_working_set, _| {
                            response
                                .tools
                                .into_iter()
                                .map(|tool| {
                                    log::info!("registering context server tool: {:?}", tool.name);
                                    tool_working_set.insert(Arc::new(ContextServerTool::new(
                                        context_server_store.clone(),
                                        server.id(),
                                        tool,
                                    )))
                                })
                                .collect::<Vec<_>>()
                        })
                        .log_err();

                    // NOTE(review): `insert` replaces any ids previously stored for this
                    // server without removing those tools from the working set — confirm
                    // a server cannot report `Running` twice without stopping in between.
                    if let Some(tool_ids) = tool_ids {
                        this.update(cx, |this, _| {
                            this.context_server_tool_ids.insert(server_id, tool_ids);
                        })
                        .log_err();
                    }
                }
            }
        })
        .detach();
    }
596}
597
/// Lightweight description of a saved thread, as returned when listing threads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Identifier of the thread.
    pub id: ThreadId,
    /// Summary text stored for the thread.
    pub summary: SharedString,
    /// Time of the thread's last update.
    pub updated_at: DateTime<Utc>,
}
604
/// The persisted form of a thread in the current format version.
///
/// Every field marked `#[serde(default)]` may be absent from stored data and
/// then falls back to its type's default value.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version string; see [`SerializedThread::VERSION`].
    pub version: String,
    /// Summary text of the thread.
    pub summary: SharedString,
    /// Time of the thread's last update.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages, in order.
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
630
/// Identifies a language model by provider and model name, both stored as strings.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
636
impl SerializedThread {
    /// Version string written into newly serialized threads.
    pub const VERSION: &'static str = "0.2.0";

    /// Parses a thread from JSON bytes, upgrading older formats to the current one.
    ///
    /// The format is selected by the top-level `"version"` field:
    /// - `"0.1.0"` is parsed as [`SerializedThreadV0_1_0`] and upgraded,
    /// - `"0.2.0"` is parsed directly,
    /// - a missing field is treated as the legacy, unversioned format,
    /// - anything else (unknown string or non-string value) is an error.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                SerializedThreadV0_1_0::VERSION => {
                    let saved_thread =
                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
                    Ok(saved_thread.upgrade())
                }
                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
                    saved_thread_json,
                )?),
                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
            },
            None => {
                let saved_thread =
                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
                Ok(saved_thread.upgrade())
            }
            // A `version` field that is present but not a string.
            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
        }
    }
}
663
/// A thread stored in format version `0.1.0`.
///
/// It wraps [`SerializedThread`] because the field layout is the same; only the
/// placement of tool results differs, which `upgrade` corrects.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
670
671impl SerializedThreadV0_1_0 {
672 pub const VERSION: &'static str = "0.1.0";
673
674 pub fn upgrade(self) -> SerializedThread {
675 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
676
677 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
678
679 for message in self.0.messages {
680 if message.role == Role::User && !message.tool_results.is_empty() {
681 if let Some(last_message) = messages.last_mut() {
682 debug_assert!(last_message.role == Role::Assistant);
683
684 last_message.tool_results = message.tool_results;
685 continue;
686 }
687 }
688
689 messages.push(message);
690 }
691
692 SerializedThread {
693 messages,
694 version: SerializedThread::VERSION.to_string(),
695 ..self.0
696 }
697 }
698}
699
/// The persisted form of a single message in a thread.
///
/// All fields except `id` and `role` default to empty/false when missing from
/// stored data.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Content segments of the message (text, thinking, redacted thinking).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
717
/// One segment of a message's content, serialized with an internal `"type"` tag.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text; tagged `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Thinking text with an optional signature; tagged `"thinking"`.
    /// The signature is omitted from the output when it is `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Redacted thinking data. This variant has no `rename`, so its tag is the
    /// variant name `"RedactedThinking"`, unlike the lowercase tags above.
    RedactedThinking {
        data: Vec<u8>,
    },
}
735
/// The persisted form of a tool invocation: its id, the tool's name and the
/// JSON input it was called with.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
742
/// The persisted form of a tool's result.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    /// Id of the tool use this result belongs to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the result represents an error.
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    /// Optional JSON output stored alongside the content.
    pub output: Option<serde_json::Value>,
}
750
/// A thread in the oldest stored format, which has no `version` field.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
759
760impl LegacySerializedThread {
761 pub fn upgrade(self) -> SerializedThread {
762 SerializedThread {
763 version: SerializedThread::VERSION.to_string(),
764 summary: self.summary,
765 updated_at: self.updated_at,
766 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
767 initial_project_snapshot: self.initial_project_snapshot,
768 cumulative_token_usage: TokenUsage::default(),
769 request_token_usage: Vec::new(),
770 detailed_summary_state: DetailedSummaryState::default(),
771 exceeded_window_error: None,
772 model: None,
773 completion_mode: None,
774 tool_use_limit_reached: false,
775 profile: None,
776 }
777 }
778}
779
/// A message in the legacy thread format, where the content is a single text string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
790
791impl LegacySerializedMessage {
792 fn upgrade(self) -> SerializedMessage {
793 SerializedMessage {
794 id: self.id,
795 role: self.role,
796 segments: vec![SerializedMessageSegment::Text { text: self.text }],
797 tool_uses: self.tool_uses,
798 tool_results: self.tool_results,
799 context: String::new(),
800 creases: Vec::new(),
801 is_hidden: false,
802 }
803 }
804}
805
/// The persisted form of a crease attached to a message.
// NOTE(review): `start`/`end` look like offsets into the message text — confirm
// the unit (bytes vs. chars) against the code that produces them.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
813
/// App-global holder of the threads database.
///
/// The database is opened asynchronously, so the global stores a shared,
/// cloneable future that resolves to the database or to the error from opening it.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
819
/// SQLite-backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which all database operations are spawned.
    executor: BackgroundExecutor,
    /// The single SQLite connection, shared between tasks behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
824
825impl ThreadsDatabase {
826 fn connection(&self) -> Arc<Mutex<Connection>> {
827 self.connection.clone()
828 }
829
830 const COMPRESSION_LEVEL: i32 = 3;
831}
832
833impl Bind for ThreadId {
834 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
835 self.to_string().bind(statement, start_index)
836 }
837}
838
839impl Column for ThreadId {
840 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
841 let (id_str, next_index) = String::column(statement, start_index)?;
842 Ok((ThreadId::from(id_str.as_str()), next_index))
843 }
844}
845
846impl ThreadsDatabase {
    /// Returns a clone of the shared future that resolves to the global database
    /// (or to the error encountered while opening it).
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }
852
    /// Starts opening the database in the `threads` directory under the data
    /// dir on the background executor and installs the resulting shared future
    /// as a global.
    fn init(cx: &mut App) {
        let executor = cx.background_executor().clone();
        let database_future = executor
            .spawn({
                let executor = executor.clone();
                let threads_dir = paths::data_dir().join("threads");
                async move { ThreadsDatabase::new(threads_dir, executor) }
            })
            // Wrap both outcomes in `Arc` so the result can be cloned by `Shared`.
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
            .shared();

        cx.set_global(GlobalThreadsDatabase(database_future));
    }
867
868 pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
869 std::fs::create_dir_all(&threads_dir)?;
870
871 let sqlite_path = threads_dir.join("threads.db");
872 let mdb_path = threads_dir.join("threads-db.1.mdb");
873
874 let needs_migration_from_heed = mdb_path.exists();
875
876 let connection = Connection::open_file(&sqlite_path.to_string_lossy());
877
878 connection.exec(indoc! {"
879 CREATE TABLE IF NOT EXISTS threads (
880 id TEXT PRIMARY KEY,
881 summary TEXT NOT NULL,
882 updated_at TEXT NOT NULL,
883 data_type TEXT NOT NULL,
884 data BLOB NOT NULL
885 )
886 "})?()
887 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
888
889 let db = Self {
890 executor: executor.clone(),
891 connection: Arc::new(Mutex::new(connection)),
892 };
893
894 if needs_migration_from_heed {
895 let db_connection = db.connection();
896 let executor_clone = executor.clone();
897 executor
898 .spawn(async move {
899 log::info!("Starting threads.db migration");
900 Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
901 std::fs::remove_dir_all(mdb_path)?;
902 log::info!("threads.db migrated to sqlite");
903 Ok::<(), anyhow::Error>(())
904 })
905 .detach();
906 }
907
908 Ok(db)
909 }
910
    // Remove this migration after 2025-09-01
    /// Copies every thread from the legacy heed database at `mdb_path` into the
    /// SQLite database behind `connection`.
    ///
    /// Stops at the first thread that fails to decode or save and returns that
    /// error. The `_executor` argument is not used.
    fn migrate_from_heed(
        mdb_path: &Path,
        connection: Arc<Mutex<Connection>>,
        _executor: BackgroundExecutor,
    ) -> Result<()> {
        use heed::types::SerdeBincode;
        // Codec for heed values: threads are stored as JSON and decoded through
        // `SerializedThread::from_json`, which also upgrades older formats.
        struct SerializedThreadHeed(SerializedThread);

        impl heed::BytesEncode<'_> for SerializedThreadHeed {
            type EItem = SerializedThreadHeed;

            fn bytes_encode(
                item: &Self::EItem,
            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
                serde_json::to_vec(&item.0)
                    .map(std::borrow::Cow::Owned)
                    .map_err(Into::into)
            }
        }

        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
            type DItem = SerializedThreadHeed;

            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
                SerializedThread::from_json(bytes)
                    .map(SerializedThreadHeed)
                    .map_err(Into::into)
            }
        }

        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;

        // NOTE(review): opening a heed environment is `unsafe`; presumably the
        // requirement is that the same environment is not opened elsewhere in
        // this process at the same time — confirm against the heed documentation.
        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(ONE_GB_IN_BYTES)
                .max_dbs(1)
                .open(mdb_path)?
        };

        // NOTE(review): only reads are performed through this transaction; a
        // write transaction is opened nonetheless — confirm whether that is required.
        let txn = env.write_txn()?;
        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
            .open_database(&txn, Some("threads"))?
            .ok_or_else(|| anyhow!("threads database not found"))?;

        for result in threads.iter(&txn)? {
            let (thread_id, thread_heed) = result?;
            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
        }

        Ok(())
    }
963
964 fn save_thread_sync(
965 connection: &Arc<Mutex<Connection>>,
966 id: ThreadId,
967 thread: SerializedThread,
968 ) -> Result<()> {
969 let json_data = serde_json::to_string(&thread)?;
970 let summary = thread.summary.to_string();
971 let updated_at = thread.updated_at.to_rfc3339();
972
973 let connection = connection.lock().unwrap();
974
975 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
976 let data_type = DataType::Zstd;
977 let data = compressed;
978
979 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
980 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
981 "})?;
982
983 insert((id, summary, updated_at, data_type, data))?;
984
985 Ok(())
986 }
987
988 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
989 let connection = self.connection.clone();
990
991 self.executor.spawn(async move {
992 let connection = connection.lock().unwrap();
993 let mut select =
994 connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
995 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
996 "})?;
997
998 let rows = select(())?;
999 let mut threads = Vec::new();
1000
1001 for (id, summary, updated_at) in rows {
1002 threads.push(SerializedThreadMetadata {
1003 id,
1004 summary: summary.into(),
1005 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1006 });
1007 }
1008
1009 Ok(threads)
1010 })
1011 }
1012
1013 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1014 let connection = self.connection.clone();
1015
1016 self.executor.spawn(async move {
1017 let connection = connection.lock().unwrap();
1018 let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1019 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1020 "})?;
1021
1022 let rows = select(id)?;
1023 if let Some((data_type, data)) = rows.into_iter().next() {
1024 let json_data = match data_type {
1025 DataType::Zstd => {
1026 let decompressed = zstd::decode_all(&data[..])?;
1027 String::from_utf8(decompressed)?
1028 }
1029 DataType::Json => String::from_utf8(data)?,
1030 };
1031
1032 let thread = SerializedThread::from_json(json_data.as_bytes())?;
1033 Ok(Some(thread))
1034 } else {
1035 Ok(None)
1036 }
1037 })
1038 }
1039
    /// Saves `thread` under `id` on the database executor; see `save_thread_sync`
    /// for how the thread is stored.
    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor
            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
    }
1046
    /// Deletes the row stored under `id` on the database executor.
    ///
    /// The task fails only if the statement fails. The task panics if the
    /// connection mutex is poisoned.
    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();

            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
                DELETE FROM threads WHERE id = ?
            "})?;

            delete(id)?;

            Ok(())
        })
    }
1062}
1063
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// Upgrading a legacy thread turns each message's text into a single text
    /// segment and fills every newer field with its default.
    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    /// Upgrading a 0.1.0 thread moves the tool results of a trailing user
    /// message onto the preceding assistant message and drops that user message.
    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                // This user message only exists to carry the tool result; it is
                // expected to be merged away by the upgrade.
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
}