use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<AvailableModel>,
}

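/// A model the user makes available via settings, merged with (and taking
/// precedence over) whatever the local Ollama server reports.
///
/// A minimal sketch of the corresponding settings JSON; the
/// `language_models.ollama` path is an assumption about where these settings
/// are grouped, while the field names match this struct:
///
/// ```json
/// {
///   "language_models": {
///     "ollama": {
///       "available_models": [
///         {
///           "name": "llama3.1:latest",
///           "display_name": "Llama 3.1",
///           "max_tokens": 8192
///         }
///       ]
///     }
///   }
/// }
/// ```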
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.1:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size, passed to Ollama as the `num_ctx` (aka `n_ctx`) option.
    pub max_tokens: usize,
    /// How long Ollama keeps the model loaded in memory after the last request (the `keep_alive` parameter).
    pub keep_alive: Option<KeepAlive>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

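    /// Fetches the installed models from the Ollama server and caches them,
    /// dropping embedding models (identified by name, since the API exposes
    /// no model-type metadata) and sorting the rest by name.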
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check that it's up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
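    /// Creates the provider, subscribing to settings changes so the model
    /// list is refetched whenever the Ollama configuration changes, and
    /// kicking off an initial fetch.
    ///
    /// A minimal usage sketch; how the provider is actually registered is an
    /// assumption here, only the constructor call is taken from this file:
    ///
    /// ```ignore
    /// // Hypothetical call site during app startup:
    /// let provider = OllamaLanguageModelProvider::new(http_client.clone(), cx);
    /// ```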
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
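    /// Converts Zed's provider-agnostic request into Ollama's chat format,
    /// mapping each message to the `ChatMessage` variant for its role and
    /// passing the configured context window through as `num_ctx`.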
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
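
    /// Issues a single non-streaming completion; used by the tool-call path
    /// below. The API URL is resolved from settings up front so the returned
    /// future does not borrow the app context.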
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
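        // Instead, use the rough ~4-characters-per-token heuristic: e.g. a
        // 2,000-character conversation is estimated at ~500 tokens.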
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

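    /// Streams a chat completion, yielding each Ollama delta's message
    /// content as a `LanguageModelCompletionEvent::Text` chunk. Requests are
    /// funneled through the provider's rate limiter.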
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

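    /// Forces use of a single tool: sends the request with exactly one
    /// function tool attached, then scans the assistant's reply for a call to
    /// that tool and yields its arguments as a JSON string.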
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

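/// The provider's configuration UI: shows download/setup instructions while
/// the Ollama server is unreachable, and a "Connected" indicator once at
/// least one model has been fetched.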
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Run Llama 3.1, Mistral, Gemma 2, and other large language models locally with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button so that the spacing stays correct;
                            // it should remain disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}