legacy_thread.rs

  1use crate::ProjectSnapshot;
  2use agent_settings::{AgentProfileId, CompletionMode};
  3use anyhow::Result;
  4use chrono::{DateTime, Utc};
  5use gpui::SharedString;
  6use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  7use serde::{Deserialize, Serialize};
  8use std::sync::Arc;
  9
 10#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 11pub enum DetailedSummaryState {
 12    #[default]
 13    NotGenerated,
 14    Generating,
 15    Generated {
 16        text: SharedString,
 17    },
 18}
 19
 20#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 21pub struct MessageId(pub usize);
 22
 23#[derive(Serialize, Deserialize, Debug, PartialEq)]
 24pub struct SerializedThread {
 25    pub version: String,
 26    pub summary: SharedString,
 27    pub updated_at: DateTime<Utc>,
 28    pub messages: Vec<SerializedMessage>,
 29    #[serde(default)]
 30    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
 31    #[serde(default)]
 32    pub cumulative_token_usage: TokenUsage,
 33    #[serde(default)]
 34    pub request_token_usage: Vec<TokenUsage>,
 35    #[serde(default)]
 36    pub detailed_summary_state: DetailedSummaryState,
 37    #[serde(default)]
 38    pub model: Option<SerializedLanguageModel>,
 39    #[serde(default)]
 40    pub completion_mode: Option<CompletionMode>,
 41    #[serde(default)]
 42    pub tool_use_limit_reached: bool,
 43    #[serde(default)]
 44    pub profile: Option<AgentProfileId>,
 45}
 46
 47#[derive(Serialize, Deserialize, Debug, PartialEq)]
 48pub struct SerializedLanguageModel {
 49    pub provider: String,
 50    pub model: String,
 51}
 52
 53impl SerializedThread {
 54    pub const VERSION: &'static str = "0.2.0";
 55
 56    pub fn from_json(json: &[u8]) -> Result<Self> {
 57        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 58        match saved_thread_json.get("version") {
 59            Some(serde_json::Value::String(version)) => match version.as_str() {
 60                SerializedThreadV0_1_0::VERSION => {
 61                    let saved_thread =
 62                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 63                    Ok(saved_thread.upgrade())
 64                }
 65                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 66                    saved_thread_json,
 67                )?),
 68                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 69            },
 70            None => {
 71                let saved_thread =
 72                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 73                Ok(saved_thread.upgrade())
 74            }
 75            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 76        }
 77    }
 78}
 79
 80#[derive(Serialize, Deserialize, Debug)]
 81pub struct SerializedThreadV0_1_0(
 82    // The structure did not change, so we are reusing the latest SerializedThread.
 83    // When making the next version, make sure this points to SerializedThreadV0_2_0
 84    SerializedThread,
 85);
 86
 87impl SerializedThreadV0_1_0 {
 88    pub const VERSION: &'static str = "0.1.0";
 89
 90    pub fn upgrade(self) -> SerializedThread {
 91        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 92
 93        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 94
 95        for message in self.0.messages {
 96            if message.role == Role::User
 97                && !message.tool_results.is_empty()
 98                && let Some(last_message) = messages.last_mut()
 99            {
100                debug_assert!(last_message.role == Role::Assistant);
101
102                last_message.tool_results = message.tool_results;
103                continue;
104            }
105
106            messages.push(message);
107        }
108
109        SerializedThread {
110            messages,
111            version: SerializedThread::VERSION.to_string(),
112            ..self.0
113        }
114    }
115}
116
117#[derive(Debug, Serialize, Deserialize, PartialEq)]
118pub struct SerializedMessage {
119    pub id: MessageId,
120    pub role: Role,
121    #[serde(default)]
122    pub segments: Vec<SerializedMessageSegment>,
123    #[serde(default)]
124    pub tool_uses: Vec<SerializedToolUse>,
125    #[serde(default)]
126    pub tool_results: Vec<SerializedToolResult>,
127    #[serde(default)]
128    pub context: String,
129    #[serde(default)]
130    pub creases: Vec<SerializedCrease>,
131    #[serde(default)]
132    pub is_hidden: bool,
133}
134
135#[derive(Debug, Serialize, Deserialize, PartialEq)]
136#[serde(tag = "type")]
137pub enum SerializedMessageSegment {
138    #[serde(rename = "text")]
139    Text {
140        text: String,
141    },
142    #[serde(rename = "thinking")]
143    Thinking {
144        text: String,
145        #[serde(skip_serializing_if = "Option::is_none")]
146        signature: Option<String>,
147    },
148    RedactedThinking {
149        data: String,
150    },
151}
152
153#[derive(Debug, Serialize, Deserialize, PartialEq)]
154pub struct SerializedToolUse {
155    pub id: LanguageModelToolUseId,
156    pub name: SharedString,
157    pub input: serde_json::Value,
158}
159
160#[derive(Debug, Serialize, Deserialize, PartialEq)]
161pub struct SerializedToolResult {
162    pub tool_use_id: LanguageModelToolUseId,
163    pub is_error: bool,
164    pub content: LanguageModelToolResultContent,
165    pub output: Option<serde_json::Value>,
166}
167
168#[derive(Serialize, Deserialize)]
169struct LegacySerializedThread {
170    pub summary: SharedString,
171    pub updated_at: DateTime<Utc>,
172    pub messages: Vec<LegacySerializedMessage>,
173    #[serde(default)]
174    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
175}
176
177impl LegacySerializedThread {
178    pub fn upgrade(self) -> SerializedThread {
179        SerializedThread {
180            version: SerializedThread::VERSION.to_string(),
181            summary: self.summary,
182            updated_at: self.updated_at,
183            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
184            initial_project_snapshot: self.initial_project_snapshot,
185            cumulative_token_usage: TokenUsage::default(),
186            request_token_usage: Vec::new(),
187            detailed_summary_state: DetailedSummaryState::default(),
188            model: None,
189            completion_mode: None,
190            tool_use_limit_reached: false,
191            profile: None,
192        }
193    }
194}
195
196#[derive(Debug, Serialize, Deserialize)]
197struct LegacySerializedMessage {
198    pub id: MessageId,
199    pub role: Role,
200    pub text: String,
201    #[serde(default)]
202    pub tool_uses: Vec<SerializedToolUse>,
203    #[serde(default)]
204    pub tool_results: Vec<SerializedToolResult>,
205}
206
207impl LegacySerializedMessage {
208    fn upgrade(self) -> SerializedMessage {
209        SerializedMessage {
210            id: self.id,
211            role: self.role,
212            segments: vec![SerializedMessageSegment::Text { text: self.text }],
213            tool_uses: self.tool_uses,
214            tool_results: self.tool_results,
215            context: String::new(),
216            creases: Vec::new(),
217            is_hidden: false,
218        }
219    }
220}
221
222#[derive(Debug, Serialize, Deserialize, PartialEq)]
223pub struct SerializedCrease {
224    pub start: usize,
225    pub end: usize,
226    pub icon_path: SharedString,
227    pub label: SharedString,
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    // Distinct id: this is a separate message that gets folded into
                    // the preceding assistant message during the upgrade.
                    id: MessageId(3),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    /// A thread written in the current format must parse back unchanged.
    #[test]
    fn test_from_json_round_trips_current_version() {
        let thread = SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: "Round trip".into(),
            updated_at: Utc::now(),
            messages: vec![SerializedMessage {
                id: MessageId(1),
                role: Role::User,
                segments: vec![SerializedMessageSegment::Text {
                    text: "Hello".to_string(),
                }],
                tool_uses: vec![],
                tool_results: vec![],
                context: "".to_string(),
                creases: vec![],
                is_hidden: false,
            }],
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        };

        let json = serde_json::to_vec(&thread).unwrap();
        assert_eq!(SerializedThread::from_json(&json).unwrap(), thread);
    }

    /// JSON without a `version` key is decoded as a legacy thread and upgraded.
    #[test]
    fn test_from_json_upgrades_unversioned_thread() {
        let updated_at = Utc::now();
        let legacy_thread = || LegacySerializedThread {
            summary: "Legacy conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(7),
                role: Role::Assistant,
                text: "Hi there".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let json = serde_json::to_value(legacy_thread()).unwrap();
        assert!(json.get("version").is_none());

        let parsed = SerializedThread::from_json(&serde_json::to_vec(&json).unwrap()).unwrap();
        assert_eq!(parsed, legacy_thread().upgrade());
    }

    /// Unknown version strings and non-string `version` values are errors, not
    /// silently treated as legacy data.
    #[test]
    fn test_from_json_rejects_unrecognized_versions() {
        for json in [
            r#"{"version":"9.9.9"}"#,
            r#"{"version":1}"#,
            r#"{"version":null}"#,
        ] {
            let error = SerializedThread::from_json(json.as_bytes()).unwrap_err();
            assert!(
                error
                    .to_string()
                    .contains("unrecognized serialized thread version"),
                "unexpected error for {json}: {error}"
            );
        }

        assert!(SerializedThread::from_json(b"not json").is_err());
    }
}