1use crate::{
2 MessageId, ThreadId,
3 context_server_tool::ContextServerTool,
4 thread::{DetailedSummaryState, ExceededWindowError, ProjectSnapshot, Thread},
5};
6use agent_settings::{AgentProfileId, CompletionMode};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ToolId, ToolWorkingSet};
9use chrono::{DateTime, Utc};
10use collections::HashMap;
11use context_server::ContextServerId;
12use futures::{
13 FutureExt as _, StreamExt as _,
14 channel::{mpsc, oneshot},
15 future::{self, BoxFuture, Shared},
16};
17use gpui::{
18 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
19 Subscription, Task, Window, prelude::*,
20};
21use indoc::indoc;
22use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
23use project::context_server_store::{ContextServerStatus, ContextServerStore};
24use project::{Project, ProjectItem, ProjectPath, Worktree};
25use prompt_store::{
26 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
27 UserRulesContext, WorktreeContext,
28};
29use serde::{Deserialize, Serialize};
30use sqlez::{
31 bindable::{Bind, Column},
32 connection::Connection,
33 statement::Statement,
34};
35use std::{
36 cell::{Ref, RefCell},
37 path::{Path, PathBuf},
38 rc::Rc,
39 sync::{Arc, Mutex},
40};
41use util::ResultExt as _;
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44pub enum DataType {
45 #[serde(rename = "json")]
46 Json,
47 #[serde(rename = "zstd")]
48 Zstd,
49}
50
51impl Bind for DataType {
52 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
53 let value = match self {
54 DataType::Json => "json",
55 DataType::Zstd => "zstd",
56 };
57 value.bind(statement, start_index)
58 }
59}
60
61impl Column for DataType {
62 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
63 let (value, next_index) = String::column(statement, start_index)?;
64 let data_type = match value.as_str() {
65 "json" => DataType::Json,
66 "zstd" => DataType::Zstd,
67 _ => anyhow::bail!("Unknown data type: {}", value),
68 };
69 Ok((data_type, next_index))
70 }
71}
72
/// Worktree-relative paths recognised as rules files.
///
/// Order is significant: `load_worktree_rules_file` scans this list front to back and uses the
/// first entry that exists as a file in the worktree.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
84
85pub fn init(cx: &mut App) {
86 ThreadsDatabase::init(cx);
87}
88
89/// A system prompt shared by all threads created by this ThreadStore
90#[derive(Clone, Default)]
91pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
92
93impl SharedProjectContext {
94 pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
95 self.0.borrow()
96 }
97}
98
/// Crate-local name for `assistant_context::ContextStore`.
pub type TextThreadStore = assistant_context::ContextStore;
100
/// Owns the list of saved threads for a project, keeps the shared project context
/// (system prompt inputs) up to date, and registers context-server tools.
pub struct ThreadStore {
    /// Project whose worktrees, buffers and context server store are consulted.
    project: Entity<Project>,
    /// Tool working set that context-server tools are inserted into and removed from.
    tools: Entity<ToolWorkingSet>,
    /// Stored at construction; only referenced from the commented-out thread construction code
    /// in this file.
    prompt_builder: Arc<PromptBuilder>,
    /// Optional prompt store; when present, its default prompts become the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered per context server, so they can be removed when that server stops
    /// or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata as last returned by `ThreadsDatabase::list_threads`.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with threads; rewritten after every system prompt reload.
    project_context: SharedProjectContext,
    /// Sender used to request another system prompt reload from the reload task.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Handle of the background loop that performs system prompt reloads.
    /// NOTE(review): presumably held so the task is not cancelled on drop — confirm.
    _reload_system_prompt_task: Task<()>,
    /// Project (and, if available, prompt store) event subscriptions held for the store's
    /// lifetime.
    _subscriptions: Vec<Subscription>,
}
113
/// Event payload emitted by `ThreadStore` when a rules file or a default user rule fails to
/// load.
pub struct RulesLoadingError {
    /// Formatted text of the underlying error.
    pub message: SharedString,
}
117
// Allows `ThreadStore` to emit `RulesLoadingError` events (see `reload_system_prompt`).
impl EventEmitter<RulesLoadingError> for ThreadStore {}
119
120impl ThreadStore {
121 pub fn load(
122 project: Entity<Project>,
123 tools: Entity<ToolWorkingSet>,
124 prompt_store: Option<Entity<PromptStore>>,
125 prompt_builder: Arc<PromptBuilder>,
126 cx: &mut App,
127 ) -> Task<Result<Entity<Self>>> {
128 cx.spawn(async move |cx| {
129 let (thread_store, ready_rx) = cx.update(|cx| {
130 let mut option_ready_rx = None;
131 let thread_store = cx.new(|cx| {
132 let (thread_store, ready_rx) =
133 Self::new(project, tools, prompt_builder, prompt_store, cx);
134 option_ready_rx = Some(ready_rx);
135 thread_store
136 });
137 (thread_store, option_ready_rx.take().unwrap())
138 })?;
139 ready_rx.await?;
140 Ok(thread_store)
141 })
142 }
143
144 fn new(
145 project: Entity<Project>,
146 tools: Entity<ToolWorkingSet>,
147 prompt_builder: Arc<PromptBuilder>,
148 prompt_store: Option<Entity<PromptStore>>,
149 cx: &mut Context<Self>,
150 ) -> (Self, oneshot::Receiver<()>) {
151 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
152
153 if let Some(prompt_store) = prompt_store.as_ref() {
154 subscriptions.push(cx.subscribe(
155 prompt_store,
156 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
157 this.enqueue_system_prompt_reload();
158 },
159 ))
160 }
161
162 // This channel and task prevent concurrent and redundant loading of the system prompt.
163 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
164 let (ready_tx, ready_rx) = oneshot::channel();
165 let mut ready_tx = Some(ready_tx);
166 let reload_system_prompt_task = cx.spawn({
167 let prompt_store = prompt_store.clone();
168 async move |thread_store, cx| {
169 loop {
170 let Some(reload_task) = thread_store
171 .update(cx, |thread_store, cx| {
172 thread_store.reload_system_prompt(prompt_store.clone(), cx)
173 })
174 .ok()
175 else {
176 return;
177 };
178 reload_task.await;
179 if let Some(ready_tx) = ready_tx.take() {
180 ready_tx.send(()).ok();
181 }
182 reload_system_prompt_rx.next().await;
183 }
184 }
185 });
186
187 let this = Self {
188 project,
189 tools,
190 prompt_builder,
191 prompt_store,
192 context_server_tool_ids: HashMap::default(),
193 threads: Vec::new(),
194 project_context: SharedProjectContext::default(),
195 reload_system_prompt_tx,
196 _reload_system_prompt_task: reload_system_prompt_task,
197 _subscriptions: subscriptions,
198 };
199 this.register_context_server_handlers(cx);
200 this.reload(cx).detach_and_log_err(cx);
201 (this, ready_rx)
202 }
203
204 fn handle_project_event(
205 &mut self,
206 _project: Entity<Project>,
207 event: &project::Event,
208 _cx: &mut Context<Self>,
209 ) {
210 match event {
211 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
212 self.enqueue_system_prompt_reload();
213 }
214 project::Event::WorktreeUpdatedEntries(_, items) => {
215 if items.iter().any(|(path, _, _)| {
216 RULES_FILE_NAMES
217 .iter()
218 .any(|name| path.as_ref() == Path::new(name))
219 }) {
220 self.enqueue_system_prompt_reload();
221 }
222 }
223 _ => {}
224 }
225 }
226
227 fn enqueue_system_prompt_reload(&mut self) {
228 self.reload_system_prompt_tx.try_send(()).ok();
229 }
230
231 // Note that this should only be called from `reload_system_prompt_task`.
232 fn reload_system_prompt(
233 &self,
234 prompt_store: Option<Entity<PromptStore>>,
235 cx: &mut Context<Self>,
236 ) -> Task<()> {
237 let worktrees = self
238 .project
239 .read(cx)
240 .visible_worktrees(cx)
241 .collect::<Vec<_>>();
242 let worktree_tasks = worktrees
243 .into_iter()
244 .map(|worktree| {
245 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
246 })
247 .collect::<Vec<_>>();
248 let default_user_rules_task = match prompt_store {
249 None => Task::ready(vec![]),
250 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
251 let prompts = prompt_store.default_prompt_metadata();
252 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
253 let contents = prompt_store.load(prompt_metadata.id, cx);
254 async move { (contents.await, prompt_metadata) }
255 });
256 cx.background_spawn(future::join_all(load_tasks))
257 }),
258 };
259
260 cx.spawn(async move |this, cx| {
261 let (worktrees, default_user_rules) =
262 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
263
264 let worktrees = worktrees
265 .into_iter()
266 .map(|(worktree, rules_error)| {
267 if let Some(rules_error) = rules_error {
268 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
269 }
270 worktree
271 })
272 .collect::<Vec<_>>();
273
274 let default_user_rules = default_user_rules
275 .into_iter()
276 .flat_map(|(contents, prompt_metadata)| match contents {
277 Ok(contents) => Some(UserRulesContext {
278 uuid: match prompt_metadata.id {
279 PromptId::User { uuid } => uuid,
280 PromptId::EditWorkflow => return None,
281 },
282 title: prompt_metadata.title.map(|title| title.to_string()),
283 contents,
284 }),
285 Err(err) => {
286 this.update(cx, |_, cx| {
287 cx.emit(RulesLoadingError {
288 message: format!("{err:?}").into(),
289 });
290 })
291 .ok();
292 None
293 }
294 })
295 .collect::<Vec<_>>();
296
297 this.update(cx, |this, _cx| {
298 *this.project_context.0.borrow_mut() =
299 Some(ProjectContext::new(worktrees, default_user_rules));
300 })
301 .ok();
302 })
303 }
304
305 fn load_worktree_info_for_system_prompt(
306 worktree: Entity<Worktree>,
307 project: Entity<Project>,
308 cx: &mut App,
309 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
310 let tree = worktree.read(cx);
311 let root_name = tree.root_name().into();
312 let abs_path = tree.abs_path();
313
314 let mut context = WorktreeContext {
315 root_name,
316 abs_path,
317 rules_file: None,
318 };
319
320 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
321 let Some(rules_task) = rules_task else {
322 return Task::ready((context, None));
323 };
324
325 cx.spawn(async move |_| {
326 let (rules_file, rules_file_error) = match rules_task.await {
327 Ok(rules_file) => (Some(rules_file), None),
328 Err(err) => (
329 None,
330 Some(RulesLoadingError {
331 message: format!("{err}").into(),
332 }),
333 ),
334 };
335 context.rules_file = rules_file;
336 (context, rules_file_error)
337 })
338 }
339
340 fn load_worktree_rules_file(
341 worktree: Entity<Worktree>,
342 project: Entity<Project>,
343 cx: &mut App,
344 ) -> Option<Task<Result<RulesFileContext>>> {
345 let worktree = worktree.read(cx);
346 let worktree_id = worktree.id();
347 let selected_rules_file = RULES_FILE_NAMES
348 .into_iter()
349 .filter_map(|name| {
350 worktree
351 .entry_for_path(name)
352 .filter(|entry| entry.is_file())
353 .map(|entry| entry.path.clone())
354 })
355 .next();
356
357 // Note that Cline supports `.clinerules` being a directory, but that is not currently
358 // supported. This doesn't seem to occur often in GitHub repositories.
359 selected_rules_file.map(|path_in_worktree| {
360 let project_path = ProjectPath {
361 worktree_id,
362 path: path_in_worktree.clone(),
363 };
364 let buffer_task =
365 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
366 let rope_task = cx.spawn(async move |cx| {
367 buffer_task.await?.read_with(cx, |buffer, cx| {
368 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
369 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
370 })?
371 });
372 // Build a string from the rope on a background thread.
373 cx.background_spawn(async move {
374 let (project_entry_id, rope) = rope_task.await?;
375 anyhow::Ok(RulesFileContext {
376 path_in_worktree,
377 text: rope.to_string().trim().to_string(),
378 project_entry_id: project_entry_id.to_usize(),
379 })
380 })
381 })
382 }
383
    /// Returns the prompt store this thread store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
387
388 pub fn tools(&self) -> Entity<ToolWorkingSet> {
389 self.tools.clone()
390 }
391
    /// Returns the number of threads in the in-memory metadata list, i.e. the result of the
    /// most recent successful `reload`.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
396
    /// Iterates over the cached thread metadata, most recently updated first.
    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // No sorting happens here: the order comes from the "ORDER BY updated_at DESC" clause
        // in `list_threads`.
        self.threads.iter()
    }
401
    /// Creates a new thread backed by this store.
    ///
    /// # Panics
    ///
    /// Not implemented yet: the body is `todo!()`, so every call panics. The commented-out
    /// code below is the previous implementation, kept as a reference.
    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
        todo!()
        // cx.new(|cx| {
        //     Thread::new(
        //         self.project.clone(),
        //         self.tools.clone(),
        //         self.prompt_builder.clone(),
        //         self.project_context.clone(),
        //         cx,
        //     )
        // })
    }
414
    /// Loads the serialized thread with the given id from the database.
    ///
    /// # Errors
    ///
    /// The task fails if the database could not be opened, the lookup fails, or no thread with
    /// `id` exists.
    ///
    /// # Panics
    ///
    /// Deserializing into a live `Thread` is not implemented yet: after a successful lookup
    /// the task reaches `todo!()` and panics. The commented-out code is the previous
    /// implementation, kept as a reference.
    pub fn open_thread(
        &self,
        id: &ThreadId,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        // Currently only used by the commented-out code below.
        let this = cx.weak_entity();
        window.spawn(cx, async move |cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let thread = database
                .try_find_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            todo!();
            // let thread = this.update_in(cx, |this, window, cx| {
            //     cx.new(|cx| {
            //         Thread::deserialize(
            //             id.clone(),
            //             thread,
            //             this.project.clone(),
            //             this.tools.clone(),
            //             this.prompt_builder.clone(),
            //             this.project_context.clone(),
            //             Some(window),
            //             cx,
            //         )
            //     })
            // })?;
            // Ok(thread)
        })
    }
449
    /// Persists `thread` to the database.
    ///
    /// # Panics
    ///
    /// Not implemented yet: the body is `todo!()`, so every call panics. The commented-out
    /// code below is the previous implementation (serialize, save, then reload the list).
    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
        todo!()
        // let (metadata, serialized_thread) =
        //     thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));

        // let database_future = ThreadsDatabase::global_future(cx);
        // cx.spawn(async move |this, cx| {
        //     let serialized_thread = serialized_thread.await?;
        //     let database = database_future.await.map_err(|err| anyhow!(err))?;
        //     database.save_thread(metadata, serialized_thread).await?;

        //     this.update(cx, |this, cx| this.reload(cx))?.await
        // })
    }
464
    /// Deletes the thread with the given id.
    ///
    /// # Panics
    ///
    /// Not implemented yet: the body is `todo!()`, so every call panics. The commented-out
    /// code below is the previous implementation (delete the row, then drop the cached
    /// metadata entry and notify).
    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        todo!()
        // let id = id.clone();
        // let database_future = ThreadsDatabase::global_future(cx);
        // cx.spawn(async move |this, cx| {
        //     let database = database_future.await.map_err(|err| anyhow!(err))?;
        //     database.delete_thread(id.clone()).await?;

        //     this.update(cx, |this, cx| {
        //         this.threads.retain(|thread| thread.id != id);
        //         cx.notify();
        //     })
        // })
    }
479
480 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
481 let database_future = ThreadsDatabase::global_future(cx);
482 cx.spawn(async move |this, cx| {
483 let threads = database_future
484 .await
485 .map_err(|err| anyhow!(err))?
486 .list_threads()
487 .await?;
488
489 this.update(cx, |this, cx| {
490 this.threads = threads;
491 cx.notify();
492 })
493 })
494 }
495
496 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
497 let context_server_store = self.project.read(cx).context_server_store();
498 cx.subscribe(&context_server_store, Self::handle_context_server_event)
499 .detach();
500
501 // Check for any servers that were already running before the handler was registered
502 for server in context_server_store.read(cx).running_servers() {
503 self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
504 }
505 }
506
507 fn handle_context_server_event(
508 &mut self,
509 context_server_store: Entity<ContextServerStore>,
510 event: &project::context_server_store::Event,
511 cx: &mut Context<Self>,
512 ) {
513 let tool_working_set = self.tools.clone();
514 match event {
515 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
516 match status {
517 ContextServerStatus::Starting => {}
518 ContextServerStatus::Running => {
519 self.load_context_server_tools(server_id.clone(), context_server_store, cx);
520 }
521 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
522 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
523 tool_working_set.update(cx, |tool_working_set, _| {
524 tool_working_set.remove(&tool_ids);
525 });
526 }
527 }
528 }
529 }
530 }
531 }
532
533 fn load_context_server_tools(
534 &self,
535 server_id: ContextServerId,
536 context_server_store: Entity<ContextServerStore>,
537 cx: &mut Context<Self>,
538 ) {
539 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
540 return;
541 };
542 let tool_working_set = self.tools.clone();
543 cx.spawn(async move |this, cx| {
544 let Some(protocol) = server.client() else {
545 return;
546 };
547
548 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
549 if let Some(response) = protocol
550 .request::<context_server::types::requests::ListTools>(())
551 .await
552 .log_err()
553 {
554 let tool_ids = tool_working_set
555 .update(cx, |tool_working_set, _| {
556 response
557 .tools
558 .into_iter()
559 .map(|tool| {
560 log::info!("registering context server tool: {:?}", tool.name);
561 tool_working_set.insert(Arc::new(ContextServerTool::new(
562 context_server_store.clone(),
563 server.id(),
564 tool,
565 )))
566 })
567 .collect::<Vec<_>>()
568 })
569 .log_err();
570
571 if let Some(tool_ids) = tool_ids {
572 this.update(cx, |this, _| {
573 this.context_server_tool_ids.insert(server_id, tool_ids);
574 })
575 .log_err();
576 }
577 }
578 }
579 })
580 .detach();
581 }
582}
583
/// Lightweight description of a saved thread, as returned by `ThreadsDatabase::list_threads`
/// (the thread body itself is not loaded).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Identifier of the thread (the `id` column).
    pub id: ThreadId,
    /// Summary text (the `summary` column).
    pub summary: SharedString,
    /// Last update time, parsed from the RFC 3339 string in the `updated_at` column.
    pub updated_at: DateTime<Utc>,
}
590
/// The current on-disk representation of a thread.
///
/// Every field marked `#[serde(default)]` may be absent from stored JSON and then takes its
/// type's default value.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version string; `from_json` dispatches on it.
    pub version: String,
    /// Summary text; also written to the `summary` column when saving.
    pub summary: SharedString,
    /// Last update time; also written to the `updated_at` column when saving.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages, in stored order.
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
616
/// Provider/model pair stored in `SerializedThread::model`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    /// Provider name. NOTE(review): presumably a language model provider id — confirm.
    pub provider: String,
    /// Model name. NOTE(review): presumably a model id within that provider — confirm.
    pub model: String,
}
622
623impl SerializedThread {
624 pub const VERSION: &'static str = "0.2.0";
625
626 pub fn from_json(json: &[u8]) -> Result<Self> {
627 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
628 match saved_thread_json.get("version") {
629 Some(serde_json::Value::String(version)) => match version.as_str() {
630 SerializedThreadV0_1_0::VERSION => {
631 let saved_thread =
632 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
633 Ok(saved_thread.upgrade())
634 }
635 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
636 saved_thread_json,
637 )?),
638 _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
639 },
640 None => {
641 let saved_thread =
642 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
643 Ok(saved_thread.upgrade())
644 }
645 version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
646 }
647 }
648}
649
/// A thread stored in format version `0.1.0`.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change between 0.1.0 and the latest format, so the latest
    // `SerializedThread` is reused as the payload type.
    // When introducing the next version, point this at a frozen `SerializedThreadV0_2_0`.
    SerializedThread,
);
656
657impl SerializedThreadV0_1_0 {
658 pub const VERSION: &'static str = "0.1.0";
659
660 pub fn upgrade(self) -> SerializedThread {
661 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
662
663 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
664
665 for message in self.0.messages {
666 if message.role == Role::User && !message.tool_results.is_empty() {
667 if let Some(last_message) = messages.last_mut() {
668 debug_assert!(last_message.role == Role::Assistant);
669
670 last_message.tool_results = message.tool_results;
671 continue;
672 }
673 }
674
675 messages.push(message);
676 }
677
678 SerializedThread {
679 messages,
680 version: SerializedThread::VERSION.to_string(),
681 ..self.0
682 }
683 }
684}
685
/// One message of a serialized thread.
///
/// All fields other than `id` and `role` default to empty/false when missing from stored
/// JSON.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Content segments of the message (text, thinking, redacted thinking).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool calls recorded on this message.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results recorded on this message. When upgrading from 0.1.0 these are moved here
    /// from the following user message.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
703
/// One content segment of a message, serialized with a `"type"` tag field.
///
/// NOTE(review): `Text` and `Thinking` are renamed to lowercase tags, while
/// `RedactedThinking` has no `rename` and therefore keeps serde's default tag (the variant
/// name). Stored threads depend on these tags, so confirm before normalising them.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        // Omitted from the output entirely when there is no signature.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
}
721
/// A tool call recorded on a message.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    /// Id of this tool use; `SerializedToolResult::tool_use_id` has the same type.
    pub id: LanguageModelToolUseId,
    /// Name of the tool.
    pub name: SharedString,
    /// Tool input as arbitrary JSON.
    pub input: serde_json::Value,
}
728
/// The recorded result of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    /// Id of the tool use this result belongs to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool call failed.
    pub is_error: bool,
    /// Result content.
    pub content: LanguageModelToolResultContent,
    /// Optional raw JSON output of the tool.
    pub output: Option<serde_json::Value>,
}
736
/// The oldest stored thread format: it has no `version` field, which is how `from_json`
/// recognises it.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    /// Messages in the legacy single-text shape.
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
745
746impl LegacySerializedThread {
747 pub fn upgrade(self) -> SerializedThread {
748 SerializedThread {
749 version: SerializedThread::VERSION.to_string(),
750 summary: self.summary,
751 updated_at: self.updated_at,
752 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
753 initial_project_snapshot: self.initial_project_snapshot,
754 cumulative_token_usage: TokenUsage::default(),
755 request_token_usage: Vec::new(),
756 detailed_summary_state: DetailedSummaryState::default(),
757 exceeded_window_error: None,
758 model: None,
759 completion_mode: None,
760 tool_use_limit_reached: false,
761 profile: None,
762 }
763 }
764}
765
/// A message in the legacy (unversioned) thread format, where the content is one plain text
/// string instead of a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Entire message text; becomes a single `Text` segment on upgrade.
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
776
777impl LegacySerializedMessage {
778 fn upgrade(self) -> SerializedMessage {
779 SerializedMessage {
780 id: self.id,
781 role: self.role,
782 segments: vec![SerializedMessageSegment::Text { text: self.text }],
783 tool_uses: self.tool_uses,
784 tool_results: self.tool_results,
785 context: String::new(),
786 creases: Vec::new(),
787 is_hidden: false,
788 }
789 }
790}
791
/// A crease stored with a message.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    // NOTE(review): `start`/`end` are presumably offsets into the message text — confirm the
    // unit (bytes vs. chars) against the code that produces creases.
    pub start: usize,
    pub end: usize,
    /// Path of the icon associated with the crease.
    pub icon_path: SharedString,
    /// Label associated with the crease.
    pub label: SharedString,
}
799
/// App global wrapping the shared future that resolves to the opened threads database, or to
/// the error produced while opening it. Set by `ThreadsDatabase::init`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
803
// Marks the wrapper as a gpui global so it can be stored with `cx.set_global`.
impl Global for GlobalThreadsDatabase {}
805
/// SQLite-backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned.
    executor: BackgroundExecutor,
    /// The single connection, shared between tasks behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
810
811impl ThreadsDatabase {
812 fn connection(&self) -> Arc<Mutex<Connection>> {
813 self.connection.clone()
814 }
815
816 const COMPRESSION_LEVEL: i32 = 3;
817}
818
819impl Bind for ThreadId {
820 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
821 self.to_string().bind(statement, start_index)
822 }
823}
824
825impl Column for ThreadId {
826 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
827 let (id_str, next_index) = String::column(statement, start_index)?;
828 Ok((ThreadId::from(id_str.as_str()), next_index))
829 }
830}
831
832impl ThreadsDatabase {
833 fn global_future(
834 cx: &mut App,
835 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
836 GlobalThreadsDatabase::global(cx).0.clone()
837 }
838
839 fn init(cx: &mut App) {
840 let executor = cx.background_executor().clone();
841 let database_future = executor
842 .spawn({
843 let executor = executor.clone();
844 let threads_dir = paths::data_dir().join("threads");
845 async move { ThreadsDatabase::new(threads_dir, executor) }
846 })
847 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
848 .boxed()
849 .shared();
850
851 cx.set_global(GlobalThreadsDatabase(database_future));
852 }
853
854 pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
855 std::fs::create_dir_all(&threads_dir)?;
856
857 let sqlite_path = threads_dir.join("threads.db");
858 let mdb_path = threads_dir.join("threads-db.1.mdb");
859
860 let needs_migration_from_heed = mdb_path.exists();
861
862 let connection = Connection::open_file(&sqlite_path.to_string_lossy());
863
864 connection.exec(indoc! {"
865 CREATE TABLE IF NOT EXISTS threads (
866 id TEXT PRIMARY KEY,
867 summary TEXT NOT NULL,
868 updated_at TEXT NOT NULL,
869 data_type TEXT NOT NULL,
870 data BLOB NOT NULL
871 )
872 "})?()
873 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
874
875 let db = Self {
876 executor: executor.clone(),
877 connection: Arc::new(Mutex::new(connection)),
878 };
879
880 if needs_migration_from_heed {
881 let db_connection = db.connection();
882 let executor_clone = executor.clone();
883 executor
884 .spawn(async move {
885 log::info!("Starting threads.db migration");
886 Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
887 std::fs::remove_dir_all(mdb_path)?;
888 log::info!("threads.db migrated to sqlite");
889 Ok::<(), anyhow::Error>(())
890 })
891 .detach();
892 }
893
894 Ok(db)
895 }
896
897 // Remove this migration after 2025-09-01
898 fn migrate_from_heed(
899 mdb_path: &Path,
900 connection: Arc<Mutex<Connection>>,
901 _executor: BackgroundExecutor,
902 ) -> Result<()> {
903 use heed::types::SerdeBincode;
904 struct SerializedThreadHeed(SerializedThread);
905
906 impl heed::BytesEncode<'_> for SerializedThreadHeed {
907 type EItem = SerializedThreadHeed;
908
909 fn bytes_encode(
910 item: &Self::EItem,
911 ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
912 serde_json::to_vec(&item.0)
913 .map(std::borrow::Cow::Owned)
914 .map_err(Into::into)
915 }
916 }
917
918 impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
919 type DItem = SerializedThreadHeed;
920
921 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
922 SerializedThread::from_json(bytes)
923 .map(SerializedThreadHeed)
924 .map_err(Into::into)
925 }
926 }
927
928 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
929
930 let env = unsafe {
931 heed::EnvOpenOptions::new()
932 .map_size(ONE_GB_IN_BYTES)
933 .max_dbs(1)
934 .open(mdb_path)?
935 };
936
937 let txn = env.write_txn()?;
938 let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
939 .open_database(&txn, Some("threads"))?
940 .ok_or_else(|| anyhow!("threads database not found"))?;
941
942 for result in threads.iter(&txn)? {
943 let (thread_id, thread_heed) = result?;
944 Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
945 }
946
947 Ok(())
948 }
949
950 fn save_thread_sync(
951 connection: &Arc<Mutex<Connection>>,
952 id: ThreadId,
953 thread: SerializedThread,
954 ) -> Result<()> {
955 let json_data = serde_json::to_string(&thread)?;
956 let summary = thread.summary.to_string();
957 let updated_at = thread.updated_at.to_rfc3339();
958
959 let connection = connection.lock().unwrap();
960
961 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
962 let data_type = DataType::Zstd;
963 let data = compressed;
964
965 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
966 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
967 "})?;
968
969 insert((id, summary, updated_at, data_type, data))?;
970
971 Ok(())
972 }
973
974 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
975 let connection = self.connection.clone();
976
977 self.executor.spawn(async move {
978 let connection = connection.lock().unwrap();
979 let mut select =
980 connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
981 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
982 "})?;
983
984 let rows = select(())?;
985 let mut threads = Vec::new();
986
987 for (id, summary, updated_at) in rows {
988 threads.push(SerializedThreadMetadata {
989 id,
990 summary: summary.into(),
991 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
992 });
993 }
994
995 Ok(threads)
996 })
997 }
998
999 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1000 let connection = self.connection.clone();
1001
1002 self.executor.spawn(async move {
1003 let connection = connection.lock().unwrap();
1004 let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1005 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1006 "})?;
1007
1008 let rows = select(id)?;
1009 if let Some((data_type, data)) = rows.into_iter().next() {
1010 let json_data = match data_type {
1011 DataType::Zstd => {
1012 let decompressed = zstd::decode_all(&data[..])?;
1013 String::from_utf8(decompressed)?
1014 }
1015 DataType::Json => String::from_utf8(data)?,
1016 };
1017
1018 let thread = SerializedThread::from_json(json_data.as_bytes())?;
1019 Ok(Some(thread))
1020 } else {
1021 Ok(None)
1022 }
1023 })
1024 }
1025
1026 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1027 let connection = self.connection.clone();
1028
1029 self.executor
1030 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1031 }
1032
1033 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1034 let connection = self.connection.clone();
1035
1036 self.executor.spawn(async move {
1037 let connection = connection.lock().unwrap();
1038
1039 let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1040 DELETE FROM threads WHERE id = ?
1041 "})?;
1042
1043 delete(id)?;
1044
1045 Ok(())
1046 })
1047 }
1048}
1049
1050#[cfg(test)]
1051mod tests {
1052 use super::*;
1053 use crate::{MessageId, thread::DetailedSummaryState};
1054 use chrono::Utc;
1055 use language_model::{Role, TokenUsage};
1056 use pretty_assertions::assert_eq;
1057
1058 #[test]
1059 fn test_legacy_serialized_thread_upgrade() {
1060 let updated_at = Utc::now();
1061 let legacy_thread = LegacySerializedThread {
1062 summary: "Test conversation".into(),
1063 updated_at,
1064 messages: vec![LegacySerializedMessage {
1065 id: MessageId(1),
1066 role: Role::User,
1067 text: "Hello, world!".to_string(),
1068 tool_uses: vec![],
1069 tool_results: vec![],
1070 }],
1071 initial_project_snapshot: None,
1072 };
1073
1074 let upgraded = legacy_thread.upgrade();
1075
1076 assert_eq!(
1077 upgraded,
1078 SerializedThread {
1079 summary: "Test conversation".into(),
1080 updated_at,
1081 messages: vec![SerializedMessage {
1082 id: MessageId(1),
1083 role: Role::User,
1084 segments: vec![SerializedMessageSegment::Text {
1085 text: "Hello, world!".to_string()
1086 }],
1087 tool_uses: vec![],
1088 tool_results: vec![],
1089 context: "".to_string(),
1090 creases: vec![],
1091 is_hidden: false
1092 }],
1093 version: SerializedThread::VERSION.to_string(),
1094 initial_project_snapshot: None,
1095 cumulative_token_usage: TokenUsage::default(),
1096 request_token_usage: vec![],
1097 detailed_summary_state: DetailedSummaryState::default(),
1098 exceeded_window_error: None,
1099 model: None,
1100 completion_mode: None,
1101 tool_use_limit_reached: false,
1102 profile: None
1103 }
1104 )
1105 }
1106
1107 #[test]
1108 fn test_serialized_threadv0_1_0_upgrade() {
1109 let updated_at = Utc::now();
1110 let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
1111 summary: "Test conversation".into(),
1112 updated_at,
1113 messages: vec![
1114 SerializedMessage {
1115 id: MessageId(1),
1116 role: Role::User,
1117 segments: vec![SerializedMessageSegment::Text {
1118 text: "Use tool_1".to_string(),
1119 }],
1120 tool_uses: vec![],
1121 tool_results: vec![],
1122 context: "".to_string(),
1123 creases: vec![],
1124 is_hidden: false,
1125 },
1126 SerializedMessage {
1127 id: MessageId(2),
1128 role: Role::Assistant,
1129 segments: vec![SerializedMessageSegment::Text {
1130 text: "I want to use a tool".to_string(),
1131 }],
1132 tool_uses: vec![SerializedToolUse {
1133 id: "abc".into(),
1134 name: "tool_1".into(),
1135 input: serde_json::Value::Null,
1136 }],
1137 tool_results: vec![],
1138 context: "".to_string(),
1139 creases: vec![],
1140 is_hidden: false,
1141 },
1142 SerializedMessage {
1143 id: MessageId(1),
1144 role: Role::User,
1145 segments: vec![SerializedMessageSegment::Text {
1146 text: "Here is the tool result".to_string(),
1147 }],
1148 tool_uses: vec![],
1149 tool_results: vec![SerializedToolResult {
1150 tool_use_id: "abc".into(),
1151 is_error: false,
1152 content: LanguageModelToolResultContent::Text("abcdef".into()),
1153 output: Some(serde_json::Value::Null),
1154 }],
1155 context: "".to_string(),
1156 creases: vec![],
1157 is_hidden: false,
1158 },
1159 ],
1160 version: SerializedThreadV0_1_0::VERSION.to_string(),
1161 initial_project_snapshot: None,
1162 cumulative_token_usage: TokenUsage::default(),
1163 request_token_usage: vec![],
1164 detailed_summary_state: DetailedSummaryState::default(),
1165 exceeded_window_error: None,
1166 model: None,
1167 completion_mode: None,
1168 tool_use_limit_reached: false,
1169 profile: None,
1170 });
1171 let upgraded = thread_v0_1_0.upgrade();
1172
1173 assert_eq!(
1174 upgraded,
1175 SerializedThread {
1176 summary: "Test conversation".into(),
1177 updated_at,
1178 messages: vec![
1179 SerializedMessage {
1180 id: MessageId(1),
1181 role: Role::User,
1182 segments: vec![SerializedMessageSegment::Text {
1183 text: "Use tool_1".to_string()
1184 }],
1185 tool_uses: vec![],
1186 tool_results: vec![],
1187 context: "".to_string(),
1188 creases: vec![],
1189 is_hidden: false
1190 },
1191 SerializedMessage {
1192 id: MessageId(2),
1193 role: Role::Assistant,
1194 segments: vec![SerializedMessageSegment::Text {
1195 text: "I want to use a tool".to_string(),
1196 }],
1197 tool_uses: vec![SerializedToolUse {
1198 id: "abc".into(),
1199 name: "tool_1".into(),
1200 input: serde_json::Value::Null,
1201 }],
1202 tool_results: vec![SerializedToolResult {
1203 tool_use_id: "abc".into(),
1204 is_error: false,
1205 content: LanguageModelToolResultContent::Text("abcdef".into()),
1206 output: Some(serde_json::Value::Null),
1207 }],
1208 context: "".to_string(),
1209 creases: vec![],
1210 is_hidden: false,
1211 },
1212 ],
1213 version: SerializedThread::VERSION.to_string(),
1214 initial_project_snapshot: None,
1215 cumulative_token_usage: TokenUsage::default(),
1216 request_token_usage: vec![],
1217 detailed_summary_state: DetailedSummaryState::default(),
1218 exceeded_window_error: None,
1219 model: None,
1220 completion_mode: None,
1221 tool_use_limit_reached: false,
1222 profile: None
1223 }
1224 )
1225 }
1226}