db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::{AgentProfileId, CompletionMode};
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20use zed_env_vars::ZED_STATELESS;
 21
 22pub type DbMessage = crate::Message;
 23pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
 24pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 25
 26#[derive(Debug, Clone, Serialize, Deserialize)]
 27pub struct DbThreadMetadata {
 28    pub id: acp::SessionId,
 29    #[serde(alias = "summary")]
 30    pub title: SharedString,
 31    pub updated_at: DateTime<Utc>,
 32}
 33
 34#[derive(Debug, Serialize, Deserialize)]
 35pub struct DbThread {
 36    pub title: SharedString,
 37    pub messages: Vec<DbMessage>,
 38    pub updated_at: DateTime<Utc>,
 39    #[serde(default)]
 40    pub detailed_summary: Option<SharedString>,
 41    #[serde(default)]
 42    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
 43    #[serde(default)]
 44    pub cumulative_token_usage: language_model::TokenUsage,
 45    #[serde(default)]
 46    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
 47    #[serde(default)]
 48    pub model: Option<DbLanguageModel>,
 49    #[serde(default)]
 50    pub completion_mode: Option<CompletionMode>,
 51    #[serde(default)]
 52    pub profile: Option<AgentProfileId>,
 53}
 54
 55impl DbThread {
 56    pub const VERSION: &'static str = "0.3.0";
 57
 58    pub fn from_json(json: &[u8]) -> Result<Self> {
 59        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 60        match saved_thread_json.get("version") {
 61            Some(serde_json::Value::String(version)) => match version.as_str() {
 62                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
 63                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
 64                    json,
 65                )?),
 66            },
 67            _ => {
 68                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
 69            }
 70        }
 71    }
 72
 73    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
 74        let mut messages = Vec::new();
 75        let mut request_token_usage = HashMap::default();
 76
 77        let mut last_user_message_id = None;
 78        for (ix, msg) in thread.messages.into_iter().enumerate() {
 79            let message = match msg.role {
 80                language_model::Role::User => {
 81                    let mut content = Vec::new();
 82
 83                    // Convert segments to content
 84                    for segment in msg.segments {
 85                        match segment {
 86                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
 87                                content.push(UserMessageContent::Text(text));
 88                            }
 89                            crate::legacy_thread::SerializedMessageSegment::Thinking {
 90                                text,
 91                                ..
 92                            } => {
 93                                // User messages don't have thinking segments, but handle gracefully
 94                                content.push(UserMessageContent::Text(text));
 95                            }
 96                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
 97                                ..
 98                            } => {
 99                                // User messages don't have redacted thinking, skip.
100                            }
101                        }
102                    }
103
104                    // If no content was added, add context as text if available
105                    if content.is_empty() && !msg.context.is_empty() {
106                        content.push(UserMessageContent::Text(msg.context));
107                    }
108
109                    let id = UserMessageId::new();
110                    last_user_message_id = Some(id.clone());
111
112                    crate::Message::User(UserMessage {
113                        // MessageId from old format can't be meaningfully converted, so generate a new one
114                        id,
115                        content,
116                    })
117                }
118                language_model::Role::Assistant => {
119                    let mut content = Vec::new();
120
121                    // Convert segments to content
122                    for segment in msg.segments {
123                        match segment {
124                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
125                                content.push(AgentMessageContent::Text(text));
126                            }
127                            crate::legacy_thread::SerializedMessageSegment::Thinking {
128                                text,
129                                signature,
130                            } => {
131                                content.push(AgentMessageContent::Thinking { text, signature });
132                            }
133                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
134                                data,
135                            } => {
136                                content.push(AgentMessageContent::RedactedThinking(data));
137                            }
138                        }
139                    }
140
141                    // Convert tool uses
142                    let mut tool_names_by_id = HashMap::default();
143                    for tool_use in msg.tool_uses {
144                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
145                        content.push(AgentMessageContent::ToolUse(
146                            language_model::LanguageModelToolUse {
147                                id: tool_use.id,
148                                name: tool_use.name.into(),
149                                raw_input: serde_json::to_string(&tool_use.input)
150                                    .unwrap_or_default(),
151                                input: tool_use.input,
152                                is_input_complete: true,
153                                thought_signature: None,
154                            },
155                        ));
156                    }
157
158                    // Convert tool results
159                    let mut tool_results = IndexMap::default();
160                    for tool_result in msg.tool_results {
161                        let name = tool_names_by_id
162                            .remove(&tool_result.tool_use_id)
163                            .unwrap_or_else(|| SharedString::from("unknown"));
164                        tool_results.insert(
165                            tool_result.tool_use_id.clone(),
166                            language_model::LanguageModelToolResult {
167                                tool_use_id: tool_result.tool_use_id,
168                                tool_name: name.into(),
169                                is_error: tool_result.is_error,
170                                content: tool_result.content,
171                                output: tool_result.output,
172                            },
173                        );
174                    }
175
176                    if let Some(last_user_message_id) = &last_user_message_id
177                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
178                    {
179                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
180                    }
181
182                    crate::Message::Agent(AgentMessage {
183                        content,
184                        tool_results,
185                    })
186                }
187                language_model::Role::System => {
188                    // Skip system messages as they're not supported in the new format
189                    continue;
190                }
191            };
192
193            messages.push(message);
194        }
195
196        Ok(Self {
197            title: thread.summary,
198            messages,
199            updated_at: thread.updated_at,
200            detailed_summary: match thread.detailed_summary_state {
201                crate::legacy_thread::DetailedSummaryState::NotGenerated
202                | crate::legacy_thread::DetailedSummaryState::Generating => None,
203                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
204            },
205            initial_project_snapshot: thread.initial_project_snapshot,
206            cumulative_token_usage: thread.cumulative_token_usage,
207            request_token_usage,
208            model: thread.model,
209            completion_mode: thread.completion_mode,
210            profile: thread.profile,
211        })
212    }
213}
214
215#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
216pub enum DataType {
217    #[serde(rename = "json")]
218    Json,
219    #[serde(rename = "zstd")]
220    Zstd,
221}
222
223impl Bind for DataType {
224    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
225        let value = match self {
226            DataType::Json => "json",
227            DataType::Zstd => "zstd",
228        };
229        value.bind(statement, start_index)
230    }
231}
232
233impl Column for DataType {
234    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
235        let (value, next_index) = String::column(statement, start_index)?;
236        let data_type = match value.as_str() {
237            "json" => DataType::Json,
238            "zstd" => DataType::Zstd,
239            _ => anyhow::bail!("Unknown data type: {}", value),
240        };
241        Ok((data_type, next_index))
242    }
243}
244
245pub(crate) struct ThreadsDatabase {
246    executor: BackgroundExecutor,
247    connection: Arc<Mutex<Connection>>,
248}
249
250struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
251
252impl Global for GlobalThreadsDatabase {}
253
254impl ThreadsDatabase {
255    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
256        if cx.has_global::<GlobalThreadsDatabase>() {
257            return cx.global::<GlobalThreadsDatabase>().0.clone();
258        }
259        let executor = cx.background_executor().clone();
260        let task = executor
261            .spawn({
262                let executor = executor.clone();
263                async move {
264                    match ThreadsDatabase::new(executor) {
265                        Ok(db) => Ok(Arc::new(db)),
266                        Err(err) => Err(Arc::new(err)),
267                    }
268                }
269            })
270            .shared();
271
272        cx.set_global(GlobalThreadsDatabase(task.clone()));
273        task
274    }
275
276    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
277        let connection = if *ZED_STATELESS {
278            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
279        } else if cfg!(any(feature = "test-support", test)) {
280            // rust stores the name of the test on the current thread.
281            // We use this to automatically create a database that will
282            // be shared within the test (for the test_retrieve_old_thread)
283            // but not with concurrent tests.
284            let thread = std::thread::current();
285            let test_name = thread.name();
286            Connection::open_memory(Some(&format!(
287                "THREAD_FALLBACK_{}",
288                test_name.unwrap_or_default()
289            )))
290        } else {
291            let threads_dir = paths::data_dir().join("threads");
292            std::fs::create_dir_all(&threads_dir)?;
293            let sqlite_path = threads_dir.join("threads.db");
294            Connection::open_file(&sqlite_path.to_string_lossy())
295        };
296
297        connection.exec(indoc! {"
298            CREATE TABLE IF NOT EXISTS threads (
299                id TEXT PRIMARY KEY,
300                summary TEXT NOT NULL,
301                updated_at TEXT NOT NULL,
302                data_type TEXT NOT NULL,
303                data BLOB NOT NULL
304            )
305        "})?()
306        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
307
308        let db = Self {
309            executor,
310            connection: Arc::new(Mutex::new(connection)),
311        };
312
313        Ok(db)
314    }
315
316    fn save_thread_sync(
317        connection: &Arc<Mutex<Connection>>,
318        id: acp::SessionId,
319        thread: DbThread,
320    ) -> Result<()> {
321        const COMPRESSION_LEVEL: i32 = 3;
322
323        #[derive(Serialize)]
324        struct SerializedThread {
325            #[serde(flatten)]
326            thread: DbThread,
327            version: &'static str,
328        }
329
330        let title = thread.title.to_string();
331        let updated_at = thread.updated_at.to_rfc3339();
332        let json_data = serde_json::to_string(&SerializedThread {
333            thread,
334            version: DbThread::VERSION,
335        })?;
336
337        let connection = connection.lock();
338
339        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
340        let data_type = DataType::Zstd;
341        let data = compressed;
342
343        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
344            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
345        "})?;
346
347        insert((id.0, title, updated_at, data_type, data))?;
348
349        Ok(())
350    }
351
352    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
353        let connection = self.connection.clone();
354
355        self.executor.spawn(async move {
356            let connection = connection.lock();
357
358            let mut select =
359                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
360                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
361            "})?;
362
363            let rows = select(())?;
364            let mut threads = Vec::new();
365
366            for (id, summary, updated_at) in rows {
367                threads.push(DbThreadMetadata {
368                    id: acp::SessionId(id),
369                    title: summary.into(),
370                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
371                });
372            }
373
374            Ok(threads)
375        })
376    }
377
378    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
379        let connection = self.connection.clone();
380
381        self.executor.spawn(async move {
382            let connection = connection.lock();
383            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
384                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
385            "})?;
386
387            let rows = select(id.0)?;
388            if let Some((data_type, data)) = rows.into_iter().next() {
389                let json_data = match data_type {
390                    DataType::Zstd => {
391                        let decompressed = zstd::decode_all(&data[..])?;
392                        String::from_utf8(decompressed)?
393                    }
394                    DataType::Json => String::from_utf8(data)?,
395                };
396                let thread = DbThread::from_json(json_data.as_bytes())?;
397                Ok(Some(thread))
398            } else {
399                Ok(None)
400            }
401        })
402    }
403
404    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
405        let connection = self.connection.clone();
406
407        self.executor
408            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
409    }
410
411    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
412        let connection = self.connection.clone();
413
414        self.executor.spawn(async move {
415            let connection = connection.lock();
416
417            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
418                DELETE FROM threads WHERE id = ?
419            "})?;
420
421            delete(id.0)?;
422
423            Ok(())
424        })
425    }
426}