use anyhow::Result;
use assistant_tooling::ToolFunctionDefinition;
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AppContext, Global};
use std::sync::Arc;

pub use open_ai::RequestMessage as CompletionMessage;
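
/// A cheap-to-clone handle to whichever completion backend is registered as
/// the GPUI global. A minimal setup sketch (assuming a gpui `AppContext`
/// named `cx` and an already-connected `client`; both are illustrative):
///
/// ```ignore
/// cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(client)));
/// let model = CompletionProvider::get(cx).default_model();
/// ```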
#[derive(Clone)]
pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);

impl CompletionProvider {
    pub fn get(cx: &AppContext) -> &Self {
        cx.global::<CompletionProvider>()
    }

    pub fn new(backend: impl CompletionProviderBackend) -> Self {
        Self(Arc::new(backend))
    }

    pub fn default_model(&self) -> String {
        self.0.default_model()
    }

    pub fn available_models(&self) -> Vec<String> {
        self.0.available_models()
    }

    pub fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: Vec<ToolFunctionDefinition>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
    {
        self.0.complete(model, messages, stop, temperature, tools)
    }
}

impl Global for CompletionProvider {}
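
/// The interface a completion backend must implement: model discovery plus a
/// streaming completion request. Returning boxed futures and streams keeps
/// the trait object-safe, so implementations can sit behind the
/// `Arc<dyn CompletionProviderBackend>` held by [`CompletionProvider`].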
pub trait CompletionProviderBackend: 'static {
    fn default_model(&self) -> String;
    fn available_models(&self) -> Vec<String>;
    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: Vec<ToolFunctionDefinition>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
}
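
/// A [`CompletionProviderBackend`] that forwards completion requests to the
/// remote language-model service over the shared RPC [`Client`].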
pub struct CloudCompletionProvider {
    client: Arc<Client>,
}

impl CloudCompletionProvider {
    pub fn new(client: Arc<Client>) -> Self {
        Self { client }
    }
}

impl CompletionProviderBackend for CloudCompletionProvider {
    fn default_model(&self) -> String {
        "gpt-4-turbo".into()
    }

    fn available_models(&self) -> Vec<String> {
        vec!["gpt-4-turbo".into(), "gpt-4".into(), "gpt-3.5-turbo".into()]
    }

    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: Vec<ToolFunctionDefinition>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
    {
        let client = self.client.clone();
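        // Convert each tool definition to its proto representation; a tool
        // whose JSON schema fails to serialize is silently dropped here.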
        let tools: Vec<proto::ChatCompletionTool> = tools
            .iter()
            .filter_map(|tool| {
                Some(proto::ChatCompletionTool {
                    variant: Some(proto::chat_completion_tool::Variant::Function(
                        proto::chat_completion_tool::FunctionObject {
                            name: tool.name.clone(),
                            description: Some(tool.description.clone()),
                            parameters: Some(serde_json::to_string(&tool.parameters).ok()?),
                        },
                    )),
                })
            })
            .collect();
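
        // Only ask the model to choose tools automatically when at least one
        // tool survived the conversion above.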
        let tool_choice = match tools.is_empty() {
            true => None,
            false => Some("auto".into()),
        };
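
        // Stream the completion over RPC, translating each chat message into
        // its proto request form.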
        async move {
            let stream = client
                .request_stream(proto::CompleteWithLanguageModel {
                    model,
                    messages: messages
                        .into_iter()
                        .map(|message| match message {
                            CompletionMessage::Assistant {
                                content,
                                tool_calls,
                            } => proto::LanguageModelRequestMessage {
                                role: proto::LanguageModelRole::LanguageModelAssistant as i32,
                                content: content.unwrap_or_default(),
                                tool_call_id: None,
                                tool_calls: tool_calls
                                    .into_iter()
                                    .map(|tool_call| match tool_call.content {
                                        open_ai::ToolCallContent::Function { function } => {
                                            proto::ToolCall {
                                                id: tool_call.id,
                                                variant: Some(proto::tool_call::Variant::Function(
                                                    proto::tool_call::FunctionCall {
                                                        name: function.name,
                                                        arguments: function.arguments,
                                                    },
                                                )),
                                            }
                                        }
                                    })
                                    .collect(),
                            },
                            CompletionMessage::User { content } => {
                                proto::LanguageModelRequestMessage {
                                    role: proto::LanguageModelRole::LanguageModelUser as i32,
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Vec::new(),
                                }
                            }
                            CompletionMessage::System { content } => {
                                proto::LanguageModelRequestMessage {
                                    role: proto::LanguageModelRole::LanguageModelSystem as i32,
                                    content,
                                    tool_calls: Vec::new(),
                                    tool_call_id: None,
                                }
                            }
                            CompletionMessage::Tool {
                                content,
                                tool_call_id,
                            } => proto::LanguageModelRequestMessage {
                                role: proto::LanguageModelRole::LanguageModelTool as i32,
                                content,
                                tool_call_id: Some(tool_call_id),
                                tool_calls: Vec::new(),
                            },
                        })
                        .collect(),
                    stop,
                    temperature,
                    tool_choice,
                    tools,
                })
                .await?;
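
            // Reduce each streamed chunk to the delta of its last choice,
            // skipping chunks with no choices or no delta.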
            Ok(stream
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed())
        }
        .boxed()
    }
}