1use crate::{
2 assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
3};
4use anyhow::Result;
5use futures::StreamExt as _;
6use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
7use gpui::{AnyView, AppContext, Task};
8use http::HttpClient;
9use ollama::{
10 get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
11};
12use std::sync::Arc;
13use std::time::Duration;
14use ui::{prelude::*, ButtonLike, ElevationIndex};
15
/// Ollama's download page, opened by the button in `DownloadOllamaMessage`.
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
17
/// Completion provider backed by a locally running Ollama server.
pub struct OllamaCompletionProvider {
    /// Base URL of the Ollama HTTP API.
    api_url: String,
    /// The currently selected model.
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    /// Optional low-speed timeout forwarded to `stream_chat_completion`.
    low_speed_timeout: Option<Duration>,
    /// Version of the settings this provider was last configured from.
    settings_version: usize,
    /// Models discovered from the server; empty until `fetch_models` succeeds.
    available_models: Vec<OllamaModel>,
}
26
impl OllamaCompletionProvider {
    /// Creates a provider with an empty model list; call `fetch_models` (or
    /// `authenticate`) to discover which models the server offers.
    pub fn new(
        model: OllamaModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) -> Self {
        Self {
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
            available_models: Default::default(),
        }
    }

    /// Applies updated settings in place. The cached `available_models` list
    /// is left untouched; re-run `fetch_models` to refresh it.
    pub fn update(
        &mut self,
        model: OllamaModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) {
        self.model = model;
        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    /// Iterates over the models most recently fetched from the server
    /// (sorted by name in `fetch_models`).
    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
        self.available_models.iter()
    }

    /// The settings version this provider was last configured from.
    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

    /// Ollama has no real authentication; a non-empty model list is treated
    /// as evidence that the server was reachable.
    pub fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    /// Resolves immediately if models are already cached; otherwise probes
    /// the server by fetching its model list.
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }

    /// There are no credentials to reset, so this simply re-fetches models.
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.fetch_models(cx)
    }

    /// Asynchronously fetches the server's model list and stores it on the
    /// globally registered `CompletionProvider` (if it is the Ollama variant).
    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
        cx.spawn(|mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<OllamaModel> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::Ollama(provider) = provider {
                    provider.available_models = models;
                }
            })
        })
    }

    /// View shown when not "authenticated": a prompt to download Ollama.
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| DownloadOllamaMessage::new(cx)).into()
    }

    /// The currently configured model (cloned).
    pub fn model(&self) -> OllamaModel {
        self.model.clone()
    }

    /// Estimates the token count of the request as total characters / 4,
    /// since Ollama exposes no tokenization endpoint.
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    /// Streams a chat completion, yielding each delta's message content as a
    /// `String` regardless of which role the delta carries.
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let request =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            // Extract the text payload; all three roles carry
                            // a plain `content` string.
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    /// Converts a generic `LanguageModelRequest` into Ollama's `ChatRequest`,
    /// falling back to the provider's configured model when the request does
    /// not target an Ollama model.
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let model = match request.model {
            LanguageModel::Ollama(model) => model,
            _ => self.model(),
        };

        ChatRequest {
            model: model.name,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: model.keep_alive,
            // Always stream; `complete` consumes the response as a stream.
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}
201
202impl From<Role> for ollama::Role {
203 fn from(val: Role) -> Self {
204 match val {
205 Role::User => OllamaRole::User,
206 Role::Assistant => OllamaRole::Assistant,
207 Role::System => OllamaRole::System,
208 }
209 }
210}
211
/// Stateless view prompting the user to install Ollama, rendered when no
/// models could be fetched from the server.
struct DownloadOllamaMessage {}
213
214impl DownloadOllamaMessage {
215 pub fn new(_cx: &mut ViewContext<Self>) -> Self {
216 Self {}
217 }
218
219 fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
220 ButtonLike::new("download_ollama_button")
221 .style(ButtonStyle::Filled)
222 .size(ButtonSize::Large)
223 .layer(ElevationIndex::ModalSurface)
224 .child(Label::new("Get Ollama"))
225 .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
226 }
227}
228
229impl Render for DownloadOllamaMessage {
230 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
231 v_flex()
232 .p_4()
233 .size_full()
234 .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine.").size(LabelSize::Large))
235 .child(
236 h_flex()
237 .w_full()
238 .p_4()
239 .justify_center()
240 .child(
241 self.render_download_button(cx)
242 )
243 )
244 .into_any()
245 }
246}