1use std::cell::{Ref, RefCell};
2use std::path::{Path, PathBuf};
3use std::rc::Rc;
4use std::sync::{Arc, Mutex};
5
6use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
9use chrono::{DateTime, Utc};
10use collections::HashMap;
11use context_server::ContextServerId;
12use futures::channel::{mpsc, oneshot};
13use futures::future::{self, BoxFuture, Shared};
14use futures::{FutureExt as _, StreamExt as _};
15use gpui::{
16 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
17 Subscription, Task, prelude::*,
18};
19
20use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
21use project::context_server_store::{ContextServerStatus, ContextServerStore};
22use project::{Project, ProjectItem, ProjectPath, Worktree};
23use prompt_store::{
24 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
25 UserRulesContext, WorktreeContext,
26};
27use serde::{Deserialize, Serialize};
28use settings::{Settings as _, SettingsStore};
29use ui::Window;
30use util::ResultExt as _;
31
32use crate::context_server_tool::ContextServerTool;
33use crate::thread::{
34 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
35};
36use indoc::indoc;
37use sqlez::{
38 bindable::{Bind, Column},
39 connection::Connection,
40 statement::Statement,
41};
42
/// On-disk encoding of a serialized thread's `data` column: either plain
/// JSON text or zstd-compressed JSON.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// Uncompressed UTF-8 JSON.
    #[serde(rename = "json")]
    Json,
    /// zstd-compressed JSON bytes.
    #[serde(rename = "zstd")]
    Zstd,
}
50
51impl Bind for DataType {
52 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
53 let value = match self {
54 DataType::Json => "json",
55 DataType::Zstd => "zstd",
56 };
57 value.bind(statement, start_index)
58 }
59}
60
61impl Column for DataType {
62 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
63 let (value, next_index) = String::column(statement, start_index)?;
64 let data_type = match value.as_str() {
65 "json" => DataType::Json,
66 "zstd" => DataType::Zstd,
67 _ => anyhow::bail!("Unknown data type: {}", value),
68 };
69 Ok((data_type, next_index))
70 }
71}
72
/// Rules-file names recognized at a worktree root, in priority order (the
/// first match wins in `load_worktree_rules_file`).
// `'static` is implied for references in a `const` item, so the explicit
// lifetime was redundant (clippy::redundant_static_lifetimes).
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
81
/// Initializes thread persistence by registering the global threads database.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
85
/// A system prompt shared by all threads created by this ThreadStore.
/// Holds `None` until the first system-prompt reload completes.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
89
impl SharedProjectContext {
    /// Immutably borrows the current project context; `None` until the first
    /// system-prompt load has run.
    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
        self.0.borrow()
    }
}
95
/// Alias for the context editor's store, used for text-based threads.
pub type TextThreadStore = assistant_context_editor::ContextStore;
97
/// Manages agent threads: creation, persistence via `ThreadsDatabase`, tool
/// profiles, and the shared project context used for system prompts.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    // Tool ids registered per context server, so they can be removed when a
    // server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    // Metadata for persisted threads, loaded from the threads database.
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    // Signals `_reload_system_prompt_task` to rebuild the system prompt.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
110
/// Event emitted when loading rules (worktree rules files or prompt-store
/// user rules) fails; carries a displayable message.
pub struct RulesLoadingError {
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
116
117impl ThreadStore {
    /// Creates a `ThreadStore` and resolves once its initial system-prompt
    /// load has completed.
    pub fn load(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_store: Option<Entity<PromptStore>>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        cx.spawn(async move |cx| {
            let (thread_store, ready_rx) = cx.update(|cx| {
                // `cx.new` only returns the entity, so smuggle the readiness
                // receiver out of the constructor through this local.
                let mut option_ready_rx = None;
                let thread_store = cx.new(|cx| {
                    let (thread_store, ready_rx) =
                        Self::new(project, tools, prompt_builder, prompt_store, cx);
                    option_ready_rx = Some(ready_rx);
                    thread_store
                });
                (thread_store, option_ready_rx.take().unwrap())
            })?;
            // Block until the first system-prompt reload signals readiness.
            ready_rx.await?;
            Ok(thread_store)
        })
    }
140
    /// Builds the store, wires up settings/project/prompt-store subscriptions,
    /// and starts the system-prompt reload loop. Returns the store plus a
    /// receiver that fires after the first prompt reload completes.
    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        // Re-apply the default profile on settings changes; watch project
        // events for worktree/rules-file changes.
        let mut subscriptions = vec![
            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
                this.load_default_profile(cx);
            }),
            cx.subscribe(&project, Self::handle_project_event),
        ];

        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    // Exit the loop once the store entity has been dropped.
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    // Signal readiness exactly once, after the first reload.
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    // Park until another reload is requested.
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.load_default_profile(cx);
        this.register_context_server_handlers(cx);
        // Populate the thread metadata list from the database.
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }
206
207 fn handle_project_event(
208 &mut self,
209 _project: Entity<Project>,
210 event: &project::Event,
211 _cx: &mut Context<Self>,
212 ) {
213 match event {
214 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
215 self.enqueue_system_prompt_reload();
216 }
217 project::Event::WorktreeUpdatedEntries(_, items) => {
218 if items.iter().any(|(path, _, _)| {
219 RULES_FILE_NAMES
220 .iter()
221 .any(|name| path.as_ref() == Path::new(name))
222 }) {
223 self.enqueue_system_prompt_reload();
224 }
225 }
226 _ => {}
227 }
228 }
229
    /// Requests a system-prompt rebuild. The channel has capacity 1, so a
    /// request made while one is already pending coalesces (the failed
    /// `try_send` is intentionally ignored).
    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }
233
    // Note that this should only be called from `reload_system_prompt_task`.
    /// Gathers worktree rules files and default user rules, then replaces the
    /// shared `ProjectContext`. Loading errors are surfaced as
    /// `RulesLoadingError` events instead of aborting the reload.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .collect::<Vec<_>>();
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
            })
            .collect::<Vec<_>>();
        // Without a prompt store there are no default user rules to load.
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            // Emit per-worktree rules errors but keep the worktree context.
            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            // Built-in workflow prompts are not user rules.
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            // Publish the rebuilt context for all threads sharing it.
            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }
307
    /// Builds the `WorktreeContext` for one worktree, loading its rules file
    /// if one exists; any loading error is returned alongside the context
    /// rather than replacing it.
    fn load_worktree_info_for_system_prompt(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
        let root_name = worktree.read(cx).root_name().into();

        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
        // No rules file in this worktree: resolve immediately.
        let Some(rules_task) = rules_task else {
            return Task::ready((
                WorktreeContext {
                    root_name,
                    rules_file: None,
                },
                None,
            ));
        };

        cx.spawn(async move |_| {
            let (rules_file, rules_file_error) = match rules_task.await {
                Ok(rules_file) => (Some(rules_file), None),
                Err(err) => (
                    None,
                    Some(RulesLoadingError {
                        message: format!("{err}").into(),
                    }),
                ),
            };
            let worktree_info = WorktreeContext {
                root_name,
                rules_file,
            };
            (worktree_info, rules_file_error)
        })
    }
343
344 fn load_worktree_rules_file(
345 worktree: Entity<Worktree>,
346 project: Entity<Project>,
347 cx: &mut App,
348 ) -> Option<Task<Result<RulesFileContext>>> {
349 let worktree_ref = worktree.read(cx);
350 let worktree_id = worktree_ref.id();
351 let selected_rules_file = RULES_FILE_NAMES
352 .into_iter()
353 .filter_map(|name| {
354 worktree_ref
355 .entry_for_path(name)
356 .filter(|entry| entry.is_file())
357 .map(|entry| entry.path.clone())
358 })
359 .next();
360
361 // Note that Cline supports `.clinerules` being a directory, but that is not currently
362 // supported. This doesn't seem to occur often in GitHub repositories.
363 selected_rules_file.map(|path_in_worktree| {
364 let project_path = ProjectPath {
365 worktree_id,
366 path: path_in_worktree.clone(),
367 };
368 let buffer_task =
369 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
370 let rope_task = cx.spawn(async move |cx| {
371 buffer_task.await?.read_with(cx, |buffer, cx| {
372 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
373 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
374 })?
375 });
376 // Build a string from the rope on a background thread.
377 cx.background_spawn(async move {
378 let (project_entry_id, rope) = rope_task.await?;
379 anyhow::Ok(RulesFileContext {
380 path_in_worktree,
381 text: rope.to_string().trim().to_string(),
382 project_entry_id: project_entry_id.to_usize(),
383 })
384 })
385 })
386 }
387
    /// Returns the optional prompt store backing default user rules.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
391
    /// Returns a handle to the shared tool working set.
    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }
395
    /// Returns the number of threads currently known to this store.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
400
    /// Iterates thread metadata in no guaranteed order; use
    /// `reverse_chronological_threads` for a sorted view.
    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        self.threads.iter()
    }
404
405 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
406 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
407 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
408 threads
409 }
410
    /// Creates a fresh, empty thread wired to this store's project, tools,
    /// prompt builder, and shared project context.
    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::new(
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                cx,
            )
        })
    }
422
    /// Instantiates a thread from serialized data under a newly generated
    /// `ThreadId` (the original id is not reused).
    pub fn create_thread_from_serialized(
        &mut self,
        serialized: SerializedThread,
        cx: &mut Context<Self>,
    ) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::deserialize(
                ThreadId::new(),
                serialized,
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                None,
                cx,
            )
        })
    }
441
    /// Loads a persisted thread by id from the database and deserializes it
    /// into a live `Thread` entity. Fails if no thread with that id exists.
    pub fn open_thread(
        &self,
        id: &ThreadId,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        let this = cx.weak_entity();
        window.spawn(cx, async move |cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let thread = database
                .try_find_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            let thread = this.update_in(cx, |this, window, cx| {
                cx.new(|cx| {
                    Thread::deserialize(
                        id.clone(),
                        thread,
                        this.project.clone(),
                        this.tools.clone(),
                        this.prompt_builder.clone(),
                        this.project_context.clone(),
                        Some(window),
                        cx,
                    )
                })
            })?;

            Ok(thread)
        })
    }
476
477 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
478 let (metadata, serialized_thread) =
479 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
480
481 let database_future = ThreadsDatabase::global_future(cx);
482 cx.spawn(async move |this, cx| {
483 let serialized_thread = serialized_thread.await?;
484 let database = database_future.await.map_err(|err| anyhow!(err))?;
485 database.save_thread(metadata, serialized_thread).await?;
486
487 this.update(cx, |this, cx| this.reload(cx))?.await
488 })
489 }
490
    /// Deletes a thread from the database and drops its metadata from the
    /// in-memory list, notifying observers.
    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.delete_thread(id.clone()).await?;

            this.update(cx, |this, cx| {
                this.threads.retain(|thread| thread.id != id);
                cx.notify();
            })
        })
    }
504
    /// Re-reads the thread metadata list from the database and notifies
    /// observers.
    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let threads = database_future
                .await
                .map_err(|err| anyhow!(err))?
                .list_threads()
                .await?;

            this.update(cx, |this, cx| {
                this.threads = threads;
                cx.notify();
            })
        })
    }
520
521 fn load_default_profile(&self, cx: &mut Context<Self>) {
522 let assistant_settings = AgentSettings::get_global(cx);
523
524 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
525 }
526
    /// Loads the profile with the given id from settings; does nothing if no
    /// profile with that id is configured.
    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
        let assistant_settings = AgentSettings::get_global(cx);

        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
            self.load_profile(profile.clone(), cx);
        }
    }
534
535 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
536 self.tools.update(cx, |tools, cx| {
537 tools.disable_all_tools(cx);
538 tools.enable(
539 ToolSource::Native,
540 &profile
541 .tools
542 .into_iter()
543 .filter_map(|(tool, enabled)| enabled.then(|| tool))
544 .collect::<Vec<_>>(),
545 cx,
546 );
547 });
548
549 if profile.enable_all_context_servers {
550 for context_server_id in self
551 .project
552 .read(cx)
553 .context_server_store()
554 .read(cx)
555 .all_server_ids()
556 {
557 self.tools.update(cx, |tools, cx| {
558 tools.enable_source(
559 ToolSource::ContextServer {
560 id: context_server_id.0.into(),
561 },
562 cx,
563 );
564 });
565 }
566 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
567 for (context_server_id, preset) in profile.context_servers {
568 self.tools.update(cx, |tools, cx| {
569 tools.disable(
570 ToolSource::ContextServer {
571 id: context_server_id.into(),
572 },
573 &preset
574 .tools
575 .into_iter()
576 .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
577 .collect::<Vec<_>>(),
578 cx,
579 )
580 })
581 }
582 } else {
583 for (context_server_id, preset) in profile.context_servers {
584 self.tools.update(cx, |tools, cx| {
585 tools.enable(
586 ToolSource::ContextServer {
587 id: context_server_id.into(),
588 },
589 &preset
590 .tools
591 .into_iter()
592 .filter_map(|(tool, enabled)| enabled.then(|| tool))
593 .collect::<Vec<_>>(),
594 cx,
595 )
596 })
597 }
598 }
599 }
600
    /// Subscribes to context-server lifecycle events so server tools can be
    /// registered and unregistered as servers start and stop.
    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        cx.subscribe(
            &self.project.read(cx).context_server_store(),
            Self::handle_context_server_event,
        )
        .detach();
    }
608
    /// Reacts to context-server status changes: registers the server's tools
    /// when it starts running, and removes them when it stops or errors.
    fn handle_context_server_event(
        &mut self,
        context_server_store: Entity<ContextServerStore>,
        event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        let tool_working_set = self.tools.clone();
        match event {
            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
                match status {
                    ContextServerStatus::Running => {
                        if let Some(server) =
                            context_server_store.read(cx).get_running_server(server_id)
                        {
                            let context_server_manager = context_server_store.clone();
                            cx.spawn({
                                let server = server.clone();
                                let server_id = server_id.clone();
                                async move |this, cx| {
                                    let Some(protocol) = server.client() else {
                                        return;
                                    };

                                    // Only query servers that advertise tool support.
                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                                        if let Some(tools) = protocol.list_tools().await.log_err() {
                                            let tool_ids = tool_working_set
                                                .update(cx, |tool_working_set, _| {
                                                    tools
                                                        .tools
                                                        .into_iter()
                                                        .map(|tool| {
                                                            log::info!(
                                                                "registering context server tool: {:?}",
                                                                tool.name
                                                            );
                                                            tool_working_set.insert(Arc::new(
                                                                ContextServerTool::new(
                                                                    context_server_manager.clone(),
                                                                    server.id(),
                                                                    tool,
                                                                ),
                                                            ))
                                                        })
                                                        .collect::<Vec<_>>()
                                                })
                                                .log_err();

                                            // Remember the ids for later removal, then
                                            // re-apply the profile to pick up the new tools.
                                            if let Some(tool_ids) = tool_ids {
                                                this.update(cx, |this, cx| {
                                                    this.context_server_tool_ids
                                                        .insert(server_id, tool_ids);
                                                    this.load_default_profile(cx);
                                                })
                                                .log_err();
                                            }
                                        }
                                    }
                                }
                            })
                            .detach();
                        }
                    }
                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                        // Drop this server's tools and refresh the profile.
                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                            tool_working_set.update(cx, |tool_working_set, _| {
                                tool_working_set.remove(&tool_ids);
                            });
                            self.load_default_profile(cx);
                        }
                    }
                    _ => {}
                }
            }
        }
    }
684}
685
/// Lightweight, listable metadata for a persisted thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}
692
/// Current on-disk representation of a thread (see [`SerializedThread::VERSION`]).
/// Fields marked `#[serde(default)]` were added after the format first
/// shipped, so older records still deserialize.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
}
716
/// Provider/model name pair identifying a language model.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
722
impl SerializedThread {
    /// Current serialization format version.
    pub const VERSION: &'static str = "0.2.0";

    /// Deserializes a thread from JSON bytes, upgrading older on-disk formats
    /// (unversioned legacy and v0.1.0) to the current version.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                SerializedThreadV0_1_0::VERSION => {
                    let saved_thread =
                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
                    Ok(saved_thread.upgrade())
                }
                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
                    saved_thread_json,
                )?),
                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
            },
            None => {
                // Pre-versioning format: no `version` field at all.
                let saved_thread =
                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
                Ok(saved_thread.upgrade())
            }
            // A `version` field that is present but not a string.
            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
        }
    }
}
749
/// Thread format v0.1.0. The field layout matches the current
/// `SerializedThread`; only the placement of tool results differs (see
/// `upgrade`).
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
756
impl SerializedThreadV0_1_0 {
    /// The version string this format stores.
    pub const VERSION: &'static str = "0.1.0";

    /// Upgrades to v0.2.0 by folding tool results that v0.1.0 stored on a
    /// following user message back onto the preceding assistant message.
    pub fn upgrade(self) -> SerializedThread {
        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");

        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());

        for message in self.0.messages {
            if message.role == Role::User && !message.tool_results.is_empty() {
                if let Some(last_message) = messages.last_mut() {
                    debug_assert!(last_message.role == Role::Assistant);

                    // Move the tool results onto the assistant message and
                    // drop this user message entirely.
                    last_message.tool_results = message.tool_results;
                    continue;
                }
            }

            messages.push(message);
        }

        SerializedThread { messages, ..self.0 }
    }
}
781
/// One message in a serialized thread, including its content segments, tool
/// calls/results, attached context, and creases.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
799
/// A single content segment of a serialized message, tagged by `type`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the variants above, this one has no `rename`, so
    // it serializes with the tag "RedactedThinking" (not lowercase) — confirm
    // this matches data already on disk before changing it.
    RedactedThinking {
        data: Vec<u8>,
    },
}
817
/// A tool invocation recorded on a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
824
/// The outcome of a tool invocation, keyed by the originating tool-use id.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}
832
/// Pre-versioning on-disk thread format (records with no `version` field).
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
841
impl LegacySerializedThread {
    /// Converts to the current format, defaulting every field the legacy
    /// format never stored.
    pub fn upgrade(self) -> SerializedThread {
        SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: self.summary,
            updated_at: self.updated_at,
            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
            initial_project_snapshot: self.initial_project_snapshot,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: Vec::new(),
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
        }
    }
}
860
/// Message shape used by the pre-versioning format: a flat `text` string
/// instead of typed segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
871
impl LegacySerializedMessage {
    /// Converts to the current message format, wrapping the flat `text` in a
    /// single text segment.
    fn upgrade(self) -> SerializedMessage {
        SerializedMessage {
            id: self.id,
            role: self.role,
            segments: vec![SerializedMessageSegment::Text { text: self.text }],
            tool_uses: self.tool_uses,
            tool_results: self.tool_results,
            context: String::new(),
            creases: Vec::new(),
            is_hidden: false,
        }
    }
}
886
/// An editor crease associated with a message: a range plus display metadata.
// NOTE(review): `start`/`end` are presumably buffer offsets — confirm against
// the code that restores creases on deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
894
/// Global handle to the shared, lazily-resolved threads-database future.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
900
/// SQLite-backed persistence for serialized threads; all queries run on the
/// background executor.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}
905
impl ThreadsDatabase {
    /// Returns a clone of the shared database connection handle.
    fn connection(&self) -> Arc<Mutex<Connection>> {
        self.connection.clone()
    }

    // zstd compression level used when writing thread data.
    const COMPRESSION_LEVEL: i32 = 3;
}
913
impl Bind for ThreadId {
    /// Binds the thread id as its string representation.
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        self.to_string().bind(statement, start_index)
    }
}
919
impl Column for ThreadId {
    /// Reads a thread id back from its string column representation.
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (id_str, next_index) = String::column(statement, start_index)?;
        Ok((ThreadId::from(id_str.as_str()), next_index))
    }
}
926
927impl ThreadsDatabase {
    /// Returns the shared future that resolves to the global threads database.
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }
933
    /// Registers the global database future; the database itself is opened on
    /// a background thread and shared by all awaiters.
    fn init(cx: &mut App) {
        let executor = cx.background_executor().clone();
        let database_future = executor
            .spawn({
                let executor = executor.clone();
                let threads_dir = paths::data_dir().join("threads");
                async move { ThreadsDatabase::new(threads_dir, executor) }
            })
            // Wrap success/error in `Arc` so the shared future is cloneable.
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
            .shared();

        cx.set_global(GlobalThreadsDatabase(database_future));
    }
948
    /// Opens (or creates) the SQLite database under `threads_dir`, creating
    /// the `threads` table if needed, and kicks off a one-time background
    /// migration from the old heed/LMDB store when one is present.
    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
        std::fs::create_dir_all(&threads_dir)?;

        let sqlite_path = threads_dir.join("threads.db");
        let mdb_path = threads_dir.join("threads-db.1.mdb");

        // Presence of the old LMDB directory means it hasn't been migrated yet.
        let needs_migration_from_heed = mdb_path.exists();

        let connection = Connection::open_file(&sqlite_path.to_string_lossy());

        connection.exec(indoc! {"
            CREATE TABLE IF NOT EXISTS threads (
                id TEXT PRIMARY KEY,
                summary TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                data_type TEXT NOT NULL,
                data BLOB NOT NULL
            )
        "})?()
        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;

        let db = Self {
            executor: executor.clone(),
            connection: Arc::new(Mutex::new(connection)),
        };

        if needs_migration_from_heed {
            let db_connection = db.connection();
            let executor_clone = executor.clone();
            executor
                .spawn(async move {
                    log::info!("Starting threads.db migration");
                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
                    // Delete the old store only after a successful copy.
                    std::fs::remove_dir_all(mdb_path)?;
                    log::info!("threads.db migrated to sqlite");
                    Ok::<(), anyhow::Error>(())
                })
                .detach();
        }

        Ok(db)
    }
991
    // Remove this migration after 2025-09-01
    /// Copies every thread out of the legacy heed/LMDB database into SQLite,
    /// decoding values through the same JSON upgrade path used at load time.
    fn migrate_from_heed(
        mdb_path: &Path,
        connection: Arc<Mutex<Connection>>,
        _executor: BackgroundExecutor,
    ) -> Result<()> {
        use heed::types::SerdeBincode;
        // Wrapper so heed can encode/decode `SerializedThread` as JSON.
        struct SerializedThreadHeed(SerializedThread);

        impl heed::BytesEncode<'_> for SerializedThreadHeed {
            type EItem = SerializedThreadHeed;

            fn bytes_encode(
                item: &Self::EItem,
            ) -> Result<std::borrow::Cow<[u8]>, heed::BoxedError> {
                serde_json::to_vec(&item.0)
                    .map(std::borrow::Cow::Owned)
                    .map_err(Into::into)
            }
        }

        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
            type DItem = SerializedThreadHeed;

            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
                // Reuse the versioned upgrade path so old records migrate too.
                SerializedThread::from_json(bytes)
                    .map(SerializedThreadHeed)
                    .map_err(Into::into)
            }
        }

        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;

        // heed requires `unsafe` to open an LMDB environment.
        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(ONE_GB_IN_BYTES)
                .max_dbs(1)
                .open(mdb_path)?
        };

        let txn = env.write_txn()?;
        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
            .open_database(&txn, Some("threads"))?
            .ok_or_else(|| anyhow!("threads database not found"))?;

        for result in threads.iter(&txn)? {
            let (thread_id, thread_heed) = result?;
            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
        }

        Ok(())
    }
1044
1045 fn save_thread_sync(
1046 connection: &Arc<Mutex<Connection>>,
1047 id: ThreadId,
1048 thread: SerializedThread,
1049 ) -> Result<()> {
1050 let json_data = serde_json::to_string(&thread)?;
1051 let summary = thread.summary.to_string();
1052 let updated_at = thread.updated_at.to_rfc3339();
1053
1054 let connection = connection.lock().unwrap();
1055
1056 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1057 let data_type = DataType::Zstd;
1058 let data = compressed;
1059
1060 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1061 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1062 "})?;
1063
1064 insert((id, summary, updated_at, data_type, data))?;
1065
1066 Ok(())
1067 }
1068
    /// Lists metadata for all stored threads, most recently updated first
    /// (per the SQL `ORDER BY updated_at DESC`).
    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select =
                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
                    SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
                "})?;

            let rows = select(())?;
            let mut threads = Vec::new();

            for (id, summary, updated_at) in rows {
                threads.push(SerializedThreadMetadata {
                    id,
                    summary: summary.into(),
                    // Timestamps are stored as RFC 3339 strings.
                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
                });
            }

            Ok(threads)
        })
    }
1093
    /// Fetches a single thread by id, decompressing zstd rows as needed.
    /// Returns `Ok(None)` when no row matches.
    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
            "})?;

            let rows = select(id)?;
            if let Some((data_type, data)) = rows.into_iter().next() {
                // Stored blobs are either raw JSON or zstd-compressed JSON.
                let json_data = match data_type {
                    DataType::Zstd => {
                        let decompressed = zstd::decode_all(&data[..])?;
                        String::from_utf8(decompressed)?
                    }
                    DataType::Json => String::from_utf8(data)?,
                };

                let thread = SerializedThread::from_json(json_data.as_bytes())?;
                Ok(Some(thread))
            } else {
                Ok(None)
            }
        })
    }
1120
    /// Persists a thread (insert or replace) on the background executor.
    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor
            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
    }
1127
    /// Deletes the thread row with the given id, if present, on the
    /// background executor.
    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();

            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
                DELETE FROM threads WHERE id = ?
            "})?;

            delete(id)?;

            Ok(())
        })
    }
1143}