use crate::{
    context_server_tool::ContextServerTool,
    thread::{
        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
    },
};
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::ContextServerId;
use futures::{
    FutureExt as _, StreamExt as _,
    channel::{mpsc, oneshot},
    future::{self, BoxFuture, Shared},
};
use gpui::{
    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
    Subscription, Task, Window, prelude::*,
};
use indoc::indoc;
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
    UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use sqlez::{
    bindable::{Bind, Column},
    connection::Connection,
    statement::Statement,
};
use std::{
    cell::{Ref, RefCell},
    path::{Path, PathBuf},
    rc::Rc,
    sync::{Arc, Mutex},
};
use util::ResultExt as _;

use zed_env_vars::ZED_STATELESS;

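/// Encoding of the `data` column in the `threads` table. New rows are written as
/// zstd-compressed JSON; plain JSON is still accepted when reading older rows.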
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    #[serde(rename = "json")]
    Json,
    #[serde(rename = "zstd")]
    Zstd,
}

impl Bind for DataType {
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        let value = match self {
            DataType::Json => "json",
            DataType::Zstd => "zstd",
        };
        value.bind(statement, start_index)
    }
}

impl Column for DataType {
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (value, next_index) = String::column(statement, start_index)?;
        let data_type = match value.as_str() {
            "json" => DataType::Json,
            "zstd" => DataType::Zstd,
            _ => anyhow::bail!("Unknown data type: {}", value),
        };
        Ok((data_type, next_index))
    }
}

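/// Rules file paths checked in each worktree, in priority order; the first match is
/// used (see `load_worktree_rules_file`).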
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];

pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}

/// Project context shared by all threads created by this `ThreadStore`; it is used
/// to build each thread's system prompt.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);

impl SharedProjectContext {
    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
        self.0.borrow()
    }
}

pub type TextThreadStore = assistant_context::ContextStore;

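/// Maintains the set of saved agent threads for a project. It loads and saves
/// threads through [`ThreadsDatabase`], keeps the shared [`ProjectContext`] used
/// for system prompts up to date, and registers tools exposed by running context
/// servers.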
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}

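/// Emitted when a worktree rules file or a default user rule fails to load.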
pub struct RulesLoadingError {
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}

impl ThreadStore {
    pub fn load(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_store: Option<Entity<PromptStore>>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        cx.spawn(async move |cx| {
            let (thread_store, ready_rx) = cx.update(|cx| {
                let mut option_ready_rx = None;
                let thread_store = cx.new(|cx| {
                    let (thread_store, ready_rx) =
                        Self::new(project, tools, prompt_builder, prompt_store, cx);
                    option_ready_rx = Some(ready_rx);
                    thread_store
                });
                (thread_store, option_ready_rx.take().unwrap())
            })?;
            ready_rx.await?;
            Ok(thread_store)
        })
    }

    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];

        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.register_context_server_handlers(cx);
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn fake(project: Entity<Project>, cx: &mut App) -> Self {
        Self {
            project,
            tools: cx.new(|_| ToolWorkingSet::default()),
            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
            prompt_store: None,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx: mpsc::channel(0).0,
            _reload_system_prompt_task: Task::ready(()),
            _subscriptions: vec![],
        }
    }

    fn handle_project_event(
        &mut self,
        _project: Entity<Project>,
        event: &project::Event,
        _cx: &mut Context<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
                self.enqueue_system_prompt_reload();
            }
            project::Event::WorktreeUpdatedEntries(_, items) => {
                if items.iter().any(|(path, _, _)| {
                    RULES_FILE_NAMES
                        .iter()
                        .any(|name| path.as_ref() == Path::new(name))
                }) {
                    self.enqueue_system_prompt_reload();
                }
            }
            _ => {}
        }
    }

    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }

    // Note that this should only be called from `reload_system_prompt_task`.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .collect::<Vec<_>>();
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
            })
            .collect::<Vec<_>>();
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }

    fn load_worktree_info_for_system_prompt(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
        let tree = worktree.read(cx);
        let root_name = tree.root_name().into();
        let abs_path = tree.abs_path();

        let mut context = WorktreeContext {
            root_name,
            abs_path,
            rules_file: None,
        };

        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
        let Some(rules_task) = rules_task else {
            return Task::ready((context, None));
        };

        cx.spawn(async move |_| {
            let (rules_file, rules_file_error) = match rules_task.await {
                Ok(rules_file) => (Some(rules_file), None),
                Err(err) => (
                    None,
                    Some(RulesLoadingError {
                        message: format!("{err}").into(),
                    }),
                ),
            };
            context.rules_file = rules_file;
            (context, rules_file_error)
        })
    }

    fn load_worktree_rules_file(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Option<Task<Result<RulesFileContext>>> {
        let worktree = worktree.read(cx);
        let worktree_id = worktree.id();
        let selected_rules_file = RULES_FILE_NAMES
            .into_iter()
            .filter_map(|name| {
                worktree
                    .entry_for_path(name)
                    .filter(|entry| entry.is_file())
                    .map(|entry| entry.path.clone())
            })
            .next();

        // Note that Cline supports `.clinerules` being a directory, but that is not currently
        // supported. This doesn't seem to occur often in GitHub repositories.
        selected_rules_file.map(|path_in_worktree| {
            let project_path = ProjectPath {
                worktree_id,
                path: path_in_worktree.clone(),
            };
            let buffer_task =
                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
            let rope_task = cx.spawn(async move |cx| {
                buffer_task.await?.read_with(cx, |buffer, cx| {
                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
                })?
            });
            // Build a string from the rope on a background thread.
            cx.background_spawn(async move {
                let (project_entry_id, rope) = rope_task.await?;
                anyhow::Ok(RulesFileContext {
                    path_in_worktree,
                    text: rope.to_string().trim().to_string(),
                    project_entry_id: project_entry_id.to_usize(),
                })
            })
        })
    }

    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }

    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }

    /// Returns the number of threads.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }

    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // The ordering comes from the `ORDER BY` clause in `list_threads`.
        self.threads.iter()
    }

    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::new(
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                cx,
            )
        })
    }

    pub fn create_thread_from_serialized(
        &mut self,
        serialized: SerializedThread,
        cx: &mut Context<Self>,
    ) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::deserialize(
                ThreadId::new(),
                serialized,
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                None,
                cx,
            )
        })
    }

    pub fn open_thread(
        &self,
        id: &ThreadId,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        let this = cx.weak_entity();
        window.spawn(cx, async move |cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let thread = database
                .try_find_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            let thread = this.update_in(cx, |this, window, cx| {
                cx.new(|cx| {
                    Thread::deserialize(
                        id.clone(),
                        thread,
                        this.project.clone(),
                        this.tools.clone(),
                        this.prompt_builder.clone(),
                        this.project_context.clone(),
                        Some(window),
                        cx,
                    )
                })
            })?;

            Ok(thread)
        })
    }

    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let (metadata, serialized_thread) =
            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));

        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let serialized_thread = serialized_thread.await?;
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.save_thread(metadata, serialized_thread).await?;

            this.update(cx, |this, cx| this.reload(cx))?.await
        })
    }

    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.delete_thread(id.clone()).await?;

            this.update(cx, |this, cx| {
                this.threads.retain(|thread| thread.id != id);
                cx.notify();
            })
        })
    }

    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let threads = database_future
                .await
                .map_err(|err| anyhow!(err))?
                .list_threads()
                .await?;

            this.update(cx, |this, cx| {
                this.threads = threads;
                cx.notify();
            })
        })
    }

    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        let context_server_store = self.project.read(cx).context_server_store();
        cx.subscribe(&context_server_store, Self::handle_context_server_event)
            .detach();

        // Check for any servers that were already running before the handler was registered
        for server in context_server_store.read(cx).running_servers() {
            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
        }
    }

    fn handle_context_server_event(
        &mut self,
        context_server_store: Entity<ContextServerStore>,
        event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        let tool_working_set = self.tools.clone();
        match event {
            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
                match status {
                    ContextServerStatus::Starting => {}
                    ContextServerStatus::Running => {
                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
                    }
                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                            tool_working_set.update(cx, |tool_working_set, cx| {
                                tool_working_set.remove(&tool_ids, cx);
                            });
                        }
                    }
                }
            }
        }
    }

    fn load_context_server_tools(
        &self,
        server_id: ContextServerId,
        context_server_store: Entity<ContextServerStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
            return;
        };
        let tool_working_set = self.tools.clone();
        cx.spawn(async move |this, cx| {
            let Some(protocol) = server.client() else {
                return;
            };

            if protocol.capable(context_server::protocol::ServerCapability::Tools)
                && let Some(response) = protocol
                    .request::<context_server::types::requests::ListTools>(())
                    .await
                    .log_err()
            {
                let tool_ids = tool_working_set
                    .update(cx, |tool_working_set, cx| {
                        tool_working_set.extend(
                            response.tools.into_iter().map(|tool| {
                                Arc::new(ContextServerTool::new(
                                    context_server_store.clone(),
                                    server.id(),
                                    tool,
                                )) as Arc<dyn Tool>
                            }),
                            cx,
                        )
                    })
                    .log_err();

                if let Some(tool_ids) = tool_ids {
                    this.update(cx, |this, _| {
                        this.context_server_tool_ids.insert(server_id, tool_ids);
                    })
                    .log_err();
                }
            }
        })
        .detach();
    }
}

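/// Summary row returned by `list_threads`; the full thread body is loaded
/// separately via `try_find_thread`.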
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}

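/// Versioned, serialized form of a thread as persisted in the `threads` table.
/// Older formats are upgraded on load by [`SerializedThread::from_json`].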
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}

#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}

impl SerializedThread {
    pub const VERSION: &'static str = "0.2.0";

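    /// Deserializes a thread from JSON, upgrading unversioned legacy threads and
    /// `0.1.0` threads to the current format.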
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                SerializedThreadV0_1_0::VERSION => {
                    let saved_thread =
                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
                    Ok(saved_thread.upgrade())
                }
                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
                    saved_thread_json,
                )?),
                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
            },
            None => {
                let saved_thread =
                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
                Ok(saved_thread.upgrade())
            }
            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);

impl SerializedThreadV0_1_0 {
    pub const VERSION: &'static str = "0.1.0";

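    /// In `0.1.0`, tool results were stored on the user message that followed a
    /// tool use; `0.2.0` moves them onto the assistant message that requested it.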
    pub fn upgrade(self) -> SerializedThread {
        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");

        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());

        for message in self.0.messages {
            if message.role == Role::User
                && !message.tool_results.is_empty()
                && let Some(last_message) = messages.last_mut()
            {
                debug_assert!(last_message.role == Role::Assistant);

                last_message.tool_results = message.tool_results;
                continue;
            }

            messages.push(message);
        }

        SerializedThread {
            messages,
            version: SerializedThread::VERSION.to_string(),
            ..self.0
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}

#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}

impl LegacySerializedThread {
    pub fn upgrade(self) -> SerializedThread {
        SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: self.summary,
            updated_at: self.updated_at,
            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
            initial_project_snapshot: self.initial_project_snapshot,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: Vec::new(),
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}

impl LegacySerializedMessage {
    fn upgrade(self) -> SerializedMessage {
        SerializedMessage {
            id: self.id,
            role: self.role,
            segments: vec![SerializedMessageSegment::Text { text: self.text }],
            tool_uses: self.tool_uses,
            tool_results: self.tool_results,
            context: String::new(),
            creases: Vec::new(),
            is_hidden: false,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}

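/// Global handle to the shared future that resolves to the opened [`ThreadsDatabase`].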
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}

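/// SQLite-backed storage for serialized threads. All queries run on the background
/// executor and share a single mutex-guarded connection.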
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}

impl ThreadsDatabase {
    fn connection(&self) -> Arc<Mutex<Connection>> {
        self.connection.clone()
    }

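    /// zstd compression level used when persisting thread JSON.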
    const COMPRESSION_LEVEL: i32 = 3;
}

impl Bind for ThreadId {
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        self.to_string().bind(statement, start_index)
    }
}

impl Column for ThreadId {
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (id_str, next_index) = String::column(statement, start_index)?;
        Ok((ThreadId::from(id_str.as_str()), next_index))
    }
}

impl ThreadsDatabase {
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }

    fn init(cx: &mut App) {
        let executor = cx.background_executor().clone();
        let database_future = executor
            .spawn({
                let executor = executor.clone();
                let threads_dir = paths::data_dir().join("threads");
                async move { ThreadsDatabase::new(threads_dir, executor) }
            })
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
            .shared();

        cx.set_global(GlobalThreadsDatabase(database_future));
    }

    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
        std::fs::create_dir_all(&threads_dir)?;

        let sqlite_path = threads_dir.join("threads.db");
        let mdb_path = threads_dir.join("threads-db.1.mdb");

        let needs_migration_from_heed = mdb_path.exists();

        let connection = if *ZED_STATELESS {
            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
        } else if cfg!(any(feature = "test-support", test)) {
            // Rust's test harness stores the name of the current test on the current thread.
            // We use it to automatically create a database that is shared within a single test
            // (e.g. `test_retrieve_old_thread`) but not across concurrent tests.
            let thread = std::thread::current();
            let test_name = thread.name();
            Connection::open_memory(Some(&format!(
                "THREAD_FALLBACK_{}",
                test_name.unwrap_or_default()
            )))
        } else {
            Connection::open_file(&sqlite_path.to_string_lossy())
        };

        connection.exec(indoc! {"
            CREATE TABLE IF NOT EXISTS threads (
                id TEXT PRIMARY KEY,
                summary TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                data_type TEXT NOT NULL,
                data BLOB NOT NULL
            )
        "})?()
        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;

        let db = Self {
            executor: executor.clone(),
            connection: Arc::new(Mutex::new(connection)),
        };

        if needs_migration_from_heed {
            let db_connection = db.connection();
            let executor_clone = executor.clone();
            executor
                .spawn(async move {
                    log::info!("Starting threads.db migration");
                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
                    std::fs::remove_dir_all(mdb_path)?;
                    log::info!("threads.db migrated to sqlite");
                    Ok::<(), anyhow::Error>(())
                })
                .detach();
        }

        Ok(db)
    }

    // Remove this migration after 2025-09-01
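    // Copies every thread from the old heed (LMDB) database into the SQLite `threads`
    // table; once this succeeds, the caller deletes the old `.mdb` directory.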
    fn migrate_from_heed(
        mdb_path: &Path,
        connection: Arc<Mutex<Connection>>,
        _executor: BackgroundExecutor,
    ) -> Result<()> {
        use heed::types::SerdeBincode;
        struct SerializedThreadHeed(SerializedThread);

        impl heed::BytesEncode<'_> for SerializedThreadHeed {
            type EItem = SerializedThreadHeed;

            fn bytes_encode(
                item: &Self::EItem,
            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
                serde_json::to_vec(&item.0)
                    .map(std::borrow::Cow::Owned)
                    .map_err(Into::into)
            }
        }

        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
            type DItem = SerializedThreadHeed;

            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
                SerializedThread::from_json(bytes)
                    .map(SerializedThreadHeed)
                    .map_err(Into::into)
            }
        }

        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;

        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(ONE_GB_IN_BYTES)
                .max_dbs(1)
                .open(mdb_path)?
        };

        let txn = env.write_txn()?;
        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
            .open_database(&txn, Some("threads"))?
            .ok_or_else(|| anyhow!("threads database not found"))?;

        for result in threads.iter(&txn)? {
            let (thread_id, thread_heed) = result?;
            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
        }

        Ok(())
    }

    fn save_thread_sync(
        connection: &Arc<Mutex<Connection>>,
        id: ThreadId,
        thread: SerializedThread,
    ) -> Result<()> {
        let json_data = serde_json::to_string(&thread)?;
        let summary = thread.summary.to_string();
        let updated_at = thread.updated_at.to_rfc3339();

        let connection = connection.lock().unwrap();

        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
        let data_type = DataType::Zstd;
        let data = compressed;

        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
        "})?;

        insert((id, summary, updated_at, data_type, data))?;

        Ok(())
    }

    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select =
                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
                    SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
                "})?;

            let rows = select(())?;
            let mut threads = Vec::new();

            for (id, summary, updated_at) in rows {
                threads.push(SerializedThreadMetadata {
                    id,
                    summary: summary.into(),
                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
                });
            }

            Ok(threads)
        })
    }

    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
            "})?;

            let rows = select(id)?;
            if let Some((data_type, data)) = rows.into_iter().next() {
                let json_data = match data_type {
                    DataType::Zstd => {
                        let decompressed = zstd::decode_all(&data[..])?;
                        String::from_utf8(decompressed)?
                    }
                    DataType::Json => String::from_utf8(data)?,
                };

                let thread = SerializedThread::from_json(json_data.as_bytes())?;
                Ok(Some(thread))
            } else {
                Ok(None)
            }
        })
    }

    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor
            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
    }

    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();

            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
                DELETE FROM threads WHERE id = ?
            "})?;

            delete(id)?;

            Ok(())
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
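
    // Guard: `from_json` rejects version strings it does not know how to upgrade.
    #[test]
    fn test_from_json_rejects_unknown_version() {
        let json = br#"{"version": "9.9.9"}"#;
        assert!(SerializedThread::from_json(json).is_err());
    }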
}