package anthropic

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/anthropics/anthropic-sdk-go/packages/param"
	"github.com/anthropics/anthropic-sdk-go/vertex"
	"github.com/charmbracelet/fantasy/ai"
	"golang.org/x/oauth2/google"
)

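// Name is the provider identifier; DefaultURL is the Anthropic API endpoint
// used when no base URL is configured.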
const (
	Name       = "anthropic"
	DefaultURL = "https://api.anthropic.com"
)

type options struct {
	baseURL string
	apiKey  string
	name    string
	headers map[string]string
	client  option.HTTPClient

	vertexProject  string
	vertexLocation string
	skipAuth       bool

	useBedrock bool
}

type provider struct {
	options options
}

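// Option configures the provider created by New.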
type Option = func(*options)

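// New creates an Anthropic provider. A minimal usage sketch (the model ID
// below is illustrative):
//
//	p := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
//	model, err := p.LanguageModel("claude-sonnet-4-0")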
func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, Name)
	return &provider{options: providerOptions}
}

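// WithBaseURL overrides the default API endpoint.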
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

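// WithAPIKey sets the Anthropic API key used to authenticate requests.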
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

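// WithVertex routes requests through Google Vertex AI using the given
// project and location.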
func WithVertex(project, location string) Option {
	return func(o *options) {
		o.vertexProject = project
		o.vertexLocation = location
	}
}

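// WithSkipAuth skips Google/AWS credential lookup and substitutes dummy
// credentials (e.g. for testing, or when auth is handled elsewhere).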
func WithSkipAuth(skip bool) Option {
	return func(o *options) {
		o.skipAuth = skip
	}
}

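// WithBedrock routes requests through AWS Bedrock.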
func WithBedrock() Option {
	return func(o *options) {
		o.useBedrock = true
	}
}

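// WithName overrides the provider name reported by the models.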
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

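// WithHeaders adds extra headers to every request.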
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

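// WithHTTPClient sets a custom HTTP client.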
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

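// LanguageModel builds an ai.LanguageModel for modelID, wiring in API-key,
// Vertex, or Bedrock authentication according to the configured options.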
func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	clientOptions := make([]option.RequestOption, 0, 5+len(a.options.headers))
	if a.options.apiKey != "" {
		clientOptions = append(clientOptions, option.WithAPIKey(a.options.apiKey))
	}
	if a.options.baseURL != "" {
		clientOptions = append(clientOptions, option.WithBaseURL(a.options.baseURL))
	}
	for key, value := range a.options.headers {
		clientOptions = append(clientOptions, option.WithHeader(key, value))
	}
	if a.options.client != nil {
		clientOptions = append(clientOptions, option.WithHTTPClient(a.options.client))
	}
	if a.options.vertexProject != "" && a.options.vertexLocation != "" {
		var credentials *google.Credentials
		if a.options.skipAuth {
			credentials = &google.Credentials{TokenSource: &googleDummyTokenSource{}}
		} else {
			var err error
			credentials, err = google.FindDefaultCredentials(context.TODO())
			if err != nil {
				return nil, err
			}
		}

		clientOptions = append(
			clientOptions,
			vertex.WithCredentials(
				context.TODO(),
				a.options.vertexLocation,
				a.options.vertexProject,
				credentials,
			),
		)
	}
	if a.options.useBedrock {
		if a.options.skipAuth {
			clientOptions = append(
				clientOptions,
				bedrock.WithConfig(dummyBedrockConfig),
			)
		} else {
			clientOptions = append(
				clientOptions,
				bedrock.WithLoadDefaultConfig(context.TODO()),
			)
		}
	}
	return languageModel{
		modelID:  modelID,
		provider: a.options.name,
		options:  a.options,
		client:   anthropic.NewClient(clientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   anthropic.Client
	options  options
}

// Model implements ai.LanguageModel.
func (a languageModel) Model() string {
	return a.modelID
}

// Provider implements ai.LanguageModel.
func (a languageModel) Provider() string {
	return a.provider
}

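// prepareParams translates an ai.Call into anthropic.MessageNewParams,
// collecting warnings for settings the API does not support.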
func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
	params := &anthropic.MessageNewParams{}
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
		}
	}
	sendReasoning := true
	if providerOptions.SendReasoning != nil {
		sendReasoning = *providerOptions.SendReasoning
	}
	systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning)

	if call.FrequencyPenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "FrequencyPenalty",
		})
	}
	if call.PresencePenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "PresencePenalty",
		})
	}

	params.System = systemBlocks
	params.Messages = messages
	params.Model = anthropic.Model(a.modelID)
	params.MaxTokens = 4096

	if call.MaxOutputTokens != nil {
		params.MaxTokens = *call.MaxOutputTokens
	}

	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopK != nil {
		params.TopK = param.NewOpt(*call.TopK)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}

	isThinking := false
	var thinkingBudget int64
	if providerOptions.Thinking != nil {
		isThinking = true
		thinkingBudget = providerOptions.Thinking.BudgetTokens
	}
	if isThinking {
		if thinkingBudget == 0 {
			return nil, nil, ai.NewUnsupportedFunctionalityError("thinking requires a budget", "")
		}
		params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget)
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Temperature",
				Details: "Temperature is not supported when thinking is enabled",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported when thinking is enabled",
			})
		}
		if call.TopK != nil {
			params.TopK = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopK",
				Details: "TopK is not supported when thinking is enabled",
			})
		}
		// The thinking budget is consumed out of MaxTokens, so reserve room
		// for it on top of the requested output tokens.
		params.MaxTokens += thinkingBudget
	}

	if len(call.Tools) > 0 {
		disableParallelToolUse := false
		if providerOptions.DisableParallelToolUse != nil {
			disableParallelToolUse = *providerOptions.DisableParallelToolUse
		}
		tools, toolChoice, toolWarnings := a.toTools(call.Tools, call.ToolChoice, disableParallelToolUse)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}

	return params, warnings, nil
}

func (a *provider) Name() string {
	return Name
}

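// getCacheControl extracts Anthropic cache-control settings from provider
// options, if present.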
func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl {
	if anthropicOptions, ok := providerOptions[Name]; ok {
		if options, ok := anthropicOptions.(*ProviderCacheControlOptions); ok {
			return &options.CacheControl
		}
	}
	return nil
}

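// getReasoningMetadata extracts reasoning (thinking) metadata from provider
// options, if present.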
func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMetadata {
	if anthropicOptions, ok := providerOptions[Name]; ok {
		if reasoning, ok := anthropicOptions.(*ReasoningOptionMetadata); ok {
			return reasoning
		}
	}
	return nil
}

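// messageBlock groups consecutive messages that map to the same Anthropic
// role.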
type messageBlock struct {
	Role     ai.MessageRole
	Messages []ai.Message
}

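// groupIntoBlocks splits the prompt into runs of consecutive messages that
// share a role. Tool messages join user blocks because the Anthropic API
// expects tool results inside user turns.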
func groupIntoBlocks(prompt ai.Prompt) []*messageBlock {
	var blocks []*messageBlock

	var currentBlock *messageBlock

	for _, msg := range prompt {
		var role ai.MessageRole
		switch msg.Role {
		case ai.MessageRoleSystem, ai.MessageRoleUser, ai.MessageRoleAssistant:
			role = msg.Role
		case ai.MessageRoleTool:
			// Tool results ride along in user blocks.
			role = ai.MessageRoleUser
		default:
			continue
		}
		if currentBlock == nil || currentBlock.Role != role {
			currentBlock = &messageBlock{
				Role:     role,
				Messages: []ai.Message{},
			}
			blocks = append(blocks, currentBlock)
		}
		currentBlock.Messages = append(currentBlock.Messages, msg)
	}
	return blocks
}

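// toTools converts the generic tool definitions and tool choice into their
// Anthropic equivalents, warning about tools the provider cannot handle.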
func (a languageModel) toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			required := []string{}
			var properties any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties = props
			}
			if req, ok := ft.InputSchema["required"]; ok {
				switch reqs := req.(type) {
				case []string:
					required = reqs
				case []any:
					// JSON-decoded schemas carry "required" as []any.
					for _, r := range reqs {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			cacheControl := getCacheControl(ft.ProviderOptions)

			anthropicTool := anthropic.ToolParam{
				Name:        ft.Name,
				Description: anthropic.String(ft.Description),
				InputSchema: anthropic.ToolInputSchemaParam{
					Properties: properties,
					Required:   required,
				},
			}
			if cacheControl != nil {
				anthropicTool.CacheControl = anthropic.NewCacheControlEphemeralParam()
			}
			anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &anthropicTool})
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}

	// NOTE: Bedrock does not support this attribute.
	var disableParallelToolUse param.Opt[bool]
	if !a.options.useBedrock {
		disableParallelToolUse = param.NewOpt(disableParallelToolCalls)
	}

	if toolChoice == nil {
		if disableParallelToolCalls {
			anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
				OfAuto: &anthropic.ToolChoiceAutoParam{
					Type:                   "auto",
					DisableParallelToolUse: disableParallelToolUse,
				},
			}
		}
		return anthropicTools, anthropicToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAuto: &anthropic.ToolChoiceAutoParam{
				Type:                   "auto",
				DisableParallelToolUse: disableParallelToolUse,
			},
		}
	case ai.ToolChoiceRequired:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAny: &anthropic.ToolChoiceAnyParam{
				Type:                   "any",
				DisableParallelToolUse: disableParallelToolUse,
			},
		}
	case ai.ToolChoiceNone:
		return anthropicTools, anthropicToolChoice, warnings
	default:
		// Any other value is treated as the name of a specific tool to force.
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfTool: &anthropic.ToolChoiceToolParam{
				Type:                   "tool",
				Name:                   string(*toolChoice),
				DisableParallelToolUse: disableParallelToolUse,
			},
		}
	}
	return anthropicTools, anthropicToolChoice, warnings
}

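// toPrompt converts the generic prompt into Anthropic system blocks and
// messages, honoring per-part cache control and reasoning metadata.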
func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) {
	var systemBlocks []anthropic.TextBlockParam
	var messages []anthropic.MessageParam
	var warnings []ai.CallWarning

	blocks := groupIntoBlocks(prompt)
	seenSystemBlock := false
	for _, block := range blocks {
		switch block.Role {
		case ai.MessageRoleSystem:
			if seenSystemBlock {
				// Skip system messages that are separated from the leading
				// system block by user/assistant messages.
				// TODO: see if we need to send an error here?
				continue
			}
			seenSystemBlock = true
			for _, msg := range block.Messages {
				for i, part := range msg.Content {
					isLastPart := i == len(msg.Content)-1
					cacheControl := getCacheControl(part.Options())
					if cacheControl == nil && isLastPart {
						cacheControl = getCacheControl(msg.ProviderOptions)
					}
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok {
						continue
					}
					textBlock := anthropic.TextBlockParam{
						Text: text.Text,
					}
					if cacheControl != nil {
						textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
					}
					systemBlocks = append(systemBlocks, textBlock)
				}
			}

		case ai.MessageRoleUser:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				if msg.Role == ai.MessageRoleUser {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						switch part.GetType() {
						case ai.ContentTypeText:
							text, ok := ai.AsMessagePart[ai.TextPart](part)
							if !ok {
								continue
							}
							textBlock := &anthropic.TextBlockParam{
								Text: text.Text,
							}
							if cacheControl != nil {
								textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
								OfText: textBlock,
							})
						case ai.ContentTypeFile:
							file, ok := ai.AsMessagePart[ai.FilePart](part)
							if !ok {
								continue
							}
							// TODO: handle other file types
							if !strings.HasPrefix(file.MediaType, "image/") {
								continue
							}

							base64Encoded := base64.StdEncoding.EncodeToString(file.Data)
							imageBlock := anthropic.NewImageBlockBase64(file.MediaType, base64Encoded)
							if cacheControl != nil {
								imageBlock.OfImage.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, imageBlock)
						}
					}
				} else if msg.Role == ai.MessageRoleTool {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
						if !ok {
							continue
						}
						toolResultBlock := anthropic.ToolResultBlockParam{
							ToolUseID: result.ToolCallID,
						}
						switch result.Output.GetType() {
						case ai.ToolResultContentTypeText:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Text,
									},
								},
							}
						case ai.ToolResultContentTypeMedia:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentMedia](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage,
								},
							}
						case ai.ToolResultContentTypeError:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Error.Error(),
									},
								},
							}
							toolResultBlock.IsError = param.NewOpt(true)
						}
						if cacheControl != nil {
							toolResultBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfToolResult: &toolResultBlock,
						})
					}
				}
			}
			messages = append(messages, anthropic.NewUserMessage(anthropicContent...))
		case ai.MessageRoleAssistant:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				for i, part := range msg.Content {
					isLastPart := i == len(msg.Content)-1
					cacheControl := getCacheControl(part.Options())
					if cacheControl == nil && isLastPart {
						cacheControl = getCacheControl(msg.ProviderOptions)
					}
					switch part.GetType() {
					case ai.ContentTypeText:
						text, ok := ai.AsMessagePart[ai.TextPart](part)
						if !ok {
							continue
						}
						textBlock := &anthropic.TextBlockParam{
							Text: text.Text,
						}
						if cacheControl != nil {
							textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfText: textBlock,
						})
					case ai.ContentTypeReasoning:
						reasoning, ok := ai.AsMessagePart[ai.ReasoningPart](part)
						if !ok {
							continue
						}
						if !sendReasoningData {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "sending reasoning content is disabled for this model",
							})
							continue
						}
						reasoningMetadata := getReasoningMetadata(part.Options())
						if reasoningMetadata == nil {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}

						if reasoningMetadata.Signature != "" {
							anthropicContent = append(anthropicContent, anthropic.NewThinkingBlock(reasoningMetadata.Signature, reasoning.Text))
						} else if reasoningMetadata.RedactedData != "" {
							anthropicContent = append(anthropicContent, anthropic.NewRedactedThinkingBlock(reasoningMetadata.RedactedData))
						} else {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}
					case ai.ContentTypeToolCall:
						toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
						if !ok {
							continue
						}
						if toolCall.ProviderExecuted {
							// TODO: implement provider executed call
							continue
						}

						var inputMap map[string]any
						err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
						if err != nil {
							continue
						}
						toolUseBlock := anthropic.NewToolUseBlock(toolCall.ToolCallID, inputMap, toolCall.ToolName)
						if cacheControl != nil {
							toolUseBlock.OfToolUse.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, toolUseBlock)
					case ai.ContentTypeToolResult:
						// TODO: implement provider executed tool result
					}
				}
			}
			messages = append(messages, anthropic.NewAssistantMessage(anthropicContent...))
		}
	}
	return systemBlocks, messages, warnings
}

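// handleError converts SDK errors into ai.APICallError values, preserving
// the request/response dumps and response headers for diagnostics.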
func (a languageModel) handleError(err error) error {
	var apiErr *anthropic.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			if len(h) == 0 {
				continue
			}
			// Keep the last value when a header is repeated.
			headers[strings.ToLower(k)] = h[len(h)-1]
		}
		return ai.NewAPICallError(
			apiErr.Error(),
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

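// mapFinishReason maps Anthropic stop reasons onto the generic
// ai.FinishReason values.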
func mapFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "end_turn", "pause_turn", "stop_sequence":
		return ai.FinishReasonStop
	case "max_tokens":
		return ai.FinishReasonLength
	case "tool_use":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

// Generate implements ai.LanguageModel.
func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := a.client.Messages.New(ctx, *params)
	if err != nil {
		return nil, a.handleError(err)
	}

	var content []ai.Content
	for _, block := range response.Content {
		switch block.Type {
		case "text":
			text, ok := block.AsAny().(anthropic.TextBlock)
			if !ok {
				continue
			}
			content = append(content, ai.TextContent{
				Text: text.Text,
			})
		case "thinking":
			reasoning, ok := block.AsAny().(anthropic.ThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: reasoning.Thinking,
				ProviderMetadata: ai.ProviderMetadata{
					Name: &ReasoningOptionMetadata{
						Signature: reasoning.Signature,
					},
				},
			})
		case "redacted_thinking":
			reasoning, ok := block.AsAny().(anthropic.RedactedThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: "",
				ProviderMetadata: ai.ProviderMetadata{
					Name: &ReasoningOptionMetadata{
						RedactedData: reasoning.Data,
					},
				},
			})
		case "tool_use":
			toolUse, ok := block.AsAny().(anthropic.ToolUseBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ToolCallContent{
				ToolCallID:       toolUse.ID,
				ToolName:         toolUse.Name,
				Input:            string(toolUse.Input),
				ProviderExecuted: false,
			})
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:         response.Usage.InputTokens,
			OutputTokens:        response.Usage.OutputTokens,
			TotalTokens:         response.Usage.InputTokens + response.Usage.OutputTokens,
			CacheCreationTokens: response.Usage.CacheCreationInputTokens,
			CacheReadTokens:     response.Usage.CacheReadInputTokens,
		},
		FinishReason:     mapFinishReason(string(response.StopReason)),
		ProviderMetadata: ai.ProviderMetadata{},
		Warnings:         warnings,
	}, nil
}

// Stream implements ai.LanguageModel. The returned iterator yields stream
// parts and always ends with either a Finish or an Error part.
func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}

	stream := a.client.Messages.NewStreaming(ctx, *params)
	// Accumulate chunks so completed blocks, usage, and the stop reason can
	// be read back when the stream ends.
	acc := anthropic.Message{}
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		for stream.Next() {
			chunk := stream.Current()
			_ = acc.Accumulate(chunk)
			switch chunk.Type {
			case "content_block_start":
				switch chunk.ContentBlock.Type {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "redacted_thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							Name: &ReasoningOptionMetadata{
								RedactedData: chunk.ContentBlock.Data,
							},
						},
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputStart,
						ID:            chunk.ContentBlock.ID,
						ToolCallName:  chunk.ContentBlock.Name,
						ToolCallInput: "",
					}) {
						return
					}
				}
			case "content_block_stop":
				if int(chunk.Index) >= len(acc.Content) {
					continue
				}
				contentBlock := acc.Content[int(chunk.Index)]
				switch contentBlock.Type {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeToolInputEnd,
						ID:   contentBlock.ID,
					}) {
						return
					}
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolCall,
						ID:            contentBlock.ID,
						ToolCallName:  contentBlock.Name,
						ToolCallInput: string(contentBlock.Input),
					}) {
						return
					}
				}
			case "content_block_delta":
				switch chunk.Delta.Type {
				case "text_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Text,
					}) {
						return
					}
				case "thinking_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeReasoningDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Thinking,
					}) {
						return
					}
				case "signature_delta":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningDelta,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							Name: &ReasoningOptionMetadata{
								Signature: chunk.Delta.Signature,
							},
						},
					}) {
						return
					}
				case "input_json_delta":
					if int(chunk.Index) >= len(acc.Content) {
						continue
					}
					contentBlock := acc.Content[int(chunk.Index)]
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputDelta,
						ID:            contentBlock.ID,
						ToolCallInput: chunk.Delta.PartialJSON,
					}) {
						return
					}
				}
			case "message_stop":
				// The final usage and stop reason are read from the
				// accumulator after the loop.
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: a.handleError(err),
			})
			return
		}
		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			ID:           acc.ID,
			FinishReason: mapFinishReason(string(acc.StopReason)),
			Usage: ai.Usage{
				InputTokens:         acc.Usage.InputTokens,
				OutputTokens:        acc.Usage.OutputTokens,
				TotalTokens:         acc.Usage.InputTokens + acc.Usage.OutputTokens,
				CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
				CacheReadTokens:     acc.Usage.CacheReadInputTokens,
			},
			ProviderMetadata: ai.ProviderMetadata{},
		})
	}, nil
}