llm.go

  1// Package llm provides a unified interface for interacting with LLMs.
  2package llm
  3
  4import (
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"net/http"
 10	"os"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14	"time"
 15)
 16
// Service sends requests to an LLM on behalf of the agent and reports
// provider limits relevant to request construction.
type Service interface {
	// Do sends a request to an LLM.
	Do(context.Context, *Request) (*Response, error)
	// TokenContextWindow returns the maximum token context window size for this service.
	TokenContextWindow() int
	// MaxImageDimension returns the maximum allowed dimension (width or height) for images.
	// For multi-image requests, some providers enforce stricter limits.
	// Returns 0 if there is no limit.
	MaxImageDimension() int
}
 27
// SimplifiedPatcher is an optional interface a Service may implement;
// query it through the UseSimplifiedPatch helper rather than directly.
type SimplifiedPatcher interface {
	// UseSimplifiedPatch reports whether the service should use the simplified patch input schema.
	UseSimplifiedPatch() bool
}
 32
 33func UseSimplifiedPatch(svc Service) bool {
 34	if sp, ok := svc.(SimplifiedPatcher); ok {
 35		return sp.UseSimplifiedPatch()
 36	}
 37	return false
 38}
 39
 40// MustSchema validates that schema is a valid JSON schema and returns it as a json.RawMessage.
 41// It panics if the schema is invalid.
 42// The schema must have at least type="object" and a properties key.
 43func MustSchema(schema string) json.RawMessage {
 44	schema = strings.TrimSpace(schema)
 45	bytes := []byte(schema)
 46	var obj map[string]any
 47	if err := json.Unmarshal(bytes, &obj); err != nil {
 48		panic("failed to parse JSON schema: " + schema + ": " + err.Error())
 49	}
 50	if typ, ok := obj["type"]; !ok || typ != "object" {
 51		panic("JSON schema must have type='object': " + schema)
 52	}
 53	if _, ok := obj["properties"]; !ok {
 54		panic("JSON schema must have 'properties' key: " + schema)
 55	}
 56	return json.RawMessage(bytes)
 57}
 58
 59func EmptySchema() json.RawMessage {
 60	return MustSchema(`{"type": "object", "properties": {}}`)
 61}
 62
// ErrorType identifies system-generated error messages (not LLM content).
// Values are serialized in Message.ErrorType, so they must remain stable.
type ErrorType string

const (
	ErrorTypeNone       ErrorType = ""            // Not an error
	ErrorTypeTruncation ErrorType = "truncation"  // Response truncated due to max tokens
	ErrorTypeLLMRequest ErrorType = "llm_request" // LLM request failed
)
 71
// Request is a provider-agnostic LLM request: the conversation so far,
// the available tools, an optional tool-choice constraint, and the
// system prompt blocks.
type Request struct {
	Messages   []Message
	ToolChoice *ToolChoice
	Tools      []*Tool
	System     []SystemContent
}
 78
// Message represents a message in the conversation.
// Messages are JSON-serialized with explicit exported-name tags, so
// renaming fields is a breaking change for stored conversations.
type Message struct {
	Role      MessageRole `json:"Role"`
	Content   []Content   `json:"Content"`
	ToolUse   *ToolUse    `json:"ToolUse,omitempty"` // use to control whether/which tool to use
	EndOfTurn bool        `json:"EndOfTurn"`         // true if this message completes the agent's turn (no tool calls to make)

	// ExcludedFromContext indicates this message should be stored but not sent back to the LLM.
	// Used for truncated responses we want to keep for cost tracking but that would confuse the LLM.
	ExcludedFromContext bool `json:"ExcludedFromContext,omitempty"`

	// ErrorType indicates this is a system-generated error message (not LLM content).
	// Empty string means not an error. Values: "truncation", "llm_request".
	ErrorType ErrorType `json:"ErrorType,omitempty"`
}
 94
// ToolUse represents a tool use in the message content.
// ID presumably corresponds to Content.ToolUseID on the matching
// tool_result — confirm against the provider implementations.
type ToolUse struct {
	ID   string
	Name string
}
100
// ToolChoice constrains which tool (if any) the model may call.
// Name is only meaningful when Type is ToolChoiceTypeTool.
type ToolChoice struct {
	Type ToolChoiceType
	Name string
}
105
// SystemContent is one block of the system prompt.
// Cache presumably enables prompt caching for the block, as with
// Tool.Cache and Content.Cache — verify in the provider code.
type SystemContent struct {
	Text  string
	Type  string
	Cache bool
}
111
// Tool represents a tool available to an LLM.
// InputSchema must satisfy the constraints enforced by MustSchema
// (type "object" with a "properties" key).
type Tool struct {
	// Name is the tool name presented to the model.
	Name string
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string
	Description string
	InputSchema json.RawMessage
	// EndsTurn indicates that this tool should cause the model to end its turn when used
	EndsTurn bool
	// Cache indicates whether to use prompt caching for this tool
	Cache bool

	// The Run function is automatically called when the tool is used.
	// Run functions may be called concurrently with each other and themselves.
	// The input to Run function is the input to the tool, as provided by Claude, in compliance with the input schema.
	// The outputs from Run will be sent back to Claude.
	// If you do not want to respond to the tool call request from Claude, return ErrDoNotRespond.
	// ctx contains extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
	Run func(ctx context.Context, input json.RawMessage) ToolOut `json:"-"`
}
133
// ToolOut represents the output of a tool run.
// Use ErrorToolOut / ErrorfToolOut to construct failure values.
type ToolOut struct {
	// LLMContent is the output of the tool to be sent back to the LLM.
	// May be nil on error.
	LLMContent []Content
	// Display is content to be displayed to the user.
	// The type of content is set by the tool and coordinated with the UIs.
	// It should be JSON-serializable.
	Display any
	// Error is the error (if any) that occurred during the tool run.
	// The text contents of the error will be sent back to the LLM.
	// If non-nil, LLMContent will be ignored.
	Error error
}
148
// Content is a single content block within a message or tool result.
// Which fields are meaningful depends on Type; the groups below are
// labeled with the content types they serve.
type Content struct {
	ID   string
	Type ContentType
	Text string

	// Media type for image content
	MediaType string

	// for thinking: Thinking is the reasoning text; Data and Signature
	// are opaque provider fields (presumably for replay/verification —
	// confirm against the provider implementations)
	Thinking  string
	Data      string
	Signature string

	// for tool_use
	ToolName  string
	ToolInput json.RawMessage

	// for tool_result; ToolUseID links the result to its tool_use block
	ToolUseID  string
	ToolError  bool
	ToolResult []Content

	// timing information for tool_result; added externally; not sent to the LLM
	ToolUseStartTime *time.Time
	ToolUseEndTime   *time.Time

	// Display is content to be displayed to the user, copied from ToolOut
	Display any

	// Cache marks this block for prompt caching
	Cache bool
}
180
181func StringContent(s string) Content {
182	return Content{Type: ContentTypeText, Text: s}
183}
184
185// ContentsAttr returns contents as a slog.Attr.
186// It is meant for logging.
187func ContentsAttr(contents []Content) slog.Attr {
188	var contentAttrs []any // slog.Attr
189	for _, content := range contents {
190		var attrs []any // slog.Attr
191		switch content.Type {
192		case ContentTypeText:
193			attrs = append(attrs, slog.String("text", content.Text))
194		case ContentTypeToolUse:
195			attrs = append(attrs, slog.String("tool_name", content.ToolName))
196			attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
197		case ContentTypeToolResult:
198			attrs = append(attrs, slog.Any("tool_result", content.ToolResult))
199			attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
200		case ContentTypeThinking:
201			attrs = append(attrs, slog.String("thinking", content.Thinking))
202		default:
203			attrs = append(attrs, slog.String("unknown_content_type", content.Type.String()))
204			attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
205		}
206		contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
207	}
208	return slog.Group("contents", contentAttrs...)
209}
210
// Enumeration types for roles, content kinds, tool choice, stop reasons,
// and thinking levels. Their String methods are generated by stringer
// (see the go:generate directive below).
type (
	MessageRole    int
	ContentType    int
	ToolChoiceType int
	StopReason     int
	ThinkingLevel  int
)
218
219//go:generate go tool golang.org/x/tools/cmd/stringer -type=MessageRole,ContentType,ToolChoiceType,StopReason,ThinkingLevel -output=llm_string.go
220
// Values for the enum types above.
//
// NOTE(review): iota does NOT reset between groups within a single const
// block, so ContentTypeText is 2, ToolChoiceTypeAuto is 7, and
// StopReasonStopSequence is 11 — the later groups do not start at 0, and
// the zero value of those types names no constant here. The generated
// stringer code reflects the actual values, and these ints are serialized
// via Message/Content JSON, so do NOT renumber or reorder them. The
// "default" comment on ToolChoiceTypeAuto presumably means the provider
// default, not the Go zero value — confirm.
const (
	MessageRoleUser MessageRole = iota
	MessageRoleAssistant

	ContentTypeText ContentType = iota
	ContentTypeThinking
	ContentTypeRedactedThinking
	ContentTypeToolUse
	ContentTypeToolResult

	ToolChoiceTypeAuto ToolChoiceType = iota // default
	ToolChoiceTypeAny                        // any tool, but must use one
	ToolChoiceTypeNone                       // no tools allowed
	ToolChoiceTypeTool                       // must use the tool specified in the Name field

	StopReasonStopSequence StopReason = iota
	StopReasonMaxTokens
	StopReasonEndTurn
	StopReasonToolUse
	StopReasonRefusal
)
242
// ThinkingLevel controls how much thinking/reasoning the model does.
// ThinkingLevelOff is the zero value and disables thinking.
// Levels are ordered and contiguous; ThinkingBudgetTokens and
// ThinkingEffort map them to provider-specific settings.
const (
	ThinkingLevelOff     ThinkingLevel = iota // No thinking (zero value)
	ThinkingLevelMinimal                      // Minimal thinking (1024 tokens / "minimal")
	ThinkingLevelLow                          // Low thinking (2048 tokens / "low")
	ThinkingLevelMedium                       // Medium thinking (8192 tokens / "medium")
	ThinkingLevelHigh                         // High thinking (16384 tokens / "high")
)
252
253// ThinkingBudgetTokens returns the recommended budget_tokens for Anthropic's extended thinking.
254func (t ThinkingLevel) ThinkingBudgetTokens() int {
255	switch t {
256	case ThinkingLevelMinimal:
257		return 1024
258	case ThinkingLevelLow:
259		return 2048
260	case ThinkingLevelMedium:
261		return 8192
262	case ThinkingLevelHigh:
263		return 16384
264	default:
265		return 0
266	}
267}
268
269// ThinkingEffort returns the reasoning effort string for OpenAI's reasoning API.
270func (t ThinkingLevel) ThinkingEffort() string {
271	switch t {
272	case ThinkingLevelMinimal:
273		return "minimal"
274	case ThinkingLevelLow:
275		return "low"
276	case ThinkingLevelMedium:
277		return "medium"
278	case ThinkingLevelHigh:
279		return "high"
280	default:
281		return ""
282	}
283}
284
// Response is a provider-agnostic LLM response.
// StartTime/EndTime are optional wall-clock request timing.
type Response struct {
	ID           string
	Type         string
	Role         MessageRole
	Model        string
	Content      []Content
	StopReason   StopReason
	StopSequence *string
	Usage        Usage
	StartTime    *time.Time
	EndTime      *time.Time
}
297
298func (m *Response) ToMessage() Message {
299	return Message{
300		Role:      m.Role,
301		Content:   m.Content,
302		EndOfTurn: m.StopReason != StopReasonToolUse, // End of turn unless there are tools to call
303	}
304}
305
306func CostUSDFromResponse(headers http.Header) float64 {
307	h := headers.Get("Skaband-Cost-Microcents")
308	if h == "" {
309		return 0
310	}
311	uc, err := strconv.ParseUint(h, 10, 64)
312	if err != nil {
313		slog.Warn("failed to parse cost header", "header", h)
314		return 0
315	}
316	return float64(uc) / 100_000_000
317}
318
// Usage represents the billing and rate-limit usage.
// InputTokens counts non-cached input; cache creation/read tokens are
// tracked separately (see TotalInputTokens). CostUSD is in US dollars.
// Most LLM structs do not have JSON tags, to avoid accidental direct use in specific providers.
// However, the front-end uses this struct, and it relies on its JSON serialization.
// Do NOT use this struct directly when implementing an llm.Service.
type Usage struct {
	InputTokens              uint64     `json:"input_tokens"`
	CacheCreationInputTokens uint64     `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64     `json:"cache_read_input_tokens"`
	OutputTokens             uint64     `json:"output_tokens"`
	CostUSD                  float64    `json:"cost_usd"`
	Model                    string     `json:"model,omitempty"`
	StartTime                *time.Time `json:"start_time,omitempty"`
	EndTime                  *time.Time `json:"end_time,omitempty"`
}
333
// Add accumulates other's token counts and cost into u.
// Model, StartTime, and EndTime are not merged.
func (u *Usage) Add(other Usage) {
	u.InputTokens += other.InputTokens
	u.CacheCreationInputTokens += other.CacheCreationInputTokens
	u.CacheReadInputTokens += other.CacheReadInputTokens
	u.OutputTokens += other.OutputTokens
	u.CostUSD += other.CostUSD
}
341
342func (u *Usage) String() string {
343	return fmt.Sprintf("in: %d, out: %d", u.InputTokens, u.OutputTokens)
344}
345
346// TotalInputTokens returns the total number of input tokens including cached tokens.
347// This represents the full context that was sent to the model:
348// - InputTokens: tokens processed (not from cache)
349// - CacheCreationInputTokens: tokens written to cache (also part of input)
350// - CacheReadInputTokens: tokens read from cache (also part of input)
351func (u *Usage) TotalInputTokens() uint64 {
352	return u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens
353}
354
355// ContextWindowUsed returns the total context window usage after this response.
356// This is the size of the conversation that would be sent to the model for the next turn:
357// total input tokens + output tokens (which become part of the conversation).
358func (u *Usage) ContextWindowUsed() uint64 {
359	return u.TotalInputTokens() + u.OutputTokens
360}
361
// IsZero reports whether u is the zero Usage value.
// Note that StartTime/EndTime are pointers, so they compare by identity:
// a Usage with a non-nil timestamp is never zero.
func (u *Usage) IsZero() bool {
	return *u == Usage{}
}
365
// Attr returns the usage as a slog group attribute for structured logging.
// Model, StartTime, and EndTime are not included.
func (u *Usage) Attr() slog.Attr {
	return slog.Group("usage",
		slog.Uint64("input_tokens", u.InputTokens),
		slog.Uint64("output_tokens", u.OutputTokens),
		slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
		slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
		slog.Float64("cost_usd", u.CostUSD),
	)
}
375
376// UserStringMessage creates a user message with a single text content item.
377func UserStringMessage(text string) Message {
378	return Message{
379		Role:    MessageRoleUser,
380		Content: []Content{StringContent(text)},
381	}
382}
383
384// TextContent creates a simple text content for tool results.
385// This is a helper function to create the most common type of tool result content.
386func TextContent(text string) []Content {
387	return []Content{{
388		Type: ContentTypeText,
389		Text: text,
390	}}
391}
392
393func ErrorToolOut(err error) ToolOut {
394	if err == nil {
395		panic("ErrorToolOut called with nil error")
396	}
397	return ToolOut{
398		Error: err,
399	}
400}
401
402func ErrorfToolOut(format string, args ...any) ToolOut {
403	return ErrorToolOut(fmt.Errorf(format, args...))
404}
405
// DumpToFile writes LLM communication content to a timestamped file in ~/.cache/sketch/.
// For requests, it includes the URL followed by the content. For responses, it only includes the content.
// The typ parameter is used as a prefix in the filename ("request", "response").
// Errors are wrapped with context describing which step failed.
func DumpToFile(typ, url string, content []byte) error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("locating home dir: %w", err)
	}
	cacheDir := filepath.Join(homeDir, ".cache", "sketch")
	if err := os.MkdirAll(cacheDir, 0o700); err != nil {
		return fmt.Errorf("creating cache dir: %w", err)
	}
	filename := fmt.Sprintf("%s_%d.txt", typ, time.Now().UnixMilli())
	filePath := filepath.Join(cacheDir, filename)

	// For requests, start with the URL; for responses (url == ""), just write the content.
	data := []byte(url)
	if url != "" {
		data = append(data, "\n\n"...)
	}
	data = append(data, content...)

	if err := os.WriteFile(filePath, data, 0o600); err != nil {
		return fmt.Errorf("writing dump file: %w", err)
	}
	return nil
}