// completion_provider.rs

  1use anyhow::Result;
  2use assistant_tooling::ToolFunctionDefinition;
  3use client::{proto, Client};
  4use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  5use gpui::Global;
  6use std::sync::Arc;
  7
  8pub use open_ai::RequestMessage as CompletionMessage;
  9
/// Cheaply cloneable handle to a completion backend.
///
/// The backend is stored behind an `Arc<dyn …>`, so clones share the same
/// underlying backend instance.
#[derive(Clone)]
pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);
 12
 13impl CompletionProvider {
 14    pub fn new(backend: impl CompletionProviderBackend) -> Self {
 15        Self(Arc::new(backend))
 16    }
 17
 18    pub fn default_model(&self) -> String {
 19        self.0.default_model()
 20    }
 21
 22    pub fn available_models(&self) -> Vec<String> {
 23        self.0.available_models()
 24    }
 25
 26    pub fn complete(
 27        &self,
 28        model: String,
 29        messages: Vec<CompletionMessage>,
 30        stop: Vec<String>,
 31        temperature: f32,
 32        tools: &[ToolFunctionDefinition],
 33    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
 34    {
 35        self.0.complete(model, messages, stop, temperature, tools)
 36    }
 37}
 38
/// Marker impl allowing a `CompletionProvider` to be stored as a GPUI global.
impl Global for CompletionProvider {}
 40
/// Contract for anything that can serve chat completions.
///
/// `'static` is required so implementations can be stored behind an
/// `Arc<dyn CompletionProviderBackend>` (see [`CompletionProvider`]).
pub trait CompletionProviderBackend: 'static {
    /// The model name used when the caller does not choose one.
    fn default_model(&self) -> String;

    /// Every model name this backend can serve.
    fn available_models(&self) -> Vec<String>;

    /// Starts a completion request and resolves to a stream of response
    /// message chunks.
    ///
    /// * `model` — model name, typically one of [`Self::available_models`].
    /// * `messages` — conversation history in OpenAI request-message form.
    /// * `stop` — stop sequences that terminate generation.
    /// * `temperature` — sampling temperature.
    /// * `tools` — tool/function definitions the model may call.
    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: &[ToolFunctionDefinition],
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
}
 53
/// Backend that serves completions through the Zed cloud RPC client.
pub struct CloudCompletionProvider {
    // Shared RPC client used to issue `CompleteWithLanguageModel` requests.
    client: Arc<Client>,
}
 57
 58impl CloudCompletionProvider {
 59    pub fn new(client: Arc<Client>) -> Self {
 60        Self { client }
 61    }
 62}
 63
 64impl CompletionProviderBackend for CloudCompletionProvider {
 65    fn default_model(&self) -> String {
 66        "gpt-4-turbo".into()
 67    }
 68
 69    fn available_models(&self) -> Vec<String> {
 70        vec!["gpt-4-turbo".into(), "gpt-4".into(), "gpt-3.5-turbo".into()]
 71    }
 72
 73    fn complete(
 74        &self,
 75        model: String,
 76        messages: Vec<CompletionMessage>,
 77        stop: Vec<String>,
 78        temperature: f32,
 79        tools: &[ToolFunctionDefinition],
 80    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
 81    {
 82        let client = self.client.clone();
 83        let tools: Vec<proto::ChatCompletionTool> = tools
 84            .iter()
 85            .filter_map(|tool| {
 86                Some(proto::ChatCompletionTool {
 87                    variant: Some(proto::chat_completion_tool::Variant::Function(
 88                        proto::chat_completion_tool::FunctionObject {
 89                            name: tool.name.clone(),
 90                            description: Some(tool.description.clone()),
 91                            parameters: Some(serde_json::to_string(&tool.parameters).ok()?),
 92                        },
 93                    )),
 94                })
 95            })
 96            .collect();
 97
 98        let tool_choice = match tools.is_empty() {
 99            true => None,
100            false => Some("auto".into()),
101        };
102
103        async move {
104            let stream = client
105                .request_stream(proto::CompleteWithLanguageModel {
106                    model,
107                    messages: messages
108                        .into_iter()
109                        .map(|message| match message {
110                            CompletionMessage::Assistant {
111                                content,
112                                tool_calls,
113                            } => proto::LanguageModelRequestMessage {
114                                role: proto::LanguageModelRole::LanguageModelAssistant as i32,
115                                content: content.unwrap_or_default(),
116                                tool_call_id: None,
117                                tool_calls: tool_calls
118                                    .into_iter()
119                                    .map(|tool_call| match tool_call.content {
120                                        open_ai::ToolCallContent::Function { function } => {
121                                            proto::ToolCall {
122                                                id: tool_call.id,
123                                                variant: Some(proto::tool_call::Variant::Function(
124                                                    proto::tool_call::FunctionCall {
125                                                        name: function.name,
126                                                        arguments: function.arguments,
127                                                    },
128                                                )),
129                                            }
130                                        }
131                                    })
132                                    .collect(),
133                            },
134                            CompletionMessage::User { content } => {
135                                proto::LanguageModelRequestMessage {
136                                    role: proto::LanguageModelRole::LanguageModelUser as i32,
137                                    content,
138                                    tool_call_id: None,
139                                    tool_calls: Vec::new(),
140                                }
141                            }
142                            CompletionMessage::System { content } => {
143                                proto::LanguageModelRequestMessage {
144                                    role: proto::LanguageModelRole::LanguageModelSystem as i32,
145                                    content,
146                                    tool_calls: Vec::new(),
147                                    tool_call_id: None,
148                                }
149                            }
150                            CompletionMessage::Tool {
151                                content,
152                                tool_call_id,
153                            } => proto::LanguageModelRequestMessage {
154                                role: proto::LanguageModelRole::LanguageModelTool as i32,
155                                content,
156                                tool_call_id: Some(tool_call_id),
157                                tool_calls: Vec::new(),
158                            },
159                        })
160                        .collect(),
161                    stop,
162                    temperature,
163                    tool_choice,
164                    tools,
165                })
166                .await?;
167
168            Ok(stream
169                .filter_map(|response| async move {
170                    match response {
171                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta?)),
172                        Err(error) => Some(Err(error)),
173                    }
174                })
175                .boxed())
176        }
177        .boxed()
178    }
179}