Detailed changes
@@ -24,6 +24,16 @@ pub struct AnthropicModelCacheConfiguration {
pub max_cache_anchors: usize,
}
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub enum AnthropicModelMode {
+ #[default]
+ Default,
+ Thinking {
+ budget_tokens: Option<u32>,
+ },
+}
+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
@@ -32,6 +42,11 @@ pub enum Model {
Claude3_5Sonnet,
#[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
Claude3_7Sonnet,
+ #[serde(
+ rename = "claude-3-7-sonnet-thinking",
+ alias = "claude-3-7-sonnet-thinking-latest"
+ )]
+ Claude3_7SonnetThinking,
#[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
Claude3_5Haiku,
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
@@ -54,6 +69,8 @@ pub enum Model {
default_temperature: Option<f32>,
#[serde(default)]
extra_beta_headers: Vec<String>,
+ #[serde(default)]
+ mode: AnthropicModelMode,
},
}
@@ -61,6 +78,8 @@ impl Model {
pub fn from_id(id: &str) -> Result<Self> {
if id.starts_with("claude-3-5-sonnet") {
Ok(Self::Claude3_5Sonnet)
+ } else if id.starts_with("claude-3-7-sonnet-thinking") {
+ Ok(Self::Claude3_7SonnetThinking)
} else if id.starts_with("claude-3-7-sonnet") {
Ok(Self::Claude3_7Sonnet)
} else if id.starts_with("claude-3-5-haiku") {
@@ -80,6 +99,20 @@ impl Model {
match self {
Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
Model::Claude3_7Sonnet => "claude-3-7-sonnet-latest",
+ Model::Claude3_7SonnetThinking => "claude-3-7-sonnet-thinking-latest",
+ Model::Claude3_5Haiku => "claude-3-5-haiku-latest",
+ Model::Claude3Opus => "claude-3-opus-latest",
+ Model::Claude3Sonnet => "claude-3-sonnet-20240229",
+ Model::Claude3Haiku => "claude-3-haiku-20240307",
+ Self::Custom { name, .. } => name,
+ }
+ }
+
+ /// The id of the model that should be used for making API requests
+ pub fn request_id(&self) -> &str {
+ match self {
+ Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
+ Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => "claude-3-7-sonnet-latest",
Model::Claude3_5Haiku => "claude-3-5-haiku-latest",
Model::Claude3Opus => "claude-3-opus-latest",
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
@@ -92,6 +125,7 @@ impl Model {
match self {
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
+ Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
Self::Claude3_5Haiku => "Claude 3.5 Haiku",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
@@ -107,6 +141,7 @@ impl Model {
Self::Claude3_5Sonnet
| Self::Claude3_5Haiku
| Self::Claude3_7Sonnet
+ | Self::Claude3_7SonnetThinking
| Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
min_total_token: 2_048,
should_speculate: true,
@@ -125,6 +160,7 @@ impl Model {
Self::Claude3_5Sonnet
| Self::Claude3_5Haiku
| Self::Claude3_7Sonnet
+ | Self::Claude3_7SonnetThinking
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200_000,
@@ -135,7 +171,10 @@ impl Model {
pub fn max_output_tokens(&self) -> u32 {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
- Self::Claude3_5Sonnet | Self::Claude3_7Sonnet | Self::Claude3_5Haiku => 8_192,
+ Self::Claude3_5Sonnet
+ | Self::Claude3_7Sonnet
+ | Self::Claude3_7SonnetThinking
+ | Self::Claude3_5Haiku => 8_192,
Self::Custom {
max_output_tokens, ..
} => max_output_tokens.unwrap_or(4_096),
@@ -146,6 +185,7 @@ impl Model {
match self {
Self::Claude3_5Sonnet
| Self::Claude3_7Sonnet
+ | Self::Claude3_7SonnetThinking
| Self::Claude3_5Haiku
| Self::Claude3Opus
| Self::Claude3Sonnet
@@ -157,6 +197,21 @@ impl Model {
}
}
+ pub fn mode(&self) -> AnthropicModelMode {
+ match self {
+ Self::Claude3_5Sonnet
+ | Self::Claude3_7Sonnet
+ | Self::Claude3_5Haiku
+ | Self::Claude3Opus
+ | Self::Claude3Sonnet
+ | Self::Claude3Haiku => AnthropicModelMode::Default,
+ Self::Claude3_7SonnetThinking => AnthropicModelMode::Thinking {
+ budget_tokens: Some(4_096),
+ },
+ Self::Custom { mode, .. } => mode.clone(),
+ }
+ }
+
pub const DEFAULT_BETA_HEADERS: &[&str] = &["prompt-caching-2024-07-31"];
pub fn beta_headers(&self) -> String {
@@ -188,7 +243,7 @@ impl Model {
{
tool_override
} else {
- self.id()
+ self.request_id()
}
}
}
@@ -409,6 +464,8 @@ pub async fn extract_tool_args_from_events(
Err(error) => Some(Err(error)),
Ok(Event::ContentBlockDelta { index, delta }) => match delta {
ContentDelta::TextDelta { .. } => None,
+ ContentDelta::ThinkingDelta { .. } => None,
+ ContentDelta::SignatureDelta { .. } => None,
ContentDelta::InputJsonDelta { partial_json } => {
if index == tool_use_index {
Some(Ok(partial_json))
@@ -487,6 +544,10 @@ pub enum RequestContent {
pub enum ResponseContent {
#[serde(rename = "text")]
Text { text: String },
+ #[serde(rename = "thinking")]
+ Thinking { thinking: String },
+ #[serde(rename = "redacted_thinking")]
+ RedactedThinking { data: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
@@ -518,6 +579,12 @@ pub enum ToolChoice {
Tool { name: String },
}
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum Thinking {
+ Enabled { budget_tokens: Option<u32> },
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub model: String,
@@ -526,6 +593,8 @@ pub struct Request {
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Tool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
+ pub thinking: Option<Thinking>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
@@ -609,6 +678,10 @@ pub enum Event {
pub enum ContentDelta {
#[serde(rename = "text_delta")]
TextDelta { text: String },
+ #[serde(rename = "thinking_delta")]
+ ThinkingDelta { thinking: String },
+ #[serde(rename = "signature_delta")]
+ SignatureDelta { signature: String },
#[serde(rename = "input_json_delta")]
InputJsonDelta { partial_json: String },
}
@@ -1,5 +1,6 @@
use crate::thread::{
- LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ThreadFeedback,
+ LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
+ ThreadEvent, ThreadFeedback,
};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
@@ -7,10 +8,10 @@ use crate::ui::ContextPill;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
use gpui::{
- list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App,
- ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment,
- ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement,
- Transformation, UnderlineStyle, WeakEntity,
+ linear_color_stop, linear_gradient, list, percentage, pulsating_between, AbsoluteLength,
+ Animation, AnimationExt, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
+ Entity, Focusable, Length, ListAlignment, ListOffset, ListState, ScrollHandle, StyleRefinement,
+ Subscription, Task, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@@ -35,15 +36,175 @@ pub struct ActiveThread {
save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>,
list_state: ListState,
- rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
+ rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
rendered_tool_use_labels: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
+ expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>,
}
+struct RenderedMessage {
+ language_registry: Arc<LanguageRegistry>,
+ segments: Vec<RenderedMessageSegment>,
+}
+
+impl RenderedMessage {
+ fn from_segments(
+ segments: &[MessageSegment],
+ language_registry: Arc<LanguageRegistry>,
+ window: &Window,
+ cx: &mut App,
+ ) -> Self {
+ let mut this = Self {
+ language_registry,
+ segments: Vec::with_capacity(segments.len()),
+ };
+ for segment in segments {
+ this.push_segment(segment, window, cx);
+ }
+ this
+ }
+
+ fn append_thinking(&mut self, text: &String, window: &Window, cx: &mut App) {
+ if let Some(RenderedMessageSegment::Thinking {
+ content,
+ scroll_handle,
+ }) = self.segments.last_mut()
+ {
+ content.update(cx, |markdown, cx| {
+ markdown.append(text, cx);
+ });
+ scroll_handle.scroll_to_bottom();
+ } else {
+ self.segments.push(RenderedMessageSegment::Thinking {
+ content: render_markdown(text.into(), self.language_registry.clone(), window, cx),
+ scroll_handle: ScrollHandle::default(),
+ });
+ }
+ }
+
+ fn append_text(&mut self, text: &String, window: &Window, cx: &mut App) {
+ if let Some(RenderedMessageSegment::Text(markdown)) = self.segments.last_mut() {
+ markdown.update(cx, |markdown, cx| markdown.append(text, cx));
+ } else {
+ self.segments
+ .push(RenderedMessageSegment::Text(render_markdown(
+ SharedString::from(text),
+ self.language_registry.clone(),
+ window,
+ cx,
+ )));
+ }
+ }
+
+ fn push_segment(&mut self, segment: &MessageSegment, window: &Window, cx: &mut App) {
+ let rendered_segment = match segment {
+ MessageSegment::Thinking(text) => RenderedMessageSegment::Thinking {
+ content: render_markdown(text.into(), self.language_registry.clone(), window, cx),
+ scroll_handle: ScrollHandle::default(),
+ },
+ MessageSegment::Text(text) => RenderedMessageSegment::Text(render_markdown(
+ text.into(),
+ self.language_registry.clone(),
+ window,
+ cx,
+ )),
+ };
+ self.segments.push(rendered_segment);
+ }
+}
+
+enum RenderedMessageSegment {
+ Thinking {
+ content: Entity<Markdown>,
+ scroll_handle: ScrollHandle,
+ },
+ Text(Entity<Markdown>),
+}
+
+fn render_markdown(
+ text: SharedString,
+ language_registry: Arc<LanguageRegistry>,
+ window: &Window,
+ cx: &mut App,
+) -> Entity<Markdown> {
+ let theme_settings = ThemeSettings::get_global(cx);
+ let colors = cx.theme().colors();
+ let ui_font_size = TextSize::Default.rems(cx);
+ let buffer_font_size = TextSize::Small.rems(cx);
+ let mut text_style = window.text_style();
+
+ text_style.refine(&TextStyleRefinement {
+ font_family: Some(theme_settings.ui_font.family.clone()),
+ font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
+ font_features: Some(theme_settings.ui_font.features.clone()),
+ font_size: Some(ui_font_size.into()),
+ color: Some(cx.theme().colors().text),
+ ..Default::default()
+ });
+
+ let markdown_style = MarkdownStyle {
+ base_text_style: text_style,
+ syntax: cx.theme().syntax().clone(),
+ selection_background_color: cx.theme().players().local().selection,
+ code_block_overflow_x_scroll: true,
+ table_overflow_x_scroll: true,
+ code_block: StyleRefinement {
+ margin: EdgesRefinement {
+ top: Some(Length::Definite(rems(0.).into())),
+ left: Some(Length::Definite(rems(0.).into())),
+ right: Some(Length::Definite(rems(0.).into())),
+ bottom: Some(Length::Definite(rems(0.5).into())),
+ },
+ padding: EdgesRefinement {
+ top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
+ left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
+ right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
+ bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
+ },
+ background: Some(colors.editor_background.into()),
+ border_color: Some(colors.border_variant),
+ border_widths: EdgesRefinement {
+ top: Some(AbsoluteLength::Pixels(Pixels(1.))),
+ left: Some(AbsoluteLength::Pixels(Pixels(1.))),
+ right: Some(AbsoluteLength::Pixels(Pixels(1.))),
+ bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
+ },
+ text: Some(TextStyleRefinement {
+ font_family: Some(theme_settings.buffer_font.family.clone()),
+ font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
+ font_features: Some(theme_settings.buffer_font.features.clone()),
+ font_size: Some(buffer_font_size.into()),
+ ..Default::default()
+ }),
+ ..Default::default()
+ },
+ inline_code: TextStyleRefinement {
+ font_family: Some(theme_settings.buffer_font.family.clone()),
+ font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
+ font_features: Some(theme_settings.buffer_font.features.clone()),
+ font_size: Some(buffer_font_size.into()),
+ background_color: Some(colors.editor_foreground.opacity(0.1)),
+ ..Default::default()
+ },
+ link: TextStyleRefinement {
+ background_color: Some(colors.editor_foreground.opacity(0.025)),
+ underline: Some(UnderlineStyle {
+ color: Some(colors.text_accent.opacity(0.5)),
+ thickness: px(1.),
+ ..Default::default()
+ }),
+ ..Default::default()
+ },
+ ..Default::default()
+ };
+
+ cx.new(|cx| Markdown::new(text, markdown_style, Some(language_registry), None, cx))
+}
+
struct EditMessageState {
editor: Entity<Editor>,
}
@@ -75,6 +236,7 @@ impl ActiveThread {
rendered_scripting_tool_uses: HashMap::default(),
rendered_tool_use_labels: HashMap::default(),
expanded_tool_uses: HashMap::default(),
+ expanded_thinking_segments: HashMap::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade();
move |ix, window: &mut Window, cx: &mut App| {
@@ -88,7 +250,7 @@ impl ActiveThread {
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
- this.push_message(&message.id, message.text.clone(), window, cx);
+ this.push_message(&message.id, &message.segments, window, cx);
for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
this.render_tool_use_label_markdown(
@@ -156,7 +318,7 @@ impl ActiveThread {
fn push_message(
&mut self,
id: &MessageId,
- text: String,
+ segments: &[MessageSegment],
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -164,8 +326,9 @@ impl ActiveThread {
self.messages.push(*id);
self.list_state.splice(old_len..old_len, 1);
- let markdown = self.render_markdown(text.into(), window, cx);
- self.rendered_messages_by_id.insert(*id, markdown);
+ let rendered_message =
+ RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx);
+ self.rendered_messages_by_id.insert(*id, rendered_message);
self.list_state.scroll_to(ListOffset {
item_ix: old_len,
offset_in_item: Pixels(0.0),
@@ -175,7 +338,7 @@ impl ActiveThread {
fn edited_message(
&mut self,
id: &MessageId,
- text: String,
+ segments: &[MessageSegment],
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -183,8 +346,9 @@ impl ActiveThread {
return;
};
self.list_state.splice(index..index + 1, 1);
- let markdown = self.render_markdown(text.into(), window, cx);
- self.rendered_messages_by_id.insert(*id, markdown);
+ let rendered_message =
+ RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx);
+ self.rendered_messages_by_id.insert(*id, rendered_message);
}
fn deleted_message(&mut self, id: &MessageId) {
@@ -196,94 +360,6 @@ impl ActiveThread {
self.rendered_messages_by_id.remove(id);
}
- fn render_markdown(
- &self,
- text: SharedString,
- window: &Window,
- cx: &mut Context<Self>,
- ) -> Entity<Markdown> {
- let theme_settings = ThemeSettings::get_global(cx);
- let colors = cx.theme().colors();
- let ui_font_size = TextSize::Default.rems(cx);
- let buffer_font_size = TextSize::Small.rems(cx);
- let mut text_style = window.text_style();
-
- text_style.refine(&TextStyleRefinement {
- font_family: Some(theme_settings.ui_font.family.clone()),
- font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
- font_features: Some(theme_settings.ui_font.features.clone()),
- font_size: Some(ui_font_size.into()),
- color: Some(cx.theme().colors().text),
- ..Default::default()
- });
-
- let markdown_style = MarkdownStyle {
- base_text_style: text_style,
- syntax: cx.theme().syntax().clone(),
- selection_background_color: cx.theme().players().local().selection,
- code_block_overflow_x_scroll: true,
- table_overflow_x_scroll: true,
- code_block: StyleRefinement {
- margin: EdgesRefinement {
- top: Some(Length::Definite(rems(0.).into())),
- left: Some(Length::Definite(rems(0.).into())),
- right: Some(Length::Definite(rems(0.).into())),
- bottom: Some(Length::Definite(rems(0.5).into())),
- },
- padding: EdgesRefinement {
- top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
- left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
- right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
- bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
- },
- background: Some(colors.editor_background.into()),
- border_color: Some(colors.border_variant),
- border_widths: EdgesRefinement {
- top: Some(AbsoluteLength::Pixels(Pixels(1.))),
- left: Some(AbsoluteLength::Pixels(Pixels(1.))),
- right: Some(AbsoluteLength::Pixels(Pixels(1.))),
- bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
- },
- text: Some(TextStyleRefinement {
- font_family: Some(theme_settings.buffer_font.family.clone()),
- font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
- font_features: Some(theme_settings.buffer_font.features.clone()),
- font_size: Some(buffer_font_size.into()),
- ..Default::default()
- }),
- ..Default::default()
- },
- inline_code: TextStyleRefinement {
- font_family: Some(theme_settings.buffer_font.family.clone()),
- font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
- font_features: Some(theme_settings.buffer_font.features.clone()),
- font_size: Some(buffer_font_size.into()),
- background_color: Some(colors.editor_foreground.opacity(0.1)),
- ..Default::default()
- },
- link: TextStyleRefinement {
- background_color: Some(colors.editor_foreground.opacity(0.025)),
- underline: Some(UnderlineStyle {
- color: Some(colors.text_accent.opacity(0.5)),
- thickness: px(1.),
- ..Default::default()
- }),
- ..Default::default()
- },
- ..Default::default()
- };
-
- cx.new(|cx| {
- Markdown::new(
- text,
- markdown_style,
- Some(self.language_registry.clone()),
- None,
- cx,
- )
- })
- }
-
/// Renders the input of a scripting tool use to Markdown.
///
/// Does nothing if the tool use does not correspond to the scripting tool.
@@ -303,8 +379,12 @@ impl ActiveThread {
.map(|input| input.lua_script)
.unwrap_or_default();
- let lua_script =
- self.render_markdown(format!("```lua\n{lua_script}\n```").into(), window, cx);
+ let lua_script = render_markdown(
+ format!("```lua\n{lua_script}\n```").into(),
+ self.language_registry.clone(),
+ window,
+ cx,
+ );
self.rendered_scripting_tool_uses
.insert(tool_use_id, lua_script);
@@ -319,7 +399,12 @@ impl ActiveThread {
) {
self.rendered_tool_use_labels.insert(
tool_use_id,
- self.render_markdown(tool_label.into(), window, cx),
+ render_markdown(
+ tool_label.into(),
+ self.language_registry.clone(),
+ window,
+ cx,
+ ),
);
}
@@ -339,33 +424,36 @@ impl ActiveThread {
}
ThreadEvent::DoneStreaming => {}
ThreadEvent::StreamedAssistantText(message_id, text) => {
- if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
- markdown.update(cx, |markdown, cx| {
- markdown.append(text, cx);
- });
+ if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
+ rendered_message.append_text(text, window, cx);
+ }
+ }
+ ThreadEvent::StreamedAssistantThinking(message_id, text) => {
+ if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
+ rendered_message.append_thinking(text, window, cx);
}
}
ThreadEvent::MessageAdded(message_id) => {
- if let Some(message_text) = self
+ if let Some(message_segments) = self
.thread
.read(cx)
.message(*message_id)
- .map(|message| message.text.clone())
+ .map(|message| message.segments.clone())
{
- self.push_message(message_id, message_text, window, cx);
+ self.push_message(message_id, &message_segments, window, cx);
}
self.save_thread(cx);
cx.notify();
}
ThreadEvent::MessageEdited(message_id) => {
- if let Some(message_text) = self
+ if let Some(message_segments) = self
.thread
.read(cx)
.message(*message_id)
- .map(|message| message.text.clone())
+ .map(|message| message.segments.clone())
{
- self.edited_message(message_id, message_text, window, cx);
+ self.edited_message(message_id, &message_segments, window, cx);
}
self.save_thread(cx);
@@ -490,10 +578,16 @@ impl ActiveThread {
fn start_editing_message(
&mut self,
message_id: MessageId,
- message_text: String,
+ message_segments: &[MessageSegment],
window: &mut Window,
cx: &mut Context<Self>,
) {
+ // User messages should always consist of a single text segment,
+ // so we simply return early if the first segment is not a text segment.
+ let Some(MessageSegment::Text(message_text)) = message_segments.first() else {
+ return;
+ };
+
let buffer = cx.new(|cx| {
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
});
@@ -534,7 +628,12 @@ impl ActiveThread {
};
let edited_text = state.editor.read(cx).text(cx);
self.thread.update(cx, |thread, cx| {
- thread.edit_message(message_id, Role::User, edited_text, cx);
+ thread.edit_message(
+ message_id,
+ Role::User,
+ vec![MessageSegment::Text(edited_text)],
+ cx,
+ );
for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
@@ -617,7 +716,7 @@ impl ActiveThread {
return Empty.into_any();
};
- let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
+ let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any();
};
@@ -759,7 +858,10 @@ impl ActiveThread {
.min_h_6()
.child(edit_message_editor)
} else {
- div().min_h_6().text_ui(cx).child(markdown.clone())
+ div()
+ .min_h_6()
+ .text_ui(cx)
+ .child(self.render_message_content(message_id, rendered_message, cx))
},
)
.when_some(context, |parent, context| {
@@ -869,11 +971,12 @@ impl ActiveThread {
Button::new("edit-message", "Edit")
.label_size(LabelSize::Small)
.on_click(cx.listener({
- let message_text = message.text.clone();
+ let message_segments =
+ message.segments.clone();
move |this, _, window, cx| {
this.start_editing_message(
message_id,
- message_text.clone(),
+ &message_segments,
window,
cx,
);
@@ -995,6 +1098,190 @@ impl ActiveThread {
.into_any()
}
+ fn render_message_content(
+ &self,
+ message_id: MessageId,
+ rendered_message: &RenderedMessage,
+ cx: &Context<Self>,
+ ) -> impl IntoElement {
+ let pending_thinking_segment_index = rendered_message
+ .segments
+ .iter()
+ .enumerate()
+ .last()
+ .filter(|(_, segment)| matches!(segment, RenderedMessageSegment::Thinking { .. }))
+ .map(|(index, _)| index);
+
+ div()
+ .text_ui(cx)
+ .gap_2()
+ .children(
+ rendered_message.segments.iter().enumerate().map(
+ |(index, segment)| match segment {
+ RenderedMessageSegment::Thinking {
+ content,
+ scroll_handle,
+ } => self
+ .render_message_thinking_segment(
+ message_id,
+ index,
+ content.clone(),
+ &scroll_handle,
+ Some(index) == pending_thinking_segment_index,
+ cx,
+ )
+ .into_any_element(),
+ RenderedMessageSegment::Text(markdown) => {
+ div().p_2p5().child(markdown.clone()).into_any_element()
+ }
+ },
+ ),
+ )
+ }
+
+ fn render_message_thinking_segment(
+ &self,
+ message_id: MessageId,
+ ix: usize,
+ markdown: Entity<Markdown>,
+ scroll_handle: &ScrollHandle,
+ pending: bool,
+ cx: &Context<Self>,
+ ) -> impl IntoElement {
+ let is_open = self
+ .expanded_thinking_segments
+ .get(&(message_id, ix))
+ .copied()
+ .unwrap_or_default();
+
+ let lighter_border = cx.theme().colors().border.opacity(0.5);
+ let editor_bg = cx.theme().colors().editor_background;
+
+ v_flex()
+ .rounded_lg()
+ .border_1()
+ .border_color(lighter_border)
+ .child(
+ h_flex()
+ .justify_between()
+ .py_1()
+ .pl_1()
+ .pr_2()
+ .bg(cx.theme().colors().editor_foreground.opacity(0.025))
+ .map(|this| {
+ if is_open {
+ this.rounded_t_md()
+ .border_b_1()
+ .border_color(lighter_border)
+ } else {
+ this.rounded_md()
+ }
+ })
+ .child(
+ h_flex()
+ .gap_1()
+ .child(Disclosure::new("thinking-disclosure", is_open).on_click(
+ cx.listener({
+ move |this, _event, _window, _cx| {
+ let is_open = this
+ .expanded_thinking_segments
+ .entry((message_id, ix))
+ .or_insert(false);
+
+ *is_open = !*is_open;
+ }
+ }),
+ ))
+ .child({
+ if pending {
+ Label::new("Thinkingβ¦")
+ .size(LabelSize::Small)
+ .buffer_font(cx)
+ .with_animation(
+ "pulsating-label",
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.4, 0.8)),
+ |label, delta| label.alpha(delta),
+ )
+ .into_any_element()
+ } else {
+ Label::new("Thought Process")
+ .size(LabelSize::Small)
+ .buffer_font(cx)
+ .into_any_element()
+ }
+ }),
+ )
+ .child({
+ let (icon_name, color, animated) = if pending {
+ (IconName::ArrowCircle, Color::Accent, true)
+ } else {
+ (IconName::Check, Color::Success, false)
+ };
+
+ let icon = Icon::new(icon_name).color(color).size(IconSize::Small);
+
+ if animated {
+ icon.with_animation(
+ "arrow-circle",
+ Animation::new(Duration::from_secs(2)).repeat(),
+ |icon, delta| {
+ icon.transform(Transformation::rotate(percentage(delta)))
+ },
+ )
+ .into_any_element()
+ } else {
+ icon.into_any_element()
+ }
+ }),
+ )
+ .when(pending && !is_open, |this| {
+ let gradient_overlay = div()
+ .rounded_b_lg()
+ .h_20()
+ .absolute()
+ .w_full()
+ .bottom_0()
+ .left_0()
+ .bg(linear_gradient(
+ 180.,
+ linear_color_stop(editor_bg, 1.),
+ linear_color_stop(editor_bg.opacity(0.2), 0.),
+ ));
+
+ this.child(
+ div()
+ .relative()
+ .bg(editor_bg)
+ .rounded_b_lg()
+ .text_ui_sm(cx)
+ .child(
+ div()
+ .id(("thinking-content", ix))
+ .p_2()
+ .h_20()
+ .track_scroll(scroll_handle)
+ .child(markdown.clone())
+ .overflow_hidden(),
+ )
+ .child(gradient_overlay),
+ )
+ })
+ .when(is_open, |this| {
+ this.child(
+ div()
+ .id(("thinking-content", ix))
+ .h_full()
+ .p_2()
+ .rounded_b_lg()
+ .bg(editor_bg)
+ .text_ui_sm(cx)
+ .child(markdown.clone()),
+ )
+ })
+ }
+
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
let is_open = self
.expanded_tool_uses
@@ -1258,8 +1545,9 @@ impl ActiveThread {
}
}),
))
- .child(div().text_ui_sm(cx).child(self.render_markdown(
+ .child(div().text_ui_sm(cx).child(render_markdown(
tool_use.ui_text.clone(),
+ self.language_registry.clone(),
window,
cx,
)))
@@ -29,7 +29,8 @@ use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::{
- SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
+ SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
+ SerializedToolUse,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
@@ -69,7 +70,47 @@ impl MessageId {
pub struct Message {
pub id: MessageId,
pub role: Role,
- pub text: String,
+ pub segments: Vec<MessageSegment>,
+}
+
+impl Message {
+ pub fn push_thinking(&mut self, text: &str) {
+ if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
+ segment.push_str(text);
+ } else {
+ self.segments
+ .push(MessageSegment::Thinking(text.to_string()));
+ }
+ }
+
+ pub fn push_text(&mut self, text: &str) {
+ if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
+ segment.push_str(text);
+ } else {
+ self.segments.push(MessageSegment::Text(text.to_string()));
+ }
+ }
+
+ pub fn to_string(&self) -> String {
+ let mut result = String::new();
+ for segment in &self.segments {
+ match segment {
+ MessageSegment::Text(text) => result.push_str(text),
+ MessageSegment::Thinking(text) => {
+ result.push_str("<think>");
+ result.push_str(text);
+ result.push_str("</think>");
+ }
+ }
+ }
+ result
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum MessageSegment {
+ Text(String),
+ Thinking(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -226,7 +267,16 @@ impl Thread {
.map(|message| Message {
id: message.id,
role: message.role,
- text: message.text,
+ segments: message
+ .segments
+ .into_iter()
+ .map(|segment| match segment {
+ SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
+ SerializedMessageSegment::Thinking { text } => {
+ MessageSegment::Thinking(text)
+ }
+ })
+ .collect(),
})
.collect(),
next_message_id,
@@ -419,7 +469,8 @@ impl Thread {
checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
- let message_id = self.insert_message(Role::User, text, cx);
+ let message_id =
+ self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
@@ -433,15 +484,11 @@ impl Thread {
pub fn insert_message(
&mut self,
role: Role,
- text: impl Into<String>,
+ segments: Vec<MessageSegment>,
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
- self.messages.push(Message {
- id,
- role,
- text: text.into(),
- });
+ self.messages.push(Message { id, role, segments });
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
id
@@ -451,14 +498,14 @@ impl Thread {
&mut self,
id: MessageId,
new_role: Role,
- new_text: String,
+ new_segments: Vec<MessageSegment>,
cx: &mut Context<Self>,
) -> bool {
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
return false;
};
message.role = new_role;
- message.text = new_text;
+ message.segments = new_segments;
self.touch_updated_at();
cx.emit(ThreadEvent::MessageEdited(id));
true
@@ -489,7 +536,14 @@ impl Thread {
});
text.push('\n');
- text.push_str(&message.text);
+ for segment in &message.segments {
+ match segment {
+ MessageSegment::Text(content) => text.push_str(content),
+ MessageSegment::Thinking(content) => {
+ text.push_str(&format!("<think>{}</think>", content))
+ }
+ }
+ }
text.push('\n');
}
@@ -502,6 +556,7 @@ impl Thread {
cx.spawn(async move |this, cx| {
let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(cx, |this, cx| SerializedThread {
+ version: SerializedThread::VERSION.to_string(),
summary: this.summary_or_default(),
updated_at: this.updated_at(),
messages: this
@@ -509,7 +564,18 @@ impl Thread {
.map(|message| SerializedMessage {
id: message.id,
role: message.role,
- text: message.text.clone(),
+ segments: message
+ .segments
+ .iter()
+ .map(|segment| match segment {
+ MessageSegment::Text(text) => {
+ SerializedMessageSegment::Text { text: text.clone() }
+ }
+ MessageSegment::Thinking(text) => {
+ SerializedMessageSegment::Thinking { text: text.clone() }
+ }
+ })
+ .collect(),
tool_uses: this
.tool_uses_for_message(message.id, cx)
.into_iter()
@@ -733,10 +799,10 @@ impl Thread {
}
}
- if !message.text.is_empty() {
+ if !message.segments.is_empty() {
request_message
.content
- .push(MessageContent::Text(message.text.clone()));
+ .push(MessageContent::Text(message.to_string()));
}
match request_kind {
@@ -826,7 +892,11 @@ impl Thread {
thread.update(cx, |thread, cx| {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
- thread.insert_message(Role::Assistant, String::new(), cx);
+ thread.insert_message(
+ Role::Assistant,
+ vec![MessageSegment::Text(String::new())],
+ cx,
+ );
}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
@@ -840,7 +910,7 @@ impl Thread {
LanguageModelCompletionEvent::Text(chunk) => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
- last_message.text.push_str(&chunk);
+ last_message.push_text(&chunk);
cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id,
chunk,
@@ -851,7 +921,33 @@ impl Thread {
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
- thread.insert_message(Role::Assistant, chunk, cx);
+ thread.insert_message(
+ Role::Assistant,
+ vec![MessageSegment::Text(chunk.to_string())],
+ cx,
+ );
+ };
+ }
+ }
+ LanguageModelCompletionEvent::Thinking(chunk) => {
+ if let Some(last_message) = thread.messages.last_mut() {
+ if last_message.role == Role::Assistant {
+ last_message.push_thinking(&chunk);
+ cx.emit(ThreadEvent::StreamedAssistantThinking(
+ last_message.id,
+ chunk,
+ ));
+ } else {
+                                    // If we don't have an Assistant message yet, assume this chunk marks the beginning
+                                    // of a new Assistant response.
+                                    //
+                                    // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
+                                    // will result in duplicating the text of the chunk in the rendered Markdown.
+ thread.insert_message(
+ Role::Assistant,
+ vec![MessageSegment::Thinking(chunk.to_string())],
+ cx,
+ );
};
}
}
@@ -1357,7 +1453,14 @@ impl Thread {
Role::System => "System",
}
)?;
- writeln!(markdown, "{}\n", message.text)?;
+ for segment in &message.segments {
+ match segment {
+ MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
+ MessageSegment::Thinking(text) => {
+ writeln!(markdown, "<think>{}</think>\n", text)?
+ }
+ }
+ }
for tool_use in self.tool_uses_for_message(message.id, cx) {
writeln!(
@@ -1416,6 +1519,7 @@ pub enum ThreadEvent {
ShowError(ThreadError),
StreamedCompletion,
StreamedAssistantText(MessageId, String),
+ StreamedAssistantThinking(MessageId, String),
DoneStreaming,
MessageAdded(MessageId),
MessageEdited(MessageId),
@@ -1,3 +1,4 @@
+use std::borrow::Cow;
use std::path::PathBuf;
use std::sync::Arc;
@@ -12,7 +13,7 @@ use futures::FutureExt as _;
use gpui::{
prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
};
-use heed::types::{SerdeBincode, SerdeJson};
+use heed::types::SerdeBincode;
use heed::Database;
use language_model::{LanguageModelToolUseId, Role};
use project::Project;
@@ -259,6 +260,7 @@ pub struct SerializedThreadMetadata {
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
+ pub version: String,
pub summary: SharedString,
pub updated_at: DateTime<Utc>,
pub messages: Vec<SerializedMessage>,
@@ -266,17 +268,55 @@ pub struct SerializedThread {
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
+impl SerializedThread {
+ pub const VERSION: &'static str = "0.1.0";
+
+ pub fn from_json(json: &[u8]) -> Result<Self> {
+ let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
+ match saved_thread_json.get("version") {
+ Some(serde_json::Value::String(version)) => match version.as_str() {
+ SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
+ saved_thread_json,
+ )?),
+ _ => Err(anyhow!(
+ "unrecognized serialized thread version: {}",
+ version
+ )),
+ },
+ None => {
+ let saved_thread =
+ serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
+ Ok(saved_thread.upgrade())
+ }
+ version => Err(anyhow!(
+ "unrecognized serialized thread version: {:?}",
+ version
+ )),
+ }
+ }
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
pub id: MessageId,
pub role: Role,
- pub text: String,
+ #[serde(default)]
+ pub segments: Vec<SerializedMessageSegment>,
#[serde(default)]
pub tool_uses: Vec<SerializedToolUse>,
#[serde(default)]
pub tool_results: Vec<SerializedToolResult>,
}
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum SerializedMessageSegment {
+ #[serde(rename = "text")]
+ Text { text: String },
+ #[serde(rename = "thinking")]
+ Thinking { text: String },
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
pub id: LanguageModelToolUseId,
@@ -291,6 +331,50 @@ pub struct SerializedToolResult {
pub content: Arc<str>,
}
+#[derive(Serialize, Deserialize)]
+struct LegacySerializedThread {
+ pub summary: SharedString,
+ pub updated_at: DateTime<Utc>,
+ pub messages: Vec<LegacySerializedMessage>,
+ #[serde(default)]
+ pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
+}
+
+impl LegacySerializedThread {
+ pub fn upgrade(self) -> SerializedThread {
+ SerializedThread {
+ version: SerializedThread::VERSION.to_string(),
+ summary: self.summary,
+ updated_at: self.updated_at,
+ messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
+ initial_project_snapshot: self.initial_project_snapshot,
+ }
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct LegacySerializedMessage {
+ pub id: MessageId,
+ pub role: Role,
+ pub text: String,
+ #[serde(default)]
+ pub tool_uses: Vec<SerializedToolUse>,
+ #[serde(default)]
+ pub tool_results: Vec<SerializedToolResult>,
+}
+
+impl LegacySerializedMessage {
+ fn upgrade(self) -> SerializedMessage {
+ SerializedMessage {
+ id: self.id,
+ role: self.role,
+ segments: vec![SerializedMessageSegment::Text { text: self.text }],
+ tool_uses: self.tool_uses,
+ tool_results: self.tool_results,
+ }
+ }
+}
+
struct GlobalThreadsDatabase(
Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
@@ -300,7 +384,25 @@ impl Global for GlobalThreadsDatabase {}
pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor,
env: heed::Env,
- threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
+ threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
+}
+
+impl heed::BytesEncode<'_> for SerializedThread {
+ type EItem = SerializedThread;
+
+ fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
+ serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
+ }
+}
+
+impl<'a> heed::BytesDecode<'a> for SerializedThread {
+ type DItem = SerializedThread;
+
+ fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
+ // We implement this type manually because we want to call `SerializedThread::from_json`,
+ // instead of the Deserialize trait implementation for `SerializedThread`.
+ SerializedThread::from_json(bytes).map_err(Into::into)
+ }
}
impl ThreadsDatabase {
@@ -162,6 +162,11 @@ pub enum ContextOperation {
section: SlashCommandOutputSection<language::Anchor>,
version: clock::Global,
},
+ ThoughtProcessOutputSectionAdded {
+ timestamp: clock::Lamport,
+ section: ThoughtProcessOutputSection<language::Anchor>,
+ version: clock::Global,
+ },
BufferOperation(language::Operation),
}
@@ -259,6 +264,20 @@ impl ContextOperation {
version: language::proto::deserialize_version(&message.version),
})
}
+ proto::context_operation::Variant::ThoughtProcessOutputSectionAdded(message) => {
+ let section = message.section.context("missing section")?;
+ Ok(Self::ThoughtProcessOutputSectionAdded {
+ timestamp: language::proto::deserialize_timestamp(
+ message.timestamp.context("missing timestamp")?,
+ ),
+ section: ThoughtProcessOutputSection {
+ range: language::proto::deserialize_anchor_range(
+ section.range.context("invalid range")?,
+ )?,
+ },
+ version: language::proto::deserialize_version(&message.version),
+ })
+ }
proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation(
language::proto::deserialize_operation(
op.operation.context("invalid buffer operation")?,
@@ -370,6 +389,27 @@ impl ContextOperation {
},
)),
},
+ Self::ThoughtProcessOutputSectionAdded {
+ timestamp,
+ section,
+ version,
+ } => proto::ContextOperation {
+ variant: Some(
+ proto::context_operation::Variant::ThoughtProcessOutputSectionAdded(
+ proto::context_operation::ThoughtProcessOutputSectionAdded {
+ timestamp: Some(language::proto::serialize_timestamp(*timestamp)),
+ section: Some({
+ proto::ThoughtProcessOutputSection {
+ range: Some(language::proto::serialize_anchor_range(
+ section.range.clone(),
+ )),
+ }
+ }),
+ version: language::proto::serialize_version(version),
+ },
+ ),
+ ),
+ },
Self::BufferOperation(operation) => proto::ContextOperation {
variant: Some(proto::context_operation::Variant::BufferOperation(
proto::context_operation::BufferOperation {
@@ -387,7 +427,8 @@ impl ContextOperation {
Self::UpdateSummary { summary, .. } => summary.timestamp,
Self::SlashCommandStarted { id, .. } => id.0,
Self::SlashCommandOutputSectionAdded { timestamp, .. }
- | Self::SlashCommandFinished { timestamp, .. } => *timestamp,
+ | Self::SlashCommandFinished { timestamp, .. }
+ | Self::ThoughtProcessOutputSectionAdded { timestamp, .. } => *timestamp,
Self::BufferOperation(_) => {
panic!("reading the timestamp of a buffer operation is not supported")
}
@@ -402,7 +443,8 @@ impl ContextOperation {
| Self::UpdateSummary { version, .. }
| Self::SlashCommandStarted { version, .. }
| Self::SlashCommandOutputSectionAdded { version, .. }
- | Self::SlashCommandFinished { version, .. } => version,
+ | Self::SlashCommandFinished { version, .. }
+ | Self::ThoughtProcessOutputSectionAdded { version, .. } => version,
Self::BufferOperation(_) => {
panic!("reading the version of a buffer operation is not supported")
}
@@ -418,6 +460,8 @@ pub enum ContextEvent {
MessagesEdited,
SummaryChanged,
StreamedCompletion,
+ StartedThoughtProcess(Range<language::Anchor>),
+ EndedThoughtProcess(language::Anchor),
PatchesUpdated {
removed: Vec<Range<language::Anchor>>,
updated: Vec<Range<language::Anchor>>,
@@ -498,6 +542,17 @@ impl MessageMetadata {
}
}
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub struct ThoughtProcessOutputSection<T> {
+ pub range: Range<T>,
+}
+
+impl ThoughtProcessOutputSection<language::Anchor> {
+ pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
+ self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
+ }
+}
+
#[derive(Clone, Debug)]
pub struct Message {
pub offset_range: Range<usize>,
@@ -580,6 +635,7 @@ pub struct AssistantContext {
edits_since_last_parse: language::Subscription,
slash_commands: Arc<SlashCommandWorkingSet>,
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
+ thought_process_output_sections: Vec<ThoughtProcessOutputSection<language::Anchor>>,
message_anchors: Vec<MessageAnchor>,
contents: Vec<Content>,
messages_metadata: HashMap<MessageId, MessageMetadata>,
@@ -682,6 +738,7 @@ impl AssistantContext {
parsed_slash_commands: Vec::new(),
invoked_slash_commands: HashMap::default(),
slash_command_output_sections: Vec::new(),
+ thought_process_output_sections: Vec::new(),
edits_since_last_parse: edits_since_last_slash_command_parse,
summary: None,
pending_summary: Task::ready(None),
@@ -764,6 +821,18 @@ impl AssistantContext {
}
})
.collect(),
+ thought_process_output_sections: self
+ .thought_process_output_sections
+ .iter()
+ .filter_map(|section| {
+ if section.is_valid(buffer) {
+ let range = section.range.to_offset(buffer);
+ Some(ThoughtProcessOutputSection { range })
+ } else {
+ None
+ }
+ })
+ .collect(),
}
}
@@ -957,6 +1026,16 @@ impl AssistantContext {
cx.emit(ContextEvent::SlashCommandOutputSectionAdded { section });
}
}
+ ContextOperation::ThoughtProcessOutputSectionAdded { section, .. } => {
+ let buffer = self.buffer.read(cx);
+ if let Err(ix) = self
+ .thought_process_output_sections
+ .binary_search_by(|probe| probe.range.cmp(§ion.range, buffer))
+ {
+ self.thought_process_output_sections
+ .insert(ix, section.clone());
+ }
+ }
ContextOperation::SlashCommandFinished {
id,
error_message,
@@ -1020,6 +1099,9 @@ impl AssistantContext {
ContextOperation::SlashCommandOutputSectionAdded { section, .. } => {
self.has_received_operations_for_anchor_range(section.range.clone(), cx)
}
+ ContextOperation::ThoughtProcessOutputSectionAdded { section, .. } => {
+ self.has_received_operations_for_anchor_range(section.range.clone(), cx)
+ }
ContextOperation::SlashCommandFinished { .. } => true,
ContextOperation::BufferOperation(_) => {
panic!("buffer operations should always be applied")
@@ -1128,6 +1210,12 @@ impl AssistantContext {
&self.slash_command_output_sections
}
+ pub fn thought_process_output_sections(
+ &self,
+ ) -> &[ThoughtProcessOutputSection<language::Anchor>] {
+ &self.thought_process_output_sections
+ }
+
pub fn contains_files(&self, cx: &App) -> bool {
let buffer = self.buffer.read(cx);
self.slash_command_output_sections.iter().any(|section| {
@@ -2168,6 +2256,35 @@ impl AssistantContext {
);
}
+ fn insert_thought_process_output_section(
+ &mut self,
+ section: ThoughtProcessOutputSection<language::Anchor>,
+ cx: &mut Context<Self>,
+ ) {
+ let buffer = self.buffer.read(cx);
+ let insertion_ix = match self
+ .thought_process_output_sections
+ .binary_search_by(|probe| probe.range.cmp(§ion.range, buffer))
+ {
+ Ok(ix) | Err(ix) => ix,
+ };
+ self.thought_process_output_sections
+ .insert(insertion_ix, section.clone());
+ // cx.emit(ContextEvent::ThoughtProcessOutputSectionAdded {
+ // section: section.clone(),
+ // });
+ let version = self.version.clone();
+ let timestamp = self.next_timestamp();
+ self.push_op(
+ ContextOperation::ThoughtProcessOutputSectionAdded {
+ timestamp,
+ section,
+ version,
+ },
+ cx,
+ );
+ }
+
pub fn completion_provider_changed(&mut self, cx: &mut Context<Self>) {
self.count_remaining_tokens(cx);
}
@@ -2220,6 +2337,10 @@ impl AssistantContext {
let request_start = Instant::now();
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
+ let mut thought_process_stack = Vec::new();
+
+ const THOUGHT_PROCESS_START_MARKER: &str = "<think>\n";
+ const THOUGHT_PROCESS_END_MARKER: &str = "\n</think>";
while let Some(event) = events.next().await {
if response_latency.is_none() {
@@ -2227,6 +2348,9 @@ impl AssistantContext {
}
let event = event?;
+ let mut context_event = None;
+ let mut thought_process_output_section = None;
+
this.update(cx, |this, cx| {
let message_ix = this
.message_anchors
@@ -2245,7 +2369,50 @@ impl AssistantContext {
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
}
- LanguageModelCompletionEvent::Text(chunk) => {
+ LanguageModelCompletionEvent::Thinking(chunk) => {
+ if thought_process_stack.is_empty() {
+ let start =
+ buffer.anchor_before(message_old_end_offset);
+ thought_process_stack.push(start);
+ let chunk =
+ format!("{THOUGHT_PROCESS_START_MARKER}{chunk}{THOUGHT_PROCESS_END_MARKER}");
+ let chunk_len = chunk.len();
+ buffer.edit(
+ [(
+ message_old_end_offset..message_old_end_offset,
+ chunk,
+ )],
+ None,
+ cx,
+ );
+ let end = buffer
+ .anchor_before(message_old_end_offset + chunk_len);
+ context_event = Some(
+ ContextEvent::StartedThoughtProcess(start..end),
+ );
+ } else {
+ // This ensures that all the thinking chunks are inserted inside the thinking tag
+ let insertion_position =
+ message_old_end_offset - THOUGHT_PROCESS_END_MARKER.len();
+ buffer.edit(
+ [(insertion_position..insertion_position, chunk)],
+ None,
+ cx,
+ );
+ }
+ }
+ LanguageModelCompletionEvent::Text(mut chunk) => {
+ if let Some(start) = thought_process_stack.pop() {
+ let end = buffer.anchor_before(message_old_end_offset);
+ context_event =
+ Some(ContextEvent::EndedThoughtProcess(end));
+ thought_process_output_section =
+ Some(ThoughtProcessOutputSection {
+ range: start..end,
+ });
+ chunk.insert_str(0, "\n\n");
+ }
+
buffer.edit(
[(
message_old_end_offset..message_old_end_offset,
@@ -2260,6 +2427,13 @@ impl AssistantContext {
}
});
+ if let Some(section) = thought_process_output_section.take() {
+ this.insert_thought_process_output_section(section, cx);
+ }
+ if let Some(context_event) = context_event.take() {
+ cx.emit(context_event);
+ }
+
cx.emit(ContextEvent::StreamedCompletion);
Some(())
@@ -3127,6 +3301,8 @@ pub struct SavedContext {
pub summary: String,
pub slash_command_output_sections:
Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
+ #[serde(default)]
+ pub thought_process_output_sections: Vec<ThoughtProcessOutputSection<usize>>,
}
impl SavedContext {
@@ -3228,6 +3404,20 @@ impl SavedContext {
version.observe(timestamp);
}
+ for section in self.thought_process_output_sections {
+ let timestamp = next_timestamp.tick();
+ operations.push(ContextOperation::ThoughtProcessOutputSectionAdded {
+ timestamp,
+ section: ThoughtProcessOutputSection {
+ range: buffer.anchor_after(section.range.start)
+ ..buffer.anchor_before(section.range.end),
+ },
+ version: version.clone(),
+ });
+
+ version.observe(timestamp);
+ }
+
let timestamp = next_timestamp.tick();
operations.push(ContextOperation::UpdateSummary {
summary: ContextSummary {
@@ -3302,6 +3492,7 @@ impl SavedContextV0_3_0 {
.collect(),
summary: self.summary,
slash_command_output_sections: self.slash_command_output_sections,
+ thought_process_output_sections: Vec::new(),
}
}
}
@@ -64,7 +64,10 @@ use workspace::{
Workspace,
};
-use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker};
+use crate::{
+ slash_command::SlashCommandCompletionProvider, slash_command_picker,
+ ThoughtProcessOutputSection,
+};
use crate::{
AssistantContext, AssistantPatch, AssistantPatchStatus, CacheStatus, Content, ContextEvent,
ContextId, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId,
@@ -120,6 +123,11 @@ enum AssistError {
Message(SharedString),
}
+pub enum ThoughtProcessStatus {
+ Pending,
+ Completed,
+}
+
pub trait AssistantPanelDelegate {
fn active_context_editor(
&self,
@@ -178,6 +186,7 @@ pub struct ContextEditor {
project: Entity<Project>,
lsp_adapter_delegate: Option<Arc<dyn LspAdapterDelegate>>,
editor: Entity<Editor>,
+ pending_thought_process: Option<(CreaseId, language::Anchor)>,
blocks: HashMap<MessageId, (MessageHeader, CustomBlockId)>,
image_blocks: HashSet<CustomBlockId>,
scroll_position: Option<ScrollPosition>,
@@ -253,7 +262,8 @@ impl ContextEditor {
cx.observe_global_in::<SettingsStore>(window, Self::settings_changed),
];
- let sections = context.read(cx).slash_command_output_sections().to_vec();
+ let slash_command_sections = context.read(cx).slash_command_output_sections().to_vec();
+ let thought_process_sections = context.read(cx).thought_process_output_sections().to_vec();
let patch_ranges = context.read(cx).patch_ranges().collect::<Vec<_>>();
let slash_commands = context.read(cx).slash_commands().clone();
let mut this = Self {
@@ -265,6 +275,7 @@ impl ContextEditor {
image_blocks: Default::default(),
scroll_position: None,
remote_id: None,
+ pending_thought_process: None,
fs: fs.clone(),
workspace,
project,
@@ -294,7 +305,14 @@ impl ContextEditor {
};
this.update_message_headers(cx);
this.update_image_blocks(cx);
- this.insert_slash_command_output_sections(sections, false, window, cx);
+ this.insert_slash_command_output_sections(slash_command_sections, false, window, cx);
+ this.insert_thought_process_output_sections(
+ thought_process_sections
+ .into_iter()
+ .map(|section| (section, ThoughtProcessStatus::Completed)),
+ window,
+ cx,
+ );
this.patches_updated(&Vec::new(), &patch_ranges, window, cx);
this
}
@@ -599,6 +617,47 @@ impl ContextEditor {
context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
});
}
+ ContextEvent::StartedThoughtProcess(range) => {
+ let creases = self.insert_thought_process_output_sections(
+ [(
+ ThoughtProcessOutputSection {
+ range: range.clone(),
+ },
+ ThoughtProcessStatus::Pending,
+ )],
+ window,
+ cx,
+ );
+ self.pending_thought_process = Some((creases[0], range.start));
+ }
+ ContextEvent::EndedThoughtProcess(end) => {
+ if let Some((crease_id, start)) = self.pending_thought_process.take() {
+ self.editor.update(cx, |editor, cx| {
+ let multi_buffer_snapshot = editor.buffer().read(cx).snapshot(cx);
+ let (excerpt_id, _, _) = multi_buffer_snapshot.as_singleton().unwrap();
+ let start_anchor = multi_buffer_snapshot
+ .anchor_in_excerpt(*excerpt_id, start)
+ .unwrap();
+
+ editor.display_map.update(cx, |display_map, cx| {
+ display_map.unfold_intersecting(
+ vec![start_anchor..start_anchor],
+ true,
+ cx,
+ );
+ });
+ editor.remove_creases(vec![crease_id], cx);
+ });
+ self.insert_thought_process_output_sections(
+ [(
+ ThoughtProcessOutputSection { range: start..*end },
+ ThoughtProcessStatus::Completed,
+ )],
+ window,
+ cx,
+ );
+ }
+ }
ContextEvent::StreamedCompletion => {
self.editor.update(cx, |editor, cx| {
if let Some(scroll_position) = self.scroll_position {
@@ -946,6 +1005,62 @@ impl ContextEditor {
self.update_active_patch(window, cx);
}
+ fn insert_thought_process_output_sections(
+ &mut self,
+ sections: impl IntoIterator<
+ Item = (
+ ThoughtProcessOutputSection<language::Anchor>,
+ ThoughtProcessStatus,
+ ),
+ >,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Vec<CreaseId> {
+ self.editor.update(cx, |editor, cx| {
+ let buffer = editor.buffer().read(cx).snapshot(cx);
+ let excerpt_id = *buffer.as_singleton().unwrap().0;
+ let mut buffer_rows_to_fold = BTreeSet::new();
+ let mut creases = Vec::new();
+ for (section, status) in sections {
+ let start = buffer
+ .anchor_in_excerpt(excerpt_id, section.range.start)
+ .unwrap();
+ let end = buffer
+ .anchor_in_excerpt(excerpt_id, section.range.end)
+ .unwrap();
+ let buffer_row = MultiBufferRow(start.to_point(&buffer).row);
+ buffer_rows_to_fold.insert(buffer_row);
+ creases.push(
+ Crease::inline(
+ start..end,
+ FoldPlaceholder {
+ render: render_thought_process_fold_icon_button(
+ cx.entity().downgrade(),
+ status,
+ ),
+ merge_adjacent: false,
+ ..Default::default()
+ },
+ render_slash_command_output_toggle,
+ |_, _, _, _| Empty.into_any_element(),
+ )
+ .with_metadata(CreaseMetadata {
+ icon: IconName::Ai,
+ label: "Thinking Process".into(),
+ }),
+ );
+ }
+
+ let creases = editor.insert_creases(creases, cx);
+
+ for buffer_row in buffer_rows_to_fold.into_iter().rev() {
+ editor.fold_at(&FoldAt { buffer_row }, window, cx);
+ }
+
+ creases
+ })
+ }
+
fn insert_slash_command_output_sections(
&mut self,
sections: impl IntoIterator<Item = SlashCommandOutputSection<language::Anchor>>,
@@ -2652,6 +2767,52 @@ fn find_surrounding_code_block(snapshot: &BufferSnapshot, offset: usize) -> Opti
None
}
+fn render_thought_process_fold_icon_button(
+ editor: WeakEntity<Editor>,
+ status: ThoughtProcessStatus,
+) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
+ Arc::new(move |fold_id, fold_range, _cx| {
+ let editor = editor.clone();
+
+ let button = ButtonLike::new(fold_id).layer(ElevationIndex::ElevatedSurface);
+ let button = match status {
+ ThoughtProcessStatus::Pending => button
+ .child(
+ Icon::new(IconName::Brain)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+                Label::new("Thinking…").color(Color::Muted).with_animation(
+ "pulsating-label",
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.4, 0.8)),
+ |label, delta| label.alpha(delta),
+ ),
+ ),
+ ThoughtProcessStatus::Completed => button
+ .style(ButtonStyle::Filled)
+ .child(Icon::new(IconName::Brain).size(IconSize::Small))
+ .child(Label::new("Thought Process").single_line()),
+ };
+
+ button
+ .on_click(move |_, window, cx| {
+ editor
+ .update(cx, |editor, cx| {
+ let buffer_start = fold_range
+ .start
+ .to_point(&editor.buffer().read(cx).read(cx));
+ let buffer_row = MultiBufferRow(buffer_start.row);
+ editor.unfold_at(&UnfoldAt { buffer_row }, window, cx);
+ })
+ .ok();
+ })
+ .into_any_element()
+ })
+}
+
fn render_fold_icon_button(
editor: WeakEntity<Editor>,
icon: IconName,
@@ -120,7 +120,7 @@ impl Eval {
.count();
Ok(EvalOutput {
diff,
- last_message: last_message.text.clone(),
+ last_message: last_message.to_string(),
elapsed_time,
assistant_response_count,
tool_use_counts: assistant.tool_use_counts.clone(),
@@ -89,7 +89,7 @@ impl HeadlessAssistant {
ThreadEvent::DoneStreaming => {
let thread = thread.read(cx);
if let Some(message) = thread.messages().last() {
- println!("Message: {}", message.text,);
+ println!("Message: {}", message.to_string());
}
if thread.all_tools_finished() {
self.done_tx.send_blocking(Ok(())).unwrap()
@@ -1240,20 +1240,11 @@ impl Element for Div {
let mut state = scroll_handle.0.borrow_mut();
state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len());
state.bounds = bounds;
- let requested = state.requested_scroll_top.take();
-
- for (ix, child_layout_id) in request_layout.child_layout_ids.iter().enumerate() {
+ for child_layout_id in &request_layout.child_layout_ids {
let child_bounds = window.layout_bounds(*child_layout_id);
child_min = child_min.min(&child_bounds.origin);
child_max = child_max.max(&child_bounds.bottom_right());
state.child_bounds.push(child_bounds);
-
- if let Some(requested) = requested.as_ref() {
- if requested.0 == ix {
- *state.offset.borrow_mut() =
- bounds.origin - (child_bounds.origin - point(px(0.), requested.1));
- }
- }
}
(child_max - child_min).into()
} else {
@@ -1528,8 +1519,11 @@ impl Interactivity {
_cx: &mut App,
) -> Point<Pixels> {
if let Some(scroll_offset) = self.scroll_offset.as_ref() {
+ let mut scroll_to_bottom = false;
if let Some(scroll_handle) = &self.tracked_scroll_handle {
- scroll_handle.0.borrow_mut().overflow = style.overflow;
+ let mut state = scroll_handle.0.borrow_mut();
+ state.overflow = style.overflow;
+ scroll_to_bottom = mem::take(&mut state.scroll_to_bottom);
}
let rem_size = window.rem_size();
@@ -1555,8 +1549,14 @@ impl Interactivity {
// Clamp scroll offset in case scroll max is smaller now (e.g., if children
// were removed or the bounds became larger).
let mut scroll_offset = scroll_offset.borrow_mut();
+
scroll_offset.x = scroll_offset.x.clamp(-scroll_max.width, px(0.));
- scroll_offset.y = scroll_offset.y.clamp(-scroll_max.height, px(0.));
+ if scroll_to_bottom {
+ scroll_offset.y = -scroll_max.height;
+ } else {
+ scroll_offset.y = scroll_offset.y.clamp(-scroll_max.height, px(0.));
+ }
+
*scroll_offset
} else {
Point::default()
@@ -2861,12 +2861,13 @@ impl ScrollAnchor {
});
}
}
+
#[derive(Default, Debug)]
struct ScrollHandleState {
offset: Rc<RefCell<Point<Pixels>>>,
bounds: Bounds<Pixels>,
child_bounds: Vec<Bounds<Pixels>>,
- requested_scroll_top: Option<(usize, Pixels)>,
+ scroll_to_bottom: bool,
overflow: Point<Overflow>,
}
@@ -2955,6 +2956,12 @@ impl ScrollHandle {
}
}
+ /// Scrolls to the bottom.
+ pub fn scroll_to_bottom(&self) {
+ let mut state = self.0.borrow_mut();
+ state.scroll_to_bottom = true;
+ }
+
/// Set the offset explicitly. The offset is the distance from the top left of the
/// parent container to the top left of the first child.
/// As you scroll further down the offset becomes more negative.
@@ -2978,11 +2985,6 @@ impl ScrollHandle {
}
}
- /// Set the logical scroll top, based on a child index and a pixel offset.
- pub fn set_logical_scroll_top(&self, ix: usize, px: Pixels) {
- self.0.borrow_mut().requested_scroll_top = Some((ix, px));
- }
-
/// Get the count of children for scrollable item.
pub fn children_count(&self) -> usize {
self.0.borrow().child_bounds.len()
@@ -59,6 +59,7 @@ pub struct LanguageModelCacheConfiguration {
pub enum LanguageModelCompletionEvent {
Stop(StopReason),
Text(String),
+ Thinking(String),
ToolUse(LanguageModelToolUse),
StartMessage { message_id: String },
UsageUpdate(TokenUsage),
@@ -217,6 +218,7 @@ pub trait LanguageModel: Send + Sync {
match result {
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+ Ok(LanguageModelCompletionEvent::Thinking(_)) => None,
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None,
@@ -72,7 +72,9 @@ impl CloudModel {
pub fn availability(&self) -> LanguageModelAvailability {
match self {
Self::Anthropic(model) => match model {
- anthropic::Model::Claude3_5Sonnet | anthropic::Model::Claude3_7Sonnet => {
+ anthropic::Model::Claude3_5Sonnet
+ | anthropic::Model::Claude3_7Sonnet
+ | anthropic::Model::Claude3_7SonnetThinking => {
LanguageModelAvailability::RequiresPlan(Plan::Free)
}
anthropic::Model::Claude3Opus
@@ -1,6 +1,6 @@
use crate::ui::InstructionListItem;
use crate::AllLanguageModelSettings;
-use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage};
+use anthropic::{AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, Usage};
use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
@@ -55,6 +55,37 @@ pub struct AvailableModel {
pub default_temperature: Option<f32>,
#[serde(default)]
pub extra_beta_headers: Vec<String>,
+ /// The model's mode (e.g. thinking)
+ pub mode: Option<ModelMode>,
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+ #[default]
+ Default,
+ Thinking {
+ /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
+ budget_tokens: Option<u32>,
+ },
+}
+
+impl From<ModelMode> for AnthropicModelMode {
+ fn from(value: ModelMode) -> Self {
+ match value {
+ ModelMode::Default => AnthropicModelMode::Default,
+ ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
+ }
+ }
+}
+
+impl From<AnthropicModelMode> for ModelMode {
+ fn from(value: AnthropicModelMode) -> Self {
+ match value {
+ AnthropicModelMode::Default => ModelMode::Default,
+ AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
+ }
+ }
}
pub struct AnthropicLanguageModelProvider {
@@ -228,6 +259,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
max_output_tokens: model.max_output_tokens,
default_temperature: model.default_temperature,
extra_beta_headers: model.extra_beta_headers.clone(),
+ mode: model.mode.clone().unwrap_or_default().into(),
},
);
}
@@ -399,9 +431,10 @@ impl LanguageModel for AnthropicModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = into_anthropic(
request,
- self.model.id().into(),
+ self.model.request_id().into(),
self.model.default_temperature(),
self.model.max_output_tokens(),
+ self.model.mode(),
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
@@ -434,6 +467,7 @@ impl LanguageModel for AnthropicModel {
self.model.tool_model_id().into(),
self.model.default_temperature(),
self.model.max_output_tokens(),
+ self.model.mode(),
);
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
@@ -464,6 +498,7 @@ pub fn into_anthropic(
model: String,
default_temperature: f32,
max_output_tokens: u32,
+ mode: AnthropicModelMode,
) -> anthropic::Request {
let mut new_messages: Vec<anthropic::Message> = Vec::new();
let mut system_message = String::new();
@@ -552,6 +587,11 @@ pub fn into_anthropic(
messages: new_messages,
max_tokens: max_output_tokens,
system: Some(system_message),
+ thinking: if let AnthropicModelMode::Thinking { budget_tokens } = mode {
+ Some(anthropic::Thinking::Enabled { budget_tokens })
+ } else {
+ None
+ },
tools: request
.tools
.into_iter()
@@ -607,6 +647,16 @@ pub fn map_to_language_model_completion_events(
state,
));
}
+ ResponseContent::Thinking { thinking } => {
+ return Some((
+ vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))],
+ state,
+ ));
+ }
+ ResponseContent::RedactedThinking { .. } => {
+ // Redacted thinking is encrypted and not accessible to the user, see:
+ // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
+ }
ResponseContent::ToolUse { id, name, .. } => {
state.tool_uses_by_index.insert(
index,
@@ -625,6 +675,13 @@ pub fn map_to_language_model_completion_events(
state,
));
}
+ ContentDelta::ThinkingDelta { thinking } => {
+ return Some((
+ vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))],
+ state,
+ ));
+ }
+ ContentDelta::SignatureDelta { .. } => {}
ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
@@ -1,4 +1,4 @@
-use anthropic::AnthropicError;
+use anthropic::{AnthropicError, AnthropicModelMode};
use anyhow::{anyhow, Result};
use client::{
zed_urls, Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
@@ -91,6 +91,28 @@ pub struct AvailableModel {
/// Any extra beta headers to provide when using the model.
#[serde(default)]
pub extra_beta_headers: Vec<String>,
+ /// The model's mode (e.g. thinking)
+ pub mode: Option<ModelMode>,
+}
+
+#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+ #[default]
+ Default,
+ Thinking {
+ /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
+ budget_tokens: Option<u32>,
+ },
+}
+
+impl From<ModelMode> for AnthropicModelMode {
+ fn from(value: ModelMode) -> Self {
+ match value {
+ ModelMode::Default => AnthropicModelMode::Default,
+ ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
+ }
+ }
}
pub struct CloudLanguageModelProvider {
@@ -299,6 +321,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
default_temperature: model.default_temperature,
max_output_tokens: model.max_output_tokens,
extra_beta_headers: model.extra_beta_headers.clone(),
+ mode: model.mode.unwrap_or_default().into(),
}),
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(),
@@ -567,9 +590,10 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Anthropic(model) => {
let request = into_anthropic(
request,
- model.id().into(),
+ model.request_id().into(),
model.default_temperature(),
model.max_output_tokens(),
+ model.mode(),
);
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
@@ -669,6 +693,7 @@ impl LanguageModel for CloudLanguageModel {
model.tool_model_id().into(),
model.default_temperature(),
model.max_output_tokens(),
+ model.mode(),
);
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
@@ -109,6 +109,7 @@ impl AnthropicSettingsContent {
max_output_tokens,
default_temperature,
extra_beta_headers,
+ mode,
} => Some(provider::anthropic::AvailableModel {
name,
display_name,
@@ -124,6 +125,7 @@ impl AnthropicSettingsContent {
max_output_tokens,
default_temperature,
extra_beta_headers,
+ mode: Some(mode.into()),
}),
_ => None,
})
@@ -2503,6 +2503,10 @@ message SlashCommandOutputSection {
optional string metadata = 4;
}
+message ThoughtProcessOutputSection {
+ AnchorRange range = 1;
+}
+
message ContextOperation {
oneof variant {
InsertMessage insert_message = 1;
@@ -2512,6 +2516,7 @@ message ContextOperation {
SlashCommandStarted slash_command_started = 6;
SlashCommandOutputSectionAdded slash_command_output_section_added = 7;
SlashCommandCompleted slash_command_completed = 8;
+ ThoughtProcessOutputSectionAdded thought_process_output_section_added = 9;
}
reserved 4;
@@ -2556,6 +2561,12 @@ message ContextOperation {
repeated VectorClockEntry version = 5;
}
+ message ThoughtProcessOutputSectionAdded {
+ LamportTimestamp timestamp = 1;
+ ThoughtProcessOutputSection section = 2;
+ repeated VectorClockEntry version = 3;
+ }
+
message BufferOperation {
Operation operation = 1;
}
@@ -68,6 +68,21 @@ You can add custom models to the Anthropic provider by adding the following to y
Custom models will be listed in the model dropdown in the assistant panel.
+You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it),
+by changing the mode in your model's configuration to `thinking`, for example:
+
+```json
+{
+ "name": "claude-3-7-sonnet-latest",
+ "display_name": "claude-3-7-sonnet-thinking",
+ "max_tokens": 200000,
+ "mode": {
+ "type": "thinking",
+    "budget_tokens": 4096
+ }
+}
+```
+
### GitHub Copilot Chat {#github-copilot-chat}
You can use GitHub Copilot chat with the Zed assistant by choosing it via the model dropdown in the assistant panel.