1use crate::{
2 context_server_tool::ContextServerTool,
3 thread::{
4 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
5 },
6};
7use agent_settings::{AgentProfileId, CompletionMode};
8use anyhow::{Context as _, Result, anyhow};
9use assistant_tool::{ToolId, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use futures::{
14 FutureExt as _, StreamExt as _,
15 channel::{mpsc, oneshot},
16 future::{self, BoxFuture, Shared},
17};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, Window, prelude::*,
21};
22use indoc::indoc;
23use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
24use project::context_server_store::{ContextServerStatus, ContextServerStore};
25use project::{Project, ProjectItem, ProjectPath, Worktree};
26use prompt_store::{
27 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
28 UserRulesContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use sqlez::{
32 bindable::{Bind, Column},
33 connection::Connection,
34 statement::Statement,
35};
36use std::{
37 cell::{Ref, RefCell},
38 path::{Path, PathBuf},
39 rc::Rc,
40 sync::{Arc, Mutex},
41};
42use util::ResultExt as _;
43
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub enum DataType {
46 #[serde(rename = "json")]
47 Json,
48 #[serde(rename = "zstd")]
49 Zstd,
50}
51
52impl Bind for DataType {
53 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
54 let value = match self {
55 DataType::Json => "json",
56 DataType::Zstd => "zstd",
57 };
58 value.bind(statement, start_index)
59 }
60}
61
62impl Column for DataType {
63 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
64 let (value, next_index) = String::column(statement, start_index)?;
65 let data_type = match value.as_str() {
66 "json" => DataType::Json,
67 "zstd" => DataType::Zstd,
68 _ => anyhow::bail!("Unknown data type: {}", value),
69 };
70 Ok((data_type, next_index))
71 }
72}
73
/// Worktree-relative paths recognized as rules files.
///
/// The order is significant: lookups walk this list front to back and use the first file
/// that exists.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
85
/// Initializes this module's global state by delegating to `ThreadsDatabase::init`.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
89
90/// A system prompt shared by all threads created by this ThreadStore
91#[derive(Clone, Default)]
92pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
93
94impl SharedProjectContext {
95 pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
96 self.0.borrow()
97 }
98}
99
/// Alias for `assistant_context::ContextStore`.
pub type TextThreadStore = assistant_context::ContextStore;
101
/// Owns the state shared by all agent threads of one project: the tool working set, the
/// prompt builder, the cached thread list, and the shared system-prompt context.
pub struct ThreadStore {
    /// Project whose worktree events and context-server store are observed.
    project: Entity<Project>,
    /// Tool working set handed to every thread; context-server tools are added to and
    /// removed from it.
    tools: Entity<ToolWorkingSet>,
    /// Prompt builder passed to every thread created or deserialized by this store.
    prompt_builder: Arc<PromptBuilder>,
    /// Optional prompt store; its default prompts are loaded as user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered per context server, kept so they can be removed again when the
    /// server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata as returned by the most recent `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with threads; replaced whenever the system prompt reloads.
    project_context: SharedProjectContext,
    /// Sender used to request a system-prompt reload (channel capacity 1).
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Task running the reload loop; stored in the struct alongside the sender.
    _reload_system_prompt_task: Task<()>,
    /// Subscriptions to project (and, if present, prompt store) events.
    _subscriptions: Vec<Subscription>,
}
114
/// Event emitted by `ThreadStore` when a rules file or default user rule fails to load.
pub struct RulesLoadingError {
    /// Formatted text of the underlying error.
    pub message: SharedString,
}

// Lets `ThreadStore` emit `RulesLoadingError` events via `cx.emit`.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
120
121impl ThreadStore {
122 pub fn load(
123 project: Entity<Project>,
124 tools: Entity<ToolWorkingSet>,
125 prompt_store: Option<Entity<PromptStore>>,
126 prompt_builder: Arc<PromptBuilder>,
127 cx: &mut App,
128 ) -> Task<Result<Entity<Self>>> {
129 cx.spawn(async move |cx| {
130 let (thread_store, ready_rx) = cx.update(|cx| {
131 let mut option_ready_rx = None;
132 let thread_store = cx.new(|cx| {
133 let (thread_store, ready_rx) =
134 Self::new(project, tools, prompt_builder, prompt_store, cx);
135 option_ready_rx = Some(ready_rx);
136 thread_store
137 });
138 (thread_store, option_ready_rx.take().unwrap())
139 })?;
140 ready_rx.await?;
141 Ok(thread_store)
142 })
143 }
144
145 fn new(
146 project: Entity<Project>,
147 tools: Entity<ToolWorkingSet>,
148 prompt_builder: Arc<PromptBuilder>,
149 prompt_store: Option<Entity<PromptStore>>,
150 cx: &mut Context<Self>,
151 ) -> (Self, oneshot::Receiver<()>) {
152 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
153
154 if let Some(prompt_store) = prompt_store.as_ref() {
155 subscriptions.push(cx.subscribe(
156 prompt_store,
157 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
158 this.enqueue_system_prompt_reload();
159 },
160 ))
161 }
162
163 // This channel and task prevent concurrent and redundant loading of the system prompt.
164 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
165 let (ready_tx, ready_rx) = oneshot::channel();
166 let mut ready_tx = Some(ready_tx);
167 let reload_system_prompt_task = cx.spawn({
168 let prompt_store = prompt_store.clone();
169 async move |thread_store, cx| {
170 loop {
171 let Some(reload_task) = thread_store
172 .update(cx, |thread_store, cx| {
173 thread_store.reload_system_prompt(prompt_store.clone(), cx)
174 })
175 .ok()
176 else {
177 return;
178 };
179 reload_task.await;
180 if let Some(ready_tx) = ready_tx.take() {
181 ready_tx.send(()).ok();
182 }
183 reload_system_prompt_rx.next().await;
184 }
185 }
186 });
187
188 let this = Self {
189 project,
190 tools,
191 prompt_builder,
192 prompt_store,
193 context_server_tool_ids: HashMap::default(),
194 threads: Vec::new(),
195 project_context: SharedProjectContext::default(),
196 reload_system_prompt_tx,
197 _reload_system_prompt_task: reload_system_prompt_task,
198 _subscriptions: subscriptions,
199 };
200 this.register_context_server_handlers(cx);
201 this.reload(cx).detach_and_log_err(cx);
202 (this, ready_rx)
203 }
204
205 fn handle_project_event(
206 &mut self,
207 _project: Entity<Project>,
208 event: &project::Event,
209 _cx: &mut Context<Self>,
210 ) {
211 match event {
212 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
213 self.enqueue_system_prompt_reload();
214 }
215 project::Event::WorktreeUpdatedEntries(_, items) => {
216 if items.iter().any(|(path, _, _)| {
217 RULES_FILE_NAMES
218 .iter()
219 .any(|name| path.as_ref() == Path::new(name))
220 }) {
221 self.enqueue_system_prompt_reload();
222 }
223 }
224 _ => {}
225 }
226 }
227
228 fn enqueue_system_prompt_reload(&mut self) {
229 self.reload_system_prompt_tx.try_send(()).ok();
230 }
231
232 // Note that this should only be called from `reload_system_prompt_task`.
233 fn reload_system_prompt(
234 &self,
235 prompt_store: Option<Entity<PromptStore>>,
236 cx: &mut Context<Self>,
237 ) -> Task<()> {
238 let worktrees = self
239 .project
240 .read(cx)
241 .visible_worktrees(cx)
242 .collect::<Vec<_>>();
243 let worktree_tasks = worktrees
244 .into_iter()
245 .map(|worktree| {
246 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
247 })
248 .collect::<Vec<_>>();
249 let default_user_rules_task = match prompt_store {
250 None => Task::ready(vec![]),
251 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
252 let prompts = prompt_store.default_prompt_metadata();
253 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
254 let contents = prompt_store.load(prompt_metadata.id, cx);
255 async move { (contents.await, prompt_metadata) }
256 });
257 cx.background_spawn(future::join_all(load_tasks))
258 }),
259 };
260
261 cx.spawn(async move |this, cx| {
262 let (worktrees, default_user_rules) =
263 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
264
265 let worktrees = worktrees
266 .into_iter()
267 .map(|(worktree, rules_error)| {
268 if let Some(rules_error) = rules_error {
269 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
270 }
271 worktree
272 })
273 .collect::<Vec<_>>();
274
275 let default_user_rules = default_user_rules
276 .into_iter()
277 .flat_map(|(contents, prompt_metadata)| match contents {
278 Ok(contents) => Some(UserRulesContext {
279 uuid: match prompt_metadata.id {
280 PromptId::User { uuid } => uuid,
281 PromptId::EditWorkflow => return None,
282 },
283 title: prompt_metadata.title.map(|title| title.to_string()),
284 contents,
285 }),
286 Err(err) => {
287 this.update(cx, |_, cx| {
288 cx.emit(RulesLoadingError {
289 message: format!("{err:?}").into(),
290 });
291 })
292 .ok();
293 None
294 }
295 })
296 .collect::<Vec<_>>();
297
298 this.update(cx, |this, _cx| {
299 *this.project_context.0.borrow_mut() =
300 Some(ProjectContext::new(worktrees, default_user_rules));
301 })
302 .ok();
303 })
304 }
305
306 fn load_worktree_info_for_system_prompt(
307 worktree: Entity<Worktree>,
308 project: Entity<Project>,
309 cx: &mut App,
310 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
311 let tree = worktree.read(cx);
312 let root_name = tree.root_name().into();
313 let abs_path = tree.abs_path();
314
315 let mut context = WorktreeContext {
316 root_name,
317 abs_path,
318 rules_file: None,
319 };
320
321 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
322 let Some(rules_task) = rules_task else {
323 return Task::ready((context, None));
324 };
325
326 cx.spawn(async move |_| {
327 let (rules_file, rules_file_error) = match rules_task.await {
328 Ok(rules_file) => (Some(rules_file), None),
329 Err(err) => (
330 None,
331 Some(RulesLoadingError {
332 message: format!("{err}").into(),
333 }),
334 ),
335 };
336 context.rules_file = rules_file;
337 (context, rules_file_error)
338 })
339 }
340
341 fn load_worktree_rules_file(
342 worktree: Entity<Worktree>,
343 project: Entity<Project>,
344 cx: &mut App,
345 ) -> Option<Task<Result<RulesFileContext>>> {
346 let worktree = worktree.read(cx);
347 let worktree_id = worktree.id();
348 let selected_rules_file = RULES_FILE_NAMES
349 .into_iter()
350 .filter_map(|name| {
351 worktree
352 .entry_for_path(name)
353 .filter(|entry| entry.is_file())
354 .map(|entry| entry.path.clone())
355 })
356 .next();
357
358 // Note that Cline supports `.clinerules` being a directory, but that is not currently
359 // supported. This doesn't seem to occur often in GitHub repositories.
360 selected_rules_file.map(|path_in_worktree| {
361 let project_path = ProjectPath {
362 worktree_id,
363 path: path_in_worktree.clone(),
364 };
365 let buffer_task =
366 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
367 let rope_task = cx.spawn(async move |cx| {
368 buffer_task.await?.read_with(cx, |buffer, cx| {
369 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
370 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
371 })?
372 });
373 // Build a string from the rope on a background thread.
374 cx.background_spawn(async move {
375 let (project_entry_id, rope) = rope_task.await?;
376 anyhow::Ok(RulesFileContext {
377 path_in_worktree,
378 text: rope.to_string().trim().to_string(),
379 project_entry_id: project_entry_id.to_usize(),
380 })
381 })
382 })
383 }
384
    /// Returns the prompt store this thread store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
388
389 pub fn tools(&self) -> Entity<ToolWorkingSet> {
390 self.tools.clone()
391 }
392
    /// Returns the number of threads in the cached list (as of the last `reload`).
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
397
    /// Iterates over the cached thread metadata in stored order.
    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // ordering is from "ORDER BY" in `list_threads`
        self.threads.iter()
    }
402
403 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
404 cx.new(|cx| {
405 Thread::new(
406 self.project.clone(),
407 self.tools.clone(),
408 self.prompt_builder.clone(),
409 self.project_context.clone(),
410 cx,
411 )
412 })
413 }
414
415 pub fn create_thread_from_serialized(
416 &mut self,
417 serialized: SerializedThread,
418 cx: &mut Context<Self>,
419 ) -> Entity<Thread> {
420 cx.new(|cx| {
421 Thread::deserialize(
422 ThreadId::new(),
423 serialized,
424 self.project.clone(),
425 self.tools.clone(),
426 self.prompt_builder.clone(),
427 self.project_context.clone(),
428 None,
429 cx,
430 )
431 })
432 }
433
434 pub fn open_thread(
435 &self,
436 id: &ThreadId,
437 window: &mut Window,
438 cx: &mut Context<Self>,
439 ) -> Task<Result<Entity<Thread>>> {
440 let id = id.clone();
441 let database_future = ThreadsDatabase::global_future(cx);
442 let this = cx.weak_entity();
443 window.spawn(cx, async move |cx| {
444 let database = database_future.await.map_err(|err| anyhow!(err))?;
445 let thread = database
446 .try_find_thread(id.clone())
447 .await?
448 .with_context(|| format!("no thread found with ID: {id:?}"))?;
449
450 let thread = this.update_in(cx, |this, window, cx| {
451 cx.new(|cx| {
452 Thread::deserialize(
453 id.clone(),
454 thread,
455 this.project.clone(),
456 this.tools.clone(),
457 this.prompt_builder.clone(),
458 this.project_context.clone(),
459 Some(window),
460 cx,
461 )
462 })
463 })?;
464
465 Ok(thread)
466 })
467 }
468
469 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
470 let (metadata, serialized_thread) =
471 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
472
473 let database_future = ThreadsDatabase::global_future(cx);
474 cx.spawn(async move |this, cx| {
475 let serialized_thread = serialized_thread.await?;
476 let database = database_future.await.map_err(|err| anyhow!(err))?;
477 database.save_thread(metadata, serialized_thread).await?;
478
479 this.update(cx, |this, cx| this.reload(cx))?.await
480 })
481 }
482
483 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
484 let id = id.clone();
485 let database_future = ThreadsDatabase::global_future(cx);
486 cx.spawn(async move |this, cx| {
487 let database = database_future.await.map_err(|err| anyhow!(err))?;
488 database.delete_thread(id.clone()).await?;
489
490 this.update(cx, |this, cx| {
491 this.threads.retain(|thread| thread.id != id);
492 cx.notify();
493 })
494 })
495 }
496
497 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
498 let database_future = ThreadsDatabase::global_future(cx);
499 cx.spawn(async move |this, cx| {
500 let threads = database_future
501 .await
502 .map_err(|err| anyhow!(err))?
503 .list_threads()
504 .await?;
505
506 this.update(cx, |this, cx| {
507 this.threads = threads;
508 cx.notify();
509 })
510 })
511 }
512
513 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
514 let context_server_store = self.project.read(cx).context_server_store();
515 cx.subscribe(&context_server_store, Self::handle_context_server_event)
516 .detach();
517
518 // Check for any servers that were already running before the handler was registered
519 for server in context_server_store.read(cx).running_servers() {
520 self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
521 }
522 }
523
524 fn handle_context_server_event(
525 &mut self,
526 context_server_store: Entity<ContextServerStore>,
527 event: &project::context_server_store::Event,
528 cx: &mut Context<Self>,
529 ) {
530 let tool_working_set = self.tools.clone();
531 match event {
532 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
533 match status {
534 ContextServerStatus::Starting => {}
535 ContextServerStatus::Running => {
536 self.load_context_server_tools(server_id.clone(), context_server_store, cx);
537 }
538 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
539 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
540 tool_working_set.update(cx, |tool_working_set, _| {
541 tool_working_set.remove(&tool_ids);
542 });
543 }
544 }
545 }
546 }
547 }
548 }
549
550 fn load_context_server_tools(
551 &self,
552 server_id: ContextServerId,
553 context_server_store: Entity<ContextServerStore>,
554 cx: &mut Context<Self>,
555 ) {
556 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
557 return;
558 };
559 let tool_working_set = self.tools.clone();
560 cx.spawn(async move |this, cx| {
561 let Some(protocol) = server.client() else {
562 return;
563 };
564
565 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
566 if let Some(response) = protocol
567 .request::<context_server::types::requests::ListTools>(())
568 .await
569 .log_err()
570 {
571 let tool_ids = tool_working_set
572 .update(cx, |tool_working_set, _| {
573 response
574 .tools
575 .into_iter()
576 .map(|tool| {
577 log::info!("registering context server tool: {:?}", tool.name);
578 tool_working_set.insert(Arc::new(ContextServerTool::new(
579 context_server_store.clone(),
580 server.id(),
581 tool,
582 )))
583 })
584 .collect::<Vec<_>>()
585 })
586 .log_err();
587
588 if let Some(tool_ids) = tool_ids {
589 this.update(cx, |this, _| {
590 this.context_server_tool_ids.insert(server_id, tool_ids);
591 })
592 .log_err();
593 }
594 }
595 }
596 })
597 .detach();
598 }
599}
600
601#[derive(Debug, Clone, Serialize, Deserialize)]
602pub struct SerializedThreadMetadata {
603 pub id: ThreadId,
604 pub summary: SharedString,
605 pub updated_at: DateTime<Utc>,
606}
607
608#[derive(Serialize, Deserialize, Debug, PartialEq)]
609pub struct SerializedThread {
610 pub version: String,
611 pub summary: SharedString,
612 pub updated_at: DateTime<Utc>,
613 pub messages: Vec<SerializedMessage>,
614 #[serde(default)]
615 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
616 #[serde(default)]
617 pub cumulative_token_usage: TokenUsage,
618 #[serde(default)]
619 pub request_token_usage: Vec<TokenUsage>,
620 #[serde(default)]
621 pub detailed_summary_state: DetailedSummaryState,
622 #[serde(default)]
623 pub exceeded_window_error: Option<ExceededWindowError>,
624 #[serde(default)]
625 pub model: Option<SerializedLanguageModel>,
626 #[serde(default)]
627 pub completion_mode: Option<CompletionMode>,
628 #[serde(default)]
629 pub tool_use_limit_reached: bool,
630 #[serde(default)]
631 pub profile: Option<AgentProfileId>,
632}
633
634#[derive(Serialize, Deserialize, Debug, PartialEq)]
635pub struct SerializedLanguageModel {
636 pub provider: String,
637 pub model: String,
638}
639
640impl SerializedThread {
641 pub const VERSION: &'static str = "0.2.0";
642
643 pub fn from_json(json: &[u8]) -> Result<Self> {
644 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
645 match saved_thread_json.get("version") {
646 Some(serde_json::Value::String(version)) => match version.as_str() {
647 SerializedThreadV0_1_0::VERSION => {
648 let saved_thread =
649 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
650 Ok(saved_thread.upgrade())
651 }
652 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
653 saved_thread_json,
654 )?),
655 _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
656 },
657 None => {
658 let saved_thread =
659 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
660 Ok(saved_thread.upgrade())
661 }
662 version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
663 }
664 }
665}
666
667#[derive(Serialize, Deserialize, Debug)]
668pub struct SerializedThreadV0_1_0(
669 // The structure did not change, so we are reusing the latest SerializedThread.
670 // When making the next version, make sure this points to SerializedThreadV0_2_0
671 SerializedThread,
672);
673
674impl SerializedThreadV0_1_0 {
675 pub const VERSION: &'static str = "0.1.0";
676
677 pub fn upgrade(self) -> SerializedThread {
678 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
679
680 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
681
682 for message in self.0.messages {
683 if message.role == Role::User && !message.tool_results.is_empty() {
684 if let Some(last_message) = messages.last_mut() {
685 debug_assert!(last_message.role == Role::Assistant);
686
687 last_message.tool_results = message.tool_results;
688 continue;
689 }
690 }
691
692 messages.push(message);
693 }
694
695 SerializedThread {
696 messages,
697 version: SerializedThread::VERSION.to_string(),
698 ..self.0
699 }
700 }
701}
702
703#[derive(Debug, Serialize, Deserialize, PartialEq)]
704pub struct SerializedMessage {
705 pub id: MessageId,
706 pub role: Role,
707 #[serde(default)]
708 pub segments: Vec<SerializedMessageSegment>,
709 #[serde(default)]
710 pub tool_uses: Vec<SerializedToolUse>,
711 #[serde(default)]
712 pub tool_results: Vec<SerializedToolResult>,
713 #[serde(default)]
714 pub context: String,
715 #[serde(default)]
716 pub creases: Vec<SerializedCrease>,
717 #[serde(default)]
718 pub is_hidden: bool,
719}
720
721#[derive(Debug, Serialize, Deserialize, PartialEq)]
722#[serde(tag = "type")]
723pub enum SerializedMessageSegment {
724 #[serde(rename = "text")]
725 Text {
726 text: String,
727 },
728 #[serde(rename = "thinking")]
729 Thinking {
730 text: String,
731 #[serde(skip_serializing_if = "Option::is_none")]
732 signature: Option<String>,
733 },
734 RedactedThinking {
735 data: String,
736 },
737}
738
739#[derive(Debug, Serialize, Deserialize, PartialEq)]
740pub struct SerializedToolUse {
741 pub id: LanguageModelToolUseId,
742 pub name: SharedString,
743 pub input: serde_json::Value,
744}
745
746#[derive(Debug, Serialize, Deserialize, PartialEq)]
747pub struct SerializedToolResult {
748 pub tool_use_id: LanguageModelToolUseId,
749 pub is_error: bool,
750 pub content: LanguageModelToolResultContent,
751 pub output: Option<serde_json::Value>,
752}
753
754#[derive(Serialize, Deserialize)]
755struct LegacySerializedThread {
756 pub summary: SharedString,
757 pub updated_at: DateTime<Utc>,
758 pub messages: Vec<LegacySerializedMessage>,
759 #[serde(default)]
760 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
761}
762
763impl LegacySerializedThread {
764 pub fn upgrade(self) -> SerializedThread {
765 SerializedThread {
766 version: SerializedThread::VERSION.to_string(),
767 summary: self.summary,
768 updated_at: self.updated_at,
769 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
770 initial_project_snapshot: self.initial_project_snapshot,
771 cumulative_token_usage: TokenUsage::default(),
772 request_token_usage: Vec::new(),
773 detailed_summary_state: DetailedSummaryState::default(),
774 exceeded_window_error: None,
775 model: None,
776 completion_mode: None,
777 tool_use_limit_reached: false,
778 profile: None,
779 }
780 }
781}
782
783#[derive(Debug, Serialize, Deserialize)]
784struct LegacySerializedMessage {
785 pub id: MessageId,
786 pub role: Role,
787 pub text: String,
788 #[serde(default)]
789 pub tool_uses: Vec<SerializedToolUse>,
790 #[serde(default)]
791 pub tool_results: Vec<SerializedToolResult>,
792}
793
794impl LegacySerializedMessage {
795 fn upgrade(self) -> SerializedMessage {
796 SerializedMessage {
797 id: self.id,
798 role: self.role,
799 segments: vec![SerializedMessageSegment::Text { text: self.text }],
800 tool_uses: self.tool_uses,
801 tool_results: self.tool_results,
802 context: String::new(),
803 creases: Vec::new(),
804 is_hidden: false,
805 }
806 }
807}
808
809#[derive(Debug, Serialize, Deserialize, PartialEq)]
810pub struct SerializedCrease {
811 pub start: usize,
812 pub end: usize,
813 pub icon_path: SharedString,
814 pub label: SharedString,
815}
816
/// App-global wrapper around the shared future that resolves to the threads database.
///
/// The future is `Shared`, so every clone observes the same result; both the database and a
/// failure to open it are wrapped in `Arc`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

// Allows the wrapper to be stored with `cx.set_global` and read back with `global(cx)`.
impl Global for GlobalThreadsDatabase {}
822
823pub(crate) struct ThreadsDatabase {
824 executor: BackgroundExecutor,
825 connection: Arc<Mutex<Connection>>,
826}
827
828impl ThreadsDatabase {
829 fn connection(&self) -> Arc<Mutex<Connection>> {
830 self.connection.clone()
831 }
832
833 const COMPRESSION_LEVEL: i32 = 3;
834}
835
836impl Bind for ThreadId {
837 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
838 self.to_string().bind(statement, start_index)
839 }
840}
841
842impl Column for ThreadId {
843 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
844 let (id_str, next_index) = String::column(statement, start_index)?;
845 Ok((ThreadId::from(id_str.as_str()), next_index))
846 }
847}
848
849impl ThreadsDatabase {
850 fn global_future(
851 cx: &mut App,
852 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
853 GlobalThreadsDatabase::global(cx).0.clone()
854 }
855
856 fn init(cx: &mut App) {
857 let executor = cx.background_executor().clone();
858 let database_future = executor
859 .spawn({
860 let executor = executor.clone();
861 let threads_dir = paths::data_dir().join("threads");
862 async move { ThreadsDatabase::new(threads_dir, executor) }
863 })
864 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
865 .boxed()
866 .shared();
867
868 cx.set_global(GlobalThreadsDatabase(database_future));
869 }
870
871 pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
872 std::fs::create_dir_all(&threads_dir)?;
873
874 let sqlite_path = threads_dir.join("threads.db");
875 let mdb_path = threads_dir.join("threads-db.1.mdb");
876
877 let needs_migration_from_heed = mdb_path.exists();
878
879 let connection = Connection::open_file(&sqlite_path.to_string_lossy());
880
881 connection.exec(indoc! {"
882 CREATE TABLE IF NOT EXISTS threads (
883 id TEXT PRIMARY KEY,
884 summary TEXT NOT NULL,
885 updated_at TEXT NOT NULL,
886 data_type TEXT NOT NULL,
887 data BLOB NOT NULL
888 )
889 "})?()
890 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
891
892 let db = Self {
893 executor: executor.clone(),
894 connection: Arc::new(Mutex::new(connection)),
895 };
896
897 if needs_migration_from_heed {
898 let db_connection = db.connection();
899 let executor_clone = executor.clone();
900 executor
901 .spawn(async move {
902 log::info!("Starting threads.db migration");
903 Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
904 std::fs::remove_dir_all(mdb_path)?;
905 log::info!("threads.db migrated to sqlite");
906 Ok::<(), anyhow::Error>(())
907 })
908 .detach();
909 }
910
911 Ok(db)
912 }
913
914 // Remove this migration after 2025-09-01
915 fn migrate_from_heed(
916 mdb_path: &Path,
917 connection: Arc<Mutex<Connection>>,
918 _executor: BackgroundExecutor,
919 ) -> Result<()> {
920 use heed::types::SerdeBincode;
921 struct SerializedThreadHeed(SerializedThread);
922
923 impl heed::BytesEncode<'_> for SerializedThreadHeed {
924 type EItem = SerializedThreadHeed;
925
926 fn bytes_encode(
927 item: &Self::EItem,
928 ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
929 serde_json::to_vec(&item.0)
930 .map(std::borrow::Cow::Owned)
931 .map_err(Into::into)
932 }
933 }
934
935 impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
936 type DItem = SerializedThreadHeed;
937
938 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
939 SerializedThread::from_json(bytes)
940 .map(SerializedThreadHeed)
941 .map_err(Into::into)
942 }
943 }
944
945 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
946
947 let env = unsafe {
948 heed::EnvOpenOptions::new()
949 .map_size(ONE_GB_IN_BYTES)
950 .max_dbs(1)
951 .open(mdb_path)?
952 };
953
954 let txn = env.write_txn()?;
955 let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
956 .open_database(&txn, Some("threads"))?
957 .ok_or_else(|| anyhow!("threads database not found"))?;
958
959 for result in threads.iter(&txn)? {
960 let (thread_id, thread_heed) = result?;
961 Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
962 }
963
964 Ok(())
965 }
966
967 fn save_thread_sync(
968 connection: &Arc<Mutex<Connection>>,
969 id: ThreadId,
970 thread: SerializedThread,
971 ) -> Result<()> {
972 let json_data = serde_json::to_string(&thread)?;
973 let summary = thread.summary.to_string();
974 let updated_at = thread.updated_at.to_rfc3339();
975
976 let connection = connection.lock().unwrap();
977
978 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
979 let data_type = DataType::Zstd;
980 let data = compressed;
981
982 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
983 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
984 "})?;
985
986 insert((id, summary, updated_at, data_type, data))?;
987
988 Ok(())
989 }
990
991 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
992 let connection = self.connection.clone();
993
994 self.executor.spawn(async move {
995 let connection = connection.lock().unwrap();
996 let mut select =
997 connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
998 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
999 "})?;
1000
1001 let rows = select(())?;
1002 let mut threads = Vec::new();
1003
1004 for (id, summary, updated_at) in rows {
1005 threads.push(SerializedThreadMetadata {
1006 id,
1007 summary: summary.into(),
1008 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1009 });
1010 }
1011
1012 Ok(threads)
1013 })
1014 }
1015
1016 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1017 let connection = self.connection.clone();
1018
1019 self.executor.spawn(async move {
1020 let connection = connection.lock().unwrap();
1021 let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1022 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1023 "})?;
1024
1025 let rows = select(id)?;
1026 if let Some((data_type, data)) = rows.into_iter().next() {
1027 let json_data = match data_type {
1028 DataType::Zstd => {
1029 let decompressed = zstd::decode_all(&data[..])?;
1030 String::from_utf8(decompressed)?
1031 }
1032 DataType::Json => String::from_utf8(data)?,
1033 };
1034
1035 let thread = SerializedThread::from_json(json_data.as_bytes())?;
1036 Ok(Some(thread))
1037 } else {
1038 Ok(None)
1039 }
1040 })
1041 }
1042
1043 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1044 let connection = self.connection.clone();
1045
1046 self.executor
1047 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1048 }
1049
1050 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1051 let connection = self.connection.clone();
1052
1053 self.executor.spawn(async move {
1054 let connection = connection.lock().unwrap();
1055
1056 let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1057 DELETE FROM threads WHERE id = ?
1058 "})?;
1059
1060 delete(id)?;
1061
1062 Ok(())
1063 })
1064 }
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// Builds a visible message with one text segment and the given tool data.
    fn message(
        id: MessageId,
        role: Role,
        text: &str,
        tool_uses: Vec<SerializedToolUse>,
        tool_results: Vec<SerializedToolResult>,
    ) -> SerializedMessage {
        SerializedMessage {
            id,
            role,
            segments: vec![SerializedMessageSegment::Text {
                text: text.to_string(),
            }],
            tool_uses,
            tool_results,
            context: "".to_string(),
            creases: vec![],
            is_hidden: false,
        }
    }

    /// Builds a thread summarized "Test conversation" with every optional field defaulted.
    fn thread(
        version: &str,
        updated_at: DateTime<Utc>,
        messages: Vec<SerializedMessage>,
    ) -> SerializedThread {
        SerializedThread {
            version: version.to_string(),
            summary: "Test conversation".into(),
            updated_at,
            messages,
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }

    /// The tool use issued by the assistant in the v0.1.0 fixture.
    fn tool_1_use() -> SerializedToolUse {
        SerializedToolUse {
            id: "abc".into(),
            name: "tool_1".into(),
            input: serde_json::Value::Null,
        }
    }

    /// The result answering `tool_1_use`.
    fn tool_1_result() -> SerializedToolResult {
        SerializedToolResult {
            tool_use_id: "abc".into(),
            is_error: false,
            content: LanguageModelToolResultContent::Text("abcdef".into()),
            output: Some(serde_json::Value::Null),
        }
    }

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let expected = thread(
            SerializedThread::VERSION,
            updated_at,
            vec![message(
                MessageId(1),
                Role::User,
                "Hello, world!",
                vec![],
                vec![],
            )],
        );

        assert_eq!(legacy_thread.upgrade(), expected)
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();

        // In v0.1.0 the tool result travels in a separate user message after the assistant's.
        let thread_v0_1_0 = SerializedThreadV0_1_0(thread(
            SerializedThreadV0_1_0::VERSION,
            updated_at,
            vec![
                message(MessageId(1), Role::User, "Use tool_1", vec![], vec![]),
                message(
                    MessageId(2),
                    Role::Assistant,
                    "I want to use a tool",
                    vec![tool_1_use()],
                    vec![],
                ),
                message(
                    MessageId(1),
                    Role::User,
                    "Here is the tool result",
                    vec![],
                    vec![tool_1_result()],
                ),
            ],
        ));

        // After the upgrade the result sits on the assistant message and the carrier is gone.
        let expected = thread(
            SerializedThread::VERSION,
            updated_at,
            vec![
                message(MessageId(1), Role::User, "Use tool_1", vec![], vec![]),
                message(
                    MessageId(2),
                    Role::Assistant,
                    "I want to use a tool",
                    vec![tool_1_use()],
                    vec![tool_1_result()],
                ),
            ],
        );

        assert_eq!(thread_v0_1_0.upgrade(), expected)
    }
}