@@ -1,3 +1,4 @@
+use super::create_label_for_command;
use anyhow::{anyhow, Result};
use assistant_slash_command::{
AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
@@ -6,9 +7,9 @@ use assistant_slash_command::{
use collections::HashMap;
use context_servers::{
manager::{ContextServer, ContextServerManager},
- protocol::PromptInfo,
+ types::Prompt,
};
-use gpui::{Task, WeakView, WindowContext};
+use gpui::{AppContext, Task, WeakView, WindowContext};
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
@@ -18,11 +19,11 @@ use workspace::Workspace;
pub struct ContextServerSlashCommand {
server_id: String,
- prompt: PromptInfo,
+ prompt: Prompt,
}
impl ContextServerSlashCommand {
- pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
+ pub fn new(server: &Arc<ContextServer>, prompt: Prompt) -> Self {
Self {
server_id: server.id.clone(),
prompt,
@@ -35,12 +36,28 @@ impl SlashCommand for ContextServerSlashCommand {
self.prompt.name.clone()
}
+ fn label(&self, cx: &AppContext) -> language::CodeLabel {
+ let mut parts = vec![self.prompt.name.as_str()];
+ if let Some(args) = &self.prompt.arguments {
+ if let Some(arg) = args.first() {
+ parts.push(arg.name.as_str());
+ }
+ }
+ create_label_for_command(&parts[0], &parts[1..], cx)
+ }
+
fn description(&self) -> String {
- format!("Run context server command: {}", self.prompt.name)
+ match &self.prompt.description {
+ Some(desc) => desc.clone(),
+ None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
+ }
}
fn menu_text(&self) -> String {
- format!("Run '{}' from {}", self.prompt.name, self.server_id)
+ match &self.prompt.description {
+ Some(desc) => desc.clone(),
+ None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
+ }
}
fn requires_argument(&self) -> bool {
@@ -154,7 +171,7 @@ impl SlashCommand for ContextServerSlashCommand {
}
}
-fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
+fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
if arguments.is_empty() {
return Err(anyhow!("No arguments given"));
}
@@ -170,7 +187,7 @@ fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(Str
}
}
-fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
+fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
match &prompt.arguments {
Some(args) if args.len() > 1 => Err(anyhow!(
"Prompt has more than one argument, which is not supported"
@@ -199,7 +216,7 @@ fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap
/// MCP servers can return prompts with multiple arguments. Since we only
/// support one argument, we ignore all others. This is the necessary predicate
/// for this.
-pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
+pub fn acceptable_prompt(prompt: &Prompt) -> bool {
match &prompt.arguments {
None => true,
Some(args) if args.len() <= 1 => true,
@@ -26,7 +26,7 @@ const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
-type NotificationHandler = Box<dyn Send + FnMut(RequestId, Value, AsyncAppContext)>;
+type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
@@ -94,7 +94,6 @@ enum CspResult<T> {
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
jsonrpc: &'static str,
- id: RequestId,
#[serde(borrow)]
method: &'a str,
params: T,
@@ -103,7 +102,6 @@ struct Notification<'a, T> {
#[derive(Debug, Clone, Deserialize)]
struct AnyNotification<'a> {
jsonrpc: &'a str,
- id: RequestId,
method: String,
#[serde(default)]
params: Option<Value>,
@@ -246,11 +244,7 @@ impl Client {
if let Some(handler) =
notification_handlers.get_mut(notification.method.as_str())
{
- handler(
- notification.id,
- notification.params.unwrap_or(Value::Null),
- cx.clone(),
- );
+ handler(notification.params.unwrap_or(Value::Null), cx.clone());
}
}
}
@@ -378,10 +372,8 @@ impl Client {
/// Sends a notification to the context server without expecting a response.
/// This function serializes the notification and sends it through the outbound channel.
pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
- let id = self.next_id.fetch_add(1, SeqCst);
let notification = serde_json::to_string(&Notification {
jsonrpc: JSON_RPC_VERSION,
- id: RequestId::Int(id),
method,
params,
})
@@ -390,13 +382,13 @@ impl Client {
Ok(())
}
- pub fn on_notification<F>(&self, method: &'static str, mut f: F)
+ pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncAppContext),
{
self.notification_handlers
.lock()
- .insert(method, Box::new(move |_, params, cx| f(params, cx)));
+ .insert(method, Box::new(f));
}
pub fn name(&self) -> &str {
@@ -15,6 +15,7 @@ pub enum RequestType {
PromptsGet,
PromptsList,
CompletionComplete,
+ Ping,
}
impl RequestType {
@@ -30,6 +31,7 @@ impl RequestType {
RequestType::PromptsGet => "prompts/get",
RequestType::PromptsList => "prompts/list",
RequestType::CompletionComplete => "completion/complete",
+ RequestType::Ping => "ping",
}
}
}
@@ -39,14 +41,15 @@ impl RequestType {
pub struct InitializeParams {
pub protocol_version: u32,
pub capabilities: ClientCapabilities,
- pub client_info: EntityInfo,
+ pub client_info: Implementation,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolParams {
pub name: String,
- pub arguments: Option<serde_json::Value>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub arguments: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
@@ -77,6 +80,7 @@ pub struct LoggingSetLevelParams {
#[serde(rename_all = "camelCase")]
pub struct PromptsGetParams {
pub name: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, String>>,
}
@@ -101,6 +105,13 @@ pub struct PromptReference {
pub name: String,
}
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourceReference {
+ pub r#type: PromptReferenceType,
+ pub uri: Url,
+}
+
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PromptReferenceType {
@@ -110,13 +121,6 @@ pub enum PromptReferenceType {
Resource,
}
-#[derive(Debug, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub struct ResourceReference {
- pub r#type: String,
- pub uri: String,
-}
-
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionArgument {
@@ -129,7 +133,7 @@ pub struct CompletionArgument {
pub struct InitializeResponse {
pub protocol_version: u32,
pub capabilities: ServerCapabilities,
- pub server_info: EntityInfo,
+ pub server_info: Implementation,
}
#[derive(Debug, Deserialize)]
@@ -141,13 +145,39 @@ pub struct ResourcesReadResponse {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesListResponse {
+ #[serde(skip_serializing_if = "Option::is_none")]
pub resource_templates: Option<Vec<ResourceTemplate>>,
- pub resources: Vec<Resource>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub resources: Option<Vec<Resource>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SamplingMessage {
+ pub role: SamplingRole,
+ pub content: SamplingContent,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum SamplingRole {
+ User,
+ Assistant,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum SamplingContent {
+ #[serde(rename = "text")]
+ Text { text: String },
+ #[serde(rename = "image")]
+ Image { data: String, mime_type: String },
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsGetResponse {
+ #[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub prompt: String,
}
@@ -155,7 +185,7 @@ pub struct PromptsGetResponse {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsListResponse {
- pub prompts: Vec<PromptInfo>,
+ pub prompts: Vec<Prompt>,
}
#[derive(Debug, Deserialize)]
@@ -168,61 +198,91 @@ pub struct CompletionCompleteResponse {
#[serde(rename_all = "camelCase")]
pub struct CompletionResult {
pub values: Vec<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub has_more: Option<bool>,
}
-#[derive(Debug, Deserialize, Clone)]
+#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct PromptInfo {
+pub struct Prompt {
pub name: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub description: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<Vec<PromptArgument>>,
}
-#[derive(Debug, Deserialize, Clone)]
+#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptArgument {
pub name: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
}
-// Shared Types
-
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientCapabilities {
+ #[serde(skip_serializing_if = "Option::is_none")]
pub experimental: Option<HashMap<String, serde_json::Value>>,
- pub sampling: Option<HashMap<String, serde_json::Value>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub sampling: Option<serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerCapabilities {
+ #[serde(skip_serializing_if = "Option::is_none")]
pub experimental: Option<HashMap<String, serde_json::Value>>,
- pub logging: Option<HashMap<String, serde_json::Value>>,
- pub prompts: Option<HashMap<String, serde_json::Value>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub logging: Option<serde_json::Value>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub prompts: Option<PromptsCapabilities>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub resources: Option<ResourcesCapabilities>,
- pub tools: Option<HashMap<String, serde_json::Value>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tools: Option<ToolsCapabilities>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptsCapabilities {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesCapabilities {
+ #[serde(skip_serializing_if = "Option::is_none")]
pub subscribe: Option<bool>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub list_changed: Option<bool>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolsCapabilities {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub name: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub input_schema: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-pub struct EntityInfo {
+pub struct Implementation {
pub name: String,
pub version: String,
}
@@ -231,6 +291,10 @@ pub struct EntityInfo {
#[serde(rename_all = "camelCase")]
pub struct Resource {
pub uri: Url,
+ pub name: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub description: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
}
@@ -238,17 +302,23 @@ pub struct Resource {
#[serde(rename_all = "camelCase")]
pub struct ResourceContent {
pub uri: Url,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
- pub data: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub blob: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceTemplate {
pub uri_template: String,
- pub name: Option<String>,
+ pub name: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub mime_type: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -260,13 +330,16 @@ pub enum LoggingLevel {
Error,
}
-// Client Notifications
-
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum NotificationType {
Initialized,
Progress,
+ Message,
+ ResourcesUpdated,
+ ResourcesListChanged,
+ ToolsListChanged,
+ PromptsListChanged,
}
impl NotificationType {
@@ -274,6 +347,11 @@ impl NotificationType {
match self {
NotificationType::Initialized => "notifications/initialized",
NotificationType::Progress => "notifications/progress",
+ NotificationType::Message => "notifications/message",
+ NotificationType::ResourcesUpdated => "notifications/resources/updated",
+ NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
+ NotificationType::ToolsListChanged => "notifications/tools/list_changed",
+ NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
}
}
}
@@ -288,12 +366,13 @@ pub enum ClientNotification {
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ProgressParams {
- pub progress_token: String,
+ pub progress_token: ProgressToken,
pub progress: f64,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<f64>,
}
-// Helper Types that don't map directly to the protocol
+pub type ProgressToken = String;
pub enum CompletionTotal {
Exact(u32),