1package anthropic
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "maps"
12 "strings"
13
14 "github.com/anthropics/anthropic-sdk-go"
15 "github.com/anthropics/anthropic-sdk-go/option"
16 "github.com/anthropics/anthropic-sdk-go/packages/param"
17 "github.com/charmbracelet/ai/ai"
18)
19
const (
	// ProviderName is the default identifier reported for this provider.
	ProviderName = "anthropic"
	// DefaultURL is the default Anthropic API base endpoint.
	DefaultURL = "https://api.anthropic.com"
)
24
// options holds the settings collected from Option functions before a
// provider is constructed by New.
type options struct {
	baseURL string            // API endpoint; empty means DefaultURL
	apiKey  string            // API key; omitted from requests when empty
	name    string            // provider name; empty means ProviderName
	headers map[string]string // extra headers added to every request
	client  option.HTTPClient // optional custom HTTP client
}
32
// provider is the ai.Provider implementation backed by the Anthropic API.
type provider struct {
	options options
}
36
37type Option = func(*options)
38
// New constructs an Anthropic ai.Provider, applying the given options and
// falling back to DefaultURL and ProviderName for unset fields.
func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	// cmp.Or keeps the first non-zero value, i.e. the caller's setting wins.
	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, ProviderName)
	return &provider{options: providerOptions}
}
51
// WithBaseURL overrides the API base endpoint (default DefaultURL).
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}
57
// WithAPIKey sets the Anthropic API key used for authentication.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}
63
// WithName overrides the provider name (default ProviderName).
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}
69
// WithHeaders merges the given headers into the set sent with every request.
// Repeated calls accumulate; later calls overwrite duplicate keys.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}
75
// WithHTTPClient sets a custom HTTP client for API requests.
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}
81
82func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
83 anthropicClientOptions := []option.RequestOption{}
84 if a.options.apiKey != "" {
85 anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(a.options.apiKey))
86 }
87 if a.options.baseURL != "" {
88 anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(a.options.baseURL))
89 }
90
91 for key, value := range a.options.headers {
92 anthropicClientOptions = append(anthropicClientOptions, option.WithHeader(key, value))
93 }
94
95 if a.options.client != nil {
96 anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(a.options.client))
97 }
98 return languageModel{
99 modelID: modelID,
100 provider: fmt.Sprintf("%s.messages", a.options.name),
101 options: a.options,
102 client: anthropic.NewClient(anthropicClientOptions...),
103 }, nil
104}
105
// languageModel is the ai.LanguageModel implementation for a single
// Anthropic model, carrying the SDK client built in LanguageModel.
type languageModel struct {
	provider string // provider label, e.g. "anthropic.messages"
	modelID  string
	client   anthropic.Client
	options  options
}
112
// Model implements ai.LanguageModel. It returns the model identifier.
func (a languageModel) Model() string {
	return a.modelID
}
117
// Provider implements ai.LanguageModel. It returns the provider label.
func (a languageModel) Provider() string {
	return a.provider
}
122
// prepareParams translates a provider-agnostic ai.Call into Anthropic
// MessageNewParams. It returns the request params, non-fatal warnings for
// settings that had to be dropped or adjusted, and an error when the
// provider options are of the wrong type or internally inconsistent.
func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
	params := &anthropic.MessageNewParams{}
	// Anthropic-specific options may be attached under ProviderOptionsKey;
	// a value of a different concrete type is a caller error.
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[ProviderOptionsKey]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
		}
	}
	// Reasoning content is forwarded by default unless explicitly disabled.
	sendReasoning := true
	if providerOptions.SendReasoning != nil {
		sendReasoning = *providerOptions.SendReasoning
	}
	systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning)

	// The Anthropic API has no frequency/presence penalty parameters; warn
	// rather than silently ignoring them.
	if call.FrequencyPenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "FrequencyPenalty",
		})
	}
	if call.PresencePenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "PresencePenalty",
		})
	}

	params.System = systemBlocks
	params.Messages = messages
	params.Model = anthropic.Model(a.modelID)
	// MaxTokens is required by the API; default to 4096 when unspecified.
	params.MaxTokens = 4096

	if call.MaxOutputTokens != nil {
		params.MaxTokens = *call.MaxOutputTokens
	}

	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopK != nil {
		params.TopK = param.NewOpt(*call.TopK)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}

	isThinking := false
	var thinkingBudget int64
	if providerOptions.Thinking != nil {
		isThinking = true
		thinkingBudget = providerOptions.Thinking.BudgetTokens
	}
	if isThinking {
		// Extended thinking requires an explicit token budget.
		if thinkingBudget == 0 {
			return nil, nil, ai.NewUnsupportedFunctionalityError("thinking requires budget", "")
		}
		params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget)
		// Sampling knobs are incompatible with thinking mode: reset each one
		// to its zero Opt (unset) and surface a warning.
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported when thinking is enabled",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported when thinking is enabled",
			})
		}
		if call.TopK != nil {
			params.TopK = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopK",
				Details: "TopK is not supported when thinking is enabled",
			})
		}
		// Thinking tokens count against MaxTokens, so grow the limit by the
		// budget to preserve the caller's intended output allowance.
		params.MaxTokens = params.MaxTokens + thinkingBudget
	}

	if len(call.Tools) > 0 {
		disableParallelToolUse := false
		if providerOptions.DisableParallelToolUse != nil {
			disableParallelToolUse = *providerOptions.DisableParallelToolUse
		}
		tools, toolChoice, toolWarnings := toTools(call.Tools, call.ToolChoice, disableParallelToolUse)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}

	return params, warnings, nil
}
223
224func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl {
225 if anthropicOptions, ok := providerOptions[ProviderOptionsKey]; ok {
226 if options, ok := anthropicOptions.(*ProviderCacheControlOptions); ok {
227 return &options.CacheControl
228 }
229 }
230 return nil
231}
232
233func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMetadata {
234 if anthropicOptions, ok := providerOptions[ProviderOptionsKey]; ok {
235 if reasoning, ok := anthropicOptions.(*ReasoningOptionMetadata); ok {
236 return reasoning
237 }
238 }
239 return nil
240}
241
// messageBlock groups consecutive prompt messages that share an effective
// role so they can be merged into a single Anthropic message.
type messageBlock struct {
	Role     ai.MessageRole
	Messages []ai.Message
}
246
247func groupIntoBlocks(prompt ai.Prompt) []*messageBlock {
248 var blocks []*messageBlock
249
250 var currentBlock *messageBlock
251
252 for _, msg := range prompt {
253 switch msg.Role {
254 case ai.MessageRoleSystem:
255 if currentBlock == nil || currentBlock.Role != ai.MessageRoleSystem {
256 currentBlock = &messageBlock{
257 Role: ai.MessageRoleSystem,
258 Messages: []ai.Message{},
259 }
260 blocks = append(blocks, currentBlock)
261 }
262 currentBlock.Messages = append(currentBlock.Messages, msg)
263 case ai.MessageRoleUser:
264 if currentBlock == nil || currentBlock.Role != ai.MessageRoleUser {
265 currentBlock = &messageBlock{
266 Role: ai.MessageRoleUser,
267 Messages: []ai.Message{},
268 }
269 blocks = append(blocks, currentBlock)
270 }
271 currentBlock.Messages = append(currentBlock.Messages, msg)
272 case ai.MessageRoleAssistant:
273 if currentBlock == nil || currentBlock.Role != ai.MessageRoleAssistant {
274 currentBlock = &messageBlock{
275 Role: ai.MessageRoleAssistant,
276 Messages: []ai.Message{},
277 }
278 blocks = append(blocks, currentBlock)
279 }
280 currentBlock.Messages = append(currentBlock.Messages, msg)
281 case ai.MessageRoleTool:
282 if currentBlock == nil || currentBlock.Role != ai.MessageRoleUser {
283 currentBlock = &messageBlock{
284 Role: ai.MessageRoleUser,
285 Messages: []ai.Message{},
286 }
287 blocks = append(blocks, currentBlock)
288 }
289 currentBlock.Messages = append(currentBlock.Messages, msg)
290 }
291 }
292 return blocks
293}
294
// toTools converts generic function tools and the tool-choice setting into
// the Anthropic SDK's union params. Non-function tools produce warnings.
// A nil return for anthropicToolChoice means "leave the API default".
func toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			required := []string{}
			var properties any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties = props
			}
			if req, ok := ft.InputSchema["required"]; ok {
				// NOTE(review): this assertion only succeeds when the schema
				// stores "required" as []string; a JSON-decoded schema would
				// carry []any and silently yield an empty required list —
				// confirm how InputSchema is populated upstream.
				if reqArr, ok := req.([]string); ok {
					required = reqArr
				}
			}
			cacheControl := getCacheControl(ft.ProviderOptions)

			anthropicTool := anthropic.ToolParam{
				Name:        ft.Name,
				Description: anthropic.String(ft.Description),
				InputSchema: anthropic.ToolInputSchemaParam{
					Properties: properties,
					Required:   required,
				},
			}
			if cacheControl != nil {
				anthropicTool.CacheControl = anthropic.NewCacheControlEphemeralParam()
			}
			anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &anthropicTool})
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	// No explicit choice: only emit a tool_choice when parallel tool use
	// must be disabled (the setting rides on the choice object).
	if toolChoice == nil {
		if disableParallelToolCalls {
			anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
				OfAuto: &anthropic.ToolChoiceAutoParam{
					Type:                   "auto",
					DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
				},
			}
		}
		return anthropicTools, anthropicToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAuto: &anthropic.ToolChoiceAutoParam{
				Type:                   "auto",
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	case ai.ToolChoiceRequired:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAny: &anthropic.ToolChoiceAnyParam{
				Type:                   "any",
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	case ai.ToolChoiceNone:
		// "none" maps to omitting tool_choice entirely.
		return anthropicTools, anthropicToolChoice, warnings
	default:
		// Any other value is treated as the name of a specific tool to force.
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfTool: &anthropic.ToolChoiceToolParam{
				Type:                   "tool",
				Name:                   string(*toolChoice),
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	}
	return anthropicTools, anthropicToolChoice, warnings
}
375
// toPrompt converts the generic prompt into Anthropic system blocks and
// messages. System messages become top-level system text blocks; user and
// tool messages become user messages; assistant messages (including
// reasoning and tool calls) become assistant messages. Unconvertible parts
// are skipped, some with warnings. sendReasoningData controls whether
// assistant reasoning content is forwarded.
func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) {
	var systemBlocks []anthropic.TextBlockParam
	var messages []anthropic.MessageParam
	var warnings []ai.CallWarning

	blocks := groupIntoBlocks(prompt)
	finishedSystemBlock := false
	for _, block := range blocks {
		switch block.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// skip multiple system messages that are separated by user/assistant messages
				// TODO: see if we need to send error here?
				continue
			}
			finishedSystemBlock = true
			// Only text parts are supported in system messages.
			for _, msg := range block.Messages {
				for _, part := range msg.Content {
					cacheControl := getCacheControl(part.Options())
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok {
						continue
					}
					textBlock := anthropic.TextBlockParam{
						Text: text.Text,
					}
					if cacheControl != nil {
						textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
					}
					systemBlocks = append(systemBlocks, textBlock)
				}
			}

		case ai.MessageRoleUser:
			// A user block may interleave user messages (text/files) and
			// tool messages (tool results); all land in one user message.
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				if msg.Role == ai.MessageRoleUser {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						// Part-level cache control wins; the message-level
						// setting applies only to the final part.
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						switch part.GetType() {
						case ai.ContentTypeText:
							text, ok := ai.AsMessagePart[ai.TextPart](part)
							if !ok {
								continue
							}
							textBlock := &anthropic.TextBlockParam{
								Text: text.Text,
							}
							if cacheControl != nil {
								textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
								OfText: textBlock,
							})
						case ai.ContentTypeFile:
							file, ok := ai.AsMessagePart[ai.FilePart](part)
							if !ok {
								continue
							}
							// TODO: handle other file types
							// Only images are supported; others are dropped
							// silently for now.
							if !strings.HasPrefix(file.MediaType, "image/") {
								continue
							}

							base64Encoded := base64.StdEncoding.EncodeToString(file.Data)
							imageBlock := anthropic.NewImageBlockBase64(file.MediaType, base64Encoded)
							if cacheControl != nil {
								imageBlock.OfImage.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, imageBlock)
						}
					}
				} else if msg.Role == ai.MessageRoleTool {
					// Tool results: each part becomes a tool_result block
					// keyed by the originating tool call ID.
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
						if !ok {
							continue
						}
						toolResultBlock := anthropic.ToolResultBlockParam{
							ToolUseID: result.ToolCallID,
						}
						switch result.Output.GetType() {
						case ai.ToolResultContentTypeText:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Text,
									},
								},
							}
						case ai.ToolResultContentTypeMedia:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentMedia](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage,
								},
							}
						case ai.ToolResultContentTypeError:
							// Errors are sent as text with is_error set.
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Error.Error(),
									},
								},
							}
							toolResultBlock.IsError = param.NewOpt(true)
						}
						if cacheControl != nil {
							toolResultBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfToolResult: &toolResultBlock,
						})
					}
				}
			}
			messages = append(messages, anthropic.NewUserMessage(anthropicContent...))
		case ai.MessageRoleAssistant:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				for i, part := range msg.Content {
					isLastPart := i == len(msg.Content)-1
					cacheControl := getCacheControl(part.Options())
					if cacheControl == nil && isLastPart {
						cacheControl = getCacheControl(msg.ProviderOptions)
					}
					switch part.GetType() {
					case ai.ContentTypeText:
						text, ok := ai.AsMessagePart[ai.TextPart](part)
						if !ok {
							continue
						}
						textBlock := &anthropic.TextBlockParam{
							Text: text.Text,
						}
						if cacheControl != nil {
							textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfText: textBlock,
						})
					case ai.ContentTypeReasoning:
						reasoning, ok := ai.AsMessagePart[ai.ReasoningPart](part)
						if !ok {
							continue
						}
						if !sendReasoningData {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "sending reasoning content is disabled for this model",
							})
							continue
						}
						// Reasoning needs either a signature (thinking) or
						// redacted data; anything else cannot be round-tripped.
						reasoningMetadata := getReasoningMetadata(part.Options())
						if reasoningMetadata == nil {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}

						if reasoningMetadata.Signature != "" {
							anthropicContent = append(anthropicContent, anthropic.NewThinkingBlock(reasoningMetadata.Signature, reasoning.Text))
						} else if reasoningMetadata.RedactedData != "" {
							anthropicContent = append(anthropicContent, anthropic.NewRedactedThinkingBlock(reasoningMetadata.RedactedData))
						} else {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}
					case ai.ContentTypeToolCall:
						toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
						if !ok {
							continue
						}
						if toolCall.ProviderExecuted {
							// TODO: implement provider executed call
							continue
						}

						// Tool input must be valid JSON; unparsable input is
						// dropped silently.
						var inputMap map[string]any
						err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
						if err != nil {
							continue
						}
						toolUseBlock := anthropic.NewToolUseBlock(toolCall.ToolCallID, inputMap, toolCall.ToolName)
						if cacheControl != nil {
							toolUseBlock.OfToolUse.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, toolUseBlock)
					case ai.ContentTypeToolResult:
						// TODO: implement provider executed tool result
					}
				}
			}
			messages = append(messages, anthropic.NewAssistantMessage(anthropicContent...))
		}
	}
	return systemBlocks, messages, warnings
}
599
600func (o languageModel) handleError(err error) error {
601 var apiErr *anthropic.Error
602 if errors.As(err, &apiErr) {
603 requestDump := apiErr.DumpRequest(true)
604 responseDump := apiErr.DumpResponse(true)
605 headers := map[string]string{}
606 for k, h := range apiErr.Response.Header {
607 v := h[len(h)-1]
608 headers[strings.ToLower(k)] = v
609 }
610 return ai.NewAPICallError(
611 apiErr.Error(),
612 apiErr.Request.URL.String(),
613 string(requestDump),
614 apiErr.StatusCode,
615 headers,
616 string(responseDump),
617 apiErr,
618 false,
619 )
620 }
621 return err
622}
623
624func mapFinishReason(finishReason string) ai.FinishReason {
625 switch finishReason {
626 case "end_turn", "pause_turn", "stop_sequence":
627 return ai.FinishReasonStop
628 case "max_tokens":
629 return ai.FinishReasonLength
630 case "tool_use":
631 return ai.FinishReasonToolCalls
632 default:
633 return ai.FinishReasonUnknown
634 }
635}
636
// Generate implements ai.LanguageModel. It performs a blocking Messages API
// call and maps the response content blocks (text, thinking, redacted
// thinking, tool use) and usage into the provider-agnostic ai.Response.
func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := a.client.Messages.New(ctx, *params)
	if err != nil {
		return nil, a.handleError(err)
	}

	var content []ai.Content
	for _, block := range response.Content {
		switch block.Type {
		case "text":
			text, ok := block.AsAny().(anthropic.TextBlock)
			if !ok {
				continue
			}
			content = append(content, ai.TextContent{
				Text: text.Text,
			})
		case "thinking":
			// The signature is preserved in metadata so reasoning can be
			// round-tripped back via toPrompt.
			reasoning, ok := block.AsAny().(anthropic.ThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: reasoning.Thinking,
				ProviderMetadata: ai.ProviderMetadata{
					ProviderOptionsKey: &ReasoningOptionMetadata{
						Signature: reasoning.Signature,
					},
				},
			})
		case "redacted_thinking":
			// Redacted thinking has no visible text; only the opaque data
			// blob is carried in metadata.
			reasoning, ok := block.AsAny().(anthropic.RedactedThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: "",
				ProviderMetadata: ai.ProviderMetadata{
					ProviderOptionsKey: &ReasoningOptionMetadata{
						RedactedData: reasoning.Data,
					},
				},
			})
		case "tool_use":
			toolUse, ok := block.AsAny().(anthropic.ToolUseBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ToolCallContent{
				ToolCallID:       toolUse.ID,
				ToolName:         toolUse.Name,
				Input:            string(toolUse.Input),
				ProviderExecuted: false,
			})
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:  response.Usage.InputTokens,
			OutputTokens: response.Usage.OutputTokens,
			// NOTE(review): TotalTokens excludes cache creation/read tokens —
			// confirm that matches the ai.Usage contract.
			TotalTokens:         response.Usage.InputTokens + response.Usage.OutputTokens,
			CacheCreationTokens: response.Usage.CacheCreationInputTokens,
			CacheReadTokens:     response.Usage.CacheReadInputTokens,
		},
		FinishReason:     mapFinishReason(string(response.StopReason)),
		ProviderMetadata: ai.ProviderMetadata{},
		Warnings:         warnings,
	}, nil
}
713
// Stream implements ai.LanguageModel. It starts a streaming Messages API
// call and returns an iterator that translates Anthropic SSE events into
// ai.StreamPart values. An accumulator mirrors the message so that
// content_block_stop and the final finish part can report complete tool
// inputs, IDs, stop reason, and usage. Content part IDs are the block
// index, except tool parts which use the tool-use block ID.
func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}

	stream := a.client.Messages.NewStreaming(ctx, *params)
	acc := anthropic.Message{}
	return func(yield func(ai.StreamPart) bool) {
		// Surface any preparation warnings before streaming content.
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		for stream.Next() {
			chunk := stream.Current()
			// Best-effort accumulation; errors are ignored here and surface
			// via stream.Err() after the loop.
			_ = acc.Accumulate(chunk)
			switch chunk.Type {
			case "content_block_start":
				contentBlockType := chunk.ContentBlock.Type
				switch contentBlockType {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "redacted_thinking":
					// The opaque redacted payload arrives complete in the
					// start event and is carried as reasoning metadata.
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							ProviderOptionsKey: &ReasoningOptionMetadata{
								RedactedData: chunk.ContentBlock.Data,
							},
						},
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputStart,
						ID:            chunk.ContentBlock.ID,
						ToolCallName:  chunk.ContentBlock.Name,
						ToolCallInput: "",
					}) {
						return
					}
				}
			case "content_block_stop":
				// Skip if the accumulator has not seen this block index.
				if len(acc.Content)-1 < int(chunk.Index) {
					continue
				}
				contentBlock := acc.Content[int(chunk.Index)]
				switch contentBlock.Type {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "tool_use":
					// End the input stream, then emit the completed tool
					// call with the fully accumulated JSON input.
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeToolInputEnd,
						ID:   contentBlock.ID,
					}) {
						return
					}
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolCall,
						ID:            contentBlock.ID,
						ToolCallName:  contentBlock.Name,
						ToolCallInput: string(contentBlock.Input),
					}) {
						return
					}
				}
			case "content_block_delta":
				switch chunk.Delta.Type {
				case "text_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Text,
					}) {
						return
					}
				case "thinking_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeReasoningDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Thinking,
					}) {
						return
					}
				case "signature_delta":
					// The signature arrives as a delta with no text; it is
					// forwarded purely as metadata.
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningDelta,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							ProviderOptionsKey: &ReasoningOptionMetadata{
								Signature: chunk.Delta.Signature,
							},
						},
					}) {
						return
					}
				case "input_json_delta":
					// Tool input deltas need the accumulated block to
					// recover the tool-use ID.
					if len(acc.Content)-1 < int(chunk.Index) {
						continue
					}
					contentBlock := acc.Content[int(chunk.Index)]
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputDelta,
						ID:            contentBlock.ID,
						ToolCallInput: chunk.Delta.PartialJSON,
					}) {
						return
					}
				}
			case "message_stop":
			}
		}

		// After the stream drains, either report a clean finish with the
		// accumulated usage, or surface the error as a stream part.
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:         ai.StreamPartTypeFinish,
				ID:           acc.ID,
				FinishReason: mapFinishReason(string(acc.StopReason)),
				Usage: ai.Usage{
					InputTokens:         acc.Usage.InputTokens,
					OutputTokens:        acc.Usage.OutputTokens,
					TotalTokens:         acc.Usage.InputTokens + acc.Usage.OutputTokens,
					CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
					CacheReadTokens:     acc.Usage.CacheReadInputTokens,
				},
				ProviderMetadata: ai.ProviderMetadata{},
			})
			return
		} else {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: a.handleError(err),
			})
			return
		}
	}, nil
}