use crate::{
    context_server_tool::ContextServerTool,
    thread::{
        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
    },
};
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::CloudUserStore;
use collections::HashMap;
use context_server::ContextServerId;
use futures::{
    FutureExt as _, StreamExt as _,
    channel::{mpsc, oneshot},
    future::{self, BoxFuture, Shared},
};
use gpui::{
    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
    Subscription, Task, Window, prelude::*,
};
use indoc::indoc;
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
    UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use sqlez::{
    bindable::{Bind, Column},
    connection::Connection,
    statement::Statement,
};
use std::{
    cell::{Ref, RefCell},
    path::{Path, PathBuf},
    rc::Rc,
    sync::{Arc, Mutex},
};
use util::ResultExt as _;

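/// When the `ZED_STATELESS` environment variable is set to any non-empty value,
/// `ThreadsDatabase::new` opens an in-memory SQLite connection instead of the on-disk
/// `threads.db`, so thread data is not persisted across sessions.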
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
    std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));

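/// How a serialized thread is stored in the `data` column of the `threads` table: either
/// plain JSON or zstd-compressed JSON.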
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    #[serde(rename = "json")]
    Json,
    #[serde(rename = "zstd")]
    Zstd,
}

impl Bind for DataType {
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        let value = match self {
            DataType::Json => "json",
            DataType::Zstd => "zstd",
        };
        value.bind(statement, start_index)
    }
}

impl Column for DataType {
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (value, next_index) = String::column(statement, start_index)?;
        let data_type = match value.as_str() {
            "json" => DataType::Json,
            "zstd" => DataType::Zstd,
            _ => anyhow::bail!("Unknown data type: {}", value),
        };
        Ok((data_type, next_index))
    }
}

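/// File names treated as project rules when present at a worktree root. The first match, in
/// this order, is loaded into the system prompt (see `load_worktree_rules_file`).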
const RULES_FILE_NAMES: [&'static str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];

pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}

/// Project context shared by all threads created by this ThreadStore, used when building
/// their system prompts.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);

impl SharedProjectContext {
    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
        self.0.borrow()
    }
}

pub type TextThreadStore = assistant_context::ContextStore;

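/// Owns the set of persisted agent threads for a project: loads and saves them through
/// [`ThreadsDatabase`], keeps the shared [`ProjectContext`] (worktree rules files and default
/// user rules) up to date, and registers tools exposed by running context servers.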
pub struct ThreadStore {
    project: Entity<Project>,
    cloud_user_store: Entity<CloudUserStore>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}

pub struct RulesLoadingError {
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}

impl ThreadStore {
    pub fn load(
        project: Entity<Project>,
        cloud_user_store: Entity<CloudUserStore>,
        tools: Entity<ToolWorkingSet>,
        prompt_store: Option<Entity<PromptStore>>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        cx.spawn(async move |cx| {
            let (thread_store, ready_rx) = cx.update(|cx| {
                let mut option_ready_rx = None;
                let thread_store = cx.new(|cx| {
                    let (thread_store, ready_rx) = Self::new(
                        project,
                        cloud_user_store,
                        tools,
                        prompt_builder,
                        prompt_store,
                        cx,
                    );
                    option_ready_rx = Some(ready_rx);
                    thread_store
                });
                (thread_store, option_ready_rx.take().unwrap())
            })?;
            ready_rx.await?;
            Ok(thread_store)
        })
    }

    fn new(
        project: Entity<Project>,
        cloud_user_store: Entity<CloudUserStore>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];

        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            cloud_user_store,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.register_context_server_handlers(cx);
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }

    fn handle_project_event(
        &mut self,
        _project: Entity<Project>,
        event: &project::Event,
        _cx: &mut Context<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
                self.enqueue_system_prompt_reload();
            }
            project::Event::WorktreeUpdatedEntries(_, items) => {
                if items.iter().any(|(path, _, _)| {
                    RULES_FILE_NAMES
                        .iter()
                        .any(|name| path.as_ref() == Path::new(name))
                }) {
                    self.enqueue_system_prompt_reload();
                }
            }
            _ => {}
        }
    }

    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }

    // Note that this should only be called from `reload_system_prompt_task`.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .collect::<Vec<_>>();
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
            })
            .collect::<Vec<_>>();
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }

    fn load_worktree_info_for_system_prompt(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
        let tree = worktree.read(cx);
        let root_name = tree.root_name().into();
        let abs_path = tree.abs_path();

        let mut context = WorktreeContext {
            root_name,
            abs_path,
            rules_file: None,
        };

        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
        let Some(rules_task) = rules_task else {
            return Task::ready((context, None));
        };

        cx.spawn(async move |_| {
            let (rules_file, rules_file_error) = match rules_task.await {
                Ok(rules_file) => (Some(rules_file), None),
                Err(err) => (
                    None,
                    Some(RulesLoadingError {
                        message: format!("{err}").into(),
                    }),
                ),
            };
            context.rules_file = rules_file;
            (context, rules_file_error)
        })
    }

    fn load_worktree_rules_file(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Option<Task<Result<RulesFileContext>>> {
        let worktree = worktree.read(cx);
        let worktree_id = worktree.id();
        let selected_rules_file = RULES_FILE_NAMES
            .into_iter()
            .filter_map(|name| {
                worktree
                    .entry_for_path(name)
                    .filter(|entry| entry.is_file())
                    .map(|entry| entry.path.clone())
            })
            .next();

        // Note that Cline supports `.clinerules` being a directory, but that is not currently
        // supported. This doesn't seem to occur often in GitHub repositories.
        selected_rules_file.map(|path_in_worktree| {
            let project_path = ProjectPath {
                worktree_id,
                path: path_in_worktree.clone(),
            };
            let buffer_task =
                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
            let rope_task = cx.spawn(async move |cx| {
                buffer_task.await?.read_with(cx, |buffer, cx| {
                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
                })?
            });
            // Build a string from the rope on a background thread.
            cx.background_spawn(async move {
                let (project_entry_id, rope) = rope_task.await?;
                anyhow::Ok(RulesFileContext {
                    path_in_worktree,
                    text: rope.to_string().trim().to_string(),
                    project_entry_id: project_entry_id.to_usize(),
                })
            })
        })
    }

    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }

    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }

    /// Returns the number of threads.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }

    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // ordering is from "ORDER BY" in `list_threads`
        self.threads.iter()
    }

    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::new(
                self.project.clone(),
                self.cloud_user_store.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                cx,
            )
        })
    }

    pub fn create_thread_from_serialized(
        &mut self,
        serialized: SerializedThread,
        cx: &mut Context<Self>,
    ) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::deserialize(
                ThreadId::new(),
                serialized,
                self.project.clone(),
                self.cloud_user_store.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                None,
                cx,
            )
        })
    }

    pub fn open_thread(
        &self,
        id: &ThreadId,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        let this = cx.weak_entity();
        window.spawn(cx, async move |cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let thread = database
                .try_find_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            let thread = this.update_in(cx, |this, window, cx| {
                cx.new(|cx| {
                    Thread::deserialize(
                        id.clone(),
                        thread,
                        this.project.clone(),
                        this.cloud_user_store.clone(),
                        this.tools.clone(),
                        this.prompt_builder.clone(),
                        this.project_context.clone(),
                        Some(window),
                        cx,
                    )
                })
            })?;

            Ok(thread)
        })
    }

    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let (metadata, serialized_thread) =
            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));

        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let serialized_thread = serialized_thread.await?;
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.save_thread(metadata, serialized_thread).await?;

            this.update(cx, |this, cx| this.reload(cx))?.await
        })
    }

    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.delete_thread(id.clone()).await?;

            this.update(cx, |this, cx| {
                this.threads.retain(|thread| thread.id != id);
                cx.notify();
            })
        })
    }

    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let threads = database_future
                .await
                .map_err(|err| anyhow!(err))?
                .list_threads()
                .await?;

            this.update(cx, |this, cx| {
                this.threads = threads;
                cx.notify();
            })
        })
    }

    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        let context_server_store = self.project.read(cx).context_server_store();
        cx.subscribe(&context_server_store, Self::handle_context_server_event)
            .detach();

        // Check for any servers that were already running before the handler was registered
        for server in context_server_store.read(cx).running_servers() {
            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
        }
    }

    fn handle_context_server_event(
        &mut self,
        context_server_store: Entity<ContextServerStore>,
        event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        let tool_working_set = self.tools.clone();
        match event {
            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
                match status {
                    ContextServerStatus::Starting => {}
                    ContextServerStatus::Running => {
                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
                    }
                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                            tool_working_set.update(cx, |tool_working_set, cx| {
                                tool_working_set.remove(&tool_ids, cx);
                            });
                        }
                    }
                }
            }
        }
    }

    fn load_context_server_tools(
        &self,
        server_id: ContextServerId,
        context_server_store: Entity<ContextServerStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
            return;
        };
        let tool_working_set = self.tools.clone();
        cx.spawn(async move |this, cx| {
            let Some(protocol) = server.client() else {
                return;
            };

            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                if let Some(response) = protocol
                    .request::<context_server::types::requests::ListTools>(())
                    .await
                    .log_err()
                {
                    let tool_ids = tool_working_set
                        .update(cx, |tool_working_set, cx| {
                            tool_working_set.extend(
                                response.tools.into_iter().map(|tool| {
                                    Arc::new(ContextServerTool::new(
                                        context_server_store.clone(),
                                        server.id(),
                                        tool,
                                    )) as Arc<dyn Tool>
                                }),
                                cx,
                            )
                        })
                        .log_err();

                    if let Some(tool_ids) = tool_ids {
                        this.update(cx, |this, _| {
                            this.context_server_tool_ids.insert(server_id, tool_ids);
                        })
                        .log_err();
                    }
                }
            }
        })
        .detach();
    }
}

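/// The lightweight per-thread row returned by `ThreadsDatabase::list_threads`, used to populate
/// the thread history without deserializing full thread data.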
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}

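/// The on-disk representation of a thread. `version` identifies the serialization schema;
/// older payloads are upgraded on load (see [`SerializedThread::from_json`]).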
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}

#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}

impl SerializedThread {
    pub const VERSION: &'static str = "0.2.0";

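    /// Deserializes a thread from JSON, upgrading it to the current schema when needed:
    /// payloads without a `version` field are treated as the legacy format, "0.1.0" payloads
    /// go through [`SerializedThreadV0_1_0::upgrade`], and "0.2.0" payloads parse directly.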
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                SerializedThreadV0_1_0::VERSION => {
                    let saved_thread =
                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
                    Ok(saved_thread.upgrade())
                }
                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
                    saved_thread_json,
                )?),
                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
            },
            None => {
                let saved_thread =
                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
                Ok(saved_thread.upgrade())
            }
            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);

impl SerializedThreadV0_1_0 {
    pub const VERSION: &'static str = "0.1.0";

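    /// Upgrades a v0.1.0 thread to v0.2.0. In v0.1.0, tool results were carried on a follow-up
    /// `User` message; v0.2.0 attaches them to the preceding `Assistant` message that issued
    /// the tool uses, so qualifying `User` messages are folded into their predecessor here.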
    pub fn upgrade(self) -> SerializedThread {
        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");

        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());

        for message in self.0.messages {
            if message.role == Role::User && !message.tool_results.is_empty() {
                if let Some(last_message) = messages.last_mut() {
                    debug_assert!(last_message.role == Role::Assistant);

                    last_message.tool_results = message.tool_results;
                    continue;
                }
            }

            messages.push(message);
        }

        SerializedThread {
            messages,
            version: SerializedThread::VERSION.to_string(),
            ..self.0
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}

#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}

impl LegacySerializedThread {
    pub fn upgrade(self) -> SerializedThread {
        SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: self.summary,
            updated_at: self.updated_at,
            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
            initial_project_snapshot: self.initial_project_snapshot,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: Vec::new(),
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}

impl LegacySerializedMessage {
    fn upgrade(self) -> SerializedMessage {
        SerializedMessage {
            id: self.id,
            role: self.role,
            segments: vec![SerializedMessageSegment::Text { text: self.text }],
            tool_uses: self.tool_uses,
            tool_results: self.tool_results,
            context: String::new(),
            creases: Vec::new(),
            is_hidden: false,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}

struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}

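/// SQLite-backed persistence for serialized threads. The connection is shared behind a mutex
/// and all queries run on the background executor; thread data is zstd-compressed on write.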
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}

impl ThreadsDatabase {
    fn connection(&self) -> Arc<Mutex<Connection>> {
        self.connection.clone()
    }

    const COMPRESSION_LEVEL: i32 = 3;
}

impl Bind for ThreadId {
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        self.to_string().bind(statement, start_index)
    }
}

impl Column for ThreadId {
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (id_str, next_index) = String::column(statement, start_index)?;
        Ok((ThreadId::from(id_str.as_str()), next_index))
    }
}

impl ThreadsDatabase {
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }

    fn init(cx: &mut App) {
        let executor = cx.background_executor().clone();
        let database_future = executor
            .spawn({
                let executor = executor.clone();
                let threads_dir = paths::data_dir().join("threads");
                async move { ThreadsDatabase::new(threads_dir, executor) }
            })
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
            .shared();

        cx.set_global(GlobalThreadsDatabase(database_future));
    }

    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
        std::fs::create_dir_all(&threads_dir)?;

        let sqlite_path = threads_dir.join("threads.db");
        let mdb_path = threads_dir.join("threads-db.1.mdb");

        let needs_migration_from_heed = mdb_path.exists();

        let connection = if *ZED_STATELESS {
            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
        } else {
            Connection::open_file(&sqlite_path.to_string_lossy())
        };

        connection.exec(indoc! {"
            CREATE TABLE IF NOT EXISTS threads (
                id TEXT PRIMARY KEY,
                summary TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                data_type TEXT NOT NULL,
                data BLOB NOT NULL
            )
        "})?()
        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;

        let db = Self {
            executor: executor.clone(),
            connection: Arc::new(Mutex::new(connection)),
        };

        if needs_migration_from_heed {
            let db_connection = db.connection();
            let executor_clone = executor.clone();
            executor
                .spawn(async move {
                    log::info!("Starting threads.db migration");
                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
                    std::fs::remove_dir_all(mdb_path)?;
                    log::info!("threads.db migrated to sqlite");
                    Ok::<(), anyhow::Error>(())
                })
                .detach();
        }

        Ok(db)
    }

    // Remove this migration after 2025-09-01
    fn migrate_from_heed(
        mdb_path: &Path,
        connection: Arc<Mutex<Connection>>,
        _executor: BackgroundExecutor,
    ) -> Result<()> {
        use heed::types::SerdeBincode;
        struct SerializedThreadHeed(SerializedThread);

        impl heed::BytesEncode<'_> for SerializedThreadHeed {
            type EItem = SerializedThreadHeed;

            fn bytes_encode(
                item: &Self::EItem,
            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
                serde_json::to_vec(&item.0)
                    .map(std::borrow::Cow::Owned)
                    .map_err(Into::into)
            }
        }

        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
            type DItem = SerializedThreadHeed;

            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
                SerializedThread::from_json(bytes)
                    .map(SerializedThreadHeed)
                    .map_err(Into::into)
            }
        }

        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;

        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(ONE_GB_IN_BYTES)
                .max_dbs(1)
                .open(mdb_path)?
        };

        let txn = env.write_txn()?;
        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
            .open_database(&txn, Some("threads"))?
            .ok_or_else(|| anyhow!("threads database not found"))?;

        for result in threads.iter(&txn)? {
            let (thread_id, thread_heed) = result?;
            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
        }

        Ok(())
    }

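    /// Serializes the thread to JSON, zstd-compresses it, and upserts the row. This blocks on
    /// the connection mutex, so it is only called from background tasks.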
    fn save_thread_sync(
        connection: &Arc<Mutex<Connection>>,
        id: ThreadId,
        thread: SerializedThread,
    ) -> Result<()> {
        let json_data = serde_json::to_string(&thread)?;
        let summary = thread.summary.to_string();
        let updated_at = thread.updated_at.to_rfc3339();

        let connection = connection.lock().unwrap();

        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
        let data_type = DataType::Zstd;
        let data = compressed;

        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
        "})?;

        insert((id, summary, updated_at, data_type, data))?;

        Ok(())
    }

    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select =
                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
                    SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
                "})?;

            let rows = select(())?;
            let mut threads = Vec::new();

            for (id, summary, updated_at) in rows {
                threads.push(SerializedThreadMetadata {
                    id,
                    summary: summary.into(),
                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
                });
            }

            Ok(threads)
        })
    }

    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
            "})?;

            let rows = select(id)?;
            if let Some((data_type, data)) = rows.into_iter().next() {
                let json_data = match data_type {
                    DataType::Zstd => {
                        let decompressed = zstd::decode_all(&data[..])?;
                        String::from_utf8(decompressed)?
                    }
                    DataType::Json => String::from_utf8(data)?,
                };

                let thread = SerializedThread::from_json(json_data.as_bytes())?;
                Ok(Some(thread))
            } else {
                Ok(None)
            }
        })
    }

    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor
            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
    }

    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();

            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
                DELETE FROM threads WHERE id = ?
            "})?;

            delete(id)?;

            Ok(())
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
}