db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::{AgentProfileId, CompletionMode};
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20use zed_env_vars::ZED_STATELESS;
 21
/// Message type stored in a persisted thread; alias of [`crate::Message`].
pub type DbMessage = crate::Message;
/// Alias of the legacy thread format's `DetailedSummaryState`.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Serialized language-model selection; alias of the legacy thread format's
/// `SerializedLanguageModel`, so both formats store the model the same way.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 25
/// Lightweight per-thread row produced by `ThreadsDatabase::list_threads`.
///
/// It carries only the `id`, `summary` and `updated_at` columns, so a thread list can
/// be rendered without loading and decompressing each thread's `data` blob.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    pub id: acp::SessionId,
    /// Thread title. Deserialization also accepts the older `summary` key.
    #[serde(alias = "summary")]
    pub title: SharedString,
    pub updated_at: DateTime<Utc>,
}
 33
/// A full persisted agent thread.
///
/// `ThreadsDatabase` stores it as JSON (with an extra top-level `version` key),
/// zstd-compressed, in the `threads` table. Fields marked `#[serde(default)]` may be
/// missing from older saved data and then take their `Default` value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    /// Last-modified time; `list_threads` orders threads by this value.
    pub updated_at: DateTime<Utc>,
    /// Text of a generated detailed summary, if one exists.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by the id of the user message that started the request.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads created from a [`SharedThread`] via `to_db_thread`.
    #[serde(default)]
    pub imported: bool,
}
 56
 57#[derive(Debug, Clone, Serialize, Deserialize)]
 58pub struct SharedThread {
 59    pub title: SharedString,
 60    pub messages: Vec<DbMessage>,
 61    pub updated_at: DateTime<Utc>,
 62    #[serde(default)]
 63    pub model: Option<DbLanguageModel>,
 64    #[serde(default)]
 65    pub completion_mode: Option<CompletionMode>,
 66    pub version: String,
 67}
 68
 69impl SharedThread {
 70    pub const VERSION: &'static str = "1.0.0";
 71
 72    pub fn from_db_thread(thread: &DbThread) -> Self {
 73        Self {
 74            title: thread.title.clone(),
 75            messages: thread.messages.clone(),
 76            updated_at: thread.updated_at,
 77            model: thread.model.clone(),
 78            completion_mode: thread.completion_mode,
 79            version: Self::VERSION.to_string(),
 80        }
 81    }
 82
 83    pub fn to_db_thread(self) -> DbThread {
 84        DbThread {
 85            title: format!("🔗 {}", self.title).into(),
 86            messages: self.messages,
 87            updated_at: self.updated_at,
 88            detailed_summary: None,
 89            initial_project_snapshot: None,
 90            cumulative_token_usage: Default::default(),
 91            request_token_usage: Default::default(),
 92            model: self.model,
 93            completion_mode: self.completion_mode,
 94            profile: None,
 95            imported: true,
 96        }
 97    }
 98
 99    pub fn to_bytes(&self) -> Result<Vec<u8>> {
100        const COMPRESSION_LEVEL: i32 = 3;
101        let json = serde_json::to_vec(self)?;
102        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
103        Ok(compressed)
104    }
105
106    pub fn from_bytes(data: &[u8]) -> Result<Self> {
107        let decompressed = zstd::decode_all(data)?;
108        Ok(serde_json::from_slice(&decompressed)?)
109    }
110}
111
impl DbThread {
    /// Version tag written next to every thread saved by `ThreadsDatabase`. Saved JSON
    /// whose top-level `version` equals this string is deserialized directly; any other
    /// data is routed through the legacy converter.
    pub const VERSION: &'static str = "0.3.0";

    /// Parses a saved thread from its JSON bytes.
    ///
    /// If the top-level `version` field is a string equal to [`Self::VERSION`], the JSON
    /// is deserialized directly into `DbThread`. Otherwise the bytes are parsed as a
    /// `crate::legacy_thread::SerializedThread` and converted with
    /// `upgrade_from_agent_1`. "Otherwise" covers a different version string, a
    /// non-string `version`, and a missing `version`.
    ///
    /// # Errors
    /// Fails if `json` is not valid JSON, or if the chosen deserialization/upgrade path
    /// fails.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
                    json,
                )?),
            },
            _ => {
                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
            }
        }
    }

    /// Converts a thread in the legacy serialized format into a `DbThread`.
    ///
    /// - User messages: text and thinking segments become text content, and redacted
    ///   thinking is dropped. If nothing remains, the message's `context` (when
    ///   non-empty) becomes the text. Each user message receives a newly generated
    ///   `UserMessageId`.
    /// - Assistant messages: segments map one-to-one onto `AgentMessageContent`. Tool
    ///   uses and tool results are carried over. Each result's tool name is looked up
    ///   among the tool uses of the same message, falling back to `"unknown"`.
    /// - System messages are dropped.
    /// - For an assistant message at legacy index `ix`, the legacy
    ///   `request_token_usage` entry at `ix` (if any) is recorded under the id of the
    ///   most recent preceding user message. `ix` counts every legacy message,
    ///   including skipped system messages. A later assistant message for the same
    ///   user message overwrites the earlier entry.
    /// - `imported` is always `false`.
    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
        let mut messages = Vec::new();
        let mut request_token_usage = HashMap::default();

        // Id of the most recent user message seen so far; token usage of subsequent
        // assistant messages is attributed to it.
        let mut last_user_message_id = None;
        for (ix, msg) in thread.messages.into_iter().enumerate() {
            let message = match msg.role {
                language_model::Role::User => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                ..
                            } => {
                                // User messages don't have thinking segments, but handle gracefully
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                ..
                            } => {
                                // User messages don't have redacted thinking, skip.
                            }
                        }
                    }

                    // If no content was added, add context as text if available
                    if content.is_empty() && !msg.context.is_empty() {
                        content.push(UserMessageContent::Text(msg.context));
                    }

                    let id = UserMessageId::new();
                    last_user_message_id = Some(id.clone());

                    crate::Message::User(UserMessage {
                        // MessageId from old format can't be meaningfully converted, so generate a new one
                        id,
                        content,
                    })
                }
                language_model::Role::Assistant => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(AgentMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                signature,
                            } => {
                                content.push(AgentMessageContent::Thinking { text, signature });
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                data,
                            } => {
                                content.push(AgentMessageContent::RedactedThinking(data));
                            }
                        }
                    }

                    // Convert tool uses, remembering each tool's name by id so the
                    // matching results below can be labelled.
                    let mut tool_names_by_id = HashMap::default();
                    for tool_use in msg.tool_uses {
                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
                        content.push(AgentMessageContent::ToolUse(
                            language_model::LanguageModelToolUse {
                                id: tool_use.id,
                                name: tool_use.name.into(),
                                // The legacy format has no raw input; re-serialize the
                                // parsed input (empty string if that fails).
                                raw_input: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                                input: tool_use.input,
                                is_input_complete: true,
                                thought_signature: None,
                            },
                        ));
                    }

                    // Convert tool results (insertion order preserved by IndexMap)
                    let mut tool_results = IndexMap::default();
                    for tool_result in msg.tool_results {
                        let name = tool_names_by_id
                            .remove(&tool_result.tool_use_id)
                            .unwrap_or_else(|| SharedString::from("unknown"));
                        tool_results.insert(
                            tool_result.tool_use_id.clone(),
                            language_model::LanguageModelToolResult {
                                tool_use_id: tool_result.tool_use_id,
                                tool_name: name.into(),
                                is_error: tool_result.is_error,
                                content: tool_result.content,
                                output: tool_result.output,
                            },
                        );
                    }

                    // Attribute this message's legacy per-index token usage to the user
                    // message that preceded it (if there was one).
                    if let Some(last_user_message_id) = &last_user_message_id
                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
                    {
                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
                    }

                    crate::Message::Agent(AgentMessage {
                        content,
                        tool_results,
                        reasoning_details: None,
                    })
                }
                language_model::Role::System => {
                    // Skip system messages as they're not supported in the new format
                    continue;
                }
            };

            messages.push(message);
        }

        Ok(Self {
            title: thread.summary,
            messages,
            updated_at: thread.updated_at,
            // Only a fully generated summary is carried over; pending states become None.
            detailed_summary: match thread.detailed_summary_state {
                crate::legacy_thread::DetailedSummaryState::NotGenerated
                | crate::legacy_thread::DetailedSummaryState::Generating => None,
                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
            },
            initial_project_snapshot: thread.initial_project_snapshot,
            cumulative_token_usage: thread.cumulative_token_usage,
            request_token_usage,
            model: thread.model,
            completion_mode: thread.completion_mode,
            profile: thread.profile,
            imported: false,
        })
    }
}
273
274#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
275pub enum DataType {
276    #[serde(rename = "json")]
277    Json,
278    #[serde(rename = "zstd")]
279    Zstd,
280}
281
282impl Bind for DataType {
283    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
284        let value = match self {
285            DataType::Json => "json",
286            DataType::Zstd => "zstd",
287        };
288        value.bind(statement, start_index)
289    }
290}
291
292impl Column for DataType {
293    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
294        let (value, next_index) = String::column(statement, start_index)?;
295        let data_type = match value.as_str() {
296            "json" => DataType::Json,
297            "zstd" => DataType::Zstd,
298            _ => anyhow::bail!("Unknown data type: {}", value),
299        };
300        Ok((data_type, next_index))
301    }
302}
303
304pub(crate) struct ThreadsDatabase {
305    executor: BackgroundExecutor,
306    connection: Arc<Mutex<Connection>>,
307}
308
309struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
310
311impl Global for GlobalThreadsDatabase {}
312
313impl ThreadsDatabase {
314    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
315        if cx.has_global::<GlobalThreadsDatabase>() {
316            return cx.global::<GlobalThreadsDatabase>().0.clone();
317        }
318        let executor = cx.background_executor().clone();
319        let task = executor
320            .spawn({
321                let executor = executor.clone();
322                async move {
323                    match ThreadsDatabase::new(executor) {
324                        Ok(db) => Ok(Arc::new(db)),
325                        Err(err) => Err(Arc::new(err)),
326                    }
327                }
328            })
329            .shared();
330
331        cx.set_global(GlobalThreadsDatabase(task.clone()));
332        task
333    }
334
335    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
336        let connection = if *ZED_STATELESS {
337            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
338        } else if cfg!(any(feature = "test-support", test)) {
339            // rust stores the name of the test on the current thread.
340            // We use this to automatically create a database that will
341            // be shared within the test (for the test_retrieve_old_thread)
342            // but not with concurrent tests.
343            let thread = std::thread::current();
344            let test_name = thread.name();
345            Connection::open_memory(Some(&format!(
346                "THREAD_FALLBACK_{}",
347                test_name.unwrap_or_default()
348            )))
349        } else {
350            let threads_dir = paths::data_dir().join("threads");
351            std::fs::create_dir_all(&threads_dir)?;
352            let sqlite_path = threads_dir.join("threads.db");
353            Connection::open_file(&sqlite_path.to_string_lossy())
354        };
355
356        connection.exec(indoc! {"
357            CREATE TABLE IF NOT EXISTS threads (
358                id TEXT PRIMARY KEY,
359                summary TEXT NOT NULL,
360                updated_at TEXT NOT NULL,
361                data_type TEXT NOT NULL,
362                data BLOB NOT NULL
363            )
364        "})?()
365        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
366
367        let db = Self {
368            executor,
369            connection: Arc::new(Mutex::new(connection)),
370        };
371
372        Ok(db)
373    }
374
375    fn save_thread_sync(
376        connection: &Arc<Mutex<Connection>>,
377        id: acp::SessionId,
378        thread: DbThread,
379    ) -> Result<()> {
380        const COMPRESSION_LEVEL: i32 = 3;
381
382        #[derive(Serialize)]
383        struct SerializedThread {
384            #[serde(flatten)]
385            thread: DbThread,
386            version: &'static str,
387        }
388
389        let title = thread.title.to_string();
390        let updated_at = thread.updated_at.to_rfc3339();
391        let json_data = serde_json::to_string(&SerializedThread {
392            thread,
393            version: DbThread::VERSION,
394        })?;
395
396        let connection = connection.lock();
397
398        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
399        let data_type = DataType::Zstd;
400        let data = compressed;
401
402        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
403            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
404        "})?;
405
406        insert((id.0, title, updated_at, data_type, data))?;
407
408        Ok(())
409    }
410
411    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
412        let connection = self.connection.clone();
413
414        self.executor.spawn(async move {
415            let connection = connection.lock();
416
417            let mut select =
418                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
419                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
420            "})?;
421
422            let rows = select(())?;
423            let mut threads = Vec::new();
424
425            for (id, summary, updated_at) in rows {
426                threads.push(DbThreadMetadata {
427                    id: acp::SessionId::new(id),
428                    title: summary.into(),
429                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
430                });
431            }
432
433            Ok(threads)
434        })
435    }
436
437    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
438        let connection = self.connection.clone();
439
440        self.executor.spawn(async move {
441            let connection = connection.lock();
442            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
443                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
444            "})?;
445
446            let rows = select(id.0)?;
447            if let Some((data_type, data)) = rows.into_iter().next() {
448                let json_data = match data_type {
449                    DataType::Zstd => {
450                        let decompressed = zstd::decode_all(&data[..])?;
451                        String::from_utf8(decompressed)?
452                    }
453                    DataType::Json => String::from_utf8(data)?,
454                };
455                let thread = DbThread::from_json(json_data.as_bytes())?;
456                Ok(Some(thread))
457            } else {
458                Ok(None)
459            }
460        })
461    }
462
463    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
464        let connection = self.connection.clone();
465
466        self.executor
467            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
468    }
469
470    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
471        let connection = self.connection.clone();
472
473        self.executor.spawn(async move {
474            let connection = connection.lock();
475
476            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
477                DELETE FROM threads WHERE id = ?
478            "})?;
479
480            delete(id.0)?;
481
482            Ok(())
483        })
484    }
485
486    pub fn delete_threads(&self) -> Task<Result<()>> {
487        let connection = self.connection.clone();
488
489        self.executor.spawn(async move {
490            let connection = connection.lock();
491
492            let mut delete = connection.exec_bound::<()>(indoc! {"
493                DELETE FROM threads
494            "})?;
495
496            delete(())?;
497
498            Ok(())
499        })
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// A `SharedThread` survives `to_bytes` → `from_bytes` (JSON + zstd).
    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            completion_mode: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Minimal thread with the given title and timestamp and no messages.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            completion_mode: None,
            profile: None,
            imported: false,
        }
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread)
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread)
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }

    /// A saved thread loads back through the versioned JSON + zstd path. Unknown ids
    /// load as `None`, and a deleted thread is gone from both `load_thread` and
    /// `list_threads`.
    #[gpui::test]
    async fn test_save_load_and_delete_thread(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let updated_at = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();

        database
            .save_thread(thread_id.clone(), make_thread("Thread A", updated_at))
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id.clone())
            .await
            .unwrap()
            .expect("saved thread should load");
        assert_eq!(loaded.title.as_ref(), "Thread A");
        assert_eq!(loaded.updated_at, updated_at);
        assert!(loaded.messages.is_empty());
        assert!(!loaded.imported);

        assert!(
            database
                .load_thread(session_id("missing"))
                .await
                .unwrap()
                .is_none()
        );

        database.delete_thread(thread_id.clone()).await.unwrap();
        assert!(database.load_thread(thread_id).await.unwrap().is_none());
        assert!(database.list_threads().await.unwrap().is_empty());
    }
}