1// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3package ssestream
4
5import (
6 "bufio"
7 "bytes"
8 "encoding/json"
9 "fmt"
10 "io"
11 "net/http"
12 "strings"
13
14 "github.com/tidwall/gjson"
15)
16
17type Decoder interface {
18 Event() Event
19 Next() bool
20 Close() error
21 Err() error
22}
23
24func NewDecoder(res *http.Response) Decoder {
25 if res == nil || res.Body == nil {
26 return nil
27 }
28
29 var decoder Decoder
30 contentType := res.Header.Get("content-type")
31 if t, ok := decoderTypes[contentType]; ok {
32 decoder = t(res.Body)
33 } else {
34 scanner := bufio.NewScanner(res.Body)
35 decoder = &eventStreamDecoder{rc: res.Body, scn: scanner}
36 }
37 return decoder
38}
39
40var decoderTypes = map[string](func(io.ReadCloser) Decoder){}
41
42func RegisterDecoder(contentType string, decoder func(io.ReadCloser) Decoder) {
43 decoderTypes[strings.ToLower(contentType)] = decoder
44}
45
// Event is a single decoded stream event.
//
// For the built-in text/event-stream decoder, Type holds the value of the
// most recent "event:" field (empty when the event had none). Data holds the
// value of every "data:" field of the event, each followed by a newline.
type Event struct {
	Type string
	Data []byte
}

// A base implementation of a Decoder for text/event-stream.
//
// Lines are parsed as "name: value" fields. Lines starting with ':' are
// comments. Only the "event" and "data" fields are used; any other field
// (such as "id" or "retry") is ignored. A blank line ends the current event.
type eventStreamDecoder struct {
	evt Event          // most recently dispatched event, returned by Event
	rc  io.ReadCloser  // underlying stream, closed by Close
	scn *bufio.Scanner // line scanner; expected to read from rc
	err error          // sticky error; once set, Next always returns false
}

// Next reads lines until a complete event has been accumulated. It returns
// true when that event is available via Event.
//
// It returns false at the end of the input or on a read error, which is
// then reported by Err. An event left unterminated (no trailing blank line)
// at the end of the input is discarded.
func (s *eventStreamDecoder) Next() bool {
	if s.err != nil {
		return false
	}

	event := ""
	data := bytes.NewBuffer(nil)

	for s.scn.Scan() {
		txt := s.scn.Bytes()

		// A blank line dispatches the event accumulated so far.
		if len(txt) == 0 {
			if data.Len() == 0 {
				// Nothing to dispatch: the block held only comments or fields
				// without data (e.g. a keep-alive or "event: ping").
				// As the SSE spec requires, drop the pending event type and
				// keep reading. Yielding an empty event instead would hand
				// callers a payload they cannot decode.
				event = ""
				continue
			}
			s.evt = Event{
				Type: event,
				Data: data.Bytes(),
			}
			return true
		}

		// Split a line like "event: bar" into name="event" and value=" bar".
		// A line without a colon is a field name with an empty value.
		name, value, _ := bytes.Cut(txt, []byte(":"))

		// Consume a single optional space after the colon.
		value = bytes.TrimPrefix(value, []byte(" "))

		switch string(name) {
		case "":
			// A line of the form ": something" is a comment and is ignored.
			continue
		case "event":
			event = string(value)
		case "data":
			// bytes.Buffer writes always return a nil error (they panic if
			// the buffer cannot grow), so there is nothing to propagate here.
			data.Write(value)
			data.WriteByte('\n')
		}
	}

	// Scan returned false: either a clean EOF (Err is nil) or a read error.
	s.err = s.scn.Err()
	return false
}

// Event returns the event most recently dispatched by Next.
func (s *eventStreamDecoder) Event() Event {
	return s.evt
}

// Close closes the underlying stream.
func (s *eventStreamDecoder) Close() error {
	return s.rc.Close()
}

// Err returns the read error that stopped Next, or nil after a clean end of
// input.
func (s *eventStreamDecoder) Err() error {
	return s.err
}
123
124type Stream[T any] struct {
125 decoder Decoder
126 cur T
127 err error
128 done bool
129}
130
131func NewStream[T any](decoder Decoder, err error) *Stream[T] {
132 return &Stream[T]{
133 decoder: decoder,
134 err: err,
135 }
136}
137
138// Next returns false if the stream has ended or an error occurred.
139// Call Stream.Current() to get the current value.
140// Call Stream.Err() to get the error.
141//
142// for stream.Next() {
143// data := stream.Current()
144// }
145//
146// if stream.Err() != nil {
147// ...
148// }
149func (s *Stream[T]) Next() bool {
150 if s.err != nil {
151 return false
152 }
153
154 for s.decoder.Next() {
155 if s.done {
156 continue
157 }
158
159 if bytes.HasPrefix(s.decoder.Event().Data, []byte("[DONE]")) {
160 // In this case we don't break because we still want to iterate through the full stream.
161 s.done = true
162 continue
163 }
164
165 if s.decoder.Event().Type == "" || strings.HasPrefix(s.decoder.Event().Type, "response.") {
166 ep := gjson.GetBytes(s.decoder.Event().Data, "error")
167 if ep.Exists() {
168 s.err = fmt.Errorf("received error while streaming: %s", ep.String())
169 return false
170 }
171 s.err = json.Unmarshal(s.decoder.Event().Data, &s.cur)
172 if s.err != nil {
173 return false
174 }
175 return true
176 } else {
177 ep := gjson.GetBytes(s.decoder.Event().Data, "error")
178 if ep.Exists() {
179 s.err = fmt.Errorf("received error while streaming: %s", ep.String())
180 return false
181 }
182 event := s.decoder.Event().Type
183 data := s.decoder.Event().Data
184 s.err = json.Unmarshal([]byte(fmt.Sprintf(`{ "event": %q, "data": %s }`, event, data)), &s.cur)
185 if s.err != nil {
186 return false
187 }
188 return true
189 }
190 }
191
192 // decoder.Next() may be false because of an error
193 s.err = s.decoder.Err()
194
195 return false
196}
197
198func (s *Stream[T]) Current() T {
199 return s.cur
200}
201
202func (s *Stream[T]) Err() error {
203 return s.err
204}
205
206func (s *Stream[T]) Close() error {
207 if s.decoder == nil {
208 // already closed
209 return nil
210 }
211 return s.decoder.Close()
212}