1use crate::{
2 context_server_tool::ContextServerTool,
3 thread::{
4 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
5 },
6};
7use agent_settings::{AgentProfileId, CompletionMode};
8use anyhow::{Context as _, Result, anyhow};
9use assistant_tool::{ToolId, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use futures::{
14 FutureExt as _, StreamExt as _,
15 channel::{mpsc, oneshot},
16 future::{self, BoxFuture, Shared},
17};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, Window, prelude::*,
21};
22use indoc::indoc;
23use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
24use project::context_server_store::{ContextServerStatus, ContextServerStore};
25use project::{Project, ProjectItem, ProjectPath, Worktree};
26use prompt_store::{
27 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
28 UserRulesContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use sqlez::{
32 bindable::{Bind, Column},
33 connection::Connection,
34 statement::Statement,
35};
36use std::{
37 cell::{Ref, RefCell},
38 path::{Path, PathBuf},
39 rc::Rc,
40 sync::{Arc, Mutex},
41};
42use util::ResultExt as _;
43
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub enum DataType {
46 #[serde(rename = "json")]
47 Json,
48 #[serde(rename = "zstd")]
49 Zstd,
50}
51
52impl Bind for DataType {
53 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
54 let value = match self {
55 DataType::Json => "json",
56 DataType::Zstd => "zstd",
57 };
58 value.bind(statement, start_index)
59 }
60}
61
62impl Column for DataType {
63 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
64 let (value, next_index) = String::column(statement, start_index)?;
65 let data_type = match value.as_str() {
66 "json" => DataType::Json,
67 "zstd" => DataType::Zstd,
68 _ => anyhow::bail!("Unknown data type: {}", value),
69 };
70 Ok((data_type, next_index))
71 }
72}
73
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// When building the system prompt, only the first entry that exists as a file
/// in a worktree is loaded. A change to any of these paths triggers a reload.
const RULES_FILE_NAMES: [&str; 8] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
];
84
85pub fn init(cx: &mut App) {
86 ThreadsDatabase::init(cx);
87}
88
89/// A system prompt shared by all threads created by this ThreadStore
90#[derive(Clone, Default)]
91pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
92
93impl SharedProjectContext {
94 pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
95 self.0.borrow()
96 }
97}
98
/// Store for text threads; an alias of `assistant_context::ContextStore`.
pub type TextThreadStore = assistant_context::ContextStore;

/// Keeps the list of saved agent threads for a project, and the shared
/// system-prompt context handed to every `Thread` it creates or deserializes.
///
/// NOTE: field order is drop order. `reload_system_prompt_tx` is declared (and
/// therefore dropped) before `_reload_system_prompt_task`.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool IDs registered in `tools` for each context server, so they can be
    /// removed when that server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Saved-thread metadata in the order the database returns it (newest first).
    threads: Vec<SerializedThreadMetadata>,
    /// Context shared with every thread this store constructs.
    project_context: SharedProjectContext,
    /// Wakes the reload task. Sends use `try_send` with errors ignored, so
    /// bursts of requests collapse into whatever reload is already queued.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
113
114pub struct RulesLoadingError {
115 pub message: SharedString,
116}
117
118impl EventEmitter<RulesLoadingError> for ThreadStore {}
119
120impl ThreadStore {
121 pub fn load(
122 project: Entity<Project>,
123 tools: Entity<ToolWorkingSet>,
124 prompt_store: Option<Entity<PromptStore>>,
125 prompt_builder: Arc<PromptBuilder>,
126 cx: &mut App,
127 ) -> Task<Result<Entity<Self>>> {
128 cx.spawn(async move |cx| {
129 let (thread_store, ready_rx) = cx.update(|cx| {
130 let mut option_ready_rx = None;
131 let thread_store = cx.new(|cx| {
132 let (thread_store, ready_rx) =
133 Self::new(project, tools, prompt_builder, prompt_store, cx);
134 option_ready_rx = Some(ready_rx);
135 thread_store
136 });
137 (thread_store, option_ready_rx.take().unwrap())
138 })?;
139 ready_rx.await?;
140 Ok(thread_store)
141 })
142 }
143
144 fn new(
145 project: Entity<Project>,
146 tools: Entity<ToolWorkingSet>,
147 prompt_builder: Arc<PromptBuilder>,
148 prompt_store: Option<Entity<PromptStore>>,
149 cx: &mut Context<Self>,
150 ) -> (Self, oneshot::Receiver<()>) {
151 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
152
153 if let Some(prompt_store) = prompt_store.as_ref() {
154 subscriptions.push(cx.subscribe(
155 prompt_store,
156 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
157 this.enqueue_system_prompt_reload();
158 },
159 ))
160 }
161
162 // This channel and task prevent concurrent and redundant loading of the system prompt.
163 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
164 let (ready_tx, ready_rx) = oneshot::channel();
165 let mut ready_tx = Some(ready_tx);
166 let reload_system_prompt_task = cx.spawn({
167 let prompt_store = prompt_store.clone();
168 async move |thread_store, cx| {
169 loop {
170 let Some(reload_task) = thread_store
171 .update(cx, |thread_store, cx| {
172 thread_store.reload_system_prompt(prompt_store.clone(), cx)
173 })
174 .ok()
175 else {
176 return;
177 };
178 reload_task.await;
179 if let Some(ready_tx) = ready_tx.take() {
180 ready_tx.send(()).ok();
181 }
182 reload_system_prompt_rx.next().await;
183 }
184 }
185 });
186
187 let this = Self {
188 project,
189 tools,
190 prompt_builder,
191 prompt_store,
192 context_server_tool_ids: HashMap::default(),
193 threads: Vec::new(),
194 project_context: SharedProjectContext::default(),
195 reload_system_prompt_tx,
196 _reload_system_prompt_task: reload_system_prompt_task,
197 _subscriptions: subscriptions,
198 };
199 this.register_context_server_handlers(cx);
200 this.reload(cx).detach_and_log_err(cx);
201 (this, ready_rx)
202 }
203
204 fn handle_project_event(
205 &mut self,
206 _project: Entity<Project>,
207 event: &project::Event,
208 _cx: &mut Context<Self>,
209 ) {
210 match event {
211 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
212 self.enqueue_system_prompt_reload();
213 }
214 project::Event::WorktreeUpdatedEntries(_, items) => {
215 if items.iter().any(|(path, _, _)| {
216 RULES_FILE_NAMES
217 .iter()
218 .any(|name| path.as_ref() == Path::new(name))
219 }) {
220 self.enqueue_system_prompt_reload();
221 }
222 }
223 _ => {}
224 }
225 }
226
227 fn enqueue_system_prompt_reload(&mut self) {
228 self.reload_system_prompt_tx.try_send(()).ok();
229 }
230
231 // Note that this should only be called from `reload_system_prompt_task`.
232 fn reload_system_prompt(
233 &self,
234 prompt_store: Option<Entity<PromptStore>>,
235 cx: &mut Context<Self>,
236 ) -> Task<()> {
237 let worktrees = self
238 .project
239 .read(cx)
240 .visible_worktrees(cx)
241 .collect::<Vec<_>>();
242 let worktree_tasks = worktrees
243 .into_iter()
244 .map(|worktree| {
245 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
246 })
247 .collect::<Vec<_>>();
248 let default_user_rules_task = match prompt_store {
249 None => Task::ready(vec![]),
250 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
251 let prompts = prompt_store.default_prompt_metadata();
252 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
253 let contents = prompt_store.load(prompt_metadata.id, cx);
254 async move { (contents.await, prompt_metadata) }
255 });
256 cx.background_spawn(future::join_all(load_tasks))
257 }),
258 };
259
260 cx.spawn(async move |this, cx| {
261 let (worktrees, default_user_rules) =
262 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
263
264 let worktrees = worktrees
265 .into_iter()
266 .map(|(worktree, rules_error)| {
267 if let Some(rules_error) = rules_error {
268 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
269 }
270 worktree
271 })
272 .collect::<Vec<_>>();
273
274 let default_user_rules = default_user_rules
275 .into_iter()
276 .flat_map(|(contents, prompt_metadata)| match contents {
277 Ok(contents) => Some(UserRulesContext {
278 uuid: match prompt_metadata.id {
279 PromptId::User { uuid } => uuid,
280 PromptId::EditWorkflow => return None,
281 },
282 title: prompt_metadata.title.map(|title| title.to_string()),
283 contents,
284 }),
285 Err(err) => {
286 this.update(cx, |_, cx| {
287 cx.emit(RulesLoadingError {
288 message: format!("{err:?}").into(),
289 });
290 })
291 .ok();
292 None
293 }
294 })
295 .collect::<Vec<_>>();
296
297 this.update(cx, |this, _cx| {
298 *this.project_context.0.borrow_mut() =
299 Some(ProjectContext::new(worktrees, default_user_rules));
300 })
301 .ok();
302 })
303 }
304
305 fn load_worktree_info_for_system_prompt(
306 worktree: Entity<Worktree>,
307 project: Entity<Project>,
308 cx: &mut App,
309 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
310 let tree = worktree.read(cx);
311 let root_name = tree.root_name().into();
312 let abs_path = tree.abs_path();
313
314 let mut context = WorktreeContext {
315 root_name,
316 abs_path,
317 rules_file: None,
318 };
319
320 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
321 let Some(rules_task) = rules_task else {
322 return Task::ready((context, None));
323 };
324
325 cx.spawn(async move |_| {
326 let (rules_file, rules_file_error) = match rules_task.await {
327 Ok(rules_file) => (Some(rules_file), None),
328 Err(err) => (
329 None,
330 Some(RulesLoadingError {
331 message: format!("{err}").into(),
332 }),
333 ),
334 };
335 context.rules_file = rules_file;
336 (context, rules_file_error)
337 })
338 }
339
340 fn load_worktree_rules_file(
341 worktree: Entity<Worktree>,
342 project: Entity<Project>,
343 cx: &mut App,
344 ) -> Option<Task<Result<RulesFileContext>>> {
345 let worktree = worktree.read(cx);
346 let worktree_id = worktree.id();
347 let selected_rules_file = RULES_FILE_NAMES
348 .into_iter()
349 .filter_map(|name| {
350 worktree
351 .entry_for_path(name)
352 .filter(|entry| entry.is_file())
353 .map(|entry| entry.path.clone())
354 })
355 .next();
356
357 // Note that Cline supports `.clinerules` being a directory, but that is not currently
358 // supported. This doesn't seem to occur often in GitHub repositories.
359 selected_rules_file.map(|path_in_worktree| {
360 let project_path = ProjectPath {
361 worktree_id,
362 path: path_in_worktree.clone(),
363 };
364 let buffer_task =
365 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
366 let rope_task = cx.spawn(async move |cx| {
367 buffer_task.await?.read_with(cx, |buffer, cx| {
368 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
369 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
370 })?
371 });
372 // Build a string from the rope on a background thread.
373 cx.background_spawn(async move {
374 let (project_entry_id, rope) = rope_task.await?;
375 anyhow::Ok(RulesFileContext {
376 path_in_worktree,
377 text: rope.to_string().trim().to_string(),
378 project_entry_id: project_entry_id.to_usize(),
379 })
380 })
381 })
382 }
383
384 pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
385 &self.prompt_store
386 }
387
388 pub fn tools(&self) -> Entity<ToolWorkingSet> {
389 self.tools.clone()
390 }
391
392 /// Returns the number of threads.
393 pub fn thread_count(&self) -> usize {
394 self.threads.len()
395 }
396
397 pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
398 // ordering is from "ORDER BY" in `list_threads`
399 self.threads.iter()
400 }
401
402 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
403 cx.new(|cx| {
404 Thread::new(
405 self.project.clone(),
406 self.tools.clone(),
407 self.prompt_builder.clone(),
408 self.project_context.clone(),
409 cx,
410 )
411 })
412 }
413
414 pub fn create_thread_from_serialized(
415 &mut self,
416 serialized: SerializedThread,
417 cx: &mut Context<Self>,
418 ) -> Entity<Thread> {
419 cx.new(|cx| {
420 Thread::deserialize(
421 ThreadId::new(),
422 serialized,
423 self.project.clone(),
424 self.tools.clone(),
425 self.prompt_builder.clone(),
426 self.project_context.clone(),
427 None,
428 cx,
429 )
430 })
431 }
432
433 pub fn open_thread(
434 &self,
435 id: &ThreadId,
436 window: &mut Window,
437 cx: &mut Context<Self>,
438 ) -> Task<Result<Entity<Thread>>> {
439 let id = id.clone();
440 let database_future = ThreadsDatabase::global_future(cx);
441 let this = cx.weak_entity();
442 window.spawn(cx, async move |cx| {
443 let database = database_future.await.map_err(|err| anyhow!(err))?;
444 let thread = database
445 .try_find_thread(id.clone())
446 .await?
447 .with_context(|| format!("no thread found with ID: {id:?}"))?;
448
449 let thread = this.update_in(cx, |this, window, cx| {
450 cx.new(|cx| {
451 Thread::deserialize(
452 id.clone(),
453 thread,
454 this.project.clone(),
455 this.tools.clone(),
456 this.prompt_builder.clone(),
457 this.project_context.clone(),
458 Some(window),
459 cx,
460 )
461 })
462 })?;
463
464 Ok(thread)
465 })
466 }
467
468 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
469 let (metadata, serialized_thread) =
470 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
471
472 let database_future = ThreadsDatabase::global_future(cx);
473 cx.spawn(async move |this, cx| {
474 let serialized_thread = serialized_thread.await?;
475 let database = database_future.await.map_err(|err| anyhow!(err))?;
476 database.save_thread(metadata, serialized_thread).await?;
477
478 this.update(cx, |this, cx| this.reload(cx))?.await
479 })
480 }
481
482 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
483 let id = id.clone();
484 let database_future = ThreadsDatabase::global_future(cx);
485 cx.spawn(async move |this, cx| {
486 let database = database_future.await.map_err(|err| anyhow!(err))?;
487 database.delete_thread(id.clone()).await?;
488
489 this.update(cx, |this, cx| {
490 this.threads.retain(|thread| thread.id != id);
491 cx.notify();
492 })
493 })
494 }
495
496 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
497 let database_future = ThreadsDatabase::global_future(cx);
498 cx.spawn(async move |this, cx| {
499 let threads = database_future
500 .await
501 .map_err(|err| anyhow!(err))?
502 .list_threads()
503 .await?;
504
505 this.update(cx, |this, cx| {
506 this.threads = threads;
507 cx.notify();
508 })
509 })
510 }
511
512 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
513 let context_server_store = self.project.read(cx).context_server_store();
514 cx.subscribe(&context_server_store, Self::handle_context_server_event)
515 .detach();
516
517 // Check for any servers that were already running before the handler was registered
518 for server in context_server_store.read(cx).running_servers() {
519 self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
520 }
521 }
522
523 fn handle_context_server_event(
524 &mut self,
525 context_server_store: Entity<ContextServerStore>,
526 event: &project::context_server_store::Event,
527 cx: &mut Context<Self>,
528 ) {
529 let tool_working_set = self.tools.clone();
530 match event {
531 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
532 match status {
533 ContextServerStatus::Starting => {}
534 ContextServerStatus::Running => {
535 self.load_context_server_tools(server_id.clone(), context_server_store, cx);
536 }
537 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
538 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
539 tool_working_set.update(cx, |tool_working_set, _| {
540 tool_working_set.remove(&tool_ids);
541 });
542 }
543 }
544 }
545 }
546 }
547 }
548
549 fn load_context_server_tools(
550 &self,
551 server_id: ContextServerId,
552 context_server_store: Entity<ContextServerStore>,
553 cx: &mut Context<Self>,
554 ) {
555 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
556 return;
557 };
558 let tool_working_set = self.tools.clone();
559 cx.spawn(async move |this, cx| {
560 let Some(protocol) = server.client() else {
561 return;
562 };
563
564 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
565 if let Some(response) = protocol
566 .request::<context_server::types::requests::ListTools>(())
567 .await
568 .log_err()
569 {
570 let tool_ids = tool_working_set
571 .update(cx, |tool_working_set, _| {
572 response
573 .tools
574 .into_iter()
575 .map(|tool| {
576 log::info!("registering context server tool: {:?}", tool.name);
577 tool_working_set.insert(Arc::new(ContextServerTool::new(
578 context_server_store.clone(),
579 server.id(),
580 tool,
581 )))
582 })
583 .collect::<Vec<_>>()
584 })
585 .log_err();
586
587 if let Some(tool_ids) = tool_ids {
588 this.update(cx, |this, _| {
589 this.context_server_tool_ids.insert(server_id, tool_ids);
590 })
591 .log_err();
592 }
593 }
594 }
595 })
596 .detach();
597 }
598}
599
600#[derive(Debug, Clone, Serialize, Deserialize)]
601pub struct SerializedThreadMetadata {
602 pub id: ThreadId,
603 pub summary: SharedString,
604 pub updated_at: DateTime<Utc>,
605}
606
607#[derive(Serialize, Deserialize, Debug, PartialEq)]
608pub struct SerializedThread {
609 pub version: String,
610 pub summary: SharedString,
611 pub updated_at: DateTime<Utc>,
612 pub messages: Vec<SerializedMessage>,
613 #[serde(default)]
614 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
615 #[serde(default)]
616 pub cumulative_token_usage: TokenUsage,
617 #[serde(default)]
618 pub request_token_usage: Vec<TokenUsage>,
619 #[serde(default)]
620 pub detailed_summary_state: DetailedSummaryState,
621 #[serde(default)]
622 pub exceeded_window_error: Option<ExceededWindowError>,
623 #[serde(default)]
624 pub model: Option<SerializedLanguageModel>,
625 #[serde(default)]
626 pub completion_mode: Option<CompletionMode>,
627 #[serde(default)]
628 pub tool_use_limit_reached: bool,
629 #[serde(default)]
630 pub profile: Option<AgentProfileId>,
631}
632
633#[derive(Serialize, Deserialize, Debug, PartialEq)]
634pub struct SerializedLanguageModel {
635 pub provider: String,
636 pub model: String,
637}
638
639impl SerializedThread {
640 pub const VERSION: &'static str = "0.2.0";
641
642 pub fn from_json(json: &[u8]) -> Result<Self> {
643 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
644 match saved_thread_json.get("version") {
645 Some(serde_json::Value::String(version)) => match version.as_str() {
646 SerializedThreadV0_1_0::VERSION => {
647 let saved_thread =
648 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
649 Ok(saved_thread.upgrade())
650 }
651 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
652 saved_thread_json,
653 )?),
654 _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
655 },
656 None => {
657 let saved_thread =
658 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
659 Ok(saved_thread.upgrade())
660 }
661 version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
662 }
663 }
664}
665
666#[derive(Serialize, Deserialize, Debug)]
667pub struct SerializedThreadV0_1_0(
668 // The structure did not change, so we are reusing the latest SerializedThread.
669 // When making the next version, make sure this points to SerializedThreadV0_2_0
670 SerializedThread,
671);
672
673impl SerializedThreadV0_1_0 {
674 pub const VERSION: &'static str = "0.1.0";
675
676 pub fn upgrade(self) -> SerializedThread {
677 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
678
679 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
680
681 for message in self.0.messages {
682 if message.role == Role::User && !message.tool_results.is_empty() {
683 if let Some(last_message) = messages.last_mut() {
684 debug_assert!(last_message.role == Role::Assistant);
685
686 last_message.tool_results = message.tool_results;
687 continue;
688 }
689 }
690
691 messages.push(message);
692 }
693
694 SerializedThread {
695 messages,
696 version: SerializedThread::VERSION.to_string(),
697 ..self.0
698 }
699 }
700}
701
702#[derive(Debug, Serialize, Deserialize, PartialEq)]
703pub struct SerializedMessage {
704 pub id: MessageId,
705 pub role: Role,
706 #[serde(default)]
707 pub segments: Vec<SerializedMessageSegment>,
708 #[serde(default)]
709 pub tool_uses: Vec<SerializedToolUse>,
710 #[serde(default)]
711 pub tool_results: Vec<SerializedToolResult>,
712 #[serde(default)]
713 pub context: String,
714 #[serde(default)]
715 pub creases: Vec<SerializedCrease>,
716 #[serde(default)]
717 pub is_hidden: bool,
718}
719
720#[derive(Debug, Serialize, Deserialize, PartialEq)]
721#[serde(tag = "type")]
722pub enum SerializedMessageSegment {
723 #[serde(rename = "text")]
724 Text {
725 text: String,
726 },
727 #[serde(rename = "thinking")]
728 Thinking {
729 text: String,
730 #[serde(skip_serializing_if = "Option::is_none")]
731 signature: Option<String>,
732 },
733 RedactedThinking {
734 data: String,
735 },
736}
737
738#[derive(Debug, Serialize, Deserialize, PartialEq)]
739pub struct SerializedToolUse {
740 pub id: LanguageModelToolUseId,
741 pub name: SharedString,
742 pub input: serde_json::Value,
743}
744
745#[derive(Debug, Serialize, Deserialize, PartialEq)]
746pub struct SerializedToolResult {
747 pub tool_use_id: LanguageModelToolUseId,
748 pub is_error: bool,
749 pub content: LanguageModelToolResultContent,
750 pub output: Option<serde_json::Value>,
751}
752
753#[derive(Serialize, Deserialize)]
754struct LegacySerializedThread {
755 pub summary: SharedString,
756 pub updated_at: DateTime<Utc>,
757 pub messages: Vec<LegacySerializedMessage>,
758 #[serde(default)]
759 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
760}
761
762impl LegacySerializedThread {
763 pub fn upgrade(self) -> SerializedThread {
764 SerializedThread {
765 version: SerializedThread::VERSION.to_string(),
766 summary: self.summary,
767 updated_at: self.updated_at,
768 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
769 initial_project_snapshot: self.initial_project_snapshot,
770 cumulative_token_usage: TokenUsage::default(),
771 request_token_usage: Vec::new(),
772 detailed_summary_state: DetailedSummaryState::default(),
773 exceeded_window_error: None,
774 model: None,
775 completion_mode: None,
776 tool_use_limit_reached: false,
777 profile: None,
778 }
779 }
780}
781
782#[derive(Debug, Serialize, Deserialize)]
783struct LegacySerializedMessage {
784 pub id: MessageId,
785 pub role: Role,
786 pub text: String,
787 #[serde(default)]
788 pub tool_uses: Vec<SerializedToolUse>,
789 #[serde(default)]
790 pub tool_results: Vec<SerializedToolResult>,
791}
792
793impl LegacySerializedMessage {
794 fn upgrade(self) -> SerializedMessage {
795 SerializedMessage {
796 id: self.id,
797 role: self.role,
798 segments: vec![SerializedMessageSegment::Text { text: self.text }],
799 tool_uses: self.tool_uses,
800 tool_results: self.tool_results,
801 context: String::new(),
802 creases: Vec::new(),
803 is_hidden: false,
804 }
805 }
806}
807
808#[derive(Debug, Serialize, Deserialize, PartialEq)]
809pub struct SerializedCrease {
810 pub start: usize,
811 pub end: usize,
812 pub icon_path: SharedString,
813 pub label: SharedString,
814}
815
816struct GlobalThreadsDatabase(
817 Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
818);
819
820impl Global for GlobalThreadsDatabase {}
821
822pub(crate) struct ThreadsDatabase {
823 executor: BackgroundExecutor,
824 connection: Arc<Mutex<Connection>>,
825}
826
827impl ThreadsDatabase {
828 fn connection(&self) -> Arc<Mutex<Connection>> {
829 self.connection.clone()
830 }
831
832 const COMPRESSION_LEVEL: i32 = 3;
833}
834
835impl Bind for ThreadId {
836 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
837 self.to_string().bind(statement, start_index)
838 }
839}
840
841impl Column for ThreadId {
842 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
843 let (id_str, next_index) = String::column(statement, start_index)?;
844 Ok((ThreadId::from(id_str.as_str()), next_index))
845 }
846}
847
848impl ThreadsDatabase {
849 fn global_future(
850 cx: &mut App,
851 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
852 GlobalThreadsDatabase::global(cx).0.clone()
853 }
854
855 fn init(cx: &mut App) {
856 let executor = cx.background_executor().clone();
857 let database_future = executor
858 .spawn({
859 let executor = executor.clone();
860 let threads_dir = paths::data_dir().join("threads");
861 async move { ThreadsDatabase::new(threads_dir, executor) }
862 })
863 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
864 .boxed()
865 .shared();
866
867 cx.set_global(GlobalThreadsDatabase(database_future));
868 }
869
870 pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
871 std::fs::create_dir_all(&threads_dir)?;
872
873 let sqlite_path = threads_dir.join("threads.db");
874 let mdb_path = threads_dir.join("threads-db.1.mdb");
875
876 let needs_migration_from_heed = mdb_path.exists();
877
878 let connection = Connection::open_file(&sqlite_path.to_string_lossy());
879
880 connection.exec(indoc! {"
881 CREATE TABLE IF NOT EXISTS threads (
882 id TEXT PRIMARY KEY,
883 summary TEXT NOT NULL,
884 updated_at TEXT NOT NULL,
885 data_type TEXT NOT NULL,
886 data BLOB NOT NULL
887 )
888 "})?()
889 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
890
891 let db = Self {
892 executor: executor.clone(),
893 connection: Arc::new(Mutex::new(connection)),
894 };
895
896 if needs_migration_from_heed {
897 let db_connection = db.connection();
898 let executor_clone = executor.clone();
899 executor
900 .spawn(async move {
901 log::info!("Starting threads.db migration");
902 Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
903 std::fs::remove_dir_all(mdb_path)?;
904 log::info!("threads.db migrated to sqlite");
905 Ok::<(), anyhow::Error>(())
906 })
907 .detach();
908 }
909
910 Ok(db)
911 }
912
913 // Remove this migration after 2025-09-01
914 fn migrate_from_heed(
915 mdb_path: &Path,
916 connection: Arc<Mutex<Connection>>,
917 _executor: BackgroundExecutor,
918 ) -> Result<()> {
919 use heed::types::SerdeBincode;
920 struct SerializedThreadHeed(SerializedThread);
921
922 impl heed::BytesEncode<'_> for SerializedThreadHeed {
923 type EItem = SerializedThreadHeed;
924
925 fn bytes_encode(
926 item: &Self::EItem,
927 ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
928 serde_json::to_vec(&item.0)
929 .map(std::borrow::Cow::Owned)
930 .map_err(Into::into)
931 }
932 }
933
934 impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
935 type DItem = SerializedThreadHeed;
936
937 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
938 SerializedThread::from_json(bytes)
939 .map(SerializedThreadHeed)
940 .map_err(Into::into)
941 }
942 }
943
944 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
945
946 let env = unsafe {
947 heed::EnvOpenOptions::new()
948 .map_size(ONE_GB_IN_BYTES)
949 .max_dbs(1)
950 .open(mdb_path)?
951 };
952
953 let txn = env.write_txn()?;
954 let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
955 .open_database(&txn, Some("threads"))?
956 .ok_or_else(|| anyhow!("threads database not found"))?;
957
958 for result in threads.iter(&txn)? {
959 let (thread_id, thread_heed) = result?;
960 Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
961 }
962
963 Ok(())
964 }
965
966 fn save_thread_sync(
967 connection: &Arc<Mutex<Connection>>,
968 id: ThreadId,
969 thread: SerializedThread,
970 ) -> Result<()> {
971 let json_data = serde_json::to_string(&thread)?;
972 let summary = thread.summary.to_string();
973 let updated_at = thread.updated_at.to_rfc3339();
974
975 let connection = connection.lock().unwrap();
976
977 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
978 let data_type = DataType::Zstd;
979 let data = compressed;
980
981 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
982 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
983 "})?;
984
985 insert((id, summary, updated_at, data_type, data))?;
986
987 Ok(())
988 }
989
990 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
991 let connection = self.connection.clone();
992
993 self.executor.spawn(async move {
994 let connection = connection.lock().unwrap();
995 let mut select =
996 connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
997 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
998 "})?;
999
1000 let rows = select(())?;
1001 let mut threads = Vec::new();
1002
1003 for (id, summary, updated_at) in rows {
1004 threads.push(SerializedThreadMetadata {
1005 id,
1006 summary: summary.into(),
1007 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1008 });
1009 }
1010
1011 Ok(threads)
1012 })
1013 }
1014
1015 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1016 let connection = self.connection.clone();
1017
1018 self.executor.spawn(async move {
1019 let connection = connection.lock().unwrap();
1020 let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1021 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1022 "})?;
1023
1024 let rows = select(id)?;
1025 if let Some((data_type, data)) = rows.into_iter().next() {
1026 let json_data = match data_type {
1027 DataType::Zstd => {
1028 let decompressed = zstd::decode_all(&data[..])?;
1029 String::from_utf8(decompressed)?
1030 }
1031 DataType::Json => String::from_utf8(data)?,
1032 };
1033
1034 let thread = SerializedThread::from_json(json_data.as_bytes())?;
1035 Ok(Some(thread))
1036 } else {
1037 Ok(None)
1038 }
1039 })
1040 }
1041
1042 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1043 let connection = self.connection.clone();
1044
1045 self.executor
1046 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1047 }
1048
1049 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1050 let connection = self.connection.clone();
1051
1052 self.executor.spawn(async move {
1053 let connection = connection.lock().unwrap();
1054
1055 let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1056 DELETE FROM threads WHERE id = ?
1057 "})?;
1058
1059 delete(id)?;
1060
1061 Ok(())
1062 })
1063 }
1064}
1065
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::{DateTime, Utc};
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// Builds a message with one text segment and no tools, context, or creases.
    fn text_message(id: MessageId, role: Role, text: &str) -> SerializedMessage {
        SerializedMessage {
            id,
            role,
            segments: vec![SerializedMessageSegment::Text {
                text: text.to_string(),
            }],
            tool_uses: vec![],
            tool_results: vec![],
            context: "".to_string(),
            creases: vec![],
            is_hidden: false,
        }
    }

    /// Builds a "Test conversation" thread with the given version and messages;
    /// every other field is empty/default.
    fn thread_with_version(
        version: &str,
        updated_at: DateTime<Utc>,
        messages: Vec<SerializedMessage>,
    ) -> SerializedThread {
        SerializedThread {
            version: version.to_string(),
            summary: "Test conversation".into(),
            updated_at,
            messages,
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }

    fn tool_use() -> SerializedToolUse {
        SerializedToolUse {
            id: "abc".into(),
            name: "tool_1".into(),
            input: serde_json::Value::Null,
        }
    }

    fn tool_result() -> SerializedToolResult {
        SerializedToolResult {
            tool_use_id: "abc".into(),
            is_error: false,
            content: LanguageModelToolResultContent::Text("abcdef".into()),
            output: Some(serde_json::Value::Null),
        }
    }

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            thread_with_version(
                SerializedThread::VERSION,
                updated_at,
                vec![text_message(MessageId(1), Role::User, "Hello, world!")],
            )
        )
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();

        let mut assistant_message =
            text_message(MessageId(2), Role::Assistant, "I want to use a tool");
        assistant_message.tool_uses = vec![tool_use()];

        let mut tool_result_message =
            text_message(MessageId(3), Role::User, "Here is the tool result");
        tool_result_message.tool_results = vec![tool_result()];

        let thread_v0_1_0 = SerializedThreadV0_1_0(thread_with_version(
            SerializedThreadV0_1_0::VERSION,
            updated_at,
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1"),
                assistant_message,
                tool_result_message,
            ],
        ));
        let upgraded = thread_v0_1_0.upgrade();

        // The tool-result user message is folded into the preceding assistant message.
        let mut expected_assistant_message =
            text_message(MessageId(2), Role::Assistant, "I want to use a tool");
        expected_assistant_message.tool_uses = vec![tool_use()];
        expected_assistant_message.tool_results = vec![tool_result()];

        assert_eq!(
            upgraded,
            thread_with_version(
                SerializedThread::VERSION,
                updated_at,
                vec![
                    text_message(MessageId(1), Role::User, "Use tool_1"),
                    expected_assistant_message,
                ],
            )
        )
    }
}