1use crate::{
2 context_server_tool::ContextServerTool,
3 thread::{
4 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
5 },
6};
7use agent_settings::{AgentProfileId, CompletionMode};
8use anyhow::{Context as _, Result, anyhow};
9use assistant_tool::{Tool, ToolId, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use futures::{
14 FutureExt as _, StreamExt as _,
15 channel::{mpsc, oneshot},
16 future::{self, BoxFuture, Shared},
17};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, Window, prelude::*,
21};
22use indoc::indoc;
23use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
24use project::context_server_store::{ContextServerStatus, ContextServerStore};
25use project::{Project, ProjectItem, ProjectPath, Worktree};
26use prompt_store::{
27 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
28 UserRulesContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use sqlez::{
32 bindable::{Bind, Column},
33 connection::Connection,
34 statement::Statement,
35};
36use std::{
37 cell::{Ref, RefCell},
38 path::{Path, PathBuf},
39 rc::Rc,
40 sync::{Arc, Mutex},
41};
42use util::ResultExt as _;
43
/// Whether "stateless" mode is requested: `true` when the `ZED_STATELESS` environment
/// variable is set to a non-empty value. The variable is read lazily on first access.
pub static ZED_STATELESS: std::sync::LazyLock<bool> = std::sync::LazyLock::new(|| {
    std::env::var("ZED_STATELESS").is_ok_and(|value| !value.is_empty())
});
46
47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub enum DataType {
49 #[serde(rename = "json")]
50 Json,
51 #[serde(rename = "zstd")]
52 Zstd,
53}
54
55impl Bind for DataType {
56 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
57 let value = match self {
58 DataType::Json => "json",
59 DataType::Zstd => "zstd",
60 };
61 value.bind(statement, start_index)
62 }
63}
64
65impl Column for DataType {
66 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
67 let (value, next_index) = String::column(statement, start_index)?;
68 let data_type = match value.as_str() {
69 "json" => DataType::Json,
70 "zstd" => DataType::Zstd,
71 _ => anyhow::bail!("Unknown data type: {}", value),
72 };
73 Ok((data_type, next_index))
74 }
75}
76
/// Worktree-relative paths recognised as rules files.
///
/// Order matters: the first entry that exists as a file in a worktree is the one loaded.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
88
/// Module entry point: installs the global threads database future on `cx`
/// by delegating to `ThreadsDatabase::init`.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
92
/// A system prompt shared by all threads created by this ThreadStore
///
/// The inner value is `None` until the store has finished loading the project context
/// for the first time; clones share the same underlying cell.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
96
impl SharedProjectContext {
    /// Immutably borrows the current project context (`None` if it has not been loaded yet).
    ///
    /// This is a `RefCell` borrow, so it panics if the cell is mutably borrowed at the
    /// same time.
    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
        self.0.borrow()
    }
}
102
/// Alias for `assistant_context::ContextStore`.
pub type TextThreadStore = assistant_context::ContextStore;
104
/// Owns the list of saved thread metadata and the shared state (tools, project context)
/// that every thread created from this store is given.
pub struct ThreadStore {
    /// Project whose worktrees and context servers this store observes.
    project: Entity<Project>,
    /// Tool working set handed to each thread; context server tools are added to it.
    tools: Entity<ToolWorkingSet>,
    /// Prompt builder handed to each thread.
    prompt_builder: Arc<PromptBuilder>,
    /// Optional prompt store; when present its default prompts become user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered per context server, so they can be removed when it stops.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata as last loaded from the database by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with every thread created by this store.
    project_context: SharedProjectContext,
    /// Sender used to request another system prompt reload.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Background loop that performs system prompt reloads; held but never read.
    _reload_system_prompt_task: Task<()>,
    /// Event subscriptions (project, prompt store); held but never read.
    _subscriptions: Vec<Subscription>,
}
117
/// Event emitted by `ThreadStore` when a rules file or a default user prompt fails to load.
pub struct RulesLoadingError {
    /// Formatted description of the underlying error.
    pub message: SharedString,
}
121
// Lets `ThreadStore` emit `RulesLoadingError` events to its subscribers.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
123
124impl ThreadStore {
    /// Creates a `ThreadStore` entity and resolves once its first system prompt load is done.
    ///
    /// The entity is constructed through `Self::new`, which also returns a oneshot receiver;
    /// the returned task awaits that receiver before yielding the entity.
    ///
    /// # Errors
    /// Fails if the app context can no longer be updated or if the readiness sender is
    /// dropped before signalling.
    pub fn load(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_store: Option<Entity<PromptStore>>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        cx.spawn(async move |cx| {
            let (thread_store, ready_rx) = cx.update(|cx| {
                // `cx.new` only hands back the entity, so the receiver produced inside the
                // constructor closure is carried out through this local.
                let mut option_ready_rx = None;
                let thread_store = cx.new(|cx| {
                    let (thread_store, ready_rx) =
                        Self::new(project, tools, prompt_builder, prompt_store, cx);
                    option_ready_rx = Some(ready_rx);
                    thread_store
                });
                (thread_store, option_ready_rx.take().unwrap())
            })?;
            // Wait until the first system prompt reload has completed.
            ready_rx.await?;
            Ok(thread_store)
        })
    }
147
    /// Builds the store and starts its background work.
    ///
    /// Subscribes to project events (and prompt store updates when a prompt store is given),
    /// spawns the loop that reloads the system prompt, registers context server handlers and
    /// kicks off an initial `reload` of thread metadata.
    ///
    /// Returns the store together with a receiver that fires after the first system prompt
    /// reload has finished.
    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];

        if let Some(prompt_store) = prompt_store.as_ref() {
            // Any prompt change may affect the default user rules, so request a reload.
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        // Wrapped in an `Option` so the readiness signal is sent only after the first reload.
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    // Stop looping once the store entity can no longer be updated.
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    // Park until another reload is requested.
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.register_context_server_handlers(cx);
        // Load thread metadata in the background; failures are only logged.
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }
207
208 fn handle_project_event(
209 &mut self,
210 _project: Entity<Project>,
211 event: &project::Event,
212 _cx: &mut Context<Self>,
213 ) {
214 match event {
215 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
216 self.enqueue_system_prompt_reload();
217 }
218 project::Event::WorktreeUpdatedEntries(_, items) => {
219 if items.iter().any(|(path, _, _)| {
220 RULES_FILE_NAMES
221 .iter()
222 .any(|name| path.as_ref() == Path::new(name))
223 }) {
224 self.enqueue_system_prompt_reload();
225 }
226 }
227 _ => {}
228 }
229 }
230
    /// Signals the reload loop to run again.
    ///
    /// Uses `try_send` and ignores the result, so a failed send (for example because the
    /// capacity-1 channel already holds a pending request) is deliberately dropped.
    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }
234
    // Note that this should only be called from `reload_system_prompt_task`.
    //
    // Rebuilds the shared `ProjectContext` from the visible worktrees (including each
    // worktree's rules file, if any) and the prompt store's default prompts. Loading errors
    // are emitted as `RulesLoadingError` events and do not abort the reload.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .collect::<Vec<_>>();
        // One task per worktree, each yielding its context plus an optional rules error.
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
            })
            .collect::<Vec<_>>();
        // Load the contents of every default prompt, keeping the metadata alongside.
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            // Surface rules-file errors as events while keeping the worktree context.
            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            // The edit-workflow prompt is not a user rule; skip it.
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            // Publish the new context; every holder of the shared handle sees it.
            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }
308
309 fn load_worktree_info_for_system_prompt(
310 worktree: Entity<Worktree>,
311 project: Entity<Project>,
312 cx: &mut App,
313 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
314 let tree = worktree.read(cx);
315 let root_name = tree.root_name().into();
316 let abs_path = tree.abs_path();
317
318 let mut context = WorktreeContext {
319 root_name,
320 abs_path,
321 rules_file: None,
322 };
323
324 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
325 let Some(rules_task) = rules_task else {
326 return Task::ready((context, None));
327 };
328
329 cx.spawn(async move |_| {
330 let (rules_file, rules_file_error) = match rules_task.await {
331 Ok(rules_file) => (Some(rules_file), None),
332 Err(err) => (
333 None,
334 Some(RulesLoadingError {
335 message: format!("{err}").into(),
336 }),
337 ),
338 };
339 context.rules_file = rules_file;
340 (context, rules_file_error)
341 })
342 }
343
    /// Starts loading the first rules file found in `worktree`.
    ///
    /// Candidates are probed in `RULES_FILE_NAMES` order and only entries that are files
    /// qualify. Returns `None` when no candidate exists; otherwise a task that opens the
    /// file as a project buffer and yields its trimmed text.
    fn load_worktree_rules_file(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Option<Task<Result<RulesFileContext>>> {
        let worktree = worktree.read(cx);
        let worktree_id = worktree.id();
        let selected_rules_file = RULES_FILE_NAMES
            .into_iter()
            .filter_map(|name| {
                worktree
                    .entry_for_path(name)
                    .filter(|entry| entry.is_file())
                    .map(|entry| entry.path.clone())
            })
            .next();

        // Note that Cline supports `.clinerules` being a directory, but that is not currently
        // supported. This doesn't seem to occur often in GitHub repositories.
        selected_rules_file.map(|path_in_worktree| {
            let project_path = ProjectPath {
                worktree_id,
                path: path_in_worktree.clone(),
            };
            let buffer_task =
                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
            // Grab the entry id and a clone of the rope once the buffer is open.
            let rope_task = cx.spawn(async move |cx| {
                buffer_task.await?.read_with(cx, |buffer, cx| {
                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
                })?
            });
            // Build a string from the rope on a background thread.
            cx.background_spawn(async move {
                let (project_entry_id, rope) = rope_task.await?;
                anyhow::Ok(RulesFileContext {
                    path_in_worktree,
                    text: rope.to_string().trim().to_string(),
                    project_entry_id: project_entry_id.to_usize(),
                })
            })
        })
    }
387
    /// Returns the prompt store this thread store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
391
    /// Returns a handle to the tool working set shared with the threads of this store.
    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }
395
    /// Returns the number of threads.
    ///
    /// This counts the metadata currently cached in memory, i.e. as of the last `reload`
    /// or `delete_thread` that completed.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
400
    /// Iterates over the cached thread metadata in stored order, which is most recently
    /// updated first.
    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // ordering is from "ORDER BY" in `list_threads`
        self.threads.iter()
    }
405
    /// Creates a new, empty thread entity wired to this store's project, tools, prompt
    /// builder and shared project context.
    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::new(
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                cx,
            )
        })
    }
417
    /// Creates a thread entity from already-deserialized data.
    ///
    /// The thread is given a freshly generated `ThreadId` and is deserialized without a
    /// window (`None`).
    pub fn create_thread_from_serialized(
        &mut self,
        serialized: SerializedThread,
        cx: &mut Context<Self>,
    ) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::deserialize(
                ThreadId::new(),
                serialized,
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                None,
                cx,
            )
        })
    }
436
    /// Loads the thread with the given `id` from the database and turns it into an entity.
    ///
    /// # Errors
    /// Fails if the database could not be opened, the lookup fails, no thread with that id
    /// exists, or the store/window is gone by the time the data arrives.
    pub fn open_thread(
        &self,
        id: &ThreadId,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        // Weak handle so the pending task does not keep the store alive.
        let this = cx.weak_entity();
        window.spawn(cx, async move |cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let thread = database
                .try_find_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            let thread = this.update_in(cx, |this, window, cx| {
                cx.new(|cx| {
                    Thread::deserialize(
                        id.clone(),
                        thread,
                        this.project.clone(),
                        this.tools.clone(),
                        this.prompt_builder.clone(),
                        this.project_context.clone(),
                        Some(window),
                        cx,
                    )
                })
            })?;

            Ok(thread)
        })
    }
471
472 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
473 let (metadata, serialized_thread) =
474 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
475
476 let database_future = ThreadsDatabase::global_future(cx);
477 cx.spawn(async move |this, cx| {
478 let serialized_thread = serialized_thread.await?;
479 let database = database_future.await.map_err(|err| anyhow!(err))?;
480 database.save_thread(metadata, serialized_thread).await?;
481
482 this.update(cx, |this, cx| this.reload(cx))?.await
483 })
484 }
485
    /// Deletes the thread with the given `id` from the database, then drops it from the
    /// cached metadata list and notifies observers.
    ///
    /// The cache is only touched after the database delete succeeded.
    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.delete_thread(id.clone()).await?;

            this.update(cx, |this, cx| {
                this.threads.retain(|thread| thread.id != id);
                cx.notify();
            })
        })
    }
499
500 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
501 let database_future = ThreadsDatabase::global_future(cx);
502 cx.spawn(async move |this, cx| {
503 let threads = database_future
504 .await
505 .map_err(|err| anyhow!(err))?
506 .list_threads()
507 .await?;
508
509 this.update(cx, |this, cx| {
510 this.threads = threads;
511 cx.notify();
512 })
513 })
514 }
515
    /// Subscribes to the project's context server store and loads tools for every server
    /// that is already running.
    ///
    /// The subscription is detached rather than stored in `_subscriptions`.
    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        let context_server_store = self.project.read(cx).context_server_store();
        cx.subscribe(&context_server_store, Self::handle_context_server_event)
            .detach();

        // Check for any servers that were already running before the handler was registered
        for server in context_server_store.read(cx).running_servers() {
            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
        }
    }
526
    /// Reacts to context server status changes: loads the server's tools when it starts
    /// running and removes its previously registered tools when it stops or errors.
    fn handle_context_server_event(
        &mut self,
        context_server_store: Entity<ContextServerStore>,
        event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        let tool_working_set = self.tools.clone();
        match event {
            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
                match status {
                    ContextServerStatus::Starting => {}
                    ContextServerStatus::Running => {
                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
                    }
                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                        // Only servers whose tools were recorded have anything to remove.
                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                            tool_working_set.update(cx, |tool_working_set, cx| {
                                tool_working_set.remove(&tool_ids, cx);
                            });
                        }
                    }
                }
            }
        }
    }
552
    /// Asks a running context server for its tools and registers them.
    ///
    /// Does nothing if the server is not running, has no client, or lacks the `Tools`
    /// capability. On success the tools are added to the working set and the returned ids
    /// are recorded under `server_id`. The work runs in a detached task; request and update
    /// failures are logged, not returned.
    ///
    /// NOTE(review): the recorded ids are inserted with `HashMap::insert`, so a second load
    /// for the same server replaces the earlier id list without removing those tools from
    /// the working set. Confirm whether that can leave stale tools behind.
    fn load_context_server_tools(
        &self,
        server_id: ContextServerId,
        context_server_store: Entity<ContextServerStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
            return;
        };
        let tool_working_set = self.tools.clone();
        cx.spawn(async move |this, cx| {
            let Some(protocol) = server.client() else {
                return;
            };

            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                if let Some(response) = protocol
                    .request::<context_server::types::requests::ListTools>(())
                    .await
                    .log_err()
                {
                    // Wrap each advertised tool and add it to the shared working set.
                    let tool_ids = tool_working_set
                        .update(cx, |tool_working_set, cx| {
                            tool_working_set.extend(
                                response.tools.into_iter().map(|tool| {
                                    Arc::new(ContextServerTool::new(
                                        context_server_store.clone(),
                                        server.id(),
                                        tool,
                                    )) as Arc<dyn Tool>
                                }),
                                cx,
                            )
                        })
                        .log_err();

                    // Remember the ids so they can be removed when the server stops.
                    if let Some(tool_ids) = tool_ids {
                        this.update(cx, |this, _| {
                            this.context_server_tool_ids.insert(server_id, tool_ids);
                        })
                        .log_err();
                    }
                }
            }
        })
        .detach();
    }
600}
601
/// Lightweight listing information for a saved thread (no message contents).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Identifier of the thread.
    pub id: ThreadId,
    /// Summary text of the thread.
    pub summary: SharedString,
    /// When the thread was last updated.
    pub updated_at: DateTime<Utc>,
}
608
/// The persisted form of a thread in the current format version.
///
/// Every field marked `#[serde(default)]` may be absent in stored JSON and then takes its
/// `Default` value.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version string; see `SerializedThread::VERSION`.
    pub version: String,
    /// Summary text of the thread.
    pub summary: SharedString,
    /// When the thread was last updated.
    pub updated_at: DateTime<Utc>,
    /// The messages of the thread, in order.
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
634
/// Identifies a language model by provider and model name, both stored as plain strings.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
640
impl SerializedThread {
    /// Version string written into threads serialized in the current format.
    pub const VERSION: &'static str = "0.2.0";

    /// Parses a stored thread from JSON bytes, upgrading older formats to the current one.
    ///
    /// Dispatches on the top-level `"version"` field:
    /// - `"0.1.0"`: parsed as `SerializedThreadV0_1_0` and upgraded;
    /// - `"0.2.0"`: parsed directly;
    /// - missing: parsed as `LegacySerializedThread` and upgraded.
    ///
    /// # Errors
    /// Fails on invalid JSON, on a payload that does not match the selected format, and on
    /// an unknown or non-string `"version"` value.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        // Parse to a generic value first so the version can be inspected before choosing
        // the concrete type.
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                SerializedThreadV0_1_0::VERSION => {
                    let saved_thread =
                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
                    Ok(saved_thread.upgrade())
                }
                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
                    saved_thread_json,
                )?),
                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
            },
            None => {
                let saved_thread =
                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
                Ok(saved_thread.upgrade())
            }
            // A `"version"` field that is present but not a string.
            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
        }
    }
}
667
/// A thread stored in format version 0.1.0, wrapping the current struct layout.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
674
impl SerializedThreadV0_1_0 {
    /// Version string identifying this format.
    pub const VERSION: &'static str = "0.1.0";

    /// Converts a 0.1.0 thread into the current format.
    ///
    /// A user message that carries tool results is merged into the message emitted before
    /// it: its tool results replace that message's `tool_results` and the user message
    /// itself, including its segments, is dropped. If there is no earlier message, the
    /// user message is kept unchanged. All other fields are carried over and `version` is
    /// set to the current version.
    pub fn upgrade(self) -> SerializedThread {
        // Guards against this upgrade path being left behind when the version is bumped.
        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");

        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());

        for message in self.0.messages {
            if message.role == Role::User && !message.tool_results.is_empty() {
                if let Some(last_message) = messages.last_mut() {
                    debug_assert!(last_message.role == Role::Assistant);

                    // NOTE(review): this assigns rather than extends, so results already
                    // on `last_message` would be overwritten. Confirm that cannot occur.
                    last_message.tool_results = message.tool_results;
                    continue;
                }
            }

            messages.push(message);
        }

        SerializedThread {
            messages,
            version: SerializedThread::VERSION.to_string(),
            ..self.0
        }
    }
}
703
/// The persisted form of one message in a thread.
///
/// Every field except `id` and `role` defaults when absent from stored JSON.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Content segments of the message, in order.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool invocations attached to this message.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results attached to this message.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
721
/// One segment of a message, serialized as an internally tagged enum (tag field `"type"`).
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text; tagged `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Thinking text with an optional signature; tagged `"thinking"`. The signature is
    /// omitted from the output when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Redacted thinking data. Unlike the other variants this one has no `rename`, so its
    /// tag is the variant name `"RedactedThinking"`.
    RedactedThinking {
        data: String,
    },
}
739
/// The persisted form of a tool invocation.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    /// Identifier of this tool use; tool results refer to it via `tool_use_id`.
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Input passed to the tool, as arbitrary JSON.
    pub input: serde_json::Value,
}
746
/// The persisted form of a tool's result.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    /// Identifier of the tool use this result belongs to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the result is an error.
    pub is_error: bool,
    /// The result content.
    pub content: LanguageModelToolResultContent,
    /// Optional additional output, as arbitrary JSON.
    pub output: Option<serde_json::Value>,
}
754
/// The oldest stored thread format, used when the JSON has no `"version"` field.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
763
764impl LegacySerializedThread {
765 pub fn upgrade(self) -> SerializedThread {
766 SerializedThread {
767 version: SerializedThread::VERSION.to_string(),
768 summary: self.summary,
769 updated_at: self.updated_at,
770 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
771 initial_project_snapshot: self.initial_project_snapshot,
772 cumulative_token_usage: TokenUsage::default(),
773 request_token_usage: Vec::new(),
774 detailed_summary_state: DetailedSummaryState::default(),
775 exceeded_window_error: None,
776 model: None,
777 completion_mode: None,
778 tool_use_limit_reached: false,
779 profile: None,
780 }
781 }
782}
783
/// A message in the legacy thread format: a single text body plus optional tool data.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// The whole message body as one string.
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
794
795impl LegacySerializedMessage {
796 fn upgrade(self) -> SerializedMessage {
797 SerializedMessage {
798 id: self.id,
799 role: self.role,
800 segments: vec![SerializedMessageSegment::Text { text: self.text }],
801 tool_uses: self.tool_uses,
802 tool_results: self.tool_results,
803 context: String::new(),
804 creases: Vec::new(),
805 is_hidden: false,
806 }
807 }
808}
809
/// The persisted form of a crease attached to a message.
// NOTE(review): `start`/`end` look like offsets into the message text; confirm units.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
817
/// App-global holder for the shared, cloneable future that resolves to the threads
/// database (or to the error that prevented opening it).
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
821
// Allows `GlobalThreadsDatabase` to be stored with `set_global` and read with `global`.
impl Global for GlobalThreadsDatabase {}
823
/// SQLite-backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned.
    executor: BackgroundExecutor,
    /// The single connection, shared behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
828
impl ThreadsDatabase {
    /// Returns another handle to the shared connection.
    fn connection(&self) -> Arc<Mutex<Connection>> {
        self.connection.clone()
    }

    /// zstd compression level used when storing thread data.
    const COMPRESSION_LEVEL: i32 = 3;
}
836
impl Bind for ThreadId {
    /// Binds the id as its string representation.
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        self.to_string().bind(statement, start_index)
    }
}
842
impl Column for ThreadId {
    /// Reads a text column and converts it into a `ThreadId`.
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (id_str, next_index) = String::column(statement, start_index)?;
        Ok((ThreadId::from(id_str.as_str()), next_index))
    }
}
849
850impl ThreadsDatabase {
    /// Returns a clone of the shared future installed by `init`, resolving to the database
    /// or to the error that occurred while opening it.
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }
856
    /// Starts opening the database (under `<data_dir>/threads`) on the background executor
    /// and stores the resulting shared future as an app global.
    ///
    /// Both the database and any error are wrapped in `Arc`, so the result can be cloned
    /// out of the shared future by every caller.
    fn init(cx: &mut App) {
        let executor = cx.background_executor().clone();
        let database_future = executor
            .spawn({
                let executor = executor.clone();
                let threads_dir = paths::data_dir().join("threads");
                async move { ThreadsDatabase::new(threads_dir, executor) }
            })
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
            .shared();

        cx.set_global(GlobalThreadsDatabase(database_future));
    }
871
872 pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
873 std::fs::create_dir_all(&threads_dir)?;
874
875 let sqlite_path = threads_dir.join("threads.db");
876 let mdb_path = threads_dir.join("threads-db.1.mdb");
877
878 let needs_migration_from_heed = mdb_path.exists();
879
880 let connection = if *ZED_STATELESS {
881 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
882 } else {
883 Connection::open_file(&sqlite_path.to_string_lossy())
884 };
885
886 connection.exec(indoc! {"
887 CREATE TABLE IF NOT EXISTS threads (
888 id TEXT PRIMARY KEY,
889 summary TEXT NOT NULL,
890 updated_at TEXT NOT NULL,
891 data_type TEXT NOT NULL,
892 data BLOB NOT NULL
893 )
894 "})?()
895 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
896
897 let db = Self {
898 executor: executor.clone(),
899 connection: Arc::new(Mutex::new(connection)),
900 };
901
902 if needs_migration_from_heed {
903 let db_connection = db.connection();
904 let executor_clone = executor.clone();
905 executor
906 .spawn(async move {
907 log::info!("Starting threads.db migration");
908 Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
909 std::fs::remove_dir_all(mdb_path)?;
910 log::info!("threads.db migrated to sqlite");
911 Ok::<(), anyhow::Error>(())
912 })
913 .detach();
914 }
915
916 Ok(db)
917 }
918
    // Remove this migration after 2025-09-01
    //
    // Copies every thread from the legacy heed database at `mdb_path` into the SQLite
    // connection. Stops at the first error; threads copied before it stay in SQLite.
    // The `_executor` parameter is unused.
    fn migrate_from_heed(
        mdb_path: &Path,
        connection: Arc<Mutex<Connection>>,
        _executor: BackgroundExecutor,
    ) -> Result<()> {
        use heed::types::SerdeBincode;
        // Codec wrapper: values in the heed database are JSON-encoded threads.
        struct SerializedThreadHeed(SerializedThread);

        impl heed::BytesEncode<'_> for SerializedThreadHeed {
            type EItem = SerializedThreadHeed;

            fn bytes_encode(
                item: &Self::EItem,
            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
                serde_json::to_vec(&item.0)
                    .map(std::borrow::Cow::Owned)
                    .map_err(Into::into)
            }
        }

        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
            type DItem = SerializedThreadHeed;

            // Decoding goes through `from_json`, so old formats are upgraded on read.
            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
                SerializedThread::from_json(bytes)
                    .map(SerializedThreadHeed)
                    .map_err(Into::into)
            }
        }

        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;

        // NOTE(review): opening a heed environment is `unsafe`; presumably the contract is
        // about the files being memory-mapped and not modified concurrently. Confirm
        // against the heed documentation and record it as a SAFETY comment.
        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(ONE_GB_IN_BYTES)
                .max_dbs(1)
                .open(mdb_path)?
        };

        // NOTE(review): a write transaction is opened but only used for reads and is
        // never committed. Confirm whether a read transaction would suffice.
        let txn = env.write_txn()?;
        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
            .open_database(&txn, Some("threads"))?
            .ok_or_else(|| anyhow!("threads database not found"))?;

        for result in threads.iter(&txn)? {
            let (thread_id, thread_heed) = result?;
            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
        }

        Ok(())
    }
971
972 fn save_thread_sync(
973 connection: &Arc<Mutex<Connection>>,
974 id: ThreadId,
975 thread: SerializedThread,
976 ) -> Result<()> {
977 let json_data = serde_json::to_string(&thread)?;
978 let summary = thread.summary.to_string();
979 let updated_at = thread.updated_at.to_rfc3339();
980
981 let connection = connection.lock().unwrap();
982
983 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
984 let data_type = DataType::Zstd;
985 let data = compressed;
986
987 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
988 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
989 "})?;
990
991 insert((id, summary, updated_at, data_type, data))?;
992
993 Ok(())
994 }
995
996 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
997 let connection = self.connection.clone();
998
999 self.executor.spawn(async move {
1000 let connection = connection.lock().unwrap();
1001 let mut select =
1002 connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1003 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1004 "})?;
1005
1006 let rows = select(())?;
1007 let mut threads = Vec::new();
1008
1009 for (id, summary, updated_at) in rows {
1010 threads.push(SerializedThreadMetadata {
1011 id,
1012 summary: summary.into(),
1013 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1014 });
1015 }
1016
1017 Ok(threads)
1018 })
1019 }
1020
    /// Looks up the thread stored under `id`.
    ///
    /// Resolves to `Ok(None)` when no row matches. A matching row is decoded according to
    /// its data type tag (zstd blobs are decompressed first) and parsed with
    /// `SerializedThread::from_json`, which also upgrades older formats.
    ///
    /// # Errors
    /// Fails if the query, decompression, UTF-8 decoding or JSON parsing fails.
    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
            "})?;

            let rows = select(id)?;
            if let Some((data_type, data)) = rows.into_iter().next() {
                let json_data = match data_type {
                    DataType::Zstd => {
                        let decompressed = zstd::decode_all(&data[..])?;
                        String::from_utf8(decompressed)?
                    }
                    DataType::Json => String::from_utf8(data)?,
                };

                let thread = SerializedThread::from_json(json_data.as_bytes())?;
                Ok(Some(thread))
            } else {
                Ok(None)
            }
        })
    }
1047
    /// Writes `thread` under `id` on the background executor; see `save_thread_sync`.
    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor
            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
    }
1054
    /// Deletes the row stored under `id` on the background executor.
    ///
    /// The statement is a plain `DELETE ... WHERE id = ?`, so a missing id is not an error.
    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();

            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
                DELETE FROM threads WHERE id = ?
            "})?;

            delete(id)?;

            Ok(())
        })
    }
1070}
1071
1072#[cfg(test)]
1073mod tests {
1074 use super::*;
1075 use crate::thread::{DetailedSummaryState, MessageId};
1076 use chrono::Utc;
1077 use language_model::{Role, TokenUsage};
1078 use pretty_assertions::assert_eq;
1079
    // Upgrading a legacy (unversioned) thread must carry over summary, timestamp and
    // messages, turn each message's text into a single `Text` segment, stamp the current
    // version, and fill every newer field with its default.
    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
1128
1129 #[test]
1130 fn test_serialized_threadv0_1_0_upgrade() {
1131 let updated_at = Utc::now();
1132 let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
1133 summary: "Test conversation".into(),
1134 updated_at,
1135 messages: vec![
1136 SerializedMessage {
1137 id: MessageId(1),
1138 role: Role::User,
1139 segments: vec![SerializedMessageSegment::Text {
1140 text: "Use tool_1".to_string(),
1141 }],
1142 tool_uses: vec![],
1143 tool_results: vec![],
1144 context: "".to_string(),
1145 creases: vec![],
1146 is_hidden: false,
1147 },
1148 SerializedMessage {
1149 id: MessageId(2),
1150 role: Role::Assistant,
1151 segments: vec![SerializedMessageSegment::Text {
1152 text: "I want to use a tool".to_string(),
1153 }],
1154 tool_uses: vec![SerializedToolUse {
1155 id: "abc".into(),
1156 name: "tool_1".into(),
1157 input: serde_json::Value::Null,
1158 }],
1159 tool_results: vec![],
1160 context: "".to_string(),
1161 creases: vec![],
1162 is_hidden: false,
1163 },
1164 SerializedMessage {
1165 id: MessageId(1),
1166 role: Role::User,
1167 segments: vec![SerializedMessageSegment::Text {
1168 text: "Here is the tool result".to_string(),
1169 }],
1170 tool_uses: vec![],
1171 tool_results: vec![SerializedToolResult {
1172 tool_use_id: "abc".into(),
1173 is_error: false,
1174 content: LanguageModelToolResultContent::Text("abcdef".into()),
1175 output: Some(serde_json::Value::Null),
1176 }],
1177 context: "".to_string(),
1178 creases: vec![],
1179 is_hidden: false,
1180 },
1181 ],
1182 version: SerializedThreadV0_1_0::VERSION.to_string(),
1183 initial_project_snapshot: None,
1184 cumulative_token_usage: TokenUsage::default(),
1185 request_token_usage: vec![],
1186 detailed_summary_state: DetailedSummaryState::default(),
1187 exceeded_window_error: None,
1188 model: None,
1189 completion_mode: None,
1190 tool_use_limit_reached: false,
1191 profile: None,
1192 });
1193 let upgraded = thread_v0_1_0.upgrade();
1194
1195 assert_eq!(
1196 upgraded,
1197 SerializedThread {
1198 summary: "Test conversation".into(),
1199 updated_at,
1200 messages: vec![
1201 SerializedMessage {
1202 id: MessageId(1),
1203 role: Role::User,
1204 segments: vec![SerializedMessageSegment::Text {
1205 text: "Use tool_1".to_string()
1206 }],
1207 tool_uses: vec![],
1208 tool_results: vec![],
1209 context: "".to_string(),
1210 creases: vec![],
1211 is_hidden: false
1212 },
1213 SerializedMessage {
1214 id: MessageId(2),
1215 role: Role::Assistant,
1216 segments: vec![SerializedMessageSegment::Text {
1217 text: "I want to use a tool".to_string(),
1218 }],
1219 tool_uses: vec![SerializedToolUse {
1220 id: "abc".into(),
1221 name: "tool_1".into(),
1222 input: serde_json::Value::Null,
1223 }],
1224 tool_results: vec![SerializedToolResult {
1225 tool_use_id: "abc".into(),
1226 is_error: false,
1227 content: LanguageModelToolResultContent::Text("abcdef".into()),
1228 output: Some(serde_json::Value::Null),
1229 }],
1230 context: "".to_string(),
1231 creases: vec![],
1232 is_hidden: false,
1233 },
1234 ],
1235 version: SerializedThread::VERSION.to_string(),
1236 initial_project_snapshot: None,
1237 cumulative_token_usage: TokenUsage::default(),
1238 request_token_usage: vec![],
1239 detailed_summary_state: DetailedSummaryState::default(),
1240 exceeded_window_error: None,
1241 model: None,
1242 completion_mode: None,
1243 tool_use_limit_reached: false,
1244 profile: None
1245 }
1246 )
1247 }
1248}