reader.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. package wsutil
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "io"
  6. "io/ioutil"
  7. "github.com/gobwas/ws"
  8. )
  // Sentinel errors reported by Reader.
  var (
  	// ErrNoFrameAdvance is returned by Reader's Read() method when it is
  	// called without a preceding NextFrame() call.
  	ErrNoFrameAdvance = errors.New("no frame advance")

  	// ErrFrameTooLarge is returned by Reader's NextFrame() method when a
  	// frame header announces a payload longer than MaxFrameSize.
  	ErrFrameTooLarge = errors.New("frame too large")
  )
  15. // FrameHandlerFunc handles parsed frame header and its body represented by
  16. // io.Reader.
  17. //
  18. // Note that reader represents already unmasked body.
  19. type FrameHandlerFunc func(ws.Header, io.Reader) error
  20. // Reader is a wrapper around source io.Reader which represents WebSocket
  21. // connection. It contains options for reading messages from source.
  22. //
  23. // Reader implements io.Reader, which Read() method reads payload of incoming
  24. // WebSocket frames. It also takes care on fragmented frames and possibly
  25. // intermediate control frames between them.
  26. //
  27. // Note that Reader's methods are not goroutine safe.
  28. type Reader struct {
  29. Source io.Reader
  30. State ws.State
  31. // SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
  32. SkipHeaderCheck bool
  33. // CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
  34. // bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned.
  35. CheckUTF8 bool
  36. // Extensions is a list of negotiated extensions for reader Source.
  37. // It is used to meet the specs and clear appropriate bits in fragment
  38. // header RSV segment.
  39. Extensions []RecvExtension
  40. // MaxFrameSize controls the maximum frame size in bytes
  41. // that can be read. A message exceeding that size will return
  42. // a ErrFrameTooLarge to the application.
  43. //
  44. // Not setting this field means there is no limit.
  45. MaxFrameSize int64
  46. OnContinuation FrameHandlerFunc
  47. OnIntermediate FrameHandlerFunc
  48. opCode ws.OpCode // Used to store message op code on fragmentation.
  49. frame io.Reader // Used to as frame reader.
  50. raw io.LimitedReader // Used to discard frames without cipher.
  51. utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true.
  52. tmp [ws.MaxHeaderSize - 2]byte // Used for reading headers.
  53. cr *CipherReader // Used by NextFrame() to unmask frame payload.
  54. }
  55. // NewReader creates new frame reader that reads from r keeping given state to
  56. // make some protocol validity checks when it needed.
  57. func NewReader(r io.Reader, s ws.State) *Reader {
  58. return &Reader{
  59. Source: r,
  60. State: s,
  61. }
  62. }
  63. // NewClientSideReader is a helper function that calls NewReader with r and
  64. // ws.StateClientSide.
  65. func NewClientSideReader(r io.Reader) *Reader {
  66. return NewReader(r, ws.StateClientSide)
  67. }
  68. // NewServerSideReader is a helper function that calls NewReader with r and
  69. // ws.StateServerSide.
  70. func NewServerSideReader(r io.Reader) *Reader {
  71. return NewReader(r, ws.StateServerSide)
  72. }
  73. // Read implements io.Reader. It reads the next message payload into p.
  74. // It takes care on fragmented messages.
  75. //
  76. // The error is io.EOF only if all of message bytes were read.
  77. // If an io.EOF happens during reading some but not all the message bytes
  78. // Read() returns io.ErrUnexpectedEOF.
  79. //
  80. // The error is ErrNoFrameAdvance if no NextFrame() call was made before
  81. // reading next message bytes.
  82. func (r *Reader) Read(p []byte) (n int, err error) {
  83. if r.frame == nil {
  84. if !r.fragmented() {
  85. // Every new Read() must be preceded by NextFrame() call.
  86. return 0, ErrNoFrameAdvance
  87. }
  88. // Read next continuation or intermediate control frame.
  89. _, err := r.NextFrame()
  90. if err != nil {
  91. return 0, err
  92. }
  93. if r.frame == nil {
  94. // We handled intermediate control and now got nothing to read.
  95. return 0, nil
  96. }
  97. }
  98. n, err = r.frame.Read(p)
  99. if err != nil && err != io.EOF {
  100. return n, err
  101. }
  102. if err == nil && r.raw.N != 0 {
  103. return n, nil
  104. }
  105. // EOF condition (either err is io.EOF or r.raw.N is zero).
  106. switch {
  107. case r.raw.N != 0:
  108. err = io.ErrUnexpectedEOF
  109. case r.fragmented():
  110. err = nil
  111. r.resetFragment()
  112. case r.CheckUTF8 && !r.utf8.Valid():
  113. // NOTE: check utf8 only when full message received, since partial
  114. // reads may be invalid.
  115. n = r.utf8.Accepted()
  116. err = ErrInvalidUTF8
  117. default:
  118. r.reset()
  119. err = io.EOF
  120. }
  121. return n, err
  122. }
  123. // Discard discards current message unread bytes.
  124. // It discards all frames of fragmented message.
  125. func (r *Reader) Discard() (err error) {
  126. for {
  127. _, err = io.Copy(ioutil.Discard, &r.raw)
  128. if err != nil {
  129. break
  130. }
  131. if !r.fragmented() {
  132. break
  133. }
  134. if _, err = r.NextFrame(); err != nil {
  135. break
  136. }
  137. }
  138. r.reset()
  139. return err
  140. }
  141. // NextFrame prepares r to read next message. It returns received frame header
  142. // and non-nil error on failure.
  143. //
  144. // Note that next NextFrame() call must be done after receiving or discarding
  145. // all current message bytes.
  146. func (r *Reader) NextFrame() (hdr ws.Header, err error) {
  147. hdr, err = r.readHeader(r.Source)
  148. if err == io.EOF && r.fragmented() {
  149. // If we are in fragmented state EOF means that is was totally
  150. // unexpected.
  151. //
  152. // NOTE: This is necessary to prevent callers such that
  153. // ioutil.ReadAll to receive some amount of bytes without an error.
  154. // ReadAll() ignores an io.EOF error, thus caller may think that
  155. // whole message fetched, but actually only part of it.
  156. err = io.ErrUnexpectedEOF
  157. }
  158. if err == nil && !r.SkipHeaderCheck {
  159. err = ws.CheckHeader(hdr, r.State)
  160. }
  161. if err != nil {
  162. return hdr, err
  163. }
  164. if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
  165. return hdr, ErrFrameTooLarge
  166. }
  167. // Save raw reader to use it on discarding frame without ciphering and
  168. // other streaming checks.
  169. r.raw = io.LimitedReader{
  170. R: r.Source,
  171. N: hdr.Length,
  172. }
  173. frame := io.Reader(&r.raw)
  174. if hdr.Masked {
  175. if r.cr == nil {
  176. r.cr = NewCipherReader(frame, hdr.Mask)
  177. } else {
  178. r.cr.Reset(frame, hdr.Mask)
  179. }
  180. frame = r.cr
  181. }
  182. for _, x := range r.Extensions {
  183. hdr, err = x.UnsetBits(hdr)
  184. if err != nil {
  185. return hdr, err
  186. }
  187. }
  188. if r.fragmented() {
  189. if hdr.OpCode.IsControl() {
  190. if cb := r.OnIntermediate; cb != nil {
  191. err = cb(hdr, frame)
  192. }
  193. if err == nil {
  194. // Ensure that src is empty.
  195. _, err = io.Copy(ioutil.Discard, &r.raw)
  196. }
  197. return hdr, err
  198. }
  199. } else {
  200. r.opCode = hdr.OpCode
  201. }
  202. if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
  203. r.utf8.Source = frame
  204. frame = &r.utf8
  205. }
  206. // Save reader with ciphering and other streaming checks.
  207. r.frame = frame
  208. if hdr.OpCode == ws.OpContinuation {
  209. if cb := r.OnContinuation; cb != nil {
  210. err = cb(hdr, frame)
  211. }
  212. }
  213. if hdr.Fin {
  214. r.State = r.State.Clear(ws.StateFragmented)
  215. } else {
  216. r.State = r.State.Set(ws.StateFragmented)
  217. }
  218. return hdr, err
  219. }
  220. func (r *Reader) fragmented() bool {
  221. return r.State.Fragmented()
  222. }
  223. func (r *Reader) resetFragment() {
  224. r.raw = io.LimitedReader{}
  225. r.frame = nil
  226. // Reset source of the UTF8Reader, but not the state.
  227. r.utf8.Source = nil
  228. }
  229. func (r *Reader) reset() {
  230. r.raw = io.LimitedReader{}
  231. r.frame = nil
  232. r.utf8 = UTF8Reader{}
  233. r.opCode = 0
  234. }
  235. // readHeader reads a frame header from in.
  236. func (r *Reader) readHeader(in io.Reader) (h ws.Header, err error) {
  237. // Make slice of bytes with capacity 12 that could hold any header.
  238. //
  239. // The maximum header size is 14, but due to the 2 hop reads,
  240. // after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
  241. // So 14 - 2 = 12.
  242. bts := r.tmp[:2]
  243. // Prepare to hold first 2 bytes to choose size of next read.
  244. _, err = io.ReadFull(in, bts)
  245. if err != nil {
  246. return h, err
  247. }
  248. const bit0 = 0x80
  249. h.Fin = bts[0]&bit0 != 0
  250. h.Rsv = (bts[0] & 0x70) >> 4
  251. h.OpCode = ws.OpCode(bts[0] & 0x0f)
  252. var extra int
  253. if bts[1]&bit0 != 0 {
  254. h.Masked = true
  255. extra += 4
  256. }
  257. length := bts[1] & 0x7f
  258. switch {
  259. case length < 126:
  260. h.Length = int64(length)
  261. case length == 126:
  262. extra += 2
  263. case length == 127:
  264. extra += 8
  265. default:
  266. err = ws.ErrHeaderLengthUnexpected
  267. return h, err
  268. }
  269. if extra == 0 {
  270. return h, err
  271. }
  272. // Increase len of bts to extra bytes need to read.
  273. // Overwrite first 2 bytes that was read before.
  274. bts = bts[:extra]
  275. _, err = io.ReadFull(in, bts)
  276. if err != nil {
  277. return h, err
  278. }
  279. switch {
  280. case length == 126:
  281. h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
  282. bts = bts[2:]
  283. case length == 127:
  284. if bts[0]&0x80 != 0 {
  285. err = ws.ErrHeaderLengthMSB
  286. return h, err
  287. }
  288. h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
  289. bts = bts[8:]
  290. }
  291. if h.Masked {
  292. copy(h.Mask[:], bts)
  293. }
  294. return h, nil
  295. }
  296. // NextReader prepares next message read from r. It returns header that
  297. // describes the message and io.Reader to read message's payload. It returns
  298. // non-nil error when it is not possible to read message's initial frame.
  299. //
  300. // Note that next NextReader() on the same r should be done after reading all
  301. // bytes from previously returned io.Reader. For more performant way to discard
  302. // message use Reader and its Discard() method.
  303. //
  304. // Note that it will not handle any "intermediate" frames, that possibly could
  305. // be received between text/binary continuation frames. That is, if peer sent
  306. // text/binary frame with fin flag "false", then it could send ping frame, and
  307. // eventually remaining part of text/binary frame with fin "true" – with
  308. // NextReader() the ping frame will be dropped without any notice. To handle
  309. // this rare, but possible situation (and if you do not know exactly which
  310. // frames peer could send), you could use Reader with OnIntermediate field set.
  311. func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
  312. rd := &Reader{
  313. Source: r,
  314. State: s,
  315. }
  316. header, err := rd.NextFrame()
  317. if err != nil {
  318. return header, nil, err
  319. }
  320. return header, rd, nil
  321. }