use crate::{
    context_server_tool::ContextServerTool,
    thread::{
        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
    },
};
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::ContextServerId;
use fs::{Fs, RemoveOptions};
use futures::{
    FutureExt as _, StreamExt as _,
    channel::{mpsc, oneshot},
    future::{self, BoxFuture, Shared},
};
use gpui::{
    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
    Subscription, Task, Window, prelude::*,
};
use indoc::indoc;
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
    UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use sqlez::{
    bindable::{Bind, Column},
    connection::Connection,
    statement::Statement,
};
use std::{
    cell::{Ref, RefCell},
    path::{Path, PathBuf},
    rc::Rc,
    sync::{Arc, Mutex},
};
use util::{ResultExt as _, rel_path::RelPath};

use zed_env_vars::ZED_STATELESS;

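/// Encoding of the serialized thread payload stored in the `data` column: plain JSON or
/// zstd-compressed JSON.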
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    #[serde(rename = "json")]
    Json,
    #[serde(rename = "zstd")]
    Zstd,
}

impl Bind for DataType {
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        let value = match self {
            DataType::Json => "json",
            DataType::Zstd => "zstd",
        };
        value.bind(statement, start_index)
    }
}

impl Column for DataType {
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (value, next_index) = String::column(statement, start_index)?;
        let data_type = match value.as_str() {
            "json" => DataType::Json,
            "zstd" => DataType::Zstd,
            _ => anyhow::bail!("Unknown data type: {}", value),
        };
        Ok((data_type, next_index))
    }
}

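/// Rules file names recognized in a worktree, in priority order; only the first match is
/// loaded (see `load_worktree_rules_file`).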
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];

pub fn init(fs: Arc<dyn Fs>, cx: &mut App) {
    ThreadsDatabase::init(fs, cx);
}

/// Project context shared by all threads created by this `ThreadStore`; it is used when
/// building the system prompt for those threads.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);

impl SharedProjectContext {
    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
        self.0.borrow()
    }
}

pub type TextThreadStore = assistant_context::ContextStore;

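/// Owns the saved agent threads for a project: creating, opening, saving, and deleting
/// threads, while keeping the shared project context and context server tools up to date.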
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}

pub struct RulesLoadingError {
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}

impl ThreadStore {
    pub fn load(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_store: Option<Entity<PromptStore>>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        cx.spawn(async move |cx| {
            let (thread_store, ready_rx) = cx.update(|cx| {
                let mut option_ready_rx = None;
                let thread_store = cx.new(|cx| {
                    let (thread_store, ready_rx) =
                        Self::new(project, tools, prompt_builder, prompt_store, cx);
                    option_ready_rx = Some(ready_rx);
                    thread_store
                });
                (thread_store, option_ready_rx.take().unwrap())
            })?;
            ready_rx.await?;
            Ok(thread_store)
        })
    }

    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];

        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.register_context_server_handlers(cx);
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn fake(project: Entity<Project>, cx: &mut App) -> Self {
        Self {
            project,
            tools: cx.new(|_| ToolWorkingSet::default()),
            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
            prompt_store: None,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx: mpsc::channel(0).0,
            _reload_system_prompt_task: Task::ready(()),
            _subscriptions: vec![],
        }
    }

    fn handle_project_event(
        &mut self,
        _project: Entity<Project>,
        event: &project::Event,
        _cx: &mut Context<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
                self.enqueue_system_prompt_reload();
            }
            project::Event::WorktreeUpdatedEntries(_, items) => {
                if items.iter().any(|(path, _, _)| {
                    RULES_FILE_NAMES
                        .iter()
                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
                }) {
                    self.enqueue_system_prompt_reload();
                }
            }
            _ => {}
        }
    }

    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }

    // Note that this should only be called from `reload_system_prompt_task`.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .collect::<Vec<_>>();
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
            })
            .collect::<Vec<_>>();
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }

    fn load_worktree_info_for_system_prompt(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
        let tree = worktree.read(cx);
        let root_name = tree.root_name_str().into();
        let abs_path = tree.abs_path();

        let mut context = WorktreeContext {
            root_name,
            abs_path,
            rules_file: None,
        };

        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
        let Some(rules_task) = rules_task else {
            return Task::ready((context, None));
        };

        cx.spawn(async move |_| {
            let (rules_file, rules_file_error) = match rules_task.await {
                Ok(rules_file) => (Some(rules_file), None),
                Err(err) => (
                    None,
                    Some(RulesLoadingError {
                        message: format!("{err}").into(),
                    }),
                ),
            };
            context.rules_file = rules_file;
            (context, rules_file_error)
        })
    }

    fn load_worktree_rules_file(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Option<Task<Result<RulesFileContext>>> {
        let worktree = worktree.read(cx);
        let worktree_id = worktree.id();
        let selected_rules_file = RULES_FILE_NAMES
            .into_iter()
            .filter_map(|name| {
                worktree
                    .entry_for_path(RelPath::unix(name).unwrap())
                    .filter(|entry| entry.is_file())
                    .map(|entry| entry.path.clone())
            })
            .next();

        // Note that Cline supports `.clinerules` being a directory, but that is not currently
        // supported. This doesn't seem to occur often in GitHub repositories.
        selected_rules_file.map(|path_in_worktree| {
            let project_path = ProjectPath {
                worktree_id,
                path: path_in_worktree.clone(),
            };
            let buffer_task =
                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
            let rope_task = cx.spawn(async move |cx| {
                buffer_task.await?.read_with(cx, |buffer, cx| {
                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
                })?
            });
            // Build a string from the rope on a background thread.
            cx.background_spawn(async move {
                let (project_entry_id, rope) = rope_task.await?;
                anyhow::Ok(RulesFileContext {
                    path_in_worktree,
                    text: rope.to_string().trim().to_string(),
                    project_entry_id: project_entry_id.to_usize(),
                })
            })
        })
    }

    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }

    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }

    /// Returns the number of threads.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }

    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // ordering is from "ORDER BY" in `list_threads`
        self.threads.iter()
    }

    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::new(
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                cx,
            )
        })
    }

    pub fn create_thread_from_serialized(
        &mut self,
        serialized: SerializedThread,
        cx: &mut Context<Self>,
    ) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::deserialize(
                ThreadId::new(),
                serialized,
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                None,
                cx,
            )
        })
    }

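    /// Loads the thread with the given ID from the database and deserializes it.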
    pub fn open_thread(
        &self,
        id: &ThreadId,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        let this = cx.weak_entity();
        window.spawn(cx, async move |cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let thread = database
                .try_find_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            let thread = this.update_in(cx, |this, window, cx| {
                cx.new(|cx| {
                    Thread::deserialize(
                        id.clone(),
                        thread,
                        this.project.clone(),
                        this.tools.clone(),
                        this.prompt_builder.clone(),
                        this.project_context.clone(),
                        Some(window),
                        cx,
                    )
                })
            })?;

            Ok(thread)
        })
    }

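    /// Serializes the thread, persists it to the database, and then reloads the thread list.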
    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let (id, serialized_thread) =
            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));

        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let serialized_thread = serialized_thread.await?;
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.save_thread(id, serialized_thread).await?;

            this.update(cx, |this, cx| this.reload(cx))?.await
        })
    }

    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.delete_thread(id.clone()).await?;

            this.update(cx, |this, cx| {
                this.threads.retain(|thread| thread.id != id);
                cx.notify();
            })
        })
    }

    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let threads = database_future
                .await
                .map_err(|err| anyhow!(err))?
                .list_threads()
                .await?;

            this.update(cx, |this, cx| {
                this.threads = threads;
                cx.notify();
            })
        })
    }

    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        let context_server_store = self.project.read(cx).context_server_store();
        cx.subscribe(&context_server_store, Self::handle_context_server_event)
            .detach();

        // Check for any servers that were already running before the handler was registered
        for server in context_server_store.read(cx).running_servers() {
            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
        }
    }

    fn handle_context_server_event(
        &mut self,
        context_server_store: Entity<ContextServerStore>,
        event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        let tool_working_set = self.tools.clone();
        match event {
            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
                match status {
                    ContextServerStatus::Starting => {}
                    ContextServerStatus::Running => {
                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
                    }
                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                            tool_working_set.update(cx, |tool_working_set, cx| {
                                tool_working_set.remove(&tool_ids, cx);
                            });
                        }
                    }
                }
            }
        }
    }

    fn load_context_server_tools(
        &self,
        server_id: ContextServerId,
        context_server_store: Entity<ContextServerStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
            return;
        };
        let tool_working_set = self.tools.clone();
        cx.spawn(async move |this, cx| {
            let Some(protocol) = server.client() else {
                return;
            };

            if protocol.capable(context_server::protocol::ServerCapability::Tools)
                && let Some(response) = protocol
                    .request::<context_server::types::requests::ListTools>(())
                    .await
                    .log_err()
            {
                let tool_ids = tool_working_set
                    .update(cx, |tool_working_set, cx| {
                        tool_working_set.extend(
                            response.tools.into_iter().map(|tool| {
                                Arc::new(ContextServerTool::new(
                                    context_server_store.clone(),
                                    server.id(),
                                    tool,
                                )) as Arc<dyn Tool>
                            }),
                            cx,
                        )
                    })
                    .log_err();

                if let Some(tool_ids) = tool_ids {
                    this.update(cx, |this, _| {
                        this.context_server_tool_ids.insert(server_id, tool_ids);
                    })
                    .log_err();
                }
            }
        })
        .detach();
    }
}

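/// Lightweight per-thread metadata, as returned by `ThreadsDatabase::list_threads`.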
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}

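/// The serialized, on-disk representation of a thread (current version).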
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}

#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}

impl SerializedThread {
    pub const VERSION: &'static str = "0.2.0";

    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                SerializedThreadV0_1_0::VERSION => {
                    let saved_thread =
                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
                    Ok(saved_thread.upgrade())
                }
                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
                    saved_thread_json,
                )?),
                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
            },
            None => {
                let saved_thread =
                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
                Ok(saved_thread.upgrade())
            }
            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);

impl SerializedThreadV0_1_0 {
    pub const VERSION: &'static str = "0.1.0";

    pub fn upgrade(self) -> SerializedThread {
        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");

        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());

        for message in self.0.messages {
            if message.role == Role::User
                && !message.tool_results.is_empty()
                && let Some(last_message) = messages.last_mut()
            {
                debug_assert!(last_message.role == Role::Assistant);

                last_message.tool_results = message.tool_results;
                continue;
            }

            messages.push(message);
        }

        SerializedThread {
            messages,
            version: SerializedThread::VERSION.to_string(),
            ..self.0
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}

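/// Thread format that predates versioning (no `version` field); upgraded to
/// `SerializedThread` on load.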
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}

impl LegacySerializedThread {
    pub fn upgrade(self) -> SerializedThread {
        SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: self.summary,
            updated_at: self.updated_at,
            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
            initial_project_snapshot: self.initial_project_snapshot,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: Vec::new(),
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}

impl LegacySerializedMessage {
    fn upgrade(self) -> SerializedMessage {
        SerializedMessage {
            id: self.id,
            role: self.role,
            segments: vec![SerializedMessageSegment::Text { text: self.text }],
            tool_uses: self.tool_uses,
            tool_results: self.tool_results,
            context: String::new(),
            creases: Vec::new(),
            is_hidden: false,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}

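/// Global handle to the shared future that resolves to the `ThreadsDatabase` once
/// initialization has completed.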
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}

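/// SQLite-backed storage for serialized threads.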
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}

impl ThreadsDatabase {
    fn connection(&self) -> Arc<Mutex<Connection>> {
        self.connection.clone()
    }

    const COMPRESSION_LEVEL: i32 = 3;
}

impl Bind for ThreadId {
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        self.to_string().bind(statement, start_index)
    }
}

impl Column for ThreadId {
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (id_str, next_index) = String::column(statement, start_index)?;
        Ok((ThreadId::from(id_str.as_str()), next_index))
    }
}

impl ThreadsDatabase {
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }

    fn init(fs: Arc<dyn Fs>, cx: &mut App) {
        let executor = cx.background_executor().clone();
        let database_future = executor
            .spawn({
                let executor = executor.clone();
                let threads_dir = paths::data_dir().join("threads");
                async move { ThreadsDatabase::new(fs, threads_dir, executor).await }
            })
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
            .shared();

        cx.set_global(GlobalThreadsDatabase(database_future));
    }

    pub async fn new(
        fs: Arc<dyn Fs>,
        threads_dir: PathBuf,
        executor: BackgroundExecutor,
    ) -> Result<Self> {
        fs.create_dir(&threads_dir).await?;

        let sqlite_path = threads_dir.join("threads.db");
        let mdb_path = threads_dir.join("threads-db.1.mdb");

        let needs_migration_from_heed = fs.is_file(&mdb_path).await;

        let connection = if *ZED_STATELESS {
            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
        } else if cfg!(any(feature = "test-support", test)) {
            // Rust stores the name of the current test on the current thread.
            // We use this to automatically create a database that is shared
            // within a test (needed for `test_retrieve_old_thread`) but not
            // across concurrent tests.
            let thread = std::thread::current();
            let test_name = thread.name();
            Connection::open_memory(Some(&format!(
                "THREAD_FALLBACK_{}",
                test_name.unwrap_or_default()
            )))
        } else {
            Connection::open_file(&sqlite_path.to_string_lossy())
        };

        connection.exec(indoc! {"
            CREATE TABLE IF NOT EXISTS threads (
                id TEXT PRIMARY KEY,
                summary TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                data_type TEXT NOT NULL,
                data BLOB NOT NULL
            )
        "})?()
        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;

        let db = Self {
            executor: executor.clone(),
            connection: Arc::new(Mutex::new(connection)),
        };

        if needs_migration_from_heed {
            let db_connection = db.connection();
            let executor_clone = executor.clone();
            executor
                .spawn(async move {
                    log::info!("Starting threads.db migration");
                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
                    fs.remove_dir(
                        &mdb_path,
                        RemoveOptions {
                            recursive: true,
                            ignore_if_not_exists: true,
                        },
                    )
                    .await?;
                    log::info!("threads.db migrated to sqlite");
                    Ok::<(), anyhow::Error>(())
                })
                .detach();
        }

        Ok(db)
    }

    // Remove this migration after 2025-09-01
    fn migrate_from_heed(
        mdb_path: &Path,
        connection: Arc<Mutex<Connection>>,
        _executor: BackgroundExecutor,
    ) -> Result<()> {
        use heed::types::SerdeBincode;
        struct SerializedThreadHeed(SerializedThread);

        impl heed::BytesEncode<'_> for SerializedThreadHeed {
            type EItem = SerializedThreadHeed;

            fn bytes_encode(
                item: &Self::EItem,
            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
                serde_json::to_vec(&item.0)
                    .map(std::borrow::Cow::Owned)
                    .map_err(Into::into)
            }
        }

        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
            type DItem = SerializedThreadHeed;

            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
                SerializedThread::from_json(bytes)
                    .map(SerializedThreadHeed)
                    .map_err(Into::into)
            }
        }

        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;

        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(ONE_GB_IN_BYTES)
                .max_dbs(1)
                .open(mdb_path)?
        };

        let txn = env.write_txn()?;
        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
            .open_database(&txn, Some("threads"))?
            .ok_or_else(|| anyhow!("threads database not found"))?;

        for result in threads.iter(&txn)? {
            let (thread_id, thread_heed) = result?;
            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
        }

        Ok(())
    }

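    /// Synchronously writes a thread row, compressing the JSON payload with zstd.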
    fn save_thread_sync(
        connection: &Arc<Mutex<Connection>>,
        id: ThreadId,
        thread: SerializedThread,
    ) -> Result<()> {
        let json_data = serde_json::to_string(&thread)?;
        let summary = thread.summary.to_string();
        let updated_at = thread.updated_at.to_rfc3339();

        let connection = connection.lock().unwrap();

        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
        let data_type = DataType::Zstd;
        let data = compressed;

        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
        "})?;

        insert((id, summary, updated_at, data_type, data))?;

        Ok(())
    }

    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select =
                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
                    SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
                "})?;

            let rows = select(())?;
            let mut threads = Vec::new();

            for (id, summary, updated_at) in rows {
                threads.push(SerializedThreadMetadata {
                    id,
                    summary: summary.into(),
                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
                });
            }

            Ok(threads)
        })
    }

    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
            "})?;

            let rows = select(id)?;
            if let Some((data_type, data)) = rows.into_iter().next() {
                let json_data = match data_type {
                    DataType::Zstd => {
                        let decompressed = zstd::decode_all(&data[..])?;
                        String::from_utf8(decompressed)?
                    }
                    DataType::Json => String::from_utf8(data)?,
                };

                let thread = SerializedThread::from_json(json_data.as_bytes())?;
                Ok(Some(thread))
            } else {
                Ok(None)
            }
        })
    }

    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor
            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
    }

    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();

            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
                DELETE FROM threads WHERE id = ?
            "})?;

            delete(id)?;

            Ok(())
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
}