1use crate::{
2 context_server_tool::ContextServerTool,
3 thread::{
4 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
5 },
6};
7use agent_settings::{AgentProfileId, CompletionMode};
8use anyhow::{Context as _, Result, anyhow};
9use assistant_tool::{Tool, ToolId, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use fs::{Fs, RemoveOptions};
14use futures::{
15 FutureExt as _, StreamExt as _,
16 channel::{mpsc, oneshot},
17 future::{self, BoxFuture, Shared},
18};
19use gpui::{
20 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
21 Subscription, Task, Window, prelude::*,
22};
23use indoc::indoc;
24use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
25use project::context_server_store::{ContextServerStatus, ContextServerStore};
26use project::{Project, ProjectItem, ProjectPath, Worktree};
27use prompt_store::{
28 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
29 UserRulesContext, WorktreeContext,
30};
31use serde::{Deserialize, Serialize};
32use sqlez::{
33 bindable::{Bind, Column},
34 connection::Connection,
35 statement::Statement,
36};
37use std::{
38 cell::{Ref, RefCell},
39 path::{Path, PathBuf},
40 rc::Rc,
41 sync::{Arc, LazyLock, Mutex},
42};
43use util::{ResultExt as _, rel_path::RelPath};
44
45use zed_env_vars::ZED_STATELESS;
46
47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub enum DataType {
49 #[serde(rename = "json")]
50 Json,
51 #[serde(rename = "zstd")]
52 Zstd,
53}
54
55impl Bind for DataType {
56 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
57 let value = match self {
58 DataType::Json => "json",
59 DataType::Zstd => "zstd",
60 };
61 value.bind(statement, start_index)
62 }
63}
64
65impl Column for DataType {
66 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
67 let (value, next_index) = String::column(statement, start_index)?;
68 let data_type = match value.as_str() {
69 "json" => DataType::Json,
70 "zstd" => DataType::Zstd,
71 _ => anyhow::bail!("Unknown data type: {}", value),
72 };
73 Ok((data_type, next_index))
74 }
75}
76
77static RULES_FILE_NAMES: LazyLock<[&RelPath; 9]> = LazyLock::new(|| {
78 [
79 RelPath::unix(".rules").unwrap(),
80 RelPath::unix(".cursorrules").unwrap(),
81 RelPath::unix(".windsurfrules").unwrap(),
82 RelPath::unix(".clinerules").unwrap(),
83 RelPath::unix(".github/copilot-instructions.md").unwrap(),
84 RelPath::unix("CLAUDE.md").unwrap(),
85 RelPath::unix("AGENT.md").unwrap(),
86 RelPath::unix("AGENTS.md").unwrap(),
87 RelPath::unix("GEMINI.md").unwrap(),
88 ]
89});
90
91pub fn init(fs: Arc<dyn Fs>, cx: &mut App) {
92 ThreadsDatabase::init(fs, cx);
93}
94
/// A system prompt shared by all threads created by this ThreadStore
///
/// Wraps an `Rc<RefCell<Option<ProjectContext>>>`: the value starts out as
/// `None` (via `Default`) and clones share the same cell.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
98
99impl SharedProjectContext {
100 pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
101 self.0.borrow()
102 }
103}
104
/// Alias for [`assistant_context::ContextStore`].
pub type TextThreadStore = assistant_context::ContextStore;
106
/// Creates, opens, saves and lists agent threads for one project, and keeps
/// the project context (system prompt inputs) shared by those threads up to
/// date.
pub struct ThreadStore {
    /// Project whose worktrees, buffers and context servers are used.
    project: Entity<Project>,
    /// Tool working set handed to every thread; context-server tools are
    /// added to and removed from it by this store.
    tools: Entity<ToolWorkingSet>,
    /// Prompt builder handed to every thread.
    prompt_builder: Arc<PromptBuilder>,
    /// Optional prompt store; when present, its default prompts are loaded as
    /// user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Ids of the tools registered for each context server, so they can be
    /// removed again when that server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata most recently listed from the threads database.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with every thread created by this store.
    project_context: SharedProjectContext,
    /// Sender used to request another system prompt reload.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Task that performs system prompt reloads one at a time.
    _reload_system_prompt_task: Task<()>,
    /// Subscription handles for project and prompt store events, stored for
    /// the lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
119
/// Event emitted by [`ThreadStore`] when a rules file or a default user rule
/// could not be loaded.
pub struct RulesLoadingError {
    /// Formatted text of the underlying error.
    pub message: SharedString,
}

// Allows `ThreadStore` to emit `RulesLoadingError` events.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
125
126impl ThreadStore {
127 pub fn load(
128 project: Entity<Project>,
129 tools: Entity<ToolWorkingSet>,
130 prompt_store: Option<Entity<PromptStore>>,
131 prompt_builder: Arc<PromptBuilder>,
132 cx: &mut App,
133 ) -> Task<Result<Entity<Self>>> {
134 cx.spawn(async move |cx| {
135 let (thread_store, ready_rx) = cx.update(|cx| {
136 let mut option_ready_rx = None;
137 let thread_store = cx.new(|cx| {
138 let (thread_store, ready_rx) =
139 Self::new(project, tools, prompt_builder, prompt_store, cx);
140 option_ready_rx = Some(ready_rx);
141 thread_store
142 });
143 (thread_store, option_ready_rx.take().unwrap())
144 })?;
145 ready_rx.await?;
146 Ok(thread_store)
147 })
148 }
149
150 fn new(
151 project: Entity<Project>,
152 tools: Entity<ToolWorkingSet>,
153 prompt_builder: Arc<PromptBuilder>,
154 prompt_store: Option<Entity<PromptStore>>,
155 cx: &mut Context<Self>,
156 ) -> (Self, oneshot::Receiver<()>) {
157 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
158
159 if let Some(prompt_store) = prompt_store.as_ref() {
160 subscriptions.push(cx.subscribe(
161 prompt_store,
162 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
163 this.enqueue_system_prompt_reload();
164 },
165 ))
166 }
167
168 // This channel and task prevent concurrent and redundant loading of the system prompt.
169 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
170 let (ready_tx, ready_rx) = oneshot::channel();
171 let mut ready_tx = Some(ready_tx);
172 let reload_system_prompt_task = cx.spawn({
173 let prompt_store = prompt_store.clone();
174 async move |thread_store, cx| {
175 loop {
176 let Some(reload_task) = thread_store
177 .update(cx, |thread_store, cx| {
178 thread_store.reload_system_prompt(prompt_store.clone(), cx)
179 })
180 .ok()
181 else {
182 return;
183 };
184 reload_task.await;
185 if let Some(ready_tx) = ready_tx.take() {
186 ready_tx.send(()).ok();
187 }
188 reload_system_prompt_rx.next().await;
189 }
190 }
191 });
192
193 let this = Self {
194 project,
195 tools,
196 prompt_builder,
197 prompt_store,
198 context_server_tool_ids: HashMap::default(),
199 threads: Vec::new(),
200 project_context: SharedProjectContext::default(),
201 reload_system_prompt_tx,
202 _reload_system_prompt_task: reload_system_prompt_task,
203 _subscriptions: subscriptions,
204 };
205 this.register_context_server_handlers(cx);
206 this.reload(cx).detach_and_log_err(cx);
207 (this, ready_rx)
208 }
209
210 #[cfg(any(test, feature = "test-support"))]
211 pub fn fake(project: Entity<Project>, cx: &mut App) -> Self {
212 Self {
213 project,
214 tools: cx.new(|_| ToolWorkingSet::default()),
215 prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
216 prompt_store: None,
217 context_server_tool_ids: HashMap::default(),
218 threads: Vec::new(),
219 project_context: SharedProjectContext::default(),
220 reload_system_prompt_tx: mpsc::channel(0).0,
221 _reload_system_prompt_task: Task::ready(()),
222 _subscriptions: vec![],
223 }
224 }
225
226 fn handle_project_event(
227 &mut self,
228 _project: Entity<Project>,
229 event: &project::Event,
230 _cx: &mut Context<Self>,
231 ) {
232 match event {
233 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
234 self.enqueue_system_prompt_reload();
235 }
236 project::Event::WorktreeUpdatedEntries(_, items) => {
237 if items
238 .iter()
239 .any(|(path, _, _)| RULES_FILE_NAMES.iter().any(|name| path.as_ref() == *name))
240 {
241 self.enqueue_system_prompt_reload();
242 }
243 }
244 _ => {}
245 }
246 }
247
248 fn enqueue_system_prompt_reload(&mut self) {
249 self.reload_system_prompt_tx.try_send(()).ok();
250 }
251
252 // Note that this should only be called from `reload_system_prompt_task`.
253 fn reload_system_prompt(
254 &self,
255 prompt_store: Option<Entity<PromptStore>>,
256 cx: &mut Context<Self>,
257 ) -> Task<()> {
258 let worktrees = self
259 .project
260 .read(cx)
261 .visible_worktrees(cx)
262 .collect::<Vec<_>>();
263 let worktree_tasks = worktrees
264 .into_iter()
265 .map(|worktree| {
266 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
267 })
268 .collect::<Vec<_>>();
269 let default_user_rules_task = match prompt_store {
270 None => Task::ready(vec![]),
271 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
272 let prompts = prompt_store.default_prompt_metadata();
273 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
274 let contents = prompt_store.load(prompt_metadata.id, cx);
275 async move { (contents.await, prompt_metadata) }
276 });
277 cx.background_spawn(future::join_all(load_tasks))
278 }),
279 };
280
281 cx.spawn(async move |this, cx| {
282 let (worktrees, default_user_rules) =
283 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
284
285 let worktrees = worktrees
286 .into_iter()
287 .map(|(worktree, rules_error)| {
288 if let Some(rules_error) = rules_error {
289 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
290 }
291 worktree
292 })
293 .collect::<Vec<_>>();
294
295 let default_user_rules = default_user_rules
296 .into_iter()
297 .flat_map(|(contents, prompt_metadata)| match contents {
298 Ok(contents) => Some(UserRulesContext {
299 uuid: match prompt_metadata.id {
300 PromptId::User { uuid } => uuid,
301 PromptId::EditWorkflow => return None,
302 },
303 title: prompt_metadata.title.map(|title| title.to_string()),
304 contents,
305 }),
306 Err(err) => {
307 this.update(cx, |_, cx| {
308 cx.emit(RulesLoadingError {
309 message: format!("{err:?}").into(),
310 });
311 })
312 .ok();
313 None
314 }
315 })
316 .collect::<Vec<_>>();
317
318 this.update(cx, |this, _cx| {
319 *this.project_context.0.borrow_mut() =
320 Some(ProjectContext::new(worktrees, default_user_rules));
321 })
322 .ok();
323 })
324 }
325
326 fn load_worktree_info_for_system_prompt(
327 worktree: Entity<Worktree>,
328 project: Entity<Project>,
329 cx: &mut App,
330 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
331 let tree = worktree.read(cx);
332 let root_name = tree.root_name_str().into();
333 let abs_path = tree.abs_path();
334
335 let mut context = WorktreeContext {
336 root_name,
337 abs_path,
338 rules_file: None,
339 };
340
341 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
342 let Some(rules_task) = rules_task else {
343 return Task::ready((context, None));
344 };
345
346 cx.spawn(async move |_| {
347 let (rules_file, rules_file_error) = match rules_task.await {
348 Ok(rules_file) => (Some(rules_file), None),
349 Err(err) => (
350 None,
351 Some(RulesLoadingError {
352 message: format!("{err}").into(),
353 }),
354 ),
355 };
356 context.rules_file = rules_file;
357 (context, rules_file_error)
358 })
359 }
360
361 fn load_worktree_rules_file(
362 worktree: Entity<Worktree>,
363 project: Entity<Project>,
364 cx: &mut App,
365 ) -> Option<Task<Result<RulesFileContext>>> {
366 let worktree = worktree.read(cx);
367 let worktree_id = worktree.id();
368 let selected_rules_file = RULES_FILE_NAMES
369 .into_iter()
370 .filter_map(|name| {
371 worktree
372 .entry_for_path(name)
373 .filter(|entry| entry.is_file())
374 .map(|entry| entry.path.clone())
375 })
376 .next();
377
378 // Note that Cline supports `.clinerules` being a directory, but that is not currently
379 // supported. This doesn't seem to occur often in GitHub repositories.
380 selected_rules_file.map(|path_in_worktree| {
381 let project_path = ProjectPath {
382 worktree_id,
383 path: path_in_worktree.clone(),
384 };
385 let buffer_task =
386 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
387 let rope_task = cx.spawn(async move |cx| {
388 buffer_task.await?.read_with(cx, |buffer, cx| {
389 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
390 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
391 })?
392 });
393 // Build a string from the rope on a background thread.
394 cx.background_spawn(async move {
395 let (project_entry_id, rope) = rope_task.await?;
396 anyhow::Ok(RulesFileContext {
397 path_in_worktree,
398 text: rope.to_string().trim().to_string(),
399 project_entry_id: project_entry_id.to_usize(),
400 })
401 })
402 })
403 }
404
    /// Returns the prompt store this thread store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
408
409 pub fn tools(&self) -> Entity<ToolWorkingSet> {
410 self.tools.clone()
411 }
412
    /// Returns the number of threads.
    ///
    /// This counts the cached metadata entries, i.e. the result of the most
    /// recent `reload` (adjusted by any `delete_thread` since).
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
417
    /// Iterates over the cached thread metadata, most recently updated first.
    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // ordering is from "ORDER BY" in `list_threads`
        self.threads.iter()
    }
422
423 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
424 cx.new(|cx| {
425 Thread::new(
426 self.project.clone(),
427 self.tools.clone(),
428 self.prompt_builder.clone(),
429 self.project_context.clone(),
430 cx,
431 )
432 })
433 }
434
435 pub fn create_thread_from_serialized(
436 &mut self,
437 serialized: SerializedThread,
438 cx: &mut Context<Self>,
439 ) -> Entity<Thread> {
440 cx.new(|cx| {
441 Thread::deserialize(
442 ThreadId::new(),
443 serialized,
444 self.project.clone(),
445 self.tools.clone(),
446 self.prompt_builder.clone(),
447 self.project_context.clone(),
448 None,
449 cx,
450 )
451 })
452 }
453
454 pub fn open_thread(
455 &self,
456 id: &ThreadId,
457 window: &mut Window,
458 cx: &mut Context<Self>,
459 ) -> Task<Result<Entity<Thread>>> {
460 let id = id.clone();
461 let database_future = ThreadsDatabase::global_future(cx);
462 let this = cx.weak_entity();
463 window.spawn(cx, async move |cx| {
464 let database = database_future.await.map_err(|err| anyhow!(err))?;
465 let thread = database
466 .try_find_thread(id.clone())
467 .await?
468 .with_context(|| format!("no thread found with ID: {id:?}"))?;
469
470 let thread = this.update_in(cx, |this, window, cx| {
471 cx.new(|cx| {
472 Thread::deserialize(
473 id.clone(),
474 thread,
475 this.project.clone(),
476 this.tools.clone(),
477 this.prompt_builder.clone(),
478 this.project_context.clone(),
479 Some(window),
480 cx,
481 )
482 })
483 })?;
484
485 Ok(thread)
486 })
487 }
488
489 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
490 let (metadata, serialized_thread) =
491 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
492
493 let database_future = ThreadsDatabase::global_future(cx);
494 cx.spawn(async move |this, cx| {
495 let serialized_thread = serialized_thread.await?;
496 let database = database_future.await.map_err(|err| anyhow!(err))?;
497 database.save_thread(metadata, serialized_thread).await?;
498
499 this.update(cx, |this, cx| this.reload(cx))?.await
500 })
501 }
502
503 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
504 let id = id.clone();
505 let database_future = ThreadsDatabase::global_future(cx);
506 cx.spawn(async move |this, cx| {
507 let database = database_future.await.map_err(|err| anyhow!(err))?;
508 database.delete_thread(id.clone()).await?;
509
510 this.update(cx, |this, cx| {
511 this.threads.retain(|thread| thread.id != id);
512 cx.notify();
513 })
514 })
515 }
516
517 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
518 let database_future = ThreadsDatabase::global_future(cx);
519 cx.spawn(async move |this, cx| {
520 let threads = database_future
521 .await
522 .map_err(|err| anyhow!(err))?
523 .list_threads()
524 .await?;
525
526 this.update(cx, |this, cx| {
527 this.threads = threads;
528 cx.notify();
529 })
530 })
531 }
532
533 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
534 let context_server_store = self.project.read(cx).context_server_store();
535 cx.subscribe(&context_server_store, Self::handle_context_server_event)
536 .detach();
537
538 // Check for any servers that were already running before the handler was registered
539 for server in context_server_store.read(cx).running_servers() {
540 self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
541 }
542 }
543
544 fn handle_context_server_event(
545 &mut self,
546 context_server_store: Entity<ContextServerStore>,
547 event: &project::context_server_store::Event,
548 cx: &mut Context<Self>,
549 ) {
550 let tool_working_set = self.tools.clone();
551 match event {
552 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
553 match status {
554 ContextServerStatus::Starting => {}
555 ContextServerStatus::Running => {
556 self.load_context_server_tools(server_id.clone(), context_server_store, cx);
557 }
558 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
559 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
560 tool_working_set.update(cx, |tool_working_set, cx| {
561 tool_working_set.remove(&tool_ids, cx);
562 });
563 }
564 }
565 }
566 }
567 }
568 }
569
570 fn load_context_server_tools(
571 &self,
572 server_id: ContextServerId,
573 context_server_store: Entity<ContextServerStore>,
574 cx: &mut Context<Self>,
575 ) {
576 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
577 return;
578 };
579 let tool_working_set = self.tools.clone();
580 cx.spawn(async move |this, cx| {
581 let Some(protocol) = server.client() else {
582 return;
583 };
584
585 if protocol.capable(context_server::protocol::ServerCapability::Tools)
586 && let Some(response) = protocol
587 .request::<context_server::types::requests::ListTools>(())
588 .await
589 .log_err()
590 {
591 let tool_ids = tool_working_set
592 .update(cx, |tool_working_set, cx| {
593 tool_working_set.extend(
594 response.tools.into_iter().map(|tool| {
595 Arc::new(ContextServerTool::new(
596 context_server_store.clone(),
597 server.id(),
598 tool,
599 )) as Arc<dyn Tool>
600 }),
601 cx,
602 )
603 })
604 .log_err();
605
606 if let Some(tool_ids) = tool_ids {
607 this.update(cx, |this, _| {
608 this.context_server_tool_ids.insert(server_id, tool_ids);
609 })
610 .log_err();
611 }
612 }
613 })
614 .detach();
615 }
616}
617
/// Lightweight description of a stored thread, as returned by
/// `ThreadsDatabase::list_threads` (no message data).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Identifier of the thread row.
    pub id: ThreadId,
    /// Stored summary text of the thread.
    pub summary: SharedString,
    /// Stored last-update timestamp, converted to UTC.
    pub updated_at: DateTime<Utc>,
}
624
/// The persisted form of a thread in the current format
/// (see [`SerializedThread::VERSION`]).
///
/// Every field marked `#[serde(default)]` may be absent from stored JSON and
/// then takes its type's default value.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version string; used by `from_json` to pick a decoder.
    pub version: String,
    /// Summary text; also stored in the `summary` column when saved.
    pub summary: SharedString,
    /// Last-update timestamp; also stored in the `updated_at` column when saved.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages, in stored order.
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
650
/// Persisted reference to a language model, stored as two plain strings.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    /// NOTE(review): presumably the provider identifier — confirm against callers.
    pub provider: String,
    /// NOTE(review): presumably the model identifier — confirm against callers.
    pub model: String,
}
656
657impl SerializedThread {
658 pub const VERSION: &'static str = "0.2.0";
659
660 pub fn from_json(json: &[u8]) -> Result<Self> {
661 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
662 match saved_thread_json.get("version") {
663 Some(serde_json::Value::String(version)) => match version.as_str() {
664 SerializedThreadV0_1_0::VERSION => {
665 let saved_thread =
666 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
667 Ok(saved_thread.upgrade())
668 }
669 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
670 saved_thread_json,
671 )?),
672 _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
673 },
674 None => {
675 let saved_thread =
676 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
677 Ok(saved_thread.upgrade())
678 }
679 version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
680 }
681 }
682}
683
/// A thread stored in format version `0.1.0`.
///
/// It has the same fields as the current [`SerializedThread`]; use `upgrade`
/// to convert it to the current format.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
690
691impl SerializedThreadV0_1_0 {
692 pub const VERSION: &'static str = "0.1.0";
693
694 pub fn upgrade(self) -> SerializedThread {
695 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
696
697 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
698
699 for message in self.0.messages {
700 if message.role == Role::User
701 && !message.tool_results.is_empty()
702 && let Some(last_message) = messages.last_mut()
703 {
704 debug_assert!(last_message.role == Role::Assistant);
705
706 last_message.tool_results = message.tool_results;
707 continue;
708 }
709
710 messages.push(message);
711 }
712
713 SerializedThread {
714 messages,
715 version: SerializedThread::VERSION.to_string(),
716 ..self.0
717 }
718 }
719}
720
/// The persisted form of one message in a thread.
///
/// Every field marked `#[serde(default)]` may be absent from stored JSON and
/// then takes its type's default value.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Content segments; messages upgraded from the legacy format have a
    /// single text segment.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    /// Empty for messages upgraded from the legacy format.
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
738
/// One segment of a message's content, tagged in JSON by a `"type"` field.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Tagged as `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Tagged as `"thinking"`; `signature` is omitted from the output when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this one has no `rename`, so its
    // tag is the variant name `"RedactedThinking"`. Renaming it now would
    // change the stored format — confirm before touching.
    RedactedThinking {
        data: String,
    },
}
756
/// The persisted form of a tool invocation attached to a message.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    /// Identifier of this tool use; tool results refer to it via `tool_use_id`.
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    /// Tool input, kept as arbitrary JSON.
    pub input: serde_json::Value,
}
763
/// The persisted form of a tool's result attached to a message.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    /// Identifier of the tool use this result belongs to.
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    /// Optional extra output, kept as arbitrary JSON.
    pub output: Option<serde_json::Value>,
}
771
/// A thread stored in the original format, which has no `version` field.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    /// Messages in the legacy shape (plain text instead of segments).
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
780
781impl LegacySerializedThread {
782 pub fn upgrade(self) -> SerializedThread {
783 SerializedThread {
784 version: SerializedThread::VERSION.to_string(),
785 summary: self.summary,
786 updated_at: self.updated_at,
787 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
788 initial_project_snapshot: self.initial_project_snapshot,
789 cumulative_token_usage: TokenUsage::default(),
790 request_token_usage: Vec::new(),
791 detailed_summary_state: DetailedSummaryState::default(),
792 exceeded_window_error: None,
793 model: None,
794 completion_mode: None,
795 tool_use_limit_reached: false,
796 profile: None,
797 }
798 }
799}
800
/// A message in the legacy thread format, whose content is a single string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Entire message content; becomes one text segment on upgrade.
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
811
812impl LegacySerializedMessage {
813 fn upgrade(self) -> SerializedMessage {
814 SerializedMessage {
815 id: self.id,
816 role: self.role,
817 segments: vec![SerializedMessageSegment::Text { text: self.text }],
818 tool_uses: self.tool_uses,
819 tool_results: self.tool_results,
820 context: String::new(),
821 creases: Vec::new(),
822 is_hidden: false,
823 }
824 }
825}
826
/// The persisted form of a crease attached to a message.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    // NOTE(review): `start`/`end` presumably delimit a range within the
    // message text; the unit (bytes vs. characters) is not visible here — confirm.
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
834
/// App-global holder for the threads database.
///
/// Wraps a shared future that yields either the opened database or the error
/// from opening it.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

// Marks the holder as usable as a gpui global.
impl Global for GlobalThreadsDatabase {}
840
/// SQLite-backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which all database work is spawned.
    executor: BackgroundExecutor,
    /// The SQLite connection, guarded by a mutex.
    connection: Arc<Mutex<Connection>>,
}
845
846impl ThreadsDatabase {
847 fn connection(&self) -> Arc<Mutex<Connection>> {
848 self.connection.clone()
849 }
850
851 const COMPRESSION_LEVEL: i32 = 3;
852}
853
854impl Bind for ThreadId {
855 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
856 self.to_string().bind(statement, start_index)
857 }
858}
859
860impl Column for ThreadId {
861 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
862 let (id_str, next_index) = String::column(statement, start_index)?;
863 Ok((ThreadId::from(id_str.as_str()), next_index))
864 }
865}
866
867impl ThreadsDatabase {
868 fn global_future(
869 cx: &mut App,
870 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
871 GlobalThreadsDatabase::global(cx).0.clone()
872 }
873
874 fn init(fs: Arc<dyn Fs>, cx: &mut App) {
875 let executor = cx.background_executor().clone();
876 let database_future = executor
877 .spawn({
878 let executor = executor.clone();
879 let threads_dir = paths::data_dir().join("threads");
880 async move { ThreadsDatabase::new(fs, threads_dir, executor).await }
881 })
882 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
883 .boxed()
884 .shared();
885
886 cx.set_global(GlobalThreadsDatabase(database_future));
887 }
888
889 pub async fn new(
890 fs: Arc<dyn Fs>,
891 threads_dir: PathBuf,
892 executor: BackgroundExecutor,
893 ) -> Result<Self> {
894 fs.create_dir(&threads_dir).await?;
895
896 let sqlite_path = threads_dir.join("threads.db");
897 let mdb_path = threads_dir.join("threads-db.1.mdb");
898
899 let needs_migration_from_heed = fs.is_file(&mdb_path).await;
900
901 let connection = if *ZED_STATELESS {
902 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
903 } else if cfg!(any(feature = "test-support", test)) {
904 // rust stores the name of the test on the current thread.
905 // We use this to automatically create a database that will
906 // be shared within the test (for the test_retrieve_old_thread)
907 // but not with concurrent tests.
908 let thread = std::thread::current();
909 let test_name = thread.name();
910 Connection::open_memory(Some(&format!(
911 "THREAD_FALLBACK_{}",
912 test_name.unwrap_or_default()
913 )))
914 } else {
915 Connection::open_file(&sqlite_path.to_string_lossy())
916 };
917
918 connection.exec(indoc! {"
919 CREATE TABLE IF NOT EXISTS threads (
920 id TEXT PRIMARY KEY,
921 summary TEXT NOT NULL,
922 updated_at TEXT NOT NULL,
923 data_type TEXT NOT NULL,
924 data BLOB NOT NULL
925 )
926 "})?()
927 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
928
929 let db = Self {
930 executor: executor.clone(),
931 connection: Arc::new(Mutex::new(connection)),
932 };
933
934 if needs_migration_from_heed {
935 let db_connection = db.connection();
936 let executor_clone = executor.clone();
937 executor
938 .spawn(async move {
939 log::info!("Starting threads.db migration");
940 Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
941 fs.remove_dir(
942 &mdb_path,
943 RemoveOptions {
944 recursive: true,
945 ignore_if_not_exists: true,
946 },
947 )
948 .await?;
949 log::info!("threads.db migrated to sqlite");
950 Ok::<(), anyhow::Error>(())
951 })
952 .detach();
953 }
954
955 Ok(db)
956 }
957
958 // Remove this migration after 2025-09-01
959 fn migrate_from_heed(
960 mdb_path: &Path,
961 connection: Arc<Mutex<Connection>>,
962 _executor: BackgroundExecutor,
963 ) -> Result<()> {
964 use heed::types::SerdeBincode;
965 struct SerializedThreadHeed(SerializedThread);
966
967 impl heed::BytesEncode<'_> for SerializedThreadHeed {
968 type EItem = SerializedThreadHeed;
969
970 fn bytes_encode(
971 item: &Self::EItem,
972 ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
973 serde_json::to_vec(&item.0)
974 .map(std::borrow::Cow::Owned)
975 .map_err(Into::into)
976 }
977 }
978
979 impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
980 type DItem = SerializedThreadHeed;
981
982 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
983 SerializedThread::from_json(bytes)
984 .map(SerializedThreadHeed)
985 .map_err(Into::into)
986 }
987 }
988
989 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
990
991 let env = unsafe {
992 heed::EnvOpenOptions::new()
993 .map_size(ONE_GB_IN_BYTES)
994 .max_dbs(1)
995 .open(mdb_path)?
996 };
997
998 let txn = env.write_txn()?;
999 let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
1000 .open_database(&txn, Some("threads"))?
1001 .ok_or_else(|| anyhow!("threads database not found"))?;
1002
1003 for result in threads.iter(&txn)? {
1004 let (thread_id, thread_heed) = result?;
1005 Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
1006 }
1007
1008 Ok(())
1009 }
1010
1011 fn save_thread_sync(
1012 connection: &Arc<Mutex<Connection>>,
1013 id: ThreadId,
1014 thread: SerializedThread,
1015 ) -> Result<()> {
1016 let json_data = serde_json::to_string(&thread)?;
1017 let summary = thread.summary.to_string();
1018 let updated_at = thread.updated_at.to_rfc3339();
1019
1020 let connection = connection.lock().unwrap();
1021
1022 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1023 let data_type = DataType::Zstd;
1024 let data = compressed;
1025
1026 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1027 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1028 "})?;
1029
1030 insert((id, summary, updated_at, data_type, data))?;
1031
1032 Ok(())
1033 }
1034
1035 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1036 let connection = self.connection.clone();
1037
1038 self.executor.spawn(async move {
1039 let connection = connection.lock().unwrap();
1040 let mut select =
1041 connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1042 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1043 "})?;
1044
1045 let rows = select(())?;
1046 let mut threads = Vec::new();
1047
1048 for (id, summary, updated_at) in rows {
1049 threads.push(SerializedThreadMetadata {
1050 id,
1051 summary: summary.into(),
1052 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1053 });
1054 }
1055
1056 Ok(threads)
1057 })
1058 }
1059
1060 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1061 let connection = self.connection.clone();
1062
1063 self.executor.spawn(async move {
1064 let connection = connection.lock().unwrap();
1065 let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1066 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1067 "})?;
1068
1069 let rows = select(id)?;
1070 if let Some((data_type, data)) = rows.into_iter().next() {
1071 let json_data = match data_type {
1072 DataType::Zstd => {
1073 let decompressed = zstd::decode_all(&data[..])?;
1074 String::from_utf8(decompressed)?
1075 }
1076 DataType::Json => String::from_utf8(data)?,
1077 };
1078
1079 let thread = SerializedThread::from_json(json_data.as_bytes())?;
1080 Ok(Some(thread))
1081 } else {
1082 Ok(None)
1083 }
1084 })
1085 }
1086
1087 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1088 let connection = self.connection.clone();
1089
1090 self.executor
1091 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1092 }
1093
1094 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1095 let connection = self.connection.clone();
1096
1097 self.executor.spawn(async move {
1098 let connection = connection.lock().unwrap();
1099
1100 let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1101 DELETE FROM threads WHERE id = ?
1102 "})?;
1103
1104 delete(id)?;
1105
1106 Ok(())
1107 })
1108 }
1109}
1110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// Builds a visible message made of a single text segment, with an empty context
    /// string and no creases.
    fn text_message(
        id: MessageId,
        role: Role,
        text: &str,
        tool_uses: Vec<SerializedToolUse>,
        tool_results: Vec<SerializedToolResult>,
    ) -> SerializedMessage {
        SerializedMessage {
            id,
            role,
            segments: vec![SerializedMessageSegment::Text {
                text: text.to_string(),
            }],
            tool_uses,
            tool_results,
            context: String::new(),
            creases: Vec::new(),
            is_hidden: false,
        }
    }

    /// Builds a thread summarized as "Test conversation" with the given `version`,
    /// timestamp and messages; every other field is empty, `None`, `false` or default.
    fn test_thread(
        version: String,
        updated_at: DateTime<Utc>,
        messages: Vec<SerializedMessage>,
    ) -> SerializedThread {
        SerializedThread {
            version,
            summary: "Test conversation".into(),
            updated_at,
            messages,
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: Vec::new(),
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }

    /// The tool use "abc" of `tool_1` with a null input.
    fn tool_1_use() -> SerializedToolUse {
        SerializedToolUse {
            id: "abc".into(),
            name: "tool_1".into(),
            input: serde_json::Value::Null,
        }
    }

    /// The successful text result "abcdef" answering tool use "abc".
    fn tool_1_result() -> SerializedToolResult {
        SerializedToolResult {
            tool_use_id: "abc".into(),
            is_error: false,
            content: LanguageModelToolResultContent::Text("abcdef".into()),
            output: Some(serde_json::Value::Null),
        }
    }

    /// Upgrading a legacy thread turns each message's `text` into a single text
    /// segment and stamps the thread with the current `SerializedThread::VERSION`.
    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();

        let upgraded = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        }
        .upgrade();

        let expected = test_thread(
            SerializedThread::VERSION.to_string(),
            updated_at,
            vec![text_message(
                MessageId(1),
                Role::User,
                "Hello, world!",
                vec![],
                vec![],
            )],
        );

        assert_eq!(upgraded, expected);
    }

    /// Upgrading a v0.1.0 thread is expected to move the tool result from the
    /// trailing user message onto the assistant message that issued the tool use,
    /// leaving only the first two messages.
    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();

        let thread_v0_1_0 = SerializedThreadV0_1_0(test_thread(
            SerializedThreadV0_1_0::VERSION.to_string(),
            updated_at,
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1", vec![], vec![]),
                text_message(
                    MessageId(2),
                    Role::Assistant,
                    "I want to use a tool",
                    vec![tool_1_use()],
                    vec![],
                ),
                // This message shares `MessageId(1)` with the first one; it does not
                // appear in the expected output, so its id is never asserted.
                text_message(
                    MessageId(1),
                    Role::User,
                    "Here is the tool result",
                    vec![],
                    vec![tool_1_result()],
                ),
            ],
        ));

        let upgraded = thread_v0_1_0.upgrade();

        let expected = test_thread(
            SerializedThread::VERSION.to_string(),
            updated_at,
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1", vec![], vec![]),
                text_message(
                    MessageId(2),
                    Role::Assistant,
                    "I want to use a tool",
                    vec![tool_1_use()],
                    vec![tool_1_result()],
                ),
            ],
        );

        assert_eq!(upgraded, expected);
    }
}