http.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. package ws
  2. import (
  3. "bufio"
  4. "bytes"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "strconv"
  9. "github.com/gobwas/httphead"
  10. )
  11. const (
  12. crlf = "\r\n"
  13. colonAndSpace = ": "
  14. commaAndSpace = ", "
  15. )
  16. const (
  17. textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
  18. )
  19. var (
  20. textHeadBadRequest = statusText(http.StatusBadRequest)
  21. textHeadInternalServerError = statusText(http.StatusInternalServerError)
  22. textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
  23. textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
  24. textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
  25. textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
  26. textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
  27. textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
  28. textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
  29. textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
  30. textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
  31. textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
  32. )
  33. const (
  34. // Every new header must be added to TestHeaderNames test.
  35. headerHost = "Host"
  36. headerUpgrade = "Upgrade"
  37. headerConnection = "Connection"
  38. headerSecVersion = "Sec-WebSocket-Version"
  39. headerSecProtocol = "Sec-WebSocket-Protocol"
  40. headerSecExtensions = "Sec-WebSocket-Extensions"
  41. headerSecKey = "Sec-WebSocket-Key"
  42. headerSecAccept = "Sec-WebSocket-Accept"
  43. headerHostCanonical = headerHost
  44. headerUpgradeCanonical = headerUpgrade
  45. headerConnectionCanonical = headerConnection
  46. headerSecVersionCanonical = "Sec-Websocket-Version"
  47. headerSecProtocolCanonical = "Sec-Websocket-Protocol"
  48. headerSecExtensionsCanonical = "Sec-Websocket-Extensions"
  49. headerSecKeyCanonical = "Sec-Websocket-Key"
  50. headerSecAcceptCanonical = "Sec-Websocket-Accept"
  51. )
  52. var (
  53. specHeaderValueUpgrade = []byte("websocket")
  54. specHeaderValueConnection = []byte("Upgrade")
  55. specHeaderValueConnectionLower = []byte("upgrade")
  56. specHeaderValueSecVersion = []byte("13")
  57. )
  58. var (
  59. httpVersion1_0 = []byte("HTTP/1.0")
  60. httpVersion1_1 = []byte("HTTP/1.1")
  61. httpVersionPrefix = []byte("HTTP/")
  62. )
  63. type httpRequestLine struct {
  64. method, uri []byte
  65. major, minor int
  66. }
  67. type httpResponseLine struct {
  68. major, minor int
  69. status int
  70. reason []byte
  71. }
  72. // httpParseRequestLine parses http request line like "GET / HTTP/1.0".
  73. func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
  74. var proto []byte
  75. req.method, req.uri, proto = bsplit3(line, ' ')
  76. var ok bool
  77. req.major, req.minor, ok = httpParseVersion(proto)
  78. if !ok {
  79. err = ErrMalformedRequest
  80. }
  81. return req, err
  82. }
  83. func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
  84. var (
  85. proto []byte
  86. status []byte
  87. )
  88. proto, status, resp.reason = bsplit3(line, ' ')
  89. var ok bool
  90. resp.major, resp.minor, ok = httpParseVersion(proto)
  91. if !ok {
  92. return resp, ErrMalformedResponse
  93. }
  94. var convErr error
  95. resp.status, convErr = asciiToInt(status)
  96. if convErr != nil {
  97. return resp, ErrMalformedResponse
  98. }
  99. return resp, nil
  100. }
  101. // httpParseVersion parses major and minor version of HTTP protocol. It returns
  102. // parsed values and true if parse is ok.
  103. func httpParseVersion(bts []byte) (major, minor int, ok bool) {
  104. switch {
  105. case bytes.Equal(bts, httpVersion1_0):
  106. return 1, 0, true
  107. case bytes.Equal(bts, httpVersion1_1):
  108. return 1, 1, true
  109. case len(bts) < 8:
  110. return 0, 0, false
  111. case !bytes.Equal(bts[:5], httpVersionPrefix):
  112. return 0, 0, false
  113. }
  114. bts = bts[5:]
  115. dot := bytes.IndexByte(bts, '.')
  116. if dot == -1 {
  117. return 0, 0, false
  118. }
  119. var err error
  120. major, err = asciiToInt(bts[:dot])
  121. if err != nil {
  122. return major, 0, false
  123. }
  124. minor, err = asciiToInt(bts[dot+1:])
  125. if err != nil {
  126. return major, minor, false
  127. }
  128. return major, minor, true
  129. }
  130. // httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
  131. // values and true if parse is ok.
  132. func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
  133. colon := bytes.IndexByte(line, ':')
  134. if colon == -1 {
  135. return nil, nil, false
  136. }
  137. k = btrim(line[:colon])
  138. // TODO(gobwas): maybe use just lower here?
  139. canonicalizeHeaderKey(k)
  140. v = btrim(line[colon+1:])
  141. return k, v, true
  142. }
  143. // httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing,
  144. // that key is already canonical. This helps to increase performance.
  145. func httpGetHeader(h http.Header, key string) string {
  146. if h == nil {
  147. return ""
  148. }
  149. v := h[key]
  150. if len(v) == 0 {
  151. return ""
  152. }
  153. return v[0]
  154. }
  155. // The request MAY include a header field with the name
  156. // |Sec-WebSocket-Protocol|. If present, this value indicates one or more
  157. // comma-separated subprotocol the client wishes to speak, ordered by
  158. // preference. The elements that comprise this value MUST be non-empty strings
  159. // with characters in the range U+0021 to U+007E not including separator
  160. // characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
  161. // for the value of this header field is 1#token, where the definitions of
  162. // constructs and rules are as given in [RFC2616].
  163. func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
  164. ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
  165. if check(btsToString(v)) {
  166. ret = string(v)
  167. return false
  168. }
  169. return true
  170. })
  171. return ret, ok
  172. }
  173. func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
  174. var selected []byte
  175. ok = httphead.ScanTokens(h, func(v []byte) bool {
  176. if check(v) {
  177. selected = v
  178. return false
  179. }
  180. return true
  181. })
  182. if ok && selected != nil {
  183. return string(selected), true
  184. }
  185. return ret, ok
  186. }
  187. func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
  188. s := httphead.OptionSelector{
  189. Flags: httphead.SelectCopy,
  190. Check: check,
  191. }
  192. return s.Select(h, selected)
  193. }
  194. func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) {
  195. if in.Size() == 0 {
  196. return dest, nil
  197. }
  198. opt, err := f(in)
  199. if err != nil {
  200. return nil, err
  201. }
  202. if opt.Size() > 0 {
  203. dest = append(dest, opt)
  204. }
  205. return dest, nil
  206. }
  207. func negotiateExtensions(
  208. h []byte, dest []httphead.Option,
  209. f func(httphead.Option) (httphead.Option, error),
  210. ) (_ []httphead.Option, err error) {
  211. index := -1
  212. var current httphead.Option
  213. ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control {
  214. if i != index {
  215. dest, err = negotiateMaybe(current, dest, f)
  216. if err != nil {
  217. return httphead.ControlBreak
  218. }
  219. index = i
  220. current = httphead.Option{Name: name}
  221. }
  222. if attr != nil {
  223. current.Parameters.Set(attr, val)
  224. }
  225. return httphead.ControlContinue
  226. })
  227. if !ok {
  228. return nil, ErrMalformedRequest
  229. }
  230. return negotiateMaybe(current, dest, f)
  231. }
  232. func httpWriteHeader(bw *bufio.Writer, key, value string) {
  233. httpWriteHeaderKey(bw, key)
  234. bw.WriteString(value)
  235. bw.WriteString(crlf)
  236. }
  237. func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
  238. httpWriteHeaderKey(bw, key)
  239. bw.Write(value)
  240. bw.WriteString(crlf)
  241. }
  242. func httpWriteHeaderKey(bw *bufio.Writer, key string) {
  243. bw.WriteString(key)
  244. bw.WriteString(colonAndSpace)
  245. }
  246. func httpWriteUpgradeRequest(
  247. bw *bufio.Writer,
  248. u *url.URL,
  249. nonce []byte,
  250. protocols []string,
  251. extensions []httphead.Option,
  252. header HandshakeHeader,
  253. host string,
  254. ) {
  255. bw.WriteString("GET ")
  256. bw.WriteString(u.RequestURI())
  257. bw.WriteString(" HTTP/1.1\r\n")
  258. if host == "" {
  259. host = u.Host
  260. }
  261. httpWriteHeader(bw, headerHost, host)
  262. httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
  263. httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
  264. httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
  265. // NOTE: write nonce bytes as a string to prevent heap allocation –
  266. // WriteString() copy given string into its inner buffer, unlike Write()
  267. // which may write p directly to the underlying io.Writer – which in turn
  268. // will lead to p escape.
  269. httpWriteHeader(bw, headerSecKey, btsToString(nonce))
  270. if len(protocols) > 0 {
  271. httpWriteHeaderKey(bw, headerSecProtocol)
  272. for i, p := range protocols {
  273. if i > 0 {
  274. bw.WriteString(commaAndSpace)
  275. }
  276. bw.WriteString(p)
  277. }
  278. bw.WriteString(crlf)
  279. }
  280. if len(extensions) > 0 {
  281. httpWriteHeaderKey(bw, headerSecExtensions)
  282. httphead.WriteOptions(bw, extensions)
  283. bw.WriteString(crlf)
  284. }
  285. if header != nil {
  286. header.WriteTo(bw)
  287. }
  288. bw.WriteString(crlf)
  289. }
  290. func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
  291. bw.WriteString(textHeadUpgrade)
  292. httpWriteHeaderKey(bw, headerSecAccept)
  293. writeAccept(bw, nonce)
  294. bw.WriteString(crlf)
  295. if hs.Protocol != "" {
  296. httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
  297. }
  298. if len(hs.Extensions) > 0 {
  299. httpWriteHeaderKey(bw, headerSecExtensions)
  300. httphead.WriteOptions(bw, hs.Extensions)
  301. bw.WriteString(crlf)
  302. }
  303. if header != nil {
  304. header(bw)
  305. }
  306. bw.WriteString(crlf)
  307. }
  308. func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
  309. switch code {
  310. case http.StatusBadRequest:
  311. bw.WriteString(textHeadBadRequest)
  312. case http.StatusInternalServerError:
  313. bw.WriteString(textHeadInternalServerError)
  314. case http.StatusUpgradeRequired:
  315. bw.WriteString(textHeadUpgradeRequired)
  316. default:
  317. writeStatusText(bw, code)
  318. }
  319. // Write custom headers.
  320. if header != nil {
  321. header(bw)
  322. }
  323. switch err {
  324. case ErrHandshakeBadProtocol:
  325. bw.WriteString(textTailErrHandshakeBadProtocol)
  326. case ErrHandshakeBadMethod:
  327. bw.WriteString(textTailErrHandshakeBadMethod)
  328. case ErrHandshakeBadHost:
  329. bw.WriteString(textTailErrHandshakeBadHost)
  330. case ErrHandshakeBadUpgrade:
  331. bw.WriteString(textTailErrHandshakeBadUpgrade)
  332. case ErrHandshakeBadConnection:
  333. bw.WriteString(textTailErrHandshakeBadConnection)
  334. case ErrHandshakeBadSecAccept:
  335. bw.WriteString(textTailErrHandshakeBadSecAccept)
  336. case ErrHandshakeBadSecKey:
  337. bw.WriteString(textTailErrHandshakeBadSecKey)
  338. case ErrHandshakeBadSecVersion:
  339. bw.WriteString(textTailErrHandshakeBadSecVersion)
  340. case ErrHandshakeUpgradeRequired:
  341. bw.WriteString(textTailErrUpgradeRequired)
  342. case nil:
  343. bw.WriteString(crlf)
  344. default:
  345. writeErrorText(bw, err)
  346. }
  347. }
  348. func writeStatusText(bw *bufio.Writer, code int) {
  349. bw.WriteString("HTTP/1.1 ")
  350. bw.WriteString(strconv.Itoa(code))
  351. bw.WriteByte(' ')
  352. bw.WriteString(http.StatusText(code))
  353. bw.WriteString(crlf)
  354. bw.WriteString("Content-Type: text/plain; charset=utf-8")
  355. bw.WriteString(crlf)
  356. }
  357. func writeErrorText(bw *bufio.Writer, err error) {
  358. body := err.Error()
  359. bw.WriteString("Content-Length: ")
  360. bw.WriteString(strconv.Itoa(len(body)))
  361. bw.WriteString(crlf)
  362. bw.WriteString(crlf)
  363. bw.WriteString(body)
  364. }
  365. // httpError is like the http.Error with WebSocket context exception.
  366. func httpError(w http.ResponseWriter, body string, code int) {
  367. w.Header().Set("Content-Type", "text/plain; charset=utf-8")
  368. w.Header().Set("Content-Length", strconv.Itoa(len(body)))
  369. w.WriteHeader(code)
  370. w.Write([]byte(body))
  371. }
  372. // statusText is a non-performant status text generator.
  373. // NOTE: Used only to generate constants.
  374. func statusText(code int) string {
  375. var buf bytes.Buffer
  376. bw := bufio.NewWriter(&buf)
  377. writeStatusText(bw, code)
  378. bw.Flush()
  379. return buf.String()
  380. }
  381. // errorText is a non-performant error text generator.
  382. // NOTE: Used only to generate constants.
  383. func errorText(err error) string {
  384. var buf bytes.Buffer
  385. bw := bufio.NewWriter(&buf)
  386. writeErrorText(bw, err)
  387. bw.Flush()
  388. return buf.String()
  389. }
  390. // HandshakeHeader is the interface that writes both upgrade request or
  391. // response headers into a given io.Writer.
  392. type HandshakeHeader interface {
  393. io.WriterTo
  394. }
  395. // HandshakeHeaderString is an adapter to allow the use of headers represented
  396. // by ordinary string as HandshakeHeader.
  397. type HandshakeHeaderString string
  398. // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
  399. func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
  400. n, err := io.WriteString(w, string(s))
  401. return int64(n), err
  402. }
  403. // HandshakeHeaderBytes is an adapter to allow the use of headers represented
  404. // by ordinary slice of bytes as HandshakeHeader.
  405. type HandshakeHeaderBytes []byte
  406. // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
  407. func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
  408. n, err := w.Write(b)
  409. return int64(n), err
  410. }
  411. // HandshakeHeaderFunc is an adapter to allow the use of headers represented by
  412. // ordinary function as HandshakeHeader.
  413. type HandshakeHeaderFunc func(io.Writer) (int64, error)
  414. // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
  415. func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
  416. return f(w)
  417. }
  418. // HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
  419. // HandshakeHeader.
  420. type HandshakeHeaderHTTP http.Header
  421. // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
  422. func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
  423. wr := writer{w: w}
  424. err := http.Header(h).Write(&wr)
  425. return wr.n, err
  426. }
  427. type writer struct {
  428. n int64
  429. w io.Writer
  430. }
  431. func (w *writer) WriteString(s string) (int, error) {
  432. n, err := io.WriteString(w.w, s)
  433. w.n += int64(n)
  434. return n, err
  435. }
  436. func (w *writer) Write(p []byte) (int, error) {
  437. n, err := w.w.Write(p)
  438. w.n += int64(n)
  439. return n, err
  440. }