1import { getModel, type Model } from "@mariozechner/pi-ai";
2import {
3 type CustomModelConfig,
4 type RumiloConfig,
5} from "../config/schema.js";
6import { ConfigError } from "../util/errors.js";
7
/**
 * Resolve a `provider:model` reference into a pi-ai `Model`.
 *
 * The string is split on the FIRST colon only, so everything after it
 * (including further colons) is treated as the model name.
 * The `custom` provider is looked up in `config.custom_models`. Every other
 * provider is delegated to pi-ai's built-in registry via `getModel`.
 *
 * @param modelString - Model reference in `provider:model` format.
 * @param config - Loaded rumilo configuration (consulted for `custom:` models).
 * @returns The resolved model definition.
 * @throws ConfigError if the reference is malformed or no matching model exists.
 */
export function resolveModel(
  modelString: string,
  config: RumiloConfig,
): Model<any> {
  const colonIndex = modelString.indexOf(":");
  if (colonIndex === -1) {
    throw new ConfigError("Model must be in provider:model format");
  }

  const provider = modelString.slice(0, colonIndex);
  const modelName = modelString.slice(colonIndex + 1);

  // Reject ":model" and "provider:" (empty halves).
  if (!provider || !modelName) {
    throw new ConfigError("Model must be in provider:model format");
  }

  // Custom models come from the user's config, not pi-ai's registry.
  if (provider === "custom") {
    return resolveCustomModel(modelName, config);
  }

  // Built-in providers. The `as any` bypasses pi-ai's compile-time check for
  // known provider/model pairs, so an unknown pair may come back as
  // `undefined` at runtime instead of throwing. Surface that as a config
  // error here rather than letting `undefined` leak out as a Model.
  const model: Model<any> | undefined = getModel(provider as any, modelName);
  if (!model) {
    throw new ConfigError(
      `Unknown model '${modelName}' for provider '${provider}'`,
    );
  }
  return model;
}
32
/**
 * Look up a user-defined model in `config.custom_models` and convert it into
 * a pi-ai `Model`.
 *
 * @param modelName - The part after `custom:` in the model reference.
 * @param config - Loaded rumilo configuration.
 * @returns The model built from the matching custom definition.
 * @throws ConfigError if no custom models are configured or `modelName` is
 *   not one of them.
 */
function resolveCustomModel(modelName: string, config: RumiloConfig): Model<any> {
  const customModels = config.custom_models;
  if (!customModels) {
    throw new ConfigError(
      `No custom models configured. Use 'custom:' prefix only with custom model definitions in config.`,
    );
  }

  // Only accept the map's own keys. A plain index lookup would also resolve
  // inherited names such as "constructor" or "toString" via Object.prototype
  // and hand a function to buildCustomModel.
  const customConfig = Object.prototype.hasOwnProperty.call(customModels, modelName)
    ? customModels[modelName]
    : undefined;

  if (!customConfig) {
    const names = Object.keys(customModels);
    const available = names.length > 0 ? names.join(", ") : "(none)";
    throw new ConfigError(
      `Custom model '${modelName}' not found. Available custom models: ${available}`,
    );
  }

  return buildCustomModel(customConfig);
}
50
/**
 * Convert a snake_case custom model definition from config into pi-ai's
 * camelCase `Model` shape.
 *
 * Optional `headers` and `compat` are copied only when present, so the
 * resulting object has no such keys otherwise.
 *
 * @param config - A single entry from `custom_models`.
 * @returns The equivalent pi-ai model definition.
 */
function buildCustomModel(config: CustomModelConfig): Model<any> {
  const model: any = {
    id: config.id,
    name: config.name,
    api: config.api as any,
    provider: config.provider as any,
    baseUrl: config.base_url,
    reasoning: config.reasoning,
    input: config.input,
    cost: {
      input: config.cost.input,
      output: config.cost.output,
      // pi-ai's Model.cost declares cacheRead/cacheWrite as required numbers.
      // Default to 0 (no charge) instead of leaving them undefined, which
      // would make any cost arithmetic on them produce NaN.
      cacheRead: config.cost.cache_read ?? 0,
      cacheWrite: config.cost.cache_write ?? 0,
    },
    contextWindow: config.context_window,
    maxTokens: config.max_tokens,
  };

  if (config.headers) {
    model.headers = config.headers;
  }

  if (config.compat) {
    model.compat = convertCompatConfig(config.compat);
  }

  return model;
}
90
/**
 * Map the snake_case `compat` section of a custom model definition onto
 * pi-ai's camelCase compat flags.
 *
 * Every output key is always present; values the config omits come through
 * as `undefined`.
 *
 * @param compat - The `compat` section of a custom model definition.
 * @returns The camelCase compat object.
 * @throws Error if `compat` is missing (callers are expected to check first).
 */
function convertCompatConfig(
  compat: CustomModelConfig["compat"],
): any {
  if (compat) {
    const {
      supports_store: supportsStore,
      supports_developer_role: supportsDeveloperRole,
      supports_reasoning_effort: supportsReasoningEffort,
      supports_usage_in_streaming: supportsUsageInStreaming,
      max_tokens_field: maxTokensField,
      requires_tool_result_name: requiresToolResultName,
      requires_assistant_after_tool_result: requiresAssistantAfterToolResult,
      requires_thinking_as_text: requiresThinkingAsText,
      requires_mistral_tool_ids: requiresMistralToolIds,
      thinking_format: thinkingFormat,
    } = compat;

    return {
      supportsStore,
      supportsDeveloperRole,
      supportsReasoningEffort,
      supportsUsageInStreaming,
      maxTokensField,
      requiresToolResultName,
      requiresAssistantAfterToolResult,
      requiresThinkingAsText,
      requiresMistralToolIds,
      thinkingFormat,
    };
  }

  throw new Error("Compat config is expected to be defined");
}