db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20use zed_env_vars::ZED_STATELESS;
 21
/// Message type stored in [`DbThread::messages`]; an alias for [`crate::Message`].
pub type DbMessage = crate::Message;
/// Alias for the legacy detailed-summary state type from `crate::legacy_thread`.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Serialized language-model selection, reused as the `model` field of [`DbThread`] and
/// [`SharedThread`].
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 25
/// Lightweight per-thread metadata, as produced by [`ThreadsDatabase::list_threads`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Identifier of the thread (the `id` column).
    pub id: acp::SessionId,
    /// Session id of the parent thread, read from the `parent_id` column. `None` for threads
    /// that were saved without a subagent context.
    pub parent_session_id: Option<acp::SessionId>,
    /// Thread title (stored in the `summary` column). The serde alias lets deserialization
    /// also accept a `summary` key.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Last-modified timestamp; `list_threads` returns entries ordered by it, newest first.
    pub updated_at: DateTime<Utc>,
}
 34
/// A complete thread as persisted (as versioned JSON) in the `threads` table.
///
/// Fields marked `#[serde(default)]` fall back to their `Default` value when absent from the
/// stored JSON, so payloads written without them still deserialize.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the `summary` column on save.
    pub title: SharedString,
    /// Conversation messages, in order.
    pub messages: Vec<DbMessage>,
    /// Last-modified timestamp; written to the `updated_at` column as RFC 3339 on save.
    pub updated_at: DateTime<Utc>,
    /// Generated detailed summary, if one exists.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot captured for the thread, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated across the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Per-request token usage keyed by user message id. The legacy upgrade path attributes
    /// each assistant response's usage to the user message preceding it.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model selected for the thread, if recorded.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Agent profile selected for the thread, if recorded.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads created from a [`SharedThread`] via `to_db_thread`.
    #[serde(default)]
    pub imported: bool,
    /// Set for subagent threads; its `parent_thread_id` is written to the `parent_id` column
    /// on save.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
}
 57
/// Portable subset of a [`DbThread`] used for sharing threads; `to_bytes` encodes it as
/// zstd-compressed JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    /// Title of the shared thread.
    pub title: SharedString,
    /// Conversation messages, in order.
    pub messages: Vec<DbMessage>,
    /// Last-modified timestamp of the source thread.
    pub updated_at: DateTime<Utc>,
    /// Language model of the source thread, if recorded.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version; `from_db_thread` sets it to [`SharedThread::VERSION`].
    pub version: String,
}
 67
 68impl SharedThread {
 69    pub const VERSION: &'static str = "1.0.0";
 70
 71    pub fn from_db_thread(thread: &DbThread) -> Self {
 72        Self {
 73            title: thread.title.clone(),
 74            messages: thread.messages.clone(),
 75            updated_at: thread.updated_at,
 76            model: thread.model.clone(),
 77            version: Self::VERSION.to_string(),
 78        }
 79    }
 80
 81    pub fn to_db_thread(self) -> DbThread {
 82        DbThread {
 83            title: format!("🔗 {}", self.title).into(),
 84            messages: self.messages,
 85            updated_at: self.updated_at,
 86            detailed_summary: None,
 87            initial_project_snapshot: None,
 88            cumulative_token_usage: Default::default(),
 89            request_token_usage: Default::default(),
 90            model: self.model,
 91            profile: None,
 92            imported: true,
 93            subagent_context: None,
 94        }
 95    }
 96
 97    pub fn to_bytes(&self) -> Result<Vec<u8>> {
 98        const COMPRESSION_LEVEL: i32 = 3;
 99        let json = serde_json::to_vec(self)?;
100        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
101        Ok(compressed)
102    }
103
104    pub fn from_bytes(data: &[u8]) -> Result<Self> {
105        let decompressed = zstd::decode_all(data)?;
106        Ok(serde_json::from_slice(&decompressed)?)
107    }
108}
109
110impl DbThread {
111    pub const VERSION: &'static str = "0.3.0";
112
113    pub fn from_json(json: &[u8]) -> Result<Self> {
114        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
115        match saved_thread_json.get("version") {
116            Some(serde_json::Value::String(version)) => match version.as_str() {
117                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
118                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
119                    json,
120                )?),
121            },
122            _ => {
123                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
124            }
125        }
126    }
127
128    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
129        let mut messages = Vec::new();
130        let mut request_token_usage = HashMap::default();
131
132        let mut last_user_message_id = None;
133        for (ix, msg) in thread.messages.into_iter().enumerate() {
134            let message = match msg.role {
135                language_model::Role::User => {
136                    let mut content = Vec::new();
137
138                    // Convert segments to content
139                    for segment in msg.segments {
140                        match segment {
141                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
142                                content.push(UserMessageContent::Text(text));
143                            }
144                            crate::legacy_thread::SerializedMessageSegment::Thinking {
145                                text,
146                                ..
147                            } => {
148                                // User messages don't have thinking segments, but handle gracefully
149                                content.push(UserMessageContent::Text(text));
150                            }
151                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
152                                ..
153                            } => {
154                                // User messages don't have redacted thinking, skip.
155                            }
156                        }
157                    }
158
159                    // If no content was added, add context as text if available
160                    if content.is_empty() && !msg.context.is_empty() {
161                        content.push(UserMessageContent::Text(msg.context));
162                    }
163
164                    let id = UserMessageId::new();
165                    last_user_message_id = Some(id.clone());
166
167                    crate::Message::User(UserMessage {
168                        // MessageId from old format can't be meaningfully converted, so generate a new one
169                        id,
170                        content,
171                    })
172                }
173                language_model::Role::Assistant => {
174                    let mut content = Vec::new();
175
176                    // Convert segments to content
177                    for segment in msg.segments {
178                        match segment {
179                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
180                                content.push(AgentMessageContent::Text(text));
181                            }
182                            crate::legacy_thread::SerializedMessageSegment::Thinking {
183                                text,
184                                signature,
185                            } => {
186                                content.push(AgentMessageContent::Thinking { text, signature });
187                            }
188                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
189                                data,
190                            } => {
191                                content.push(AgentMessageContent::RedactedThinking(data));
192                            }
193                        }
194                    }
195
196                    // Convert tool uses
197                    let mut tool_names_by_id = HashMap::default();
198                    for tool_use in msg.tool_uses {
199                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
200                        content.push(AgentMessageContent::ToolUse(
201                            language_model::LanguageModelToolUse {
202                                id: tool_use.id,
203                                name: tool_use.name.into(),
204                                raw_input: serde_json::to_string(&tool_use.input)
205                                    .unwrap_or_default(),
206                                input: tool_use.input,
207                                is_input_complete: true,
208                                thought_signature: None,
209                            },
210                        ));
211                    }
212
213                    // Convert tool results
214                    let mut tool_results = IndexMap::default();
215                    for tool_result in msg.tool_results {
216                        let name = tool_names_by_id
217                            .remove(&tool_result.tool_use_id)
218                            .unwrap_or_else(|| SharedString::from("unknown"));
219                        tool_results.insert(
220                            tool_result.tool_use_id.clone(),
221                            language_model::LanguageModelToolResult {
222                                tool_use_id: tool_result.tool_use_id,
223                                tool_name: name.into(),
224                                is_error: tool_result.is_error,
225                                content: tool_result.content,
226                                output: tool_result.output,
227                            },
228                        );
229                    }
230
231                    if let Some(last_user_message_id) = &last_user_message_id
232                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
233                    {
234                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
235                    }
236
237                    crate::Message::Agent(AgentMessage {
238                        content,
239                        tool_results,
240                        reasoning_details: None,
241                    })
242                }
243                language_model::Role::System => {
244                    // Skip system messages as they're not supported in the new format
245                    continue;
246                }
247            };
248
249            messages.push(message);
250        }
251
252        Ok(Self {
253            title: thread.summary,
254            messages,
255            updated_at: thread.updated_at,
256            detailed_summary: match thread.detailed_summary_state {
257                crate::legacy_thread::DetailedSummaryState::NotGenerated
258                | crate::legacy_thread::DetailedSummaryState::Generating => None,
259                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
260            },
261            initial_project_snapshot: thread.initial_project_snapshot,
262            cumulative_token_usage: thread.cumulative_token_usage,
263            request_token_usage,
264            model: thread.model,
265            profile: thread.profile,
266            imported: false,
267            subagent_context: None,
268        })
269    }
270}
271
272#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
273pub enum DataType {
274    #[serde(rename = "json")]
275    Json,
276    #[serde(rename = "zstd")]
277    Zstd,
278}
279
280impl Bind for DataType {
281    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
282        let value = match self {
283            DataType::Json => "json",
284            DataType::Zstd => "zstd",
285        };
286        value.bind(statement, start_index)
287    }
288}
289
290impl Column for DataType {
291    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
292        let (value, next_index) = String::column(statement, start_index)?;
293        let data_type = match value.as_str() {
294            "json" => DataType::Json,
295            "zstd" => DataType::Zstd,
296            _ => anyhow::bail!("Unknown data type: {}", value),
297        };
298        Ok((data_type, next_index))
299    }
300}
301
/// SQLite-backed store of agent threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which all database work is spawned.
    executor: BackgroundExecutor,
    /// Single connection, shared across spawned tasks and guarded by a mutex.
    connection: Arc<Mutex<Connection>>,
}
306
/// gpui global caching the shared task created by `ThreadsDatabase::connect`, so the
/// database is opened only once per `App` and every caller awaits the same result.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
310
311impl ThreadsDatabase {
312    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
313        if cx.has_global::<GlobalThreadsDatabase>() {
314            return cx.global::<GlobalThreadsDatabase>().0.clone();
315        }
316        let executor = cx.background_executor().clone();
317        let task = executor
318            .spawn({
319                let executor = executor.clone();
320                async move {
321                    match ThreadsDatabase::new(executor) {
322                        Ok(db) => Ok(Arc::new(db)),
323                        Err(err) => Err(Arc::new(err)),
324                    }
325                }
326            })
327            .shared();
328
329        cx.set_global(GlobalThreadsDatabase(task.clone()));
330        task
331    }
332
333    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
334        let connection = if *ZED_STATELESS {
335            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
336        } else if cfg!(any(feature = "test-support", test)) {
337            // rust stores the name of the test on the current thread.
338            // We use this to automatically create a database that will
339            // be shared within the test (for the test_retrieve_old_thread)
340            // but not with concurrent tests.
341            let thread = std::thread::current();
342            let test_name = thread.name();
343            Connection::open_memory(Some(&format!(
344                "THREAD_FALLBACK_{}",
345                test_name.unwrap_or_default()
346            )))
347        } else {
348            let threads_dir = paths::data_dir().join("threads");
349            std::fs::create_dir_all(&threads_dir)?;
350            let sqlite_path = threads_dir.join("threads.db");
351            Connection::open_file(&sqlite_path.to_string_lossy())
352        };
353
354        connection.exec(indoc! {"
355            CREATE TABLE IF NOT EXISTS threads (
356                id TEXT PRIMARY KEY,
357                summary TEXT NOT NULL,
358                updated_at TEXT NOT NULL,
359                data_type TEXT NOT NULL,
360                data BLOB NOT NULL
361            )
362        "})?()
363        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
364
365        if let Ok(mut s) = connection.exec(indoc! {"
366            ALTER TABLE threads ADD COLUMN parent_id TEXT
367        "})
368        {
369            s().ok();
370        }
371
372        let db = Self {
373            executor,
374            connection: Arc::new(Mutex::new(connection)),
375        };
376
377        Ok(db)
378    }
379
380    fn save_thread_sync(
381        connection: &Arc<Mutex<Connection>>,
382        id: acp::SessionId,
383        thread: DbThread,
384    ) -> Result<()> {
385        const COMPRESSION_LEVEL: i32 = 3;
386
387        #[derive(Serialize)]
388        struct SerializedThread {
389            #[serde(flatten)]
390            thread: DbThread,
391            version: &'static str,
392        }
393
394        let title = thread.title.to_string();
395        let updated_at = thread.updated_at.to_rfc3339();
396        let parent_id = thread
397            .subagent_context
398            .as_ref()
399            .map(|ctx| ctx.parent_thread_id.0.clone());
400        let json_data = serde_json::to_string(&SerializedThread {
401            thread,
402            version: DbThread::VERSION,
403        })?;
404
405        let connection = connection.lock();
406
407        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
408        let data_type = DataType::Zstd;
409        let data = compressed;
410
411        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, String, String, DataType, Vec<u8>)>(indoc! {"
412            INSERT OR REPLACE INTO threads (id, parent_id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?)
413        "})?;
414
415        insert((id.0, parent_id, title, updated_at, data_type, data))?;
416
417        Ok(())
418    }
419
420    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
421        let connection = self.connection.clone();
422
423        self.executor.spawn(async move {
424            let connection = connection.lock();
425
426            let mut select = connection
427                .select_bound::<(), (Arc<str>, Option<Arc<str>>, String, String)>(indoc! {"
428                SELECT id, parent_id, summary, updated_at FROM threads ORDER BY updated_at DESC
429            "})?;
430
431            let rows = select(())?;
432            let mut threads = Vec::new();
433
434            for (id, parent_id, summary, updated_at) in rows {
435                threads.push(DbThreadMetadata {
436                    id: acp::SessionId::new(id),
437                    parent_session_id: parent_id.map(acp::SessionId::new),
438                    title: summary.into(),
439                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
440                });
441            }
442
443            Ok(threads)
444        })
445    }
446
447    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
448        let connection = self.connection.clone();
449
450        self.executor.spawn(async move {
451            let connection = connection.lock();
452            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
453                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
454            "})?;
455
456            let rows = select(id.0)?;
457            if let Some((data_type, data)) = rows.into_iter().next() {
458                let json_data = match data_type {
459                    DataType::Zstd => {
460                        let decompressed = zstd::decode_all(&data[..])?;
461                        String::from_utf8(decompressed)?
462                    }
463                    DataType::Json => String::from_utf8(data)?,
464                };
465                let thread = DbThread::from_json(json_data.as_bytes())?;
466                Ok(Some(thread))
467            } else {
468                Ok(None)
469            }
470        })
471    }
472
473    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
474        let connection = self.connection.clone();
475
476        self.executor
477            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
478    }
479
480    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
481        let connection = self.connection.clone();
482
483        self.executor.spawn(async move {
484            let connection = connection.lock();
485
486            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
487                DELETE FROM threads WHERE id = ?
488            "})?;
489
490            delete(id.0)?;
491
492            Ok(())
493        })
494    }
495
496    pub fn delete_threads(&self) -> Task<Result<()>> {
497        let connection = self.connection.clone();
498
499        self.executor.spawn(async move {
500            let connection = connection.lock();
501
502            let mut delete = connection.exec_bound::<()>(indoc! {"
503                DELETE FROM threads
504            "})?;
505
506            delete(())?;
507
508            Ok(())
509        })
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_shared_thread_to_db_thread_marks_imported() {
        let updated_at = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let shared = SharedThread {
            title: "Shared".into(),
            messages: vec![],
            updated_at,
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let db_thread = shared.to_db_thread();

        assert!(db_thread.imported, "Imported threads should be flagged");
        assert_eq!(db_thread.title.as_ref(), "🔗 Shared");
        assert_eq!(db_thread.updated_at, updated_at);
        assert!(db_thread.subagent_context.is_none());
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
        }
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread)
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread)
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }

    #[gpui::test]
    async fn test_load_missing_thread_returns_none(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let loaded = database
            .load_thread(session_id("missing-thread"))
            .await
            .unwrap();

        assert!(loaded.is_none(), "Unknown ids should load as None");
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_that_thread(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let kept_id = session_id("thread-kept");
        let deleted_id = session_id("thread-deleted");

        database
            .save_thread(
                kept_id.clone(),
                make_thread("Kept", Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()),
            )
            .await
            .unwrap();
        database
            .save_thread(
                deleted_id.clone(),
                make_thread("Deleted", Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()),
            )
            .await
            .unwrap();

        database.delete_thread(deleted_id.clone()).await.unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, kept_id);
        assert!(database.load_thread(deleted_id).await.unwrap().is_none());
    }

    #[gpui::test]
    async fn test_list_threads_reports_parent_session_id(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let parent_thread = make_thread(
            "Parent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let mut child_thread = make_thread(
            "Child Thread",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 1,
        });

        database
            .save_thread(parent_id.clone(), parent_thread)
            .await
            .unwrap();
        database
            .save_thread(child_id.clone(), child_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, child_id);
        assert_eq!(entries[0].parent_session_id, Some(parent_id.clone()));
        assert_eq!(entries[1].id, parent_id);
        assert_eq!(entries[1].parent_session_id, None);
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }
}