1package providers
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "maps"
11 "strings"
12
13 "github.com/anthropics/anthropic-sdk-go"
14 "github.com/anthropics/anthropic-sdk-go/option"
15 "github.com/anthropics/anthropic-sdk-go/packages/param"
16 "github.com/charmbracelet/crush/internal/ai"
17)
18
type (
	// AnthropicThinking configures extended thinking for a call.
	AnthropicThinking struct {
		// BudgetTokens is the number of tokens the model may spend thinking.
		BudgetTokens int64 `json:"budget_tokens"`
	}

	// AnthropicProviderOptions are the per-call options read from
	// call.ProviderOptions["anthropic"].
	AnthropicProviderOptions struct {
		// SendReasoning controls whether prior reasoning parts are replayed
		// to the API; nil means true.
		SendReasoning *bool `json:"send_reasoning,omitempty"`
		// Thinking enables extended thinking when non-nil.
		Thinking *AnthropicThinking `json:"thinking,omitempty"`
		// DisableParallelToolUse is forwarded to the tool-choice setting;
		// nil means false.
		DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
	}

	// AnthropicReasoningMetadata is the Anthropic-specific metadata attached
	// to a reasoning part: either a thinking signature or redacted data.
	AnthropicReasoningMetadata struct {
		Signature    string `json:"signature"`
		RedactedData string `json:"redacted_data"`
	}

	// AnthropicCacheControlProviderOptions describes a cache-control marker
	// supplied through provider options.
	AnthropicCacheControlProviderOptions struct {
		Type string `json:"type"`
	}

	// AnthropicFilePartProviderOptions are per-file options for file parts.
	AnthropicFilePartProviderOptions struct {
		EnableCitations bool   `json:"enable_citations"`
		Title           string `json:"title"`
		Context         string `json:"context"`
	}
)
42
43type anthropicProviderOptions struct {
44 baseURL string
45 apiKey string
46 name string
47 headers map[string]string
48 client option.HTTPClient
49}
50
51type anthropicProvider struct {
52 options anthropicProviderOptions
53}
54
55type AnthropicOption = func(*anthropicProviderOptions)
56
57func NewAnthropicProvider(opts ...AnthropicOption) ai.Provider {
58 options := anthropicProviderOptions{
59 headers: map[string]string{},
60 }
61 for _, o := range opts {
62 o(&options)
63 }
64 if options.baseURL == "" {
65 options.baseURL = "https://api.anthropic.com/v1"
66 }
67
68 if options.name == "" {
69 options.name = "anthropic"
70 }
71
72 return &anthropicProvider{
73 options: options,
74 }
75}
76
77func WithAnthropicBaseURL(baseURL string) AnthropicOption {
78 return func(o *anthropicProviderOptions) {
79 o.baseURL = baseURL
80 }
81}
82
83func WithAnthropicAPIKey(apiKey string) AnthropicOption {
84 return func(o *anthropicProviderOptions) {
85 o.apiKey = apiKey
86 }
87}
88
89func WithAnthropicName(name string) AnthropicOption {
90 return func(o *anthropicProviderOptions) {
91 o.name = name
92 }
93}
94
95func WithAnthropicHeaders(headers map[string]string) AnthropicOption {
96 return func(o *anthropicProviderOptions) {
97 maps.Copy(o.headers, headers)
98 }
99}
100
101func WithAnthropicHTTPClient(client option.HTTPClient) AnthropicOption {
102 return func(o *anthropicProviderOptions) {
103 o.client = client
104 }
105}
106
107func (a *anthropicProvider) LanguageModel(modelID string) ai.LanguageModel {
108 anthropicClientOptions := []option.RequestOption{}
109 if a.options.apiKey != "" {
110 anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(a.options.apiKey))
111 }
112 if a.options.baseURL != "" {
113 anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(a.options.baseURL))
114 }
115
116 for key, value := range a.options.headers {
117 anthropicClientOptions = append(anthropicClientOptions, option.WithHeader(key, value))
118 }
119
120 if a.options.client != nil {
121 anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(a.options.client))
122 }
123 return anthropicLanguageModel{
124 modelID: modelID,
125 provider: fmt.Sprintf("%s.messages", a.options.name),
126 providerOptions: a.options,
127 client: anthropic.NewClient(anthropicClientOptions...),
128 }
129}
130
131type anthropicLanguageModel struct {
132 provider string
133 modelID string
134 client anthropic.Client
135 providerOptions anthropicProviderOptions
136}
137
138// Model implements ai.LanguageModel.
139func (a anthropicLanguageModel) Model() string {
140 return a.modelID
141}
142
143// Provider implements ai.LanguageModel.
144func (a anthropicLanguageModel) Provider() string {
145 return a.provider
146}
147
148func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
149 params := &anthropic.MessageNewParams{}
150 providerOptions := &AnthropicProviderOptions{}
151 if v, ok := call.ProviderOptions["anthropic"]; ok {
152 err := ai.ParseOptions(v, providerOptions)
153 if err != nil {
154 return nil, nil, err
155 }
156 }
157 sendReasoning := true
158 if providerOptions.SendReasoning != nil {
159 sendReasoning = *providerOptions.SendReasoning
160 }
161 systemBlocks, messages, warnings := toAnthropicPrompt(call.Prompt, sendReasoning)
162
163 if call.FrequencyPenalty != nil {
164 warnings = append(warnings, ai.CallWarning{
165 Type: ai.CallWarningTypeUnsupportedSetting,
166 Setting: "FrequencyPenalty",
167 })
168 }
169 if call.PresencePenalty != nil {
170 warnings = append(warnings, ai.CallWarning{
171 Type: ai.CallWarningTypeUnsupportedSetting,
172 Setting: "PresencePenalty",
173 })
174 }
175
176 params.System = systemBlocks
177 params.Messages = messages
178 params.Model = anthropic.Model(a.modelID)
179
180 if call.MaxOutputTokens != nil {
181 params.MaxTokens = *call.MaxOutputTokens
182 }
183
184 if call.Temperature != nil {
185 params.Temperature = param.NewOpt(*call.Temperature)
186 }
187 if call.TopK != nil {
188 params.TopK = param.NewOpt(*call.TopK)
189 }
190 if call.TopP != nil {
191 params.TopP = param.NewOpt(*call.TopP)
192 }
193
194 isThinking := false
195 var thinkingBudget int64
196 if providerOptions.Thinking != nil {
197 isThinking = true
198 thinkingBudget = providerOptions.Thinking.BudgetTokens
199 }
200 if isThinking {
201 if thinkingBudget == 0 {
202 return nil, nil, ai.NewUnsupportedFunctionalityError("thinking requires budget", "")
203 }
204 params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget)
205 if call.Temperature != nil {
206 params.Temperature = param.Opt[float64]{}
207 warnings = append(warnings, ai.CallWarning{
208 Type: ai.CallWarningTypeUnsupportedSetting,
209 Setting: "temperature",
210 Details: "temperature is not supported when thinking is enabled",
211 })
212 }
213 if call.TopP != nil {
214 params.TopP = param.Opt[float64]{}
215 warnings = append(warnings, ai.CallWarning{
216 Type: ai.CallWarningTypeUnsupportedSetting,
217 Setting: "TopP",
218 Details: "TopP is not supported when thinking is enabled",
219 })
220 }
221 if call.TopK != nil {
222 params.TopK = param.Opt[int64]{}
223 warnings = append(warnings, ai.CallWarning{
224 Type: ai.CallWarningTypeUnsupportedSetting,
225 Setting: "TopK",
226 Details: "TopK is not supported when thinking is enabled",
227 })
228 }
229 params.MaxTokens = params.MaxTokens + thinkingBudget
230 }
231
232 if len(call.Tools) > 0 {
233 disableParallelToolUse := false
234 if providerOptions.DisableParallelToolUse != nil {
235 disableParallelToolUse = *providerOptions.DisableParallelToolUse
236 }
237 tools, toolChoice, toolWarnings := toAnthropicTools(call.Tools, call.ToolChoice, *&disableParallelToolUse)
238 params.Tools = tools
239 if toolChoice != nil {
240 params.ToolChoice = *toolChoice
241 }
242 warnings = append(warnings, toolWarnings...)
243 }
244
245 return params, warnings, nil
246}
247
248func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlProviderOptions {
249 if anthropicOptions, ok := providerOptions["anthropic"]; ok {
250 if cacheControl, ok := anthropicOptions["cache_control"]; ok {
251 if cc, ok := cacheControl.(map[string]any); ok {
252 cacheControlOption := &AnthropicCacheControlProviderOptions{}
253 err := ai.ParseOptions(cc, cacheControlOption)
254 if err != nil {
255 return cacheControlOption
256 }
257 }
258 } else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
259 if cc, ok := cacheControl.(map[string]any); ok {
260 cacheControlOption := &AnthropicCacheControlProviderOptions{}
261 err := ai.ParseOptions(cc, cacheControlOption)
262 if err != nil {
263 return cacheControlOption
264 }
265 }
266 }
267 }
268 return nil
269}
270
271func getReasoningMetadata(providerOptions ai.ProviderOptions) *AnthropicReasoningMetadata {
272 if anthropicOptions, ok := providerOptions["anthropic"]; ok {
273 reasoningMetadata := &AnthropicReasoningMetadata{}
274 err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
275 if err != nil {
276 return reasoningMetadata
277 }
278 }
279 return nil
280}
281
282type messageBlock struct {
283 Role ai.MessageRole
284 Messages []ai.Message
285}
286
287func groupIntoBlocks(prompt ai.Prompt) []*messageBlock {
288 var blocks []*messageBlock
289
290 var currentBlock *messageBlock
291
292 for _, msg := range prompt {
293 switch msg.Role {
294 case ai.MessageRoleSystem:
295 if currentBlock == nil || currentBlock.Role != ai.MessageRoleSystem {
296 currentBlock = &messageBlock{
297 Role: ai.MessageRoleSystem,
298 Messages: []ai.Message{},
299 }
300 blocks = append(blocks, currentBlock)
301 }
302 currentBlock.Messages = append(currentBlock.Messages, msg)
303 case ai.MessageRoleUser:
304 if currentBlock == nil || currentBlock.Role != ai.MessageRoleUser {
305 currentBlock = &messageBlock{
306 Role: ai.MessageRoleUser,
307 Messages: []ai.Message{},
308 }
309 blocks = append(blocks, currentBlock)
310 }
311 currentBlock.Messages = append(currentBlock.Messages, msg)
312 case ai.MessageRoleAssistant:
313 if currentBlock == nil || currentBlock.Role != ai.MessageRoleAssistant {
314 currentBlock = &messageBlock{
315 Role: ai.MessageRoleAssistant,
316 Messages: []ai.Message{},
317 }
318 blocks = append(blocks, currentBlock)
319 }
320 currentBlock.Messages = append(currentBlock.Messages, msg)
321 case ai.MessageRoleTool:
322 if currentBlock == nil || currentBlock.Role != ai.MessageRoleUser {
323 currentBlock = &messageBlock{
324 Role: ai.MessageRoleUser,
325 Messages: []ai.Message{},
326 }
327 blocks = append(blocks, currentBlock)
328 }
329 currentBlock.Messages = append(currentBlock.Messages, msg)
330 }
331 }
332 return blocks
333}
334
335func toAnthropicTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) {
336 for _, tool := range tools {
337
338 if tool.GetType() == ai.ToolTypeFunction {
339 ft, ok := tool.(ai.FunctionTool)
340 if !ok {
341 continue
342 }
343 required := []string{}
344 var properties any
345 if props, ok := ft.InputSchema["properties"]; ok {
346 properties = props
347 }
348 if req, ok := ft.InputSchema["required"]; ok {
349 if reqArr, ok := req.([]string); ok {
350 required = reqArr
351 }
352 }
353 cacheControl := getCacheControl(ft.ProviderOptions)
354
355 anthropicTool := anthropic.ToolParam{
356 Name: ft.Name,
357 Description: anthropic.String(ft.Description),
358 InputSchema: anthropic.ToolInputSchemaParam{
359 Properties: properties,
360 Required: required,
361 },
362 }
363 if cacheControl != nil {
364 anthropicTool.CacheControl = anthropic.NewCacheControlEphemeralParam()
365 }
366 anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &anthropicTool})
367
368 }
369 // TODO: handle provider tool calls
370 warnings = append(warnings, ai.CallWarning{
371 Type: ai.CallWarningTypeUnsupportedTool,
372 Tool: tool,
373 Message: "tool is not supported",
374 })
375 }
376 if toolChoice == nil {
377 return
378 }
379
380 switch *toolChoice {
381 case ai.ToolChoiceAuto:
382 anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
383 OfAuto: &anthropic.ToolChoiceAutoParam{
384 Type: "auto",
385 DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
386 },
387 }
388 case ai.ToolChoiceRequired:
389 anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
390 OfAny: &anthropic.ToolChoiceAnyParam{
391 Type: "any",
392 DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
393 },
394 }
395 default:
396 anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
397 OfTool: &anthropic.ToolChoiceToolParam{
398 Type: "tool",
399 Name: string(*toolChoice),
400 DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
401 },
402 }
403 }
404 return
405}
406
407func toAnthropicPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) {
408 var systemBlocks []anthropic.TextBlockParam
409 var messages []anthropic.MessageParam
410 var warnings []ai.CallWarning
411
412 blocks := groupIntoBlocks(prompt)
413 finishedSystemBlock := false
414 for _, block := range blocks {
415 switch block.Role {
416 case ai.MessageRoleSystem:
417 if finishedSystemBlock {
418 // skip multiple system messages that are separated by user/assistant messages
419 // TODO: see if we need to send error here?
420 continue
421 }
422 finishedSystemBlock = true
423 for _, msg := range block.Messages {
424 for _, part := range msg.Content {
425 cacheControl := getCacheControl(part.Options())
426 text, ok := ai.AsMessagePart[ai.TextPart](part)
427 if !ok {
428 continue
429 }
430 textBlock := anthropic.TextBlockParam{
431 Text: text.Text,
432 }
433 if cacheControl != nil {
434 textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
435 }
436 systemBlocks = append(systemBlocks, textBlock)
437 }
438 }
439
440 case ai.MessageRoleUser:
441 var anthropicContent []anthropic.ContentBlockParamUnion
442 for _, msg := range block.Messages {
443 if msg.Role == ai.MessageRoleUser {
444 for i, part := range msg.Content {
445 isLastPart := i == len(msg.Content)-1
446 cacheControl := getCacheControl(part.Options())
447 if cacheControl == nil && isLastPart {
448 cacheControl = getCacheControl(msg.ProviderOptions)
449 }
450 switch part.GetType() {
451 case ai.ContentTypeText:
452 text, ok := ai.AsMessagePart[ai.TextPart](part)
453 if !ok {
454 continue
455 }
456 textBlock := &anthropic.TextBlockParam{
457 Text: text.Text,
458 }
459 if cacheControl != nil {
460 textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
461 }
462 anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
463 OfText: textBlock,
464 })
465 case ai.ContentTypeFile:
466 file, ok := ai.AsMessagePart[ai.FilePart](part)
467 if !ok {
468 continue
469 }
470 // TODO: handle other file types
471 if !strings.HasPrefix(file.MediaType, "image/") {
472 continue
473 }
474
475 base64Encoded := base64.StdEncoding.EncodeToString(file.Data)
476 imageBlock := anthropic.NewImageBlockBase64(file.MediaType, base64Encoded)
477 if cacheControl != nil {
478 imageBlock.OfImage.CacheControl = anthropic.NewCacheControlEphemeralParam()
479 }
480 anthropicContent = append(anthropicContent, imageBlock)
481 }
482 }
483 } else if msg.Role == ai.MessageRoleTool {
484 for i, part := range msg.Content {
485 isLastPart := i == len(msg.Content)-1
486 cacheControl := getCacheControl(part.Options())
487 if cacheControl == nil && isLastPart {
488 cacheControl = getCacheControl(msg.ProviderOptions)
489 }
490 result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
491 if !ok {
492 continue
493 }
494 toolResultBlock := anthropic.ToolResultBlockParam{
495 ToolUseID: result.ToolCallID,
496 }
497 switch result.Output.GetType() {
498 case ai.ToolResultContentTypeText:
499 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
500 if !ok {
501 continue
502 }
503 toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
504 {
505 OfText: &anthropic.TextBlockParam{
506 Text: content.Text,
507 },
508 },
509 }
510 case ai.ToolResultContentTypeMedia:
511 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentMedia](result.Output)
512 if !ok {
513 continue
514 }
515 toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
516 {
517 OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage,
518 },
519 }
520 case ai.ToolResultContentTypeError:
521 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
522 if !ok {
523 continue
524 }
525 toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
526 {
527 OfText: &anthropic.TextBlockParam{
528 Text: content.Error.Error(),
529 },
530 },
531 }
532 toolResultBlock.IsError = param.NewOpt(true)
533 }
534 if cacheControl != nil {
535 toolResultBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
536 }
537 anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
538 OfToolResult: &toolResultBlock,
539 })
540 }
541 }
542 }
543 messages = append(messages, anthropic.NewUserMessage(anthropicContent...))
544 case ai.MessageRoleAssistant:
545 var anthropicContent []anthropic.ContentBlockParamUnion
546 for _, msg := range block.Messages {
547 for i, part := range msg.Content {
548 isLastPart := i == len(msg.Content)-1
549 cacheControl := getCacheControl(part.Options())
550 if cacheControl == nil && isLastPart {
551 cacheControl = getCacheControl(msg.ProviderOptions)
552 }
553 switch part.GetType() {
554 case ai.ContentTypeText:
555 text, ok := ai.AsMessagePart[ai.TextPart](part)
556 if !ok {
557 continue
558 }
559 textBlock := &anthropic.TextBlockParam{
560 Text: text.Text,
561 }
562 if cacheControl != nil {
563 textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
564 }
565 anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
566 OfText: textBlock,
567 })
568 case ai.ContentTypeReasoning:
569 reasoning, ok := ai.AsMessagePart[ai.ReasoningPart](part)
570 if !ok {
571 continue
572 }
573 if !sendReasoningData {
574 warnings = append(warnings, ai.CallWarning{
575 Type: "other",
576 Message: "sending reasoning content is disabled for this model",
577 })
578 continue
579 }
580 reasoningMetadata := getReasoningMetadata(part.Options())
581 if reasoningMetadata == nil {
582 warnings = append(warnings, ai.CallWarning{
583 Type: "other",
584 Message: "unsupported reasoning metadata",
585 })
586 continue
587 }
588
589 if reasoningMetadata.Signature != "" {
590 anthropicContent = append(anthropicContent, anthropic.NewThinkingBlock(reasoningMetadata.Signature, reasoning.Text))
591 } else if reasoningMetadata.RedactedData != "" {
592 anthropicContent = append(anthropicContent, anthropic.NewRedactedThinkingBlock(reasoningMetadata.RedactedData))
593 } else {
594 warnings = append(warnings, ai.CallWarning{
595 Type: "other",
596 Message: "unsupported reasoning metadata",
597 })
598 continue
599 }
600 case ai.ContentTypeToolCall:
601 toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
602 if !ok {
603 continue
604 }
605 if toolCall.ProviderExecuted {
606 // TODO: implement provider executed call
607 continue
608 }
609
610 var inputMap map[string]any
611 err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
612 if err != nil {
613 continue
614 }
615 toolUseBlock := anthropic.NewToolUseBlock(toolCall.ToolCallID, inputMap, toolCall.ToolName)
616 if cacheControl != nil {
617 toolUseBlock.OfToolUse.CacheControl = anthropic.NewCacheControlEphemeralParam()
618 }
619 anthropicContent = append(anthropicContent, toolUseBlock)
620 case ai.ContentTypeToolResult:
621 // TODO: implement provider executed tool result
622 }
623
624 }
625 }
626 messages = append(messages, anthropic.NewAssistantMessage(anthropicContent...))
627 }
628 }
629 return systemBlocks, messages, warnings
630}
631
632func (o anthropicLanguageModel) handleError(err error) error {
633 var apiErr *anthropic.Error
634 if errors.As(err, &apiErr) {
635 requestDump := apiErr.DumpRequest(true)
636 responseDump := apiErr.DumpResponse(true)
637 headers := map[string]string{}
638 for k, h := range apiErr.Response.Header {
639 v := h[len(h)-1]
640 headers[strings.ToLower(k)] = v
641 }
642 return ai.NewAPICallError(
643 apiErr.Error(),
644 apiErr.Request.URL.String(),
645 string(requestDump),
646 apiErr.StatusCode,
647 headers,
648 string(responseDump),
649 apiErr,
650 false,
651 )
652 }
653 return err
654}
655
656func mapAnthropicFinishReason(finishReason string) ai.FinishReason {
657 switch finishReason {
658 case "end", "stop_sequence":
659 return ai.FinishReasonStop
660 case "max_tokens":
661 return ai.FinishReasonLength
662 case "tool_use":
663 return ai.FinishReasonToolCalls
664 default:
665 return ai.FinishReasonUnknown
666 }
667}
668
669// Generate implements ai.LanguageModel.
670func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
671 params, warnings, err := a.prepareParams(call)
672 if err != nil {
673 return nil, err
674 }
675 response, err := a.client.Messages.New(ctx, *params)
676 if err != nil {
677 return nil, a.handleError(err)
678 }
679
680 var content []ai.Content
681 for _, block := range response.Content {
682 switch block.Type {
683 case "text":
684 text, ok := block.AsAny().(anthropic.TextBlock)
685 if !ok {
686 continue
687 }
688 content = append(content, ai.TextContent{
689 Text: text.Text,
690 })
691 case "thinking":
692 reasoning, ok := block.AsAny().(anthropic.ThinkingBlock)
693 if !ok {
694 continue
695 }
696 content = append(content, ai.ReasoningContent{
697 Text: reasoning.Thinking,
698 ProviderMetadata: map[string]map[string]any{
699 "anthropic": {
700 "signature": reasoning.Signature,
701 },
702 },
703 })
704 case "redacted_thinking":
705 reasoning, ok := block.AsAny().(anthropic.RedactedThinkingBlock)
706 if !ok {
707 continue
708 }
709 content = append(content, ai.ReasoningContent{
710 Text: "",
711 ProviderMetadata: map[string]map[string]any{
712 "anthropic": {
713 "redacted_data": reasoning.Data,
714 },
715 },
716 })
717 case "tool_use":
718 toolUse, ok := block.AsAny().(anthropic.ToolUseBlock)
719 if !ok {
720 continue
721 }
722 content = append(content, ai.ToolCallContent{
723 ToolCallID: toolUse.ID,
724 ToolName: toolUse.Name,
725 Input: string(toolUse.Input),
726 ProviderExecuted: false,
727 })
728 }
729 }
730
731 return &ai.Response{
732 Content: content,
733 Usage: ai.Usage{
734 InputTokens: response.Usage.InputTokens,
735 OutputTokens: response.Usage.OutputTokens,
736 TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
737 CacheCreationTokens: response.Usage.CacheCreationInputTokens,
738 CacheReadTokens: response.Usage.CacheReadInputTokens,
739 },
740 FinishReason: mapAnthropicFinishReason(string(response.StopReason)),
741 ProviderMetadata: ai.ProviderMetadata{
742 "anthropic": make(map[string]any),
743 },
744 Warnings: warnings,
745 }, nil
746}
747
748// Stream implements ai.LanguageModel.
749func (a anthropicLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
750 params, warnings, err := a.prepareParams(call)
751 if err != nil {
752 return nil, err
753 }
754
755 stream := a.client.Messages.NewStreaming(ctx, *params)
756 acc := anthropic.Message{}
757 return func(yield func(ai.StreamPart) bool) {
758 if len(warnings) > 0 {
759 if !yield(ai.StreamPart{
760 Type: ai.StreamPartTypeWarnings,
761 Warnings: warnings,
762 }) {
763 return
764 }
765 }
766
767 for stream.Next() {
768 chunk := stream.Current()
769 acc.Accumulate(chunk)
770 switch chunk.Type {
771 case "content_block_start":
772 contentBlockType := chunk.ContentBlock.Type
773 switch contentBlockType {
774 case "text":
775 if !yield(ai.StreamPart{
776 Type: ai.StreamPartTypeTextStart,
777 ID: fmt.Sprintf("%d", chunk.Index),
778 }) {
779 return
780 }
781 case "thinking":
782 if !yield(ai.StreamPart{
783 Type: ai.StreamPartTypeReasoningStart,
784 ID: fmt.Sprintf("%d", chunk.Index),
785 }) {
786 return
787 }
788 case "redacted_thinking":
789 if !yield(ai.StreamPart{
790 Type: ai.StreamPartTypeReasoningStart,
791 ID: fmt.Sprintf("%d", chunk.Index),
792 ProviderMetadata: ai.ProviderMetadata{
793 "anthropic": {
794 "redacted_data": chunk.ContentBlock.Data,
795 },
796 },
797 }) {
798 return
799 }
800 case "tool_use":
801 if !yield(ai.StreamPart{
802 Type: ai.StreamPartTypeToolInputStart,
803 ID: chunk.ContentBlock.ID,
804 ToolCallName: chunk.ContentBlock.Name,
805 ToolCallInput: "",
806 }) {
807 return
808 }
809 }
810 case "content_block_stop":
811 if len(acc.Content)-1 < int(chunk.Index) {
812 continue
813 }
814 contentBlock := acc.Content[int(chunk.Index)]
815 switch contentBlock.Type {
816 case "text":
817 if !yield(ai.StreamPart{
818 Type: ai.StreamPartTypeTextEnd,
819 ID: fmt.Sprintf("%d", chunk.Index),
820 }) {
821 return
822 }
823 case "thinking":
824 if !yield(ai.StreamPart{
825 Type: ai.StreamPartTypeReasoningEnd,
826 ID: fmt.Sprintf("%d", chunk.Index),
827 }) {
828 return
829 }
830 case "tool_use":
831 if !yield(ai.StreamPart{
832 Type: ai.StreamPartTypeToolInputEnd,
833 ID: contentBlock.ID,
834 }) {
835 return
836 }
837 if !yield(ai.StreamPart{
838 Type: ai.StreamPartTypeToolCall,
839 ID: contentBlock.ID,
840 ToolCallName: contentBlock.Name,
841 ToolCallInput: string(contentBlock.Input),
842 }) {
843 return
844 }
845
846 }
847 case "content_block_delta":
848 switch chunk.Delta.Type {
849 case "text_delta":
850 if !yield(ai.StreamPart{
851 Type: ai.StreamPartTypeTextDelta,
852 ID: fmt.Sprintf("%d", chunk.Index),
853 Delta: chunk.Delta.Text,
854 }) {
855 return
856 }
857 case "thinking_delta":
858 if !yield(ai.StreamPart{
859 Type: ai.StreamPartTypeReasoningDelta,
860 ID: fmt.Sprintf("%d", chunk.Index),
861 Delta: chunk.Delta.Text,
862 }) {
863 return
864 }
865 case "signature_delta":
866 if !yield(ai.StreamPart{
867 Type: ai.StreamPartTypeReasoningDelta,
868 ID: fmt.Sprintf("%d", chunk.Index),
869 ProviderMetadata: ai.ProviderMetadata{
870 "anthropic": {
871 "signature": chunk.Delta.Signature,
872 },
873 },
874 }) {
875 return
876 }
877 case "input_json_delta":
878 if len(acc.Content)-1 < int(chunk.Index) {
879 continue
880 }
881 contentBlock := acc.Content[int(chunk.Index)]
882 if !yield(ai.StreamPart{
883 Type: ai.StreamPartTypeToolInputDelta,
884 ID: contentBlock.ID,
885 ToolCallInput: chunk.Delta.PartialJSON,
886 }) {
887 return
888 }
889
890 }
891 case "message_stop":
892 }
893 }
894
895 err := stream.Err()
896 if err == nil || errors.Is(err, io.EOF) {
897 yield(ai.StreamPart{
898 Type: ai.StreamPartTypeFinish,
899 ID: acc.ID,
900 FinishReason: mapAnthropicFinishReason(string(acc.StopReason)),
901 Usage: ai.Usage{
902 InputTokens: acc.Usage.InputTokens,
903 OutputTokens: acc.Usage.OutputTokens,
904 TotalTokens: acc.Usage.InputTokens + acc.Usage.OutputTokens,
905 CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
906 CacheReadTokens: acc.Usage.CacheReadInputTokens,
907 },
908 ProviderMetadata: ai.ProviderMetadata{
909 "anthropic": make(map[string]any),
910 },
911 })
912 return
913 } else {
914 yield(ai.StreamPart{
915 Type: ai.StreamPartTypeError,
916 Error: a.handleError(err),
917 })
918 return
919 }
920 }, nil
921}