//! `legacy_thread.rs`: serialized thread formats and upgrades from older on-disk layouts.

  1use crate::ProjectSnapshot;
  2use agent_settings::AgentProfileId;
  3use anyhow::Result;
  4use chrono::{DateTime, Utc};
  5use gpui::SharedString;
  6use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  7use serde::{Deserialize, Serialize};
  8use std::sync::Arc;
  9
 10#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 11pub enum DetailedSummaryState {
 12    #[default]
 13    NotGenerated,
 14    Generating,
 15    Generated {
 16        text: SharedString,
 17    },
 18}
 19
 20#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 21pub struct MessageId(pub usize);
 22
 23#[derive(Serialize, Deserialize, Debug, PartialEq)]
 24pub struct SerializedThread {
 25    pub version: String,
 26    pub summary: SharedString,
 27    pub updated_at: DateTime<Utc>,
 28    pub messages: Vec<SerializedMessage>,
 29    #[serde(default)]
 30    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
 31    #[serde(default)]
 32    pub cumulative_token_usage: TokenUsage,
 33    #[serde(default)]
 34    pub request_token_usage: Vec<TokenUsage>,
 35    #[serde(default)]
 36    pub detailed_summary_state: DetailedSummaryState,
 37    #[serde(default)]
 38    pub model: Option<SerializedLanguageModel>,
 39    #[serde(default)]
 40    pub tool_use_limit_reached: bool,
 41    #[serde(default)]
 42    pub profile: Option<AgentProfileId>,
 43}
 44
 45#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
 46pub struct SerializedLanguageModel {
 47    pub provider: String,
 48    pub model: String,
 49}
 50
 51impl SerializedThread {
 52    pub const VERSION: &'static str = "0.2.0";
 53
 54    pub fn from_json(json: &[u8]) -> Result<Self> {
 55        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 56        match saved_thread_json.get("version") {
 57            Some(serde_json::Value::String(version)) => match version.as_str() {
 58                SerializedThreadV0_1_0::VERSION => {
 59                    let saved_thread =
 60                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 61                    Ok(saved_thread.upgrade())
 62                }
 63                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 64                    saved_thread_json,
 65                )?),
 66                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 67            },
 68            None => {
 69                let saved_thread =
 70                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 71                Ok(saved_thread.upgrade())
 72            }
 73            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 74        }
 75    }
 76}
 77
 78#[derive(Serialize, Deserialize, Debug)]
 79pub struct SerializedThreadV0_1_0(
 80    // The structure did not change, so we are reusing the latest SerializedThread.
 81    // When making the next version, make sure this points to SerializedThreadV0_2_0
 82    SerializedThread,
 83);
 84
 85impl SerializedThreadV0_1_0 {
 86    pub const VERSION: &'static str = "0.1.0";
 87
 88    pub fn upgrade(self) -> SerializedThread {
 89        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 90
 91        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 92
 93        for message in self.0.messages {
 94            if message.role == Role::User
 95                && !message.tool_results.is_empty()
 96                && let Some(last_message) = messages.last_mut()
 97            {
 98                debug_assert!(last_message.role == Role::Assistant);
 99
100                last_message.tool_results = message.tool_results;
101                continue;
102            }
103
104            messages.push(message);
105        }
106
107        SerializedThread {
108            messages,
109            version: SerializedThread::VERSION.to_string(),
110            ..self.0
111        }
112    }
113}
114
115#[derive(Debug, Serialize, Deserialize, PartialEq)]
116pub struct SerializedMessage {
117    pub id: MessageId,
118    pub role: Role,
119    #[serde(default)]
120    pub segments: Vec<SerializedMessageSegment>,
121    #[serde(default)]
122    pub tool_uses: Vec<SerializedToolUse>,
123    #[serde(default)]
124    pub tool_results: Vec<SerializedToolResult>,
125    #[serde(default)]
126    pub context: String,
127    #[serde(default)]
128    pub creases: Vec<SerializedCrease>,
129    #[serde(default)]
130    pub is_hidden: bool,
131}
132
133#[derive(Debug, Serialize, Deserialize, PartialEq)]
134#[serde(tag = "type")]
135pub enum SerializedMessageSegment {
136    #[serde(rename = "text")]
137    Text {
138        text: String,
139    },
140    #[serde(rename = "thinking")]
141    Thinking {
142        text: String,
143        #[serde(skip_serializing_if = "Option::is_none")]
144        signature: Option<String>,
145    },
146    RedactedThinking {
147        data: String,
148    },
149}
150
151#[derive(Debug, Serialize, Deserialize, PartialEq)]
152pub struct SerializedToolUse {
153    pub id: LanguageModelToolUseId,
154    pub name: SharedString,
155    pub input: serde_json::Value,
156}
157
158#[derive(Debug, Serialize, Deserialize, PartialEq)]
159pub struct SerializedToolResult {
160    pub tool_use_id: LanguageModelToolUseId,
161    pub is_error: bool,
162    pub content: LanguageModelToolResultContent,
163    pub output: Option<serde_json::Value>,
164}
165
166#[derive(Serialize, Deserialize)]
167struct LegacySerializedThread {
168    pub summary: SharedString,
169    pub updated_at: DateTime<Utc>,
170    pub messages: Vec<LegacySerializedMessage>,
171    #[serde(default)]
172    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
173}
174
175impl LegacySerializedThread {
176    pub fn upgrade(self) -> SerializedThread {
177        SerializedThread {
178            version: SerializedThread::VERSION.to_string(),
179            summary: self.summary,
180            updated_at: self.updated_at,
181            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
182            initial_project_snapshot: self.initial_project_snapshot,
183            cumulative_token_usage: TokenUsage::default(),
184            request_token_usage: Vec::new(),
185            detailed_summary_state: DetailedSummaryState::default(),
186            model: None,
187            tool_use_limit_reached: false,
188            profile: None,
189        }
190    }
191}
192
193#[derive(Debug, Serialize, Deserialize)]
194struct LegacySerializedMessage {
195    pub id: MessageId,
196    pub role: Role,
197    pub text: String,
198    #[serde(default)]
199    pub tool_uses: Vec<SerializedToolUse>,
200    #[serde(default)]
201    pub tool_results: Vec<SerializedToolResult>,
202}
203
204impl LegacySerializedMessage {
205    fn upgrade(self) -> SerializedMessage {
206        SerializedMessage {
207            id: self.id,
208            role: self.role,
209            segments: vec![SerializedMessageSegment::Text { text: self.text }],
210            tool_uses: self.tool_uses,
211            tool_results: self.tool_results,
212            context: String::new(),
213            creases: Vec::new(),
214            is_hidden: false,
215        }
216    }
217}
218
219#[derive(Debug, Serialize, Deserialize, PartialEq)]
220pub struct SerializedCrease {
221    pub start: usize,
222    pub end: usize,
223    pub icon_path: SharedString,
224    pub label: SharedString,
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                model: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    // Distinct id: this fixture previously reused `MessageId(1)`.
                    id: MessageId(3),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            model: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                model: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    /// A thread already in the current format is deserialized directly, with
    /// every `#[serde(default)]` field filled in.
    #[test]
    fn test_from_json_current_version() {
        let json = br#"{"version":"0.2.0","summary":"s","updated_at":"2024-01-01T00:00:00Z","messages":[]}"#;
        let thread = SerializedThread::from_json(json).unwrap();

        let expected_summary: SharedString = "s".into();
        assert_eq!(thread.version.as_str(), SerializedThread::VERSION);
        assert_eq!(thread.summary, expected_summary);
        assert!(thread.messages.is_empty());
        assert_eq!(thread.model, None);
        assert_eq!(
            thread.detailed_summary_state,
            DetailedSummaryState::default()
        );
    }

    /// Both older layouts are upgraded and stamped with the current version.
    #[test]
    fn test_from_json_upgrades_older_versions() {
        let v0_1_0 = br#"{"version":"0.1.0","summary":"s","updated_at":"2024-01-01T00:00:00Z","messages":[]}"#;
        let thread = SerializedThread::from_json(v0_1_0).unwrap();
        assert_eq!(thread.version.as_str(), SerializedThread::VERSION);

        let unversioned = br#"{"summary":"s","updated_at":"2024-01-01T00:00:00Z","messages":[]}"#;
        let thread = SerializedThread::from_json(unversioned).unwrap();
        assert_eq!(thread.version.as_str(), SerializedThread::VERSION);
    }

    /// Unknown string versions, non-string versions, and invalid JSON are errors.
    #[test]
    fn test_from_json_rejects_bad_input() {
        assert!(SerializedThread::from_json(br#"{"version":"9.9.9"}"#).is_err());
        assert!(SerializedThread::from_json(br#"{"version":3}"#).is_err());
        assert!(SerializedThread::from_json(b"not json").is_err());
    }
}