@@ -40,6 +40,7 @@ async fn test_echo(cx: &mut TestAppContext) {
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
})
+ .unwrap()
.collect()
.await;
thread.update(cx, |thread, _cx| {
@@ -73,6 +74,7 @@ async fn test_thinking(cx: &mut TestAppContext) {
cx,
)
})
+ .unwrap()
.collect()
.await;
thread.update(cx, |thread, _cx| {
@@ -101,9 +103,11 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
project_context.borrow_mut().shell = "test-shell".into();
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
- thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["abc"], cx)
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(
@@ -136,9 +140,11 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
let fake_model = model.as_fake();
// Send initial user message and verify it's cached
- thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Message 1"], cx)
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Message 1"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
@@ -157,9 +163,11 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
cx.run_until_parked();
// Send another user message and verify only the latest is cached
- thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Message 2"], cx)
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Message 2"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
@@ -191,9 +199,11 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
// Simulate a tool call and verify that the latest tool result is cached
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
- thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
@@ -273,6 +283,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
cx,
)
})
+ .unwrap()
.collect()
.await;
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
@@ -291,6 +302,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
cx,
)
})
+ .unwrap()
.collect()
.await;
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
@@ -322,10 +334,12 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
// Test a tool call that's likely to complete *before* streaming stops.
- let mut events = thread.update(cx, |thread, cx| {
- thread.add_tool(WordListTool);
- thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
- });
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(WordListTool);
+ thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
+ })
+ .unwrap();
let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await {
@@ -371,10 +385,12 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let mut events = thread.update(cx, |thread, cx| {
- thread.add_tool(ToolRequiringPermission);
- thread.send(UserMessageId::new(), ["abc"], cx)
- });
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(ToolRequiringPermission);
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -501,9 +517,11 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let mut events = thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["abc"], cx)
- });
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -528,10 +546,12 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let events = thread.update(cx, |thread, cx| {
- thread.add_tool(EchoTool);
- thread.send(UserMessageId::new(), ["abc"], cx)
- });
+ let events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(EchoTool);
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_id_1".into(),
@@ -644,10 +664,12 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let events = thread.update(cx, |thread, cx| {
- thread.add_tool(EchoTool);
- thread.send(UserMessageId::new(), ["abc"], cx)
- });
+ let events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(EchoTool);
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
@@ -677,9 +699,11 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
.is::<language_model::ToolUseLimitReachedError>()
);
- thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), vec!["ghi"], cx)
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), vec!["ghi"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
@@ -790,6 +814,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
cx,
)
})
+ .unwrap()
.collect()
.await;
@@ -857,10 +882,12 @@ async fn test_profiles(cx: &mut TestAppContext) {
cx.run_until_parked();
// Test that test-1 profile (default) has echo and delay tools
- thread.update(cx, |thread, cx| {
- thread.set_profile(AgentProfileId("test-1".into()));
- thread.send(UserMessageId::new(), ["test"], cx);
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.set_profile(AgentProfileId("test-1".into()));
+ thread.send(UserMessageId::new(), ["test"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
@@ -875,10 +902,12 @@ async fn test_profiles(cx: &mut TestAppContext) {
fake_model.end_last_completion_stream();
// Switch to test-2 profile, and verify that it has only the infinite tool.
- thread.update(cx, |thread, cx| {
- thread.set_profile(AgentProfileId("test-2".into()));
- thread.send(UserMessageId::new(), ["test2"], cx)
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.set_profile(AgentProfileId("test-2".into()));
+ thread.send(UserMessageId::new(), ["test2"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(pending_completions.len(), 1);
@@ -896,15 +925,17 @@ async fn test_profiles(cx: &mut TestAppContext) {
async fn test_cancellation(cx: &mut TestAppContext) {
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
- let mut events = thread.update(cx, |thread, cx| {
- thread.add_tool(InfiniteTool);
- thread.add_tool(EchoTool);
- thread.send(
- UserMessageId::new(),
- ["Call the echo tool, then call the infinite tool, then explain their output"],
- cx,
- )
- });
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(InfiniteTool);
+ thread.add_tool(EchoTool);
+ thread.send(
+ UserMessageId::new(),
+ ["Call the echo tool, then call the infinite tool, then explain their output"],
+ cx,
+ )
+ })
+ .unwrap();
// Wait until both tools are called.
let mut expected_tools = vec!["Echo", "Infinite Tool"];
@@ -960,6 +991,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
cx,
)
})
+ .unwrap()
.collect::<Vec<_>>()
.await;
thread.update(cx, |thread, _cx| {
@@ -978,16 +1010,20 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let events_1 = thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Hello 1"], cx)
- });
+ let events_1 = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello 1"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
cx.run_until_parked();
- let events_2 = thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Hello 2"], cx)
- });
+ let events_2 = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello 2"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
fake_model
@@ -1005,9 +1041,11 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let events_1 = thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Hello 1"], cx)
- });
+ let events_1 = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello 1"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
fake_model
@@ -1015,9 +1053,11 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
fake_model.end_last_completion_stream();
let events_1 = events_1.collect::<Vec<_>>().await;
- let events_2 = thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Hello 2"], cx)
- });
+ let events_2 = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello 2"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
fake_model
@@ -1034,9 +1074,11 @@ async fn test_refusal(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let events = thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Hello"], cx)
- });
+ let events = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@@ -1082,9 +1124,11 @@ async fn test_truncate(cx: &mut TestAppContext) {
let fake_model = model.as_fake();
let message_id = UserMessageId::new();
- thread.update(cx, |thread, cx| {
- thread.send(message_id.clone(), ["Hello"], cx)
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(message_id.clone(), ["Hello"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@@ -1123,9 +1167,11 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
// Ensure we can still send a new message after truncation.
- thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Hi"], cx)
- });
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hi"], cx)
+ })
+ .unwrap();
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
@@ -1291,9 +1337,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
let fake_model = model.as_fake();
- let mut events = thread.update(cx, |thread, cx| {
- thread.send(UserMessageId::new(), ["Think"], cx)
- });
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Think"], cx)
+ })
+ .unwrap();
cx.run_until_parked();
// Simulate streaming partial input.
@@ -1506,7 +1554,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
context_server_registry,
action_log,
templates,
- model.clone(),
+ Some(model.clone()),
cx,
)
});
@@ -469,7 +469,7 @@ pub struct Thread {
profile_id: AgentProfileId,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
- model: Arc<dyn LanguageModel>,
+ model: Option<Arc<dyn LanguageModel>>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
@@ -481,7 +481,7 @@ impl Thread {
context_server_registry: Entity<ContextServerRegistry>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
- model: Arc<dyn LanguageModel>,
+ model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@@ -512,12 +512,12 @@ impl Thread {
&self.action_log
}
- pub fn model(&self) -> &Arc<dyn LanguageModel> {
- &self.model
+ pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
+ self.model.as_ref()
}
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
- self.model = model;
+ self.model = Some(model);
}
pub fn completion_mode(&self) -> CompletionMode {
@@ -575,6 +575,7 @@ impl Thread {
&mut self,
cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
+ anyhow::ensure!(self.model.is_some(), "No language model configured");
anyhow::ensure!(
self.tool_use_limit_reached,
"can only resume after tool use limit is reached"
@@ -584,7 +585,7 @@ impl Thread {
cx.notify();
log::info!("Total messages in thread: {}", self.messages.len());
- Ok(self.run_turn(cx))
+ self.run_turn(cx)
}
/// Sending a message results in the model streaming a response, which could include tool calls.
@@ -595,11 +596,13 @@ impl Thread {
id: UserMessageId,
content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>,
- ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
+ ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>
where
T: Into<UserMessageContent>,
{
- log::info!("Thread::send called with model: {:?}", self.model.name());
+ let model = self.model().context("No language model configured")?;
+
+ log::info!("Thread::send called with model: {:?}", model.name());
self.advance_prompt_id();
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
@@ -616,10 +619,10 @@ impl Thread {
fn run_turn(
&mut self,
cx: &mut Context<Self>,
- ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
+ ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
self.cancel();
- let model = self.model.clone();
+ let model = self.model.clone().context("No language model configured")?;
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let event_stream = AgentResponseEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
@@ -637,7 +640,7 @@ impl Thread {
);
let request = this.update(cx, |this, cx| {
this.build_completion_request(completion_intent, cx)
- })?;
+ })??;
log::info!("Calling model.stream_completion");
let mut events = model.stream_completion(request, cx).await?;
@@ -729,7 +732,7 @@ impl Thread {
.ok();
}),
});
- events_rx
+ Ok(events_rx)
}
pub fn build_system_message(&self) -> LanguageModelRequestMessage {
@@ -917,7 +920,7 @@ impl Thread {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
- let supports_images = self.model.supports_images();
+ let supports_images = self.model().is_some_and(|model| model.supports_images());
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
log::info!("Running tool {}", tool_use.name);
Some(cx.foreground_executor().spawn(async move {
@@ -1005,7 +1008,9 @@ impl Thread {
&self,
completion_intent: CompletionIntent,
cx: &mut App,
- ) -> LanguageModelRequest {
+ ) -> Result<LanguageModelRequest> {
+ let model = self.model().context("No language model configured")?;
+
log::debug!("Building completion request");
log::debug!("Completion intent: {:?}", completion_intent);
log::debug!("Completion mode: {:?}", self.completion_mode);
@@ -1021,9 +1026,7 @@ impl Thread {
Some(LanguageModelRequestTool {
name: tool_name,
description: tool.description().to_string(),
- input_schema: tool
- .input_schema(self.model.tool_input_format())
- .log_err()?,
+ input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
})
})
.collect()
@@ -1042,20 +1045,22 @@ impl Thread {
tools,
tool_choice: None,
stop: Vec::new(),
- temperature: AgentSettings::temperature_for_model(self.model(), cx),
+ temperature: AgentSettings::temperature_for_model(&model, cx),
thinking_allowed: true,
};
log::debug!("Completion request built successfully");
- request
+ Ok(request)
}
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
+ let model = self.model().context("No language model configured")?;
+
let profile = AgentSettings::get_global(cx)
.profiles
.get(&self.profile_id)
.context("profile not found")?;
- let provider_id = self.model.provider_id();
+ let provider_id = model.provider_id();
Ok(self
.tools
@@ -237,11 +237,17 @@ impl AgentTool for EditFileTool {
});
}
- let request = self.thread.update(cx, |thread, cx| {
- thread.build_completion_request(CompletionIntent::ToolResults, cx)
- });
+ let request = self.thread.update(cx, |thread, cx| {
+ thread.build_completion_request(CompletionIntent::ToolResults, cx)
+ });
+ let request = match request {
+ Ok(request) => request,
+ Err(error) => return Task::ready(Err(error)),
+ };
let thread = self.thread.read(cx);
- let model = thread.model().clone();
+ let Some(model) = thread.model().cloned() else {
+ return Task::ready(Err(anyhow!("No language model configured")));
+ };
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx);
@@ -520,7 +526,7 @@ mod tests {
context_server_registry,
action_log,
Templates::new(),
- model,
+ Some(model),
cx,
)
});
@@ -717,7 +723,7 @@ mod tests {
context_server_registry,
action_log.clone(),
Templates::new(),
- model.clone(),
+ Some(model.clone()),
cx,
)
});
@@ -853,7 +859,7 @@ mod tests {
context_server_registry,
action_log.clone(),
Templates::new(),
- model.clone(),
+ Some(model.clone()),
cx,
)
});
@@ -979,7 +985,7 @@ mod tests {
context_server_registry,
action_log.clone(),
Templates::new(),
- model.clone(),
+ Some(model.clone()),
cx,
)
});
@@ -1116,7 +1122,7 @@ mod tests {
context_server_registry,
action_log.clone(),
Templates::new(),
- model.clone(),
+ Some(model.clone()),
cx,
)
});
@@ -1226,7 +1232,7 @@ mod tests {
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
- model.clone(),
+ Some(model.clone()),
cx,
)
});
@@ -1307,7 +1313,7 @@ mod tests {
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
- model.clone(),
+ Some(model.clone()),
cx,
)
});
@@ -1391,7 +1397,7 @@ mod tests {
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
- model.clone(),
+ Some(model.clone()),
cx,
)
});
@@ -1472,7 +1478,7 @@ mod tests {
context_server_registry,
action_log.clone(),
Templates::new(),
- model.clone(),
+ Some(model.clone()),
cx,
)
});
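Taken together, the hunks above change the caller-side contract: `Thread::new` now takes `Option<Arc<dyn LanguageModel>>`, and `send` (like `resume_last_turn`) returns a `Result`, surfacing a "No language model configured" error up front instead of assuming a model is always present. Below is a minimal caller-side sketch, not part of the diff, assuming a `gpui::App` context and the crate-local `Thread`, `UserMessageId`, and `AgentResponseEvent` types (crate-local imports omitted):

```rust
use anyhow::Result;
use futures::channel::mpsc;
use gpui::{App, Entity};

// Sketch only: start a turn on a thread that may or may not have a model set.
// `Thread::send` now returns a `Result`, so the missing-model case is reported
// here rather than failing later inside the turn.
fn start_turn(
    thread: &Entity<Thread>,
    cx: &mut App,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
    thread.update(cx, |thread, cx| {
        // Errors with "No language model configured" when `thread.model` is `None`.
        thread.send(UserMessageId::new(), ["Hello"], cx)
    })
}
```

The tests exercise the same contract by calling `.unwrap()` on the value returned from `thread.update(...)`, as shown in the hunks above.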