db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use zed_env_vars::ZED_STATELESS;
 22
 23pub type DbMessage = crate::Message;
 24pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
 25pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 26
/// Summary row for a stored thread, as produced by `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id of the thread (the `id` column of the `threads` table).
    pub id: acp::SessionId,
    /// Session id taken from the `parent_id` column, when that column is set.
    pub parent_session_id: Option<acp::SessionId>,
    /// Thread title. When deserializing, the key `summary` is accepted as an alias.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Timestamp stored with the thread; `list_threads` orders rows by it, newest first.
    pub updated_at: DateTime<Utc>,
}
 35
/// A thread in the form that is serialized to and from the `threads` table.
///
/// Every field marked `#[serde(default)]` may be missing from stored JSON and then
/// takes its type's default value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; written to the `summary` column on save.
    pub title: SharedString,
    /// Messages of the thread, in order.
    pub messages: Vec<DbMessage>,
    /// Timestamp written to the `updated_at` column (as RFC 3339) on save.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if any. For threads upgraded from the legacy format this
    /// is the text of a `Generated` summary state.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot carried with the thread, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage value stored for the thread as a whole.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by user message id.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Serialized language model associated with the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Agent profile id associated with the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads produced by `SharedThread::to_db_thread`; `false` for threads
    /// upgraded from the legacy format.
    #[serde(default)]
    pub imported: bool,
    /// Subagent context, if any. On save its `parent_thread_id` is also written to the
    /// `parent_id` column.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    // NOTE(review): the three fields below look like per-thread model settings; their
    // exact semantics are not visible in this file — confirm against the callers.
    #[serde(default)]
    pub speed: Option<Speed>,
    #[serde(default)]
    pub thinking_enabled: bool,
    #[serde(default)]
    pub thinking_effort: Option<String>,
}
 64
/// Reduced form of a `DbThread` that `to_bytes` / `from_bytes` convert to and from
/// zstd-compressed JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    /// Title copied from the source thread.
    pub title: SharedString,
    /// Messages copied from the source thread.
    pub messages: Vec<DbMessage>,
    /// Timestamp copied from the source thread.
    pub updated_at: DateTime<Utc>,
    /// Serialized language model copied from the source thread; may be absent in the JSON.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version; `from_db_thread` sets it to `SharedThread::VERSION`.
    pub version: String,
}
 74
 75impl SharedThread {
 76    pub const VERSION: &'static str = "1.0.0";
 77
 78    pub fn from_db_thread(thread: &DbThread) -> Self {
 79        Self {
 80            title: thread.title.clone(),
 81            messages: thread.messages.clone(),
 82            updated_at: thread.updated_at,
 83            model: thread.model.clone(),
 84            version: Self::VERSION.to_string(),
 85        }
 86    }
 87
 88    pub fn to_db_thread(self) -> DbThread {
 89        DbThread {
 90            title: format!("🔗 {}", self.title).into(),
 91            messages: self.messages,
 92            updated_at: self.updated_at,
 93            detailed_summary: None,
 94            initial_project_snapshot: None,
 95            cumulative_token_usage: Default::default(),
 96            request_token_usage: Default::default(),
 97            model: self.model,
 98            profile: None,
 99            imported: true,
100            subagent_context: None,
101            speed: None,
102            thinking_enabled: false,
103            thinking_effort: None,
104        }
105    }
106
107    pub fn to_bytes(&self) -> Result<Vec<u8>> {
108        const COMPRESSION_LEVEL: i32 = 3;
109        let json = serde_json::to_vec(self)?;
110        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
111        Ok(compressed)
112    }
113
114    pub fn from_bytes(data: &[u8]) -> Result<Self> {
115        let decompressed = zstd::decode_all(data)?;
116        Ok(serde_json::from_slice(&decompressed)?)
117    }
118}
119
120impl DbThread {
    /// Version string that `ThreadsDatabase::save_thread_sync` writes next to every
    /// thread. `from_json` deserializes payloads carrying exactly this version directly
    /// and sends everything else through the legacy upgrade path.
    pub const VERSION: &'static str = "0.3.0";
122
123    pub fn from_json(json: &[u8]) -> Result<Self> {
124        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
125        match saved_thread_json.get("version") {
126            Some(serde_json::Value::String(version)) => match version.as_str() {
127                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
128                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
129                    json,
130                )?),
131            },
132            _ => {
133                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
134            }
135        }
136    }
137
138    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
139        let mut messages = Vec::new();
140        let mut request_token_usage = HashMap::default();
141
142        let mut last_user_message_id = None;
143        for (ix, msg) in thread.messages.into_iter().enumerate() {
144            let message = match msg.role {
145                language_model::Role::User => {
146                    let mut content = Vec::new();
147
148                    // Convert segments to content
149                    for segment in msg.segments {
150                        match segment {
151                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
152                                content.push(UserMessageContent::Text(text));
153                            }
154                            crate::legacy_thread::SerializedMessageSegment::Thinking {
155                                text,
156                                ..
157                            } => {
158                                // User messages don't have thinking segments, but handle gracefully
159                                content.push(UserMessageContent::Text(text));
160                            }
161                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
162                                ..
163                            } => {
164                                // User messages don't have redacted thinking, skip.
165                            }
166                        }
167                    }
168
169                    // If no content was added, add context as text if available
170                    if content.is_empty() && !msg.context.is_empty() {
171                        content.push(UserMessageContent::Text(msg.context));
172                    }
173
174                    let id = UserMessageId::new();
175                    last_user_message_id = Some(id.clone());
176
177                    crate::Message::User(UserMessage {
178                        // MessageId from old format can't be meaningfully converted, so generate a new one
179                        id,
180                        content,
181                    })
182                }
183                language_model::Role::Assistant => {
184                    let mut content = Vec::new();
185
186                    // Convert segments to content
187                    for segment in msg.segments {
188                        match segment {
189                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
190                                content.push(AgentMessageContent::Text(text));
191                            }
192                            crate::legacy_thread::SerializedMessageSegment::Thinking {
193                                text,
194                                signature,
195                            } => {
196                                content.push(AgentMessageContent::Thinking { text, signature });
197                            }
198                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
199                                data,
200                            } => {
201                                content.push(AgentMessageContent::RedactedThinking(data));
202                            }
203                        }
204                    }
205
206                    // Convert tool uses
207                    let mut tool_names_by_id = HashMap::default();
208                    for tool_use in msg.tool_uses {
209                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
210                        content.push(AgentMessageContent::ToolUse(
211                            language_model::LanguageModelToolUse {
212                                id: tool_use.id,
213                                name: tool_use.name.into(),
214                                raw_input: serde_json::to_string(&tool_use.input)
215                                    .unwrap_or_default(),
216                                input: tool_use.input,
217                                is_input_complete: true,
218                                thought_signature: None,
219                            },
220                        ));
221                    }
222
223                    // Convert tool results
224                    let mut tool_results = IndexMap::default();
225                    for tool_result in msg.tool_results {
226                        let name = tool_names_by_id
227                            .remove(&tool_result.tool_use_id)
228                            .unwrap_or_else(|| SharedString::from("unknown"));
229                        tool_results.insert(
230                            tool_result.tool_use_id.clone(),
231                            language_model::LanguageModelToolResult {
232                                tool_use_id: tool_result.tool_use_id,
233                                tool_name: name.into(),
234                                is_error: tool_result.is_error,
235                                content: tool_result.content,
236                                output: tool_result.output,
237                            },
238                        );
239                    }
240
241                    if let Some(last_user_message_id) = &last_user_message_id
242                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
243                    {
244                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
245                    }
246
247                    crate::Message::Agent(AgentMessage {
248                        content,
249                        tool_results,
250                        reasoning_details: None,
251                    })
252                }
253                language_model::Role::System => {
254                    // Skip system messages as they're not supported in the new format
255                    continue;
256                }
257            };
258
259            messages.push(message);
260        }
261
262        Ok(Self {
263            title: thread.summary,
264            messages,
265            updated_at: thread.updated_at,
266            detailed_summary: match thread.detailed_summary_state {
267                crate::legacy_thread::DetailedSummaryState::NotGenerated
268                | crate::legacy_thread::DetailedSummaryState::Generating => None,
269                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
270            },
271            initial_project_snapshot: thread.initial_project_snapshot,
272            cumulative_token_usage: thread.cumulative_token_usage,
273            request_token_usage,
274            model: thread.model,
275            profile: thread.profile,
276            imported: false,
277            subagent_context: None,
278            speed: None,
279            thinking_enabled: false,
280            thinking_effort: None,
281        })
282    }
283}
284
/// Encoding of the `data` blob of a `threads` row; stored in the `data_type` column as
/// the text `"json"` or `"zstd"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// The blob is JSON text.
    #[serde(rename = "json")]
    Json,
    /// The blob is zstd-compressed JSON.
    #[serde(rename = "zstd")]
    Zstd,
}
292
293impl Bind for DataType {
294    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
295        let value = match self {
296            DataType::Json => "json",
297            DataType::Zstd => "zstd",
298        };
299        value.bind(statement, start_index)
300    }
301}
302
303impl Column for DataType {
304    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
305        let (value, next_index) = String::column(statement, start_index)?;
306        let data_type = match value.as_str() {
307            "json" => DataType::Json,
308            "zstd" => DataType::Zstd,
309            _ => anyhow::bail!("Unknown data type: {}", value),
310        };
311        Ok((data_type, next_index))
312    }
313}
314
/// Handle to the thread store: a single database connection shared behind a mutex,
/// plus the executor on which every query is spawned.
pub(crate) struct ThreadsDatabase {
    /// Executor used to spawn the tasks returned by the query methods.
    executor: BackgroundExecutor,
    /// The one connection; each operation locks it for the duration of its statements.
    connection: Arc<Mutex<Connection>>,
}
319
/// App-global wrapper around the shared task that opens the database, so that
/// `ThreadsDatabase::connect` hands out the same task on every call after the first.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
323
324impl ThreadsDatabase {
325    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
326        if cx.has_global::<GlobalThreadsDatabase>() {
327            return cx.global::<GlobalThreadsDatabase>().0.clone();
328        }
329        let executor = cx.background_executor().clone();
330        let task = executor
331            .spawn({
332                let executor = executor.clone();
333                async move {
334                    match ThreadsDatabase::new(executor) {
335                        Ok(db) => Ok(Arc::new(db)),
336                        Err(err) => Err(Arc::new(err)),
337                    }
338                }
339            })
340            .shared();
341
342        cx.set_global(GlobalThreadsDatabase(task.clone()));
343        task
344    }
345
346    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
347        let connection = if *ZED_STATELESS {
348            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
349        } else if cfg!(any(feature = "test-support", test)) {
350            // rust stores the name of the test on the current thread.
351            // We use this to automatically create a database that will
352            // be shared within the test (for the test_retrieve_old_thread)
353            // but not with concurrent tests.
354            let thread = std::thread::current();
355            let test_name = thread.name();
356            Connection::open_memory(Some(&format!(
357                "THREAD_FALLBACK_{}",
358                test_name.unwrap_or_default()
359            )))
360        } else {
361            let threads_dir = paths::data_dir().join("threads");
362            std::fs::create_dir_all(&threads_dir)?;
363            let sqlite_path = threads_dir.join("threads.db");
364            Connection::open_file(&sqlite_path.to_string_lossy())
365        };
366
367        connection.exec(indoc! {"
368            CREATE TABLE IF NOT EXISTS threads (
369                id TEXT PRIMARY KEY,
370                summary TEXT NOT NULL,
371                updated_at TEXT NOT NULL,
372                data_type TEXT NOT NULL,
373                data BLOB NOT NULL
374            )
375        "})?()
376        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
377
378        if let Ok(mut s) = connection.exec(indoc! {"
379            ALTER TABLE threads ADD COLUMN parent_id TEXT
380        "})
381        {
382            s().ok();
383        }
384
385        let db = Self {
386            executor,
387            connection: Arc::new(Mutex::new(connection)),
388        };
389
390        Ok(db)
391    }
392
393    fn save_thread_sync(
394        connection: &Arc<Mutex<Connection>>,
395        id: acp::SessionId,
396        thread: DbThread,
397    ) -> Result<()> {
398        const COMPRESSION_LEVEL: i32 = 3;
399
400        #[derive(Serialize)]
401        struct SerializedThread {
402            #[serde(flatten)]
403            thread: DbThread,
404            version: &'static str,
405        }
406
407        let title = thread.title.to_string();
408        let updated_at = thread.updated_at.to_rfc3339();
409        let parent_id = thread
410            .subagent_context
411            .as_ref()
412            .map(|ctx| ctx.parent_thread_id.0.clone());
413        let json_data = serde_json::to_string(&SerializedThread {
414            thread,
415            version: DbThread::VERSION,
416        })?;
417
418        let connection = connection.lock();
419
420        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
421        let data_type = DataType::Zstd;
422        let data = compressed;
423
424        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, String, String, DataType, Vec<u8>)>(indoc! {"
425            INSERT OR REPLACE INTO threads (id, parent_id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?)
426        "})?;
427
428        insert((id.0, parent_id, title, updated_at, data_type, data))?;
429
430        Ok(())
431    }
432
433    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
434        let connection = self.connection.clone();
435
436        self.executor.spawn(async move {
437            let connection = connection.lock();
438
439            let mut select = connection
440                .select_bound::<(), (Arc<str>, Option<Arc<str>>, String, String)>(indoc! {"
441                SELECT id, parent_id, summary, updated_at FROM threads ORDER BY updated_at DESC
442            "})?;
443
444            let rows = select(())?;
445            let mut threads = Vec::new();
446
447            for (id, parent_id, summary, updated_at) in rows {
448                threads.push(DbThreadMetadata {
449                    id: acp::SessionId::new(id),
450                    parent_session_id: parent_id.map(acp::SessionId::new),
451                    title: summary.into(),
452                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
453                });
454            }
455
456            Ok(threads)
457        })
458    }
459
460    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
461        let connection = self.connection.clone();
462
463        self.executor.spawn(async move {
464            let connection = connection.lock();
465            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
466                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
467            "})?;
468
469            let rows = select(id.0)?;
470            if let Some((data_type, data)) = rows.into_iter().next() {
471                let json_data = match data_type {
472                    DataType::Zstd => {
473                        let decompressed = zstd::decode_all(&data[..])?;
474                        String::from_utf8(decompressed)?
475                    }
476                    DataType::Json => String::from_utf8(data)?,
477                };
478                let thread = DbThread::from_json(json_data.as_bytes())?;
479                Ok(Some(thread))
480            } else {
481                Ok(None)
482            }
483        })
484    }
485
486    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
487        let connection = self.connection.clone();
488
489        self.executor
490            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
491    }
492
493    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
494        let connection = self.connection.clone();
495
496        self.executor.spawn(async move {
497            let connection = connection.lock();
498
499            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
500                DELETE FROM threads WHERE id = ?
501            "})?;
502
503            delete(id.0)?;
504
505            Ok(())
506        })
507    }
508
509    pub fn delete_threads(&self) -> Task<Result<()>> {
510        let connection = self.connection.clone();
511
512        self.executor.spawn(async move {
513            let connection = connection.lock();
514
515            let mut delete = connection.exec_bound::<()>(indoc! {"
516                DELETE FROM threads
517            "})?;
518
519            delete(())?;
520
521            Ok(())
522        })
523    }
524}
525
526#[cfg(test)]
527mod tests {
528    use super::*;
529    use chrono::{DateTime, TimeZone, Utc};
530    use collections::HashMap;
531    use gpui::TestAppContext;
532    use std::sync::Arc;
533
534    #[test]
535    fn test_shared_thread_roundtrip() {
536        let original = SharedThread {
537            title: "Test Thread".into(),
538            messages: vec![],
539            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
540            model: None,
541            version: SharedThread::VERSION.to_string(),
542        };
543
544        let bytes = original.to_bytes().expect("Failed to serialize");
545        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
546
547        assert_eq!(restored.title, original.title);
548        assert_eq!(restored.version, original.version);
549        assert_eq!(restored.updated_at, original.updated_at);
550    }
551
552    #[test]
553    fn test_imported_flag_defaults_to_false() {
554        // Simulate deserializing a thread without the imported field (backwards compatibility).
555        let json = r#"{
556            "title": "Old Thread",
557            "messages": [],
558            "updated_at": "2024-01-01T00:00:00Z"
559        }"#;
560
561        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
562
563        assert!(
564            !db_thread.imported,
565            "Legacy threads without imported field should default to false"
566        );
567    }
568
569    fn session_id(value: &str) -> acp::SessionId {
570        acp::SessionId::new(Arc::<str>::from(value))
571    }
572
573    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
574        DbThread {
575            title: title.to_string().into(),
576            messages: Vec::new(),
577            updated_at,
578            detailed_summary: None,
579            initial_project_snapshot: None,
580            cumulative_token_usage: Default::default(),
581            request_token_usage: HashMap::default(),
582            model: None,
583            profile: None,
584            imported: false,
585            subagent_context: None,
586            speed: None,
587            thinking_enabled: false,
588            thinking_effort: None,
589        }
590    }
591
592    #[gpui::test]
593    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
594        let database = ThreadsDatabase::new(cx.executor()).unwrap();
595
596        let older_id = session_id("thread-a");
597        let newer_id = session_id("thread-b");
598
599        let older_thread = make_thread(
600            "Thread A",
601            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
602        );
603        let newer_thread = make_thread(
604            "Thread B",
605            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
606        );
607
608        database
609            .save_thread(older_id.clone(), older_thread)
610            .await
611            .unwrap();
612        database
613            .save_thread(newer_id.clone(), newer_thread)
614            .await
615            .unwrap();
616
617        let entries = database.list_threads().await.unwrap();
618        assert_eq!(entries.len(), 2);
619        assert_eq!(entries[0].id, newer_id);
620        assert_eq!(entries[1].id, older_id);
621    }
622
623    #[gpui::test]
624    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
625        let database = ThreadsDatabase::new(cx.executor()).unwrap();
626
627        let thread_id = session_id("thread-a");
628        let original_thread = make_thread(
629            "Thread A",
630            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
631        );
632        let updated_thread = make_thread(
633            "Thread B",
634            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
635        );
636
637        database
638            .save_thread(thread_id.clone(), original_thread)
639            .await
640            .unwrap();
641        database
642            .save_thread(thread_id.clone(), updated_thread)
643            .await
644            .unwrap();
645
646        let entries = database.list_threads().await.unwrap();
647        assert_eq!(entries.len(), 1);
648        assert_eq!(entries[0].id, thread_id);
649        assert_eq!(entries[0].title.as_ref(), "Thread B");
650        assert_eq!(
651            entries[0].updated_at,
652            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
653        );
654    }
655
656    #[test]
657    fn test_subagent_context_defaults_to_none() {
658        let json = r#"{
659            "title": "Old Thread",
660            "messages": [],
661            "updated_at": "2024-01-01T00:00:00Z"
662        }"#;
663
664        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
665
666        assert!(
667            db_thread.subagent_context.is_none(),
668            "Legacy threads without subagent_context should default to None"
669        );
670    }
671
672    #[gpui::test]
673    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
674        let database = ThreadsDatabase::new(cx.executor()).unwrap();
675
676        let parent_id = session_id("parent-thread");
677        let child_id = session_id("child-thread");
678
679        let mut child_thread = make_thread(
680            "Subagent Thread",
681            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
682        );
683        child_thread.subagent_context = Some(crate::SubagentContext {
684            parent_thread_id: parent_id.clone(),
685            depth: 2,
686        });
687
688        database
689            .save_thread(child_id.clone(), child_thread)
690            .await
691            .unwrap();
692
693        let loaded = database
694            .load_thread(child_id)
695            .await
696            .unwrap()
697            .expect("thread should exist");
698
699        let context = loaded
700            .subagent_context
701            .expect("subagent_context should be restored");
702        assert_eq!(context.parent_thread_id, parent_id);
703        assert_eq!(context.depth, 2);
704    }
705
706    #[gpui::test]
707    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
708        let database = ThreadsDatabase::new(cx.executor()).unwrap();
709
710        let thread_id = session_id("regular-thread");
711        let thread = make_thread(
712            "Regular Thread",
713            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
714        );
715
716        database
717            .save_thread(thread_id.clone(), thread)
718            .await
719            .unwrap();
720
721        let loaded = database
722            .load_thread(thread_id)
723            .await
724            .unwrap()
725            .expect("thread should exist");
726
727        assert!(
728            loaded.subagent_context.is_none(),
729            "Regular threads should have no subagent_context"
730        );
731    }
732}