0
Fork 0
mirror of https://github.com/ninenines/cowboy.git synced 2025-07-14 20:30:23 +00:00

Fix websocket unmasking when compression is enabled

The unmasking logic was based on the length of inflated data instead
of the length of the deflated data. This meant data would get corrupted
when we receive a websocket frame split across multiple TCP packets.
This commit is contained in:
Ali Sabil 2013-07-02 11:09:27 +02:00
parent b0d0cabf12
commit a3b9438d16

View file

@ -329,45 +329,49 @@ websocket_data(State, Req, HandlerState, Data) ->
websocket_data(State=#state{frag_state=undefined}, Req, HandlerState,
Opcode, Len, MaskKey, Data, Rsv, 0) ->
websocket_payload(State#state{frag_state={nofin, Opcode, <<>>}},
Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv);
Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv);
%% Subsequent frame fragments.
websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState,
0, Len, MaskKey, Data, Rsv, 0) ->
websocket_payload(State, Req, HandlerState,
0, Len, MaskKey, <<>>, Data, Rsv);
0, Len, MaskKey, <<>>, 0, Data, Rsv);
%% Final frame fragment.
websocket_data(State=#state{frag_state={nofin, Opcode, SoFar}},
Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 1) ->
websocket_payload(State#state{frag_state={fin, Opcode, SoFar}},
Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv);
Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv);
%% Unfragmented frame.
websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 1) ->
websocket_payload(State, Req, HandlerState,
Opcode, Len, MaskKey, <<>>, Data, Rsv).
Opcode, Len, MaskKey, <<>>, 0, Data, Rsv).
-spec websocket_payload(#state{}, Req, any(),
opcode(), non_neg_integer(), mask_key(), binary(), binary(), rsv())
opcode(), non_neg_integer(), mask_key(), binary(), non_neg_integer(),
binary(), rsv())
-> {ok, Req, cowboy_middleware:env()}
| {suspend, module(), atom(), [any()]}
when Req::cowboy_req:req().
%% Close control frames with a payload MUST contain a valid close code.
websocket_payload(State, Req, HandlerState,
Opcode=8, Len, MaskKey, <<>>, << MaskedCode:2/binary, Rest/bits >>, Rsv) ->
Opcode=8, Len, MaskKey, <<>>, 0,
<< MaskedCode:2/binary, Rest/bits >>, Rsv) ->
Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>),
if Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006;
(Code > 1011) and (Code < 3000); Code > 4999 ->
websocket_close(State, Req, HandlerState, {error, badframe});
true ->
websocket_payload(State, Req, HandlerState,
Opcode, Len - 2, MaskKey, Unmasked, Rest, Rsv)
Opcode, Len - 2, MaskKey, Unmasked, byte_size(MaskedCode),
Rest, Rsv)
end;
%% Text frames and close control frames MUST have a payload that is valid UTF-8.
websocket_payload(State=#state{utf8_state=Incomplete},
Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv)
Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen,
Data, Rsv)
when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse
((Opcode =:= 8) andalso (Unmasked =/= <<>>))) ->
Unmasked2 = websocket_unmask(Data,
rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>),
rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
false ->
@ -375,14 +379,16 @@ websocket_payload(State=#state{utf8_state=Incomplete},
Utf8State ->
websocket_payload_loop(State2#state{utf8_state=Utf8State},
Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
<< Unmasked/binary, Unmasked3/binary >>, Rsv)
<< Unmasked/binary, Unmasked3/binary >>,
UnmaskedLen + byte_size(Data), Rsv)
end;
websocket_payload(State=#state{utf8_state=Incomplete},
Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv)
Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen,
Data, Rsv)
when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) ->
<< End:Len/binary, Rest/bits >> = Data,
Unmasked2 = websocket_unmask(End,
rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>),
rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
<<>> ->
@ -394,10 +400,11 @@ websocket_payload(State=#state{utf8_state=Incomplete},
end;
%% Fragmented text frames may cut payload in the middle of UTF-8 codepoints.
websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete},
Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv)
Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen,
Data, Rsv)
when byte_size(Data) < Len ->
Unmasked2 = websocket_unmask(Data,
rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>),
rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
false ->
@ -405,13 +412,15 @@ websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete},
Utf8State ->
websocket_payload_loop(State2#state{utf8_state=Utf8State},
Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
<< Unmasked/binary, Unmasked3/binary >>, Rsv)
<< Unmasked/binary, Unmasked3/binary >>,
UnmaskedLen + byte_size(Data), Rsv)
end;
websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete},
Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) ->
Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen,
Data, Rsv) ->
<< End:Len/binary, Rest/bits >> = Data,
Unmasked2 = websocket_unmask(End,
rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>),
rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
<<>> ->
@ -427,20 +436,23 @@ websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete},
end;
%% Other frames have a binary payload.
websocket_payload(State, Req, HandlerState,
Opcode, Len, MaskKey, Unmasked, Data, Rsv)
Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv)
when byte_size(Data) < Len ->
Unmasked2 = websocket_unmask(Data,
rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked),
rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
websocket_payload_loop(State2, Req, HandlerState,
Opcode, Len - byte_size(Data), MaskKey, Unmasked3, Rsv);
Opcode, Len - byte_size(Data), MaskKey,
<< Unmasked/binary, Unmasked3/binary >>, UnmaskedLen + byte_size(Data),
Rsv);
websocket_payload(State, Req, HandlerState,
Opcode, Len, MaskKey, Unmasked, Data, Rsv) ->
Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) ->
<< End:Len/binary, Rest/bits >> = Data,
Unmasked2 = websocket_unmask(End,
rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked),
rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, Unmasked3).
websocket_dispatch(State2, Req, HandlerState, Rest, Opcode,
<< Unmasked/binary, Unmasked3/binary >>).
-spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) ->
{binary(), #state{}}.
@ -513,19 +525,20 @@ is_utf8(_) ->
false.
-spec websocket_payload_loop(#state{}, Req, any(),
opcode(), non_neg_integer(), mask_key(), binary(), rsv())
opcode(), non_neg_integer(), mask_key(), binary(),
non_neg_integer(), rsv())
-> {ok, Req, cowboy_middleware:env()}
| {suspend, module(), atom(), [any()]}
when Req::cowboy_req:req().
websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
messages={OK, Closed, Error}, timeout_ref=TRef},
Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Rsv) ->
Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) ->
Transport:setopts(Socket, [{active, once}]),
receive
{OK, Socket, Data} ->
State2 = handler_loop_timeout(State),
websocket_payload(State2, Req, HandlerState,
Opcode, Len, MaskKey, Unmasked, Data, Rsv);
Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv);
{Closed, Socket} ->
handler_terminate(State, Req, HandlerState, {error, closed});
{Error, Socket, Reason} ->
@ -534,13 +547,13 @@ websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
websocket_close(State, Req, HandlerState, {normal, timeout});
{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
websocket_payload_loop(State, Req, HandlerState,
Opcode, Len, MaskKey, Unmasked, Rsv);
Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv);
Message ->
handler_call(State, Req, HandlerState,
<<>>, websocket_info, Message,
fun (State2, Req2, HandlerState2, _) ->
websocket_payload_loop(State2, Req2, HandlerState2,
Opcode, Len, MaskKey, Unmasked, Rsv)
Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv)
end)
end.